use super::optimizer::OptimizerRule;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::plan::{
Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, Window,
};
use crate::logical_plan::{
build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable,
Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion,
Repartition, Union, Values,
};
use crate::prelude::lit;
use crate::scalar::ScalarValue;
use crate::{
error::{DataFusionError, Result},
logical_plan::ExpressionVisitor,
};
use std::{collections::HashSet, sync::Arc};
// Sentinel string literals used by `expr_sub_expressions` and
// `rewrite_expression` to round-trip structured expressions (CASE and
// window functions) through a flat `Vec<Expr>`: the flattener inserts a
// marker literal before each optional/variable-length section, and the
// rewriter scans for the markers to split the flat list back apart.
// Marks the optional CASE base expression (`CASE <expr> WHEN ...`).
const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
// Marks the optional CASE ELSE expression.
const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
// Separates window-function args from the PARTITION BY expressions.
const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__";
// Separates PARTITION BY expressions from the ORDER BY expressions.
const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__";
/// Recursively collects every column referenced by each expression in
/// `expr` into `accum`, short-circuiting on the first error.
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
    expr.iter().try_for_each(|e| expr_to_columns(e, accum))
}
/// Expression visitor that accumulates every column referenced by an
/// expression tree (scalar variables are recorded as unqualified columns)
/// into the borrowed `accum` set. Driven via `Expr::accept`.
struct ColumnNameVisitor<'a> {
    // Shared output set: deduplicates columns seen across multiple visits.
    accum: &'a mut HashSet<Column>,
}
impl ExpressionVisitor for ColumnNameVisitor<'_> {
    /// Records the column (if any) carried directly by `expr` and always
    /// continues recursion so child expressions are visited too.
    fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
        match expr {
            // A direct column reference: record it verbatim.
            Expr::Column(column) => {
                self.accum.insert(column.clone());
            }
            // Scalar variables are tracked as synthetic unqualified
            // columns named by joining the variable path with dots.
            Expr::ScalarVariable(names) => {
                let name = names.join(".");
                self.accum.insert(Column::from_name(name));
            }
            // No other variant holds a column by itself; their children
            // are reached by the recursion that `accept` drives. The
            // variants are listed exhaustively (no `_` arm) so that
            // adding a new `Expr` variant forces this visitor to be
            // revisited at compile time.
            Expr::Alias(_, _)
            | Expr::Literal(_)
            | Expr::BinaryExpr { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::Sort { .. }
            | Expr::ScalarFunction { .. }
            | Expr::ScalarUDF { .. }
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::AggregateUDF { .. }
            | Expr::InList { .. }
            | Expr::Wildcard
            | Expr::GetIndexedField { .. } => {}
        }
        Ok(Recursion::Continue(self))
    }
}
/// Walks `expr` and inserts every column it references into `accum`.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    let visitor = ColumnNameVisitor { accum };
    // The returned visitor is not needed; only the side effect on `accum`.
    expr.accept(visitor).map(|_| ())
}
/// Convenience for optimizer rules: applies `optimizer` to every input of
/// `plan` and rebuilds the plan node (via `from_plan`) over the optimized
/// inputs, keeping the node's own expressions unchanged.
pub fn optimize_children(
    optimizer: &impl OptimizerRule,
    plan: &LogicalPlan,
    execution_props: &ExecutionProps,
) -> Result<LogicalPlan> {
    let inputs = plan.inputs();
    let mut optimized_inputs = Vec::with_capacity(inputs.len());
    for input in inputs {
        // Propagate the first optimization error, if any.
        optimized_inputs.push(optimizer.optimize(input, execution_props)?);
    }
    from_plan(plan, &plan.expressions(), &optimized_inputs)
}
/// Rebuilds a `LogicalPlan` of the same variant as `plan`, substituting the
/// given expressions and inputs.
///
/// `expr` must have the same layout as `plan.expressions()` and `inputs`
/// the same arity as `plan.inputs()` — callers typically obtain both from
/// those accessors, transform them, and pass them back here. Most arms
/// index `expr`/`inputs` directly and will panic if the layout invariant
/// is violated.
pub fn from_plan(
    plan: &LogicalPlan,
    expr: &[Expr],
    inputs: &[LogicalPlan],
) -> Result<LogicalPlan> {
    match plan {
        LogicalPlan::Projection(Projection { schema, alias, .. }) => {
            // Keep the original output schema; the caller is responsible
            // for ensuring the new expressions still match it.
            Ok(LogicalPlan::Projection(Projection {
                expr: expr.to_vec(),
                input: Arc::new(inputs[0].clone()),
                schema: schema.clone(),
                alias: alias.clone(),
            }))
        }
        LogicalPlan::Values(Values { schema, .. }) => Ok(LogicalPlan::Values(Values {
            schema: schema.clone(),
            // `expr` is the row-major flattening of the values; regroup it
            // into rows of `fields().len()` expressions each.
            values: expr
                .chunks_exact(schema.fields().len())
                .map(|s| s.to_vec())
                .collect::<Vec<_>>(),
        })),
        LogicalPlan::Filter { .. } => Ok(LogicalPlan::Filter(Filter {
            // A filter carries exactly one expression: its predicate.
            predicate: expr[0].clone(),
            input: Arc::new(inputs[0].clone()),
        })),
        LogicalPlan::Repartition(Repartition {
            partitioning_scheme,
            ..
        }) => match partitioning_scheme {
            // Round-robin has no partitioning expressions; only the input
            // changes.
            Partitioning::RoundRobinBatch(n) => {
                Ok(LogicalPlan::Repartition(Repartition {
                    partitioning_scheme: Partitioning::RoundRobinBatch(*n),
                    input: Arc::new(inputs[0].clone()),
                }))
            }
            // Hash partitioning: `expr` holds the new hash keys.
            Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition {
                partitioning_scheme: Partitioning::Hash(expr.to_owned(), *n),
                input: Arc::new(inputs[0].clone()),
            })),
        },
        LogicalPlan::Window(Window {
            window_expr,
            schema,
            ..
        }) => Ok(LogicalPlan::Window(Window {
            input: Arc::new(inputs[0].clone()),
            // The leading `window_expr.len()` expressions are the window
            // expressions (mirrors the layout of `expressions()`).
            window_expr: expr[0..window_expr.len()].to_vec(),
            schema: schema.clone(),
        })),
        LogicalPlan::Aggregate(Aggregate {
            group_expr, schema, ..
        }) => Ok(LogicalPlan::Aggregate(Aggregate {
            // `expr` is laid out as [group exprs..., aggregate exprs...].
            group_expr: expr[0..group_expr.len()].to_vec(),
            aggr_expr: expr[group_expr.len()..].to_vec(),
            input: Arc::new(inputs[0].clone()),
            schema: schema.clone(),
        })),
        LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort {
            expr: expr.to_vec(),
            input: Arc::new(inputs[0].clone()),
        })),
        LogicalPlan::Join(Join {
            join_type,
            join_constraint,
            on,
            null_equals_null,
            ..
        }) => {
            // Recompute the join schema from the (possibly rewritten)
            // inputs rather than reusing the stale one.
            let schema =
                build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
            Ok(LogicalPlan::Join(Join {
                left: Arc::new(inputs[0].clone()),
                right: Arc::new(inputs[1].clone()),
                join_type: *join_type,
                join_constraint: *join_constraint,
                on: on.clone(),
                schema: DFSchemaRef::new(schema),
                null_equals_null: *null_equals_null,
            }))
        }
        LogicalPlan::CrossJoin(_) => {
            // Rebuild through the builder so the cross-join schema is
            // re-derived from the new inputs.
            let left = inputs[0].clone();
            let right = &inputs[1];
            LogicalPlanBuilder::from(left).cross_join(right)?.build()
        }
        LogicalPlan::Limit(Limit { n, .. }) => Ok(LogicalPlan::Limit(Limit {
            n: *n,
            input: Arc::new(inputs[0].clone()),
        })),
        LogicalPlan::CreateMemoryTable(CreateMemoryTable { name, .. }) => {
            Ok(LogicalPlan::CreateMemoryTable(CreateMemoryTable {
                input: Arc::new(inputs[0].clone()),
                name: name.clone(),
            }))
        }
        LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension {
            // User-defined nodes know how to rebuild themselves.
            node: e.node.from_template(expr, inputs),
        })),
        LogicalPlan::Union(Union { schema, alias, .. }) => {
            Ok(LogicalPlan::Union(Union {
                inputs: inputs.to_vec(),
                schema: schema.clone(),
                alias: alias.clone(),
            }))
        }
        LogicalPlan::Analyze(a) => {
            // Analyze has exactly one input and no expressions.
            assert!(expr.is_empty());
            assert_eq!(inputs.len(), 1);
            Ok(LogicalPlan::Analyze(Analyze {
                verbose: a.verbose,
                schema: a.schema.clone(),
                input: Arc::new(inputs[0].clone()),
            }))
        }
        LogicalPlan::Explain(_) => {
            // Explain cannot be reconstructed here; it is returned as-is
            // and must not have been given expressions or inputs.
            assert!(
                expr.is_empty(),
                "Explain can not be created from utils::from_expr"
            );
            assert!(
                inputs.is_empty(),
                "Explain can not be created from utils::from_expr"
            );
            Ok(plan.clone())
        }
        // Leaf / DDL nodes: nothing to substitute, return a clone.
        LogicalPlan::EmptyRelation(_)
        | LogicalPlan::TableScan { .. }
        | LogicalPlan::CreateExternalTable(_)
        | LogicalPlan::DropTable(_) => {
            assert!(expr.is_empty(), "{:?} should have no exprs", plan);
            assert!(inputs.is_empty(), "{:?} should have no inputs", plan);
            Ok(plan.clone())
        }
    }
}
/// Flattens the direct sub-expressions of `expr` into a `Vec`, in a layout
/// that `rewrite_expression` can reassemble.
///
/// For most variants the layout is positional (e.g. `[left, right]` for a
/// binary expression). Variants with optional or variable-length parts
/// (CASE, window functions) interleave sentinel marker literals so the
/// flat list remains unambiguous.
///
/// Returns an internal error for `Wildcard`, which is not valid inside a
/// logical plan.
pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
    match expr {
        Expr::BinaryExpr { left, right, .. } => {
            Ok(vec![left.as_ref().to_owned(), right.as_ref().to_owned()])
        }
        // Single-child wrappers: exactly one sub-expression.
        Expr::IsNull(expr)
        | Expr::IsNotNull(expr)
        | Expr::Cast { expr, .. }
        | Expr::TryCast { expr, .. }
        | Expr::Alias(expr, ..)
        | Expr::Not(expr)
        | Expr::Negative(expr)
        | Expr::Sort { expr, .. }
        | Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
        // Function-like variants: the argument list is the children.
        Expr::ScalarFunction { args, .. }
        | Expr::ScalarUDF { args, .. }
        | Expr::AggregateFunction { args, .. }
        | Expr::AggregateUDF { args, .. } => Ok(args.clone()),
        Expr::WindowFunction {
            args,
            partition_by,
            order_by,
            ..
        } => {
            // Layout: [args..., PARTITION_MARKER, partition_by...,
            //          SORT_MARKER, order_by...]
            let mut expr_list: Vec<Expr> = vec![];
            expr_list.extend(args.clone());
            expr_list.push(lit(WINDOW_PARTITION_MARKER));
            expr_list.extend(partition_by.clone());
            expr_list.push(lit(WINDOW_SORT_MARKER));
            expr_list.extend(order_by.clone());
            Ok(expr_list)
        }
        Expr::Case {
            expr,
            when_then_expr,
            else_expr,
            ..
        } => {
            // Layout: [EXPR_MARKER, base?]  then  (when, then) pairs
            //         then  [ELSE_MARKER, else?]; markers appear only
            //         when the optional part is present.
            let mut expr_list: Vec<Expr> = vec![];
            if let Some(e) = expr {
                expr_list.push(lit(CASE_EXPR_MARKER));
                expr_list.push(e.as_ref().to_owned());
            }
            for (w, t) in when_then_expr {
                expr_list.push(w.as_ref().to_owned());
                expr_list.push(t.as_ref().to_owned());
            }
            if let Some(e) = else_expr {
                expr_list.push(lit(CASE_ELSE_MARKER));
                expr_list.push(e.as_ref().to_owned());
            }
            Ok(expr_list)
        }
        // Leaf expressions have no children.
        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(vec![]),
        Expr::Between {
            expr, low, high, ..
        } => Ok(vec![
            expr.as_ref().to_owned(),
            low.as_ref().to_owned(),
            high.as_ref().to_owned(),
        ]),
        // Layout: [tested expr, list items...]
        Expr::InList { expr, list, .. } => {
            let mut expr_list: Vec<Expr> = vec![expr.as_ref().to_owned()];
            for list_expr in list {
                expr_list.push(list_expr.to_owned());
            }
            Ok(expr_list)
        }
        Expr::Wildcard { .. } => Err(DataFusionError::Internal(
            "Wildcard expressions are not valid in a logical query plan".to_owned(),
        )),
    }
}
/// Rebuilds `expr` from the given (possibly rewritten) sub-expressions.
///
/// `expressions` must have the exact layout produced by
/// `expr_sub_expressions` for the same variant of `expr`, including the
/// sentinel markers for CASE and window functions. Most arms index
/// `expressions` positionally and will panic if the layout is violated.
///
/// Note: `Between` is intentionally desugared here into the equivalent
/// `expr >= low AND expr <= high` (negated: wrapped in `NOT`).
///
/// Returns an internal error for `Wildcard` or when window-function
/// markers are missing/misordered.
pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
    match expr {
        Expr::BinaryExpr { op, .. } => Ok(Expr::BinaryExpr {
            left: Box::new(expressions[0].clone()),
            op: *op,
            right: Box::new(expressions[1].clone()),
        }),
        Expr::IsNull(_) => Ok(Expr::IsNull(Box::new(expressions[0].clone()))),
        Expr::IsNotNull(_) => Ok(Expr::IsNotNull(Box::new(expressions[0].clone()))),
        Expr::ScalarFunction { fun, .. } => Ok(Expr::ScalarFunction {
            fun: fun.clone(),
            args: expressions.to_vec(),
        }),
        Expr::ScalarUDF { fun, .. } => Ok(Expr::ScalarUDF {
            fun: fun.clone(),
            args: expressions.to_vec(),
        }),
        Expr::WindowFunction {
            fun, window_frame, ..
        } => {
            // Locate the two sentinel markers that delimit the args /
            // partition-by / order-by sections of the flat list.
            let partition_index = expressions
                .iter()
                .position(|expr| {
                    matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
                        if str == WINDOW_PARTITION_MARKER)
                })
                .ok_or_else(|| {
                    DataFusionError::Internal(
                        "Ill-formed window function expressions: unexpected marker"
                            .to_owned(),
                    )
                })?;
            let sort_index = expressions
                .iter()
                .position(|expr| {
                    matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
                        if str == WINDOW_SORT_MARKER)
                })
                .ok_or_else(|| {
                    DataFusionError::Internal(
                        "Ill-formed window function expressions".to_owned(),
                    )
                })?;
            if partition_index >= sort_index {
                // The partition marker must precede the sort marker.
                Err(DataFusionError::Internal(
                    "Ill-formed window function expressions: partition index too large"
                        .to_owned(),
                ))
            } else {
                Ok(Expr::WindowFunction {
                    fun: fun.clone(),
                    args: expressions[..partition_index].to_vec(),
                    partition_by: expressions[partition_index + 1..sort_index].to_vec(),
                    order_by: expressions[sort_index + 1..].to_vec(),
                    window_frame: *window_frame,
                })
            }
        }
        Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
            fun: fun.clone(),
            args: expressions.to_vec(),
            distinct: *distinct,
        }),
        Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF {
            fun: fun.clone(),
            args: expressions.to_vec(),
        }),
        Expr::Case { .. } => {
            // Reassemble from the marker-delimited layout produced by
            // `expr_sub_expressions`: markers announce that the next
            // element is the base / else expression; everything else
            // comes in (when, then) pairs.
            let mut base_expr: Option<Box<Expr>> = None;
            let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
            let mut else_expr: Option<Box<Expr>> = None;
            let mut i = 0;
            while i < expressions.len() {
                match &expressions[i] {
                    Expr::Literal(ScalarValue::Utf8(Some(str)))
                        if str == CASE_EXPR_MARKER =>
                    {
                        base_expr = Some(Box::new(expressions[i + 1].clone()));
                        i += 2;
                    }
                    Expr::Literal(ScalarValue::Utf8(Some(str)))
                        if str == CASE_ELSE_MARKER =>
                    {
                        else_expr = Some(Box::new(expressions[i + 1].clone()));
                        i += 2;
                    }
                    _ => {
                        when_then.push((
                            Box::new(expressions[i].clone()),
                            Box::new(expressions[i + 1].clone()),
                        ));
                        i += 2;
                    }
                }
            }
            Ok(Expr::Case {
                expr: base_expr,
                when_then_expr: when_then,
                else_expr,
            })
        }
        Expr::Cast { data_type, .. } => Ok(Expr::Cast {
            expr: Box::new(expressions[0].clone()),
            data_type: data_type.clone(),
        }),
        Expr::TryCast { data_type, .. } => Ok(Expr::TryCast {
            expr: Box::new(expressions[0].clone()),
            data_type: data_type.clone(),
        }),
        Expr::Alias(_, alias) => {
            Ok(Expr::Alias(Box::new(expressions[0].clone()), alias.clone()))
        }
        Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))),
        Expr::Negative(_) => Ok(Expr::Negative(Box::new(expressions[0].clone()))),
        // Leaf expressions have no children to substitute.
        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(expr.clone()),
        // BUG FIX: previously `InList` was returned via `expr.clone()`,
        // which silently discarded any rewritten sub-expressions even
        // though `expr_sub_expressions` flattens it as
        // [tested expr, list items...]. Rebuild it from that layout so
        // rewrites inside IN lists take effect.
        Expr::InList { negated, .. } => Ok(Expr::InList {
            expr: Box::new(expressions[0].clone()),
            list: expressions[1..].to_vec(),
            negated: *negated,
        }),
        Expr::Sort {
            asc, nulls_first, ..
        } => Ok(Expr::Sort {
            expr: Box::new(expressions[0].clone()),
            asc: *asc,
            nulls_first: *nulls_first,
        }),
        Expr::Between { negated, .. } => {
            // Desugar BETWEEN into its conjunction form. Note the tested
            // expression (expressions[0]) is duplicated on both sides.
            let expr = Expr::BinaryExpr {
                left: Box::new(Expr::BinaryExpr {
                    left: Box::new(expressions[0].clone()),
                    op: Operator::GtEq,
                    right: Box::new(expressions[1].clone()),
                }),
                op: Operator::And,
                right: Box::new(Expr::BinaryExpr {
                    left: Box::new(expressions[0].clone()),
                    op: Operator::LtEq,
                    right: Box::new(expressions[2].clone()),
                }),
            };
            if *negated {
                Ok(Expr::Not(Box::new(expr)))
            } else {
                Ok(expr)
            }
        }
        Expr::Wildcard { .. } => Err(DataFusionError::Internal(
            "Wildcard expressions are not valid in a logical query plan".to_owned(),
        )),
        Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField {
            expr: Box::new(expressions[0].clone()),
            key: key.clone(),
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::col;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    /// Visiting the same column through two identical expressions must
    /// deduplicate: the accumulator set ends up with a single entry.
    #[test]
    fn test_collect_expr() -> Result<()> {
        let cast = Expr::Cast {
            expr: Box::new(col("a")),
            data_type: DataType::Float64,
        };
        let mut accum: HashSet<Column> = HashSet::new();
        for _ in 0..2 {
            expr_to_columns(&cast, &mut accum)?;
        }
        assert_eq!(1, accum.len());
        assert!(accum.contains(&Column::from_name("a")));
        Ok(())
    }
}