use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::{
and, col, combine_filters,
logical_plan::{Filter, LogicalPlan},
utils::from_plan,
Expr, Operator,
};
use std::collections::HashSet;
use std::sync::Arc;
/// Applies `optimizer` to every direct input of `plan` and rebuilds `plan`
/// around the optimized inputs, keeping the node's own expressions as-is.
///
/// # Errors
/// Propagates the first error returned by `optimizer.optimize` for an input,
/// or any error raised by [`from_plan`] while reconstructing the node.
pub fn optimize_children(
    optimizer: &impl OptimizerRule,
    plan: &LogicalPlan,
    optimizer_config: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
    let expressions = plan.expressions();
    let mut optimized_inputs = Vec::new();
    for child in plan.inputs() {
        // Stop at the first child that fails to optimize.
        let optimized = optimizer.optimize(child, optimizer_config)?;
        optimized_inputs.push(optimized);
    }
    from_plan(plan, &expressions, &optimized_inputs)
}
/// Flattens a tree of `AND`-ed expressions into its individual conjuncts.
///
/// `Alias` wrappers are looked through. Every expression that is neither an
/// `AND` nor an alias is appended to `predicates`, in left-to-right order.
pub fn split_conjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) {
    if let Expr::BinaryExpr {
        left,
        op: Operator::And,
        right,
    } = predicate
    {
        split_conjunction(left, predicates);
        split_conjunction(right, predicates);
    } else if let Expr::Alias(inner, _) = predicate {
        split_conjunction(inner, predicates);
    } else {
        predicates.push(predicate);
    }
}
/// Checks that none of `predicates` contains an `OR` anywhere in its tree.
///
/// # Errors
/// Returns a plan error as soon as a `BinaryExpr` with `Operator::Or` is
/// encountered while visiting any predicate.
pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> {
    /// Visitor that aborts the traversal on the first `OR` node it sees.
    struct RejectOr;

    impl ExpressionVisitor for RejectOr {
        fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
            if let Expr::BinaryExpr {
                op: Operator::Or, ..
            } = expr
            {
                return plan_err!("Optimizing disjunctions not supported!");
            }
            Ok(Recursion::Continue(self))
        }
    }

    predicates
        .iter()
        .try_for_each(|predicate| predicate.accept(RejectOr).map(|_| ()))
}
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
let predicate = predicates
.iter()
.skip(1)
.fold(predicates[0].clone(), |acc, predicate| {
and(acc, (*predicate).to_owned())
});
LogicalPlan::Filter(Filter {
predicate,
input: Arc::new(plan),
})
}
/// Splits `exprs` into join predicates and all other predicates.
///
/// A predicate is treated as a join predicate when it meets all of these:
/// - it is a `BinaryExpr` whose two operands are both bare `Expr::Column`s;
/// - exactly one of those columns has a flat name among the qualified field
///   names of `schema`.
///
/// Everything else goes to the second vector unchanged. Input order is
/// preserved within each vector.
///
/// # Errors
/// Returns a plan error if a predicate meets the join criteria but its
/// operator is neither `=` nor `!=`.
pub fn find_join_exprs(
    exprs: Vec<&Expr>,
    schema: &DFSchemaRef,
) -> Result<(Vec<Expr>, Vec<Expr>)> {
    let fields: HashSet<_> = schema
        .fields()
        .iter()
        .map(|it| it.qualified_name())
        .collect();

    let mut joins = vec![];
    let mut others = vec![];
    for filter in exprs {
        // Borrow the operands rather than deep-cloning the boxed sub-trees:
        // they are only inspected here.
        let column_pair = match filter {
            Expr::BinaryExpr { left, op, right } => {
                match (left.as_ref(), right.as_ref()) {
                    (Expr::Column(l), Expr::Column(r)) => Some((l, *op, r)),
                    _ => None,
                }
            }
            _ => None,
        };
        let (left, op, right) = match column_pair {
            Some(parts) => parts,
            None => {
                others.push(filter.clone());
                continue;
            }
        };

        let left_in_schema = fields.contains(&left.flat_name());
        let right_in_schema = fields.contains(&right.flat_name());
        // Both columns resolve in `schema`, or neither does: this is not a
        // comparison across the schema boundary, so keep it as a plain filter.
        if left_in_schema == right_in_schema {
            others.push(filter.clone());
            continue;
        }

        match op {
            Operator::Eq | Operator::NotEq => joins.push(filter.clone()),
            _ => {
                return plan_err!(format!("can't optimize {} column comparison", op))
            }
        }
    }
    Ok((joins, others))
}
/// Converts correlation predicates into paired join columns.
///
/// Each predicate in `exprs` must be a `BinaryExpr` comparing two columns
/// with `=` or `!=`. For each accepted predicate, the column whose flat name
/// is among `schema`'s qualified field names is placed in the *right* vector
/// and the other column in the *left* vector.
///
/// When `include_negated` is `false`, `!=` predicates are not turned into
/// join columns. They are collected and combined with `combine_filters` into
/// the returned optional expression (presumably `None` when there are none).
///
/// # Errors
/// Returns a plan error in these cases:
/// - a predicate is not a binary expression;
/// - its operator is neither `=` nor `!=`;
/// - an operand of an accepted predicate is not a column.
pub fn exprs_to_join_cols(
    exprs: &[Expr],
    schema: &DFSchemaRef,
    include_negated: bool,
) -> Result<(Vec<Column>, Vec<Column>, Option<Expr>)> {
    let fields: HashSet<_> = schema
        .fields()
        .iter()
        .map(|it| it.qualified_name())
        .collect();

    let mut joins: Vec<(String, String)> = vec![];
    let mut others: Vec<Expr> = vec![];
    for filter in exprs {
        // Borrow the boxed operands; only their column identity is needed.
        let (left, op, right) = match filter {
            Expr::BinaryExpr { left, op, right } => (left, *op, right),
            _ => return plan_err!("Invalid correlation expression!"),
        };
        match op {
            // Negated correlations stay as residual filters unless requested.
            Operator::NotEq if !include_negated => {
                others.push(filter.clone());
                continue;
            }
            Operator::Eq | Operator::NotEq => {}
            _ => return plan_err!(format!("Correlation operator unsupported: {}", op)),
        }

        let left_name = left.try_into_col()?.flat_name();
        let right_name = right.try_into_col()?.flat_name();
        // Orient the pair so the column found in `schema` always comes second.
        let pair = if fields.contains(&left_name) {
            (right_name, left_name)
        } else {
            (left_name, right_name)
        };
        joins.push(pair);
    }

    let (left_cols, right_cols): (Vec<_>, Vec<_>) = joins
        .into_iter()
        .map(|(l, r)| (Column::from(l.as_str()), Column::from(r.as_str())))
        .unzip();
    let pred = combine_filters(&others);
    Ok((left_cols, right_cols, pred))
}
/// Returns a reference to the single element of `slice`.
///
/// # Errors
/// Returns a plan error when `slice` is empty or contains more than one
/// element.
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
    match slice.len() {
        0 => plan_err!("No items found!"),
        1 => Ok(&slice[0]),
        _ => plan_err!("More than one item found!"),
    }
}
/// Pairs columns from `a` with columns from `b` and drops repeated pairs.
///
/// The columns of `a.0` followed by `a.1` are zipped positionally with the
/// columns of `b.0` followed by `b.1`; the result is truncated to the shorter
/// side. Each column is rebuilt from its flat name via `Column::from`. Equal
/// *adjacent* pairs are collapsed, and the pairs are then split back into
/// two vectors.
///
/// NOTE(review): `Vec::dedup` only removes consecutive duplicates, so equal
/// pairs that are not adjacent are kept — confirm this is intended.
pub fn merge_cols(
    a: (&[Column], &[Column]),
    b: (&[Column], &[Column]),
) -> (Vec<Column>, Vec<Column>) {
    let normalize = |column: &Column| Column::from(column.flat_name().as_str());
    let lhs = a.0.iter().chain(a.1.iter()).map(normalize);
    let rhs = b.0.iter().chain(b.1.iter()).map(normalize);

    let mut pairs: Vec<(Column, Column)> = lhs.zip(rhs).collect();
    pairs.dedup();
    pairs.into_iter().unzip()
}
/// Re-qualifies every column in `cols` with the relation `new_table`,
/// keeping each column's name.
pub fn swap_table(new_table: &str, cols: &[Column]) -> Vec<Column> {
    let mut requalified = Vec::with_capacity(cols.len());
    for column in cols {
        requalified.push(Column {
            relation: Some(new_table.to_string()),
            name: column.name.clone(),
        });
    }
    requalified
}
/// Builds one column expression per entry of `cols`.
///
/// Each expression refers to the column by its flat name and is aliased to
/// the column's unqualified name.
pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
    let mut aliased = Vec::with_capacity(cols.len());
    for column in cols {
        let qualified = column.flat_name();
        aliased.push(col(qualified.as_str()).alias(column.name.as_str()));
    }
    aliased
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use datafusion_common::Column;
    use datafusion_expr::{col, utils::expr_to_columns};
    use std::collections::HashSet;

    /// Collecting the columns of the same cast expression twice must yield a
    /// single entry for column `a`.
    #[test]
    fn test_collect_expr() -> Result<()> {
        let mut columns: HashSet<Column> = HashSet::new();
        for _ in 0..2 {
            let cast_a = Expr::Cast {
                expr: Box::new(col("a")),
                data_type: DataType::Float64,
            };
            expr_to_columns(&cast_a, &mut columns)?;
        }
        assert_eq!(1, columns.len());
        assert!(columns.contains(&Column::from_name("a")));
        Ok(())
    }
}