use arrow_schema::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{plan_err, DataFusionError, Result};
use datafusion_expr::WindowUDF;
use datafusion_expr::{
logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource,
};
use datafusion_sql::{
planner::{ContextProvider, SqlToRel},
sqlparser::{dialect::GenericDialect, parser::Parser},
TableReference,
};
use std::{collections::HashMap, sync::Arc};
fn main() {
let sql = "SELECT \
c.id, c.first_name, c.last_name, \
COUNT(*) as num_orders, \
SUM(o.price) AS total_price, \
SUM(o.price * s.sales_tax) AS state_tax \
FROM customer c \
JOIN state s ON c.state = s.id \
JOIN orders o ON c.id = o.customer_id \
WHERE o.price > 0 \
AND c.last_name LIKE 'G%' \
GROUP BY 1, 2, 3 \
ORDER BY state_tax DESC";
let dialect = GenericDialect {}; let ast = Parser::parse_sql(&dialect, sql).unwrap();
let statement = &ast[0];
let schema_provider = MySchemaProvider::new();
let sql_to_rel = SqlToRel::new(&schema_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
println!("{plan:?}");
}
struct MySchemaProvider {
options: ConfigOptions,
tables: HashMap<String, Arc<dyn TableSource>>,
}
impl MySchemaProvider {
fn new() -> Self {
let mut tables = HashMap::new();
tables.insert(
"customer".to_string(),
create_table_source(vec![
Field::new("id", DataType::Int32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("state", DataType::Utf8, false),
]),
);
tables.insert(
"state".to_string(),
create_table_source(vec![
Field::new("id", DataType::Int32, false),
Field::new("sales_tax", DataType::Decimal128(10, 2), false),
]),
);
tables.insert(
"orders".to_string(),
create_table_source(vec![
Field::new("id", DataType::Int32, false),
Field::new("customer_id", DataType::Int32, false),
Field::new("item_id", DataType::Int32, false),
Field::new("quantity", DataType::Int32, false),
Field::new("price", DataType::Decimal128(10, 2), false),
]),
);
Self {
tables,
options: Default::default(),
}
}
}
fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
Arc::new(LogicalTableSource::new(Arc::new(
Schema::new_with_metadata(fields, HashMap::new()),
)))
}
impl ContextProvider for MySchemaProvider {
fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
match self.tables.get(name.table()) {
Some(table) => Ok(table.clone()),
_ => plan_err!("Table not found: {}", name.table()),
}
}
fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
None
}
fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
None
}
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
None
}
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
None
}
fn options(&self) -> &ConfigOptions {
&self.options
}
}