use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::{BinaryExpr, Sort};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::inspect_expr_pre;
use datafusion_expr::logical_plan::LogicalPlanBuilder;
use datafusion_expr::utils::{check_all_columns_from_schema, from_plan};
use datafusion_expr::{
and,
logical_plan::{Filter, LogicalPlan},
Expr, Operator,
};
use std::collections::HashSet;
use std::sync::Arc;
/// Applies `optimizer` to every direct input of `plan` and, when at least one
/// input was rewritten, rebuilds `plan` on top of the new inputs.
///
/// Inputs for which `try_optimize` yields `None` are carried over as clones of
/// the original input.
///
/// # Returns
/// * `Ok(Some(new_plan))` if any input was rewritten.
/// * `Ok(None)` if no input was rewritten.
///
/// # Errors
/// Propagates any error returned by `try_optimize` or by `from_plan`.
pub fn optimize_children(
    optimizer: &impl OptimizerRule,
    plan: &LogicalPlan,
    config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
    let inputs = plan.inputs();
    let mut new_inputs = Vec::with_capacity(inputs.len());
    let mut plan_is_changed = false;
    for input in inputs {
        let new_input = optimizer.try_optimize(input, config)?;
        plan_is_changed |= new_input.is_some();
        new_inputs.push(new_input.unwrap_or_else(|| input.clone()));
    }
    if !plan_is_changed {
        return Ok(None);
    }
    // The expressions are only needed for the rebuild, so collect them only
    // once we know a rebuild is going to happen.
    let exprs = plan.expressions();
    Ok(Some(from_plan(plan, &exprs, &new_inputs)?))
}
/// Splits a conjunctive expression into the borrowed operands of its `AND`s.
///
/// For example `a = 5 AND b` yields `[a = 5, b]`. Aliases are looked through,
/// so an aliased operand is returned without its alias. An expression that
/// contains no top-level `AND` is returned as the single element.
pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
    let collected = Vec::new();
    split_conjunction_impl(expr, collected)
}
/// Worker for [`split_conjunction`]: appends the `AND` operands of `expr` to
/// `exprs` in left-to-right order and returns the extended vector.
fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
    // Explicit work stack; the right operand is pushed first so that the left
    // operand (and everything beneath it) is emitted before the right one.
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And,
                right,
            }) => {
                pending.push(right.as_ref());
                pending.push(left.as_ref());
            }
            // Look through aliases and keep splitting the aliased expression.
            Expr::Alias(inner, _) => pending.push(inner.as_ref()),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Owned counterpart of [`split_conjunction`]: consumes `expr` and returns the
/// operands of its `AND`s as owned expressions.
///
/// This is [`split_binary_owned`] specialised to `Operator::And`.
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
    let and_operator = Operator::And;
    split_binary_owned(expr, and_operator)
}
/// Consumes `expr` and splits it on every occurrence of the binary operator
/// `op`, returning the operands as owned expressions in left-to-right order.
///
/// Aliases are looked through. An expression whose top-level operator is not
/// `op` is returned as the single element.
pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
    let collected = Vec::new();
    split_binary_owned_impl(expr, op, collected)
}
/// Worker for [`split_binary_owned`]: appends the operands of `expr` with
/// respect to `operator` to `exprs` (left to right) and returns the vector.
fn split_binary_owned_impl(
    expr: Expr,
    operator: Operator,
    mut exprs: Vec<Expr>,
) -> Vec<Expr> {
    // Explicit work stack; pushing the right operand first makes the left
    // subtree come out (and get emitted) before the right one.
    let mut pending = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == operator => {
                pending.push(*right);
                pending.push(*left);
            }
            // Look through aliases and keep splitting the aliased expression.
            Expr::Alias(inner, _) => pending.push(*inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Splits `expr` on every occurrence of the binary operator `op`, returning
/// references to the operands in left-to-right order.
///
/// Aliases are looked through. An expression whose top-level operator is not
/// `op` is returned as the single element.
pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
    let collected = Vec::new();
    split_binary_impl(expr, op, collected)
}
/// Worker for [`split_binary`]: appends the operands of `expr` with respect to
/// `operator` to `exprs` (left to right) and returns the vector.
fn split_binary_impl<'a>(
    expr: &'a Expr,
    operator: Operator,
    mut exprs: Vec<&'a Expr>,
) -> Vec<&'a Expr> {
    // Explicit work stack; pushing the right operand first makes the left
    // subtree come out (and get emitted) before the right one.
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == operator => {
                pending.push(right.as_ref());
                pending.push(left.as_ref());
            }
            // Look through aliases and keep splitting the aliased expression.
            Expr::Alias(inner, _) => pending.push(inner.as_ref()),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Combines `filters` into a single expression joined by `AND`.
///
/// The result is left-nested: `[a, b, c]` becomes `(a AND b) AND c`.
/// Returns `None` when `filters` is empty.
pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
    let mut remaining = filters.into_iter();
    let first = remaining.next()?;
    Some(remaining.fold(first, |accum, expr| accum.and(expr)))
}
/// Combines `filters` into a single expression joined by `OR`.
///
/// The result is left-nested: `[a, b, c]` becomes `(a OR b) OR c`.
/// Returns `None` when `filters` is empty.
pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
    let mut remaining = filters.into_iter();
    let first = remaining.next()?;
    Some(remaining.fold(first, |accum, expr| accum.or(expr)))
}
/// Strips every outer `Expr::Alias` wrapper from `expr` and returns the first
/// non-alias expression underneath. Non-alias input is returned unchanged.
#[inline]
pub fn unalias(expr: Expr) -> Expr {
    let mut current = expr;
    loop {
        match current {
            Expr::Alias(inner, _) => current = *inner,
            bare => return bare,
        }
    }
}
/// Checks that none of `predicates` contains an `OR` anywhere in its tree.
///
/// # Errors
/// Returns a plan error for the first predicate in which `inspect_expr_pre`
/// visits a binary expression whose operator is `Operator::Or`.
pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> {
    for predicate in predicates {
        let outcome: Result<()> = inspect_expr_pre(predicate, |node| {
            let is_or = matches!(
                node,
                Expr::BinaryExpr(BinaryExpr {
                    op: Operator::Or,
                    ..
                })
            );
            if is_or {
                plan_err!("Optimizing disjunctions not supported!")
            } else {
                Ok(())
            }
        });
        outcome?;
    }
    Ok(())
}
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
let predicate = predicates
.iter()
.skip(1)
.fold(predicates[0].clone(), |acc, predicate| {
and(acc, (*predicate).to_owned())
});
Ok(LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(plan),
)?))
}
/// Partitions `exprs` into join expressions and all other expressions.
///
/// An expression is classified as a join expression when it is a binary
/// comparison between two columns where exactly one of the two columns'
/// flat names appears among the qualified field names of `schema`.
/// Everything else (non-binary expressions, operands that are not plain
/// columns, both columns in `schema`, or neither column in `schema`) goes
/// into the second vector.
///
/// # Returns
/// `(joins, others)`, each holding clones of the input expressions in their
/// original relative order.
///
/// # Errors
/// Returns a plan error when a column-to-column comparison that would be a
/// join expression uses an operator other than `=` or `!=`.
pub fn find_join_exprs(
    exprs: Vec<&Expr>,
    schema: &DFSchemaRef,
) -> Result<(Vec<Expr>, Vec<Expr>)> {
    let fields: HashSet<_> = schema
        .fields()
        .iter()
        .map(|it| it.qualified_name())
        .collect();
    let mut joins = vec![];
    let mut others = vec![];
    for filter in exprs {
        // Inspect the operands by reference; nothing needs to be cloned just
        // to find out whether both sides are plain columns.
        let (left, op, right) = match filter {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                (left.as_ref(), *op, right.as_ref())
            }
            _ => {
                others.push(filter.clone());
                continue;
            }
        };
        let (left, right) = match (left, right) {
            (Expr::Column(l), Expr::Column(r)) => (l, r),
            _ => {
                others.push(filter.clone());
                continue;
            }
        };
        let left_in_schema = fields.contains(&left.flat_name());
        let right_in_schema = fields.contains(&right.flat_name());
        // Both columns in the schema, or neither of them: not a join
        // expression with respect to `schema`.
        if left_in_schema == right_in_schema {
            others.push(filter.clone());
            continue;
        }
        match op {
            Operator::Eq | Operator::NotEq => joins.push(filter.clone()),
            _ => return plan_err!(format!("can't optimize {op} column comparison")),
        }
    }
    Ok((joins, others))
}
/// Converts column-comparison expressions into paired lists of join columns.
///
/// Each expression in `exprs` must be a binary expression. `=` comparisons
/// are always turned into a column pair. `!=` comparisons are turned into a
/// column pair only when `include_negated` is `true`; otherwise they are
/// collected and returned `AND`-ed together as the third tuple element.
///
/// For every pair, the column whose flat name is found among the qualified
/// field names of `schema` is placed in the second list; the check is made on
/// the left operand, and if it is not in `schema` the operands keep their
/// original order.
///
/// # Returns
/// `(first_cols, second_cols, remaining_predicate)` where the two column
/// lists have equal length and `remaining_predicate` is `None` when no
/// expression was set aside.
///
/// # Errors
/// Returns a plan error when an expression is not a binary expression, when
/// its operator is neither `=` nor `!=`, or when `try_into_col` fails for
/// one of the operands.
pub fn exprs_to_join_cols(
    exprs: &[Expr],
    schema: &DFSchemaRef,
    include_negated: bool,
) -> Result<(Vec<Column>, Vec<Column>, Option<Expr>)> {
    let fields: HashSet<_> = schema
        .fields()
        .iter()
        .map(|it| it.qualified_name())
        .collect();
    let mut joins: Vec<(String, String)> = vec![];
    let mut others: Vec<Expr> = vec![];
    for filter in exprs {
        // Borrow the operands; cloning whole sub-expressions is unnecessary
        // because only the column names are needed below.
        let (left, op, right) = match filter {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, *op, right),
            _ => return plan_err!("Invalid correlation expression!"),
        };
        match op {
            Operator::Eq => {}
            Operator::NotEq if include_negated => {}
            Operator::NotEq => {
                others.push(filter.clone());
                continue;
            }
            _ => return plan_err!(format!("Correlation operator unsupported: {op}")),
        }
        let left_name = left.try_into_col()?.flat_name();
        let right_name = right.try_into_col()?.flat_name();
        // Keep the column that belongs to `schema` on the second side.
        let sorted = if fields.contains(&left_name) {
            (right_name, left_name)
        } else {
            (left_name, right_name)
        };
        joins.push(sorted);
    }
    let (left_cols, right_cols): (Vec<_>, Vec<_>) = joins
        .into_iter()
        .map(|(l, r)| (Column::from(l.as_str()), Column::from(r.as_str())))
        .unzip();
    let pred = conjunction(others);
    Ok((left_cols, right_cols, pred))
}
/// Returns a reference to the sole element of `slice`.
///
/// # Errors
/// Returns a plan error when `slice` is empty or holds more than one element.
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
    let mut items = slice.iter();
    match (items.next(), items.next()) {
        (Some(it), None) => Ok(it),
        (None, _) => plan_err!("No items found!"),
        (Some(_), Some(_)) => plan_err!("More than one item found!"),
    }
}
/// Rewrites `expr` with `rewriter` while keeping its name stable.
///
/// The name is captured before the rewrite; if the rewritten expression
/// would be named differently, it is aliased back to the captured name (for
/// `Expr::Sort` the alias is applied to the sorted expression).
///
/// # Errors
/// Propagates errors from computing the name and from the rewrite itself.
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
    R: ExprRewriter<Expr>,
{
    let name_before = name_for_alias(&expr)?;
    expr.rewrite(rewriter)
        .and_then(|rewritten| add_alias_if_changed(name_before, rewritten))
}
/// Returns the display name used for alias comparison: `Expr::Sort` wrappers
/// are skipped and the name of the expression being sorted is used instead.
fn name_for_alias(expr: &Expr) -> Result<String> {
    let mut target = expr;
    while let Expr::Sort(Sort { expr: inner, .. }) = target {
        target = inner.as_ref();
    }
    target.display_name()
}
/// Ensures `expr` is still named `original_name`.
///
/// If the name (as computed by `name_for_alias`) is unchanged, `expr` is
/// returned as is. Otherwise an alias to `original_name` is added; for an
/// `Expr::Sort` the alias is pushed down onto the sorted expression and the
/// sort is rebuilt with the same `asc` / `nulls_first` settings.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
    if name_for_alias(&expr)? == original_name {
        return Ok(expr);
    }
    let renamed = match expr {
        Expr::Sort(Sort {
            expr: sorted,
            asc,
            nulls_first,
        }) => {
            let aliased = add_alias_if_changed(original_name, *sorted)?;
            Expr::Sort(Sort::new(Box::new(aliased), asc, nulls_first))
        }
        other => other.alias(original_name),
    };
    Ok(renamed)
}
/// Builds one schema from the schemas of all `inputs`.
///
/// Starts from an empty `DFSchema` and merges each input's schema into it,
/// in the order the inputs are given.
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
    let mut merged = DFSchema::empty();
    for input in inputs {
        merged.merge(input.schema());
    }
    merged
}
/// Separates the predicate of a `Filter` plan into join filters and filters
/// that stay on the filter's input.
///
/// The predicate is split on `AND`. Each conjunct for which
/// `check_all_columns_from_schema` succeeds against the schema of the
/// filter's input is kept and re-applied as a filter on that input; every
/// other conjunct is returned as a join filter.
///
/// # Returns
/// `(join_filters, plan)`. When `maybe_filter` is not a `Filter`, the result
/// is an empty list together with a clone of `maybe_filter`.
///
/// # Errors
/// Propagates errors from column extraction, the schema check, and plan
/// building.
pub(crate) fn extract_join_filters(
    maybe_filter: &LogicalPlan,
) -> Result<(Vec<Expr>, LogicalPlan)> {
    let plan_filter = match maybe_filter {
        LogicalPlan::Filter(filter) => filter,
        _ => return Ok((vec![], maybe_filter.clone())),
    };

    let input_schema = plan_filter.input.schema();
    let mut join_filters: Vec<Expr> = vec![];
    let mut subquery_filters: Vec<Expr> = vec![];
    for conjunct in split_conjunction(&plan_filter.predicate) {
        let referenced = conjunct.to_columns()?;
        let is_local = check_all_columns_from_schema(&referenced, input_schema.clone())?;
        let bucket = if is_local {
            &mut subquery_filters
        } else {
            &mut join_filters
        };
        bucket.push(conjunct.clone());
    }

    let builder = LogicalPlanBuilder::from((*plan_filter.input).clone());
    let builder = match conjunction(subquery_filters) {
        Some(predicate) => builder.filter(predicate)?,
        None => builder,
    };
    Ok((join_filters, builder.build()?))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use datafusion_common::Column;
    use datafusion_expr::expr::Cast;
    use datafusion_expr::{col, lit, utils::expr_to_columns};
    use std::collections::HashSet;
    use std::ops::Add;

    /// A lone column has nothing to split on and comes back as is.
    #[test]
    fn test_split_conjunction() {
        let single = col("a");
        assert_eq!(split_conjunction(&single), vec![&single]);
    }

    /// `a = 5 AND b` splits into its two operands.
    #[test]
    fn test_split_conjunction_two() {
        let lhs = col("a").eq(lit(5));
        let rhs = col("b");
        let combined = lhs.clone().and(rhs.clone());
        assert_eq!(split_conjunction(&combined), vec![&lhs, &rhs]);
    }

    /// An aliased operand is returned without its alias.
    #[test]
    fn test_split_conjunction_alias() {
        let lhs = col("a").eq(lit(5));
        let rhs = col("b");
        let combined = lhs.clone().and(rhs.clone().alias("the_alias"));
        assert_eq!(split_conjunction(&combined), vec![&lhs, &rhs]);
    }

    /// An `OR` expression is not split.
    #[test]
    fn test_split_conjunction_or() {
        let disjunct = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(split_conjunction(&disjunct), vec![&disjunct]);
    }

    /// A lone column has nothing to split on and comes back as is.
    #[test]
    fn test_split_binary_owned() {
        let single = col("a");
        let parts = split_binary_owned(single.clone(), Operator::And);
        assert_eq!(parts, vec![single]);
    }

    /// `a = 5 AND b` splits on `AND` into its two operands.
    #[test]
    fn test_split_binary_owned_two() {
        let lhs = col("a").eq(lit(5));
        let rhs = col("b");
        let parts = split_binary_owned(lhs.clone().and(rhs.clone()), Operator::And);
        assert_eq!(parts, vec![lhs, rhs]);
    }

    /// Splitting on `AND` leaves an `OR` expression untouched.
    #[test]
    fn test_split_binary_owned_different_op() {
        let disjunct = col("a").eq(lit(5)).or(col("b"));
        let parts = split_binary_owned(disjunct.clone(), Operator::And);
        assert_eq!(parts, vec![disjunct]);
    }

    /// A lone column has nothing to split on and comes back as is.
    #[test]
    fn test_split_conjunction_owned() {
        let single = col("a");
        let parts = split_conjunction_owned(single.clone());
        assert_eq!(parts, vec![single]);
    }

    /// `a = 5 AND b` splits into its two operands.
    #[test]
    fn test_split_conjunction_owned_two() {
        let lhs = col("a").eq(lit(5));
        let rhs = col("b");
        let parts = split_conjunction_owned(lhs.clone().and(rhs.clone()));
        assert_eq!(parts, vec![lhs, rhs]);
    }

    /// An aliased operand is returned without its alias.
    #[test]
    fn test_split_conjunction_owned_alias() {
        let lhs = col("a").eq(lit(5));
        let rhs = col("b");
        let parts =
            split_conjunction_owned(lhs.clone().and(rhs.clone().alias("the_alias")));
        assert_eq!(parts, vec![lhs, rhs]);
    }

    /// No filters means no combined expression.
    #[test]
    fn test_conjunction_empty() {
        assert!(conjunction(Vec::<Expr>::new()).is_none());
    }

    /// The combined `AND` expression is left-nested, not right-nested.
    #[test]
    fn test_conjunction() {
        let combined = conjunction(vec![col("a"), col("b"), col("c")]);
        let left_nested = col("a").and(col("b")).and(col("c"));
        let right_nested = col("a").and(col("b").and(col("c")));
        assert_eq!(combined, Some(left_nested));
        assert_ne!(combined, Some(right_nested));
    }

    /// No filters means no combined expression.
    #[test]
    fn test_disjunction_empty() {
        assert!(disjunction(Vec::<Expr>::new()).is_none());
    }

    /// The combined `OR` expression is left-nested, not right-nested.
    #[test]
    fn test_disjunction() {
        let combined = disjunction(vec![col("a"), col("b"), col("c")]);
        let left_nested = col("a").or(col("b")).or(col("c"));
        let right_nested = col("a").or(col("b").or(col("c")));
        assert_eq!(combined, Some(left_nested));
        assert_ne!(combined, Some(right_nested));
    }

    /// An `OR` expression is not split.
    #[test]
    fn test_split_conjunction_owned_or() {
        let disjunct = col("a").eq(lit(5)).or(col("b"));
        let parts = split_conjunction_owned(disjunct.clone());
        assert_eq!(parts, vec![disjunct]);
    }

    /// Collecting columns from the same cast twice records the column once.
    #[test]
    fn test_collect_expr() -> Result<()> {
        let mut seen: HashSet<Column> = HashSet::new();
        let cast = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
        for _ in 0..2 {
            expr_to_columns(&cast, &mut seen)?;
        }
        assert_eq!(1, seen.len());
        assert!(seen.contains(&Column::from_name("a")));
        Ok(())
    }

    /// The name survives rewrites to various differently-named expressions.
    #[test]
    fn test_rewrite_preserving_name() {
        let cases = vec![
            // Rewritten to itself.
            (col("a"), col("a")),
            // Rewritten to a different column.
            (col("a"), col("b")),
            // Rewritten to a cast of the same column.
            (
                col("a"),
                Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
            ),
            // Rewritten to the same shape with a different literal type.
            (col("a").add(lit(1i32)), col("a").add(lit(1i64))),
            // Sort expressions on both sides.
            (
                Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)),
                Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)),
            ),
        ];
        for (expr_from, rewrite_to) in cases {
            test_rewrite(expr_from, rewrite_to);
        }
    }

    /// Rewrites `expr_from` with a rewriter that always produces `rewrite_to`
    /// and asserts that the resulting name equals the original name.
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        struct TestRewriter {
            rewrite_to: Expr,
        }
        impl ExprRewriter for TestRewriter {
            fn mutate(&mut self, _: Expr) -> Result<Expr> {
                Ok(self.rewrite_to.clone())
            }
        }

        /// Display name of `expr`, looking through one `Expr::Sort` wrapper.
        fn sortless_name(expr: &Expr) -> String {
            let named = match expr {
                Expr::Sort(Sort { expr: inner, .. }) => inner.as_ref(),
                other => other,
            };
            named.display_name().unwrap()
        }

        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        let rewritten = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();
        let original_name = sortless_name(&expr_from);
        let new_name = sortless_name(&rewritten);
        assert_eq!(
            original_name, new_name,
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }
}