use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::{
and, col,
logical_plan::{Filter, LogicalPlan},
utils::from_plan,
Expr, Operator,
};
use std::collections::{HashSet, VecDeque};
use std::sync::Arc;
/// Rebuilds `plan` with every input replaced by the result of running
/// `optimizer` over that input.
///
/// The plan's own expressions are passed through unchanged. The first error
/// returned by `optimizer` aborts the rebuild and is propagated, as is any
/// error from `from_plan`.
pub fn optimize_children(
    optimizer: &impl OptimizerRule,
    plan: &LogicalPlan,
    optimizer_config: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
    let expressions = plan.expressions();
    let mut optimized_inputs = Vec::new();
    for input in plan.inputs() {
        optimized_inputs.push(optimizer.optimize(input, optimizer_config)?);
    }
    from_plan(plan, &expressions, &optimized_inputs)
}
/// Borrows the operands of the top-level `AND` chain in `expr`, left to right.
///
/// Aliases are looked through. Any expression that is not an `AND` (for
/// example `a OR b`) is returned as the single element.
///
/// For example, `a = 5 AND b` yields `[a = 5, b]`.
pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
    let collected = Vec::new();
    split_conjunction_impl(expr, collected)
}
/// Appends the `AND` operands of `expr` to `exprs` in left-to-right order and
/// returns the extended vector.
///
/// Uses an explicit worklist. Pushing `right` before `left` means the left
/// subtree is fully drained before the right one, which matches a recursive
/// in-order walk.
fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
    let mut pending: Vec<&'a Expr> = vec![expr];
    while let Some(current) = pending.pop() {
        match current {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::And,
                right,
            }) => {
                pending.push(right);
                pending.push(left);
            }
            Expr::Alias(inner, _) => pending.push(inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Owned counterpart of `split_conjunction`.
///
/// Consumes `expr` and returns the operands of its top-level `AND` chain,
/// looking through aliases.
pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
    let conjunction_op = Operator::And;
    split_binary_owned(expr, conjunction_op)
}
/// Consumes `expr` and returns the operands of its top-level chain of `op`
/// binary expressions, left to right, looking through aliases.
pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
    let collected = Vec::new();
    split_binary_owned_impl(expr, op, collected)
}
/// Moves the `operator` operands of `expr` into `exprs` in left-to-right
/// order and returns the extended vector.
///
/// Iterative worklist form: pushing `right` before `left` ensures the left
/// subtree is emitted first.
fn split_binary_owned_impl(
    expr: Expr,
    operator: Operator,
    mut exprs: Vec<Expr>,
) -> Vec<Expr> {
    let mut stack = vec![expr];
    while let Some(next) = stack.pop() {
        match next {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == operator => {
                stack.push(*right);
                stack.push(*left);
            }
            Expr::Alias(inner, _) => stack.push(*inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Borrows the operands of the top-level chain of `op` binary expressions in
/// `expr`, left to right, looking through aliases.
pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
    let collected = Vec::new();
    split_binary_impl(expr, op, collected)
}
/// Appends references to the `operator` operands of `expr` to `exprs` in
/// left-to-right order and returns the extended vector.
///
/// Iterative worklist form: `right` is pushed before `left` so the left
/// subtree is drained first.
fn split_binary_impl<'a>(
    expr: &'a Expr,
    operator: Operator,
    mut exprs: Vec<&'a Expr>,
) -> Vec<&'a Expr> {
    let mut worklist: Vec<&'a Expr> = vec![expr];
    while let Some(node) = worklist.pop() {
        match node {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == operator => {
                worklist.push(right);
                worklist.push(left);
            }
            Expr::Alias(inner, _) => worklist.push(inner),
            leaf => exprs.push(leaf),
        }
    }
    exprs
}
/// Returns every way of choosing exactly one expression from each group in
/// `exprs`, keeping the groups in their original order.
///
/// Despite its name, this is the Cartesian product of the groups, not a set
/// of permutations. For example, `[[a, b], [c]]` yields `[[a, c], [b, c]]`.
/// The result is empty if `exprs` is empty or if any group is empty.
fn permutations(mut exprs: VecDeque<Vec<&Expr>>) -> Vec<Vec<&Expr>> {
    let first = match exprs.pop_front() {
        Some(first) => first,
        None => return vec![],
    };

    // Base case: a single group yields one singleton combination per member.
    if exprs.is_empty() {
        return first.into_iter().map(|e| vec![e]).collect();
    }

    // The combinations of the remaining groups do not depend on which member
    // of `first` is picked. Compute them once rather than once per member;
    // the original code re-ran the full recursion for every member.
    let rest = permutations(exprs);
    first
        .into_iter()
        .flat_map(|expr| {
            // Prefix `expr` to each tail combination: [expr, ...tail].
            rest.iter().map(move |tail| {
                std::iter::once(expr)
                    .chain(tail.iter().copied())
                    .collect::<Vec<&Expr>>()
            })
        })
        .collect()
}
/// Exclusive upper bound on the number of clauses `cnf_rewrite` may produce.
///
/// The rewrite happens only when the clause count is strictly below this
/// value.
const MAX_CNF_REWRITE_CONJUNCTS: usize = 10;

/// Rewrites `expr` towards conjunctive normal form by distributing its
/// top-level `OR` over the `AND`s inside each disjunct.
///
/// For example, `(a AND b) OR c` becomes `(a OR c) AND (b OR c)`.
///
/// `expr` is returned unchanged in two cases:
/// - no disjunct contains an `AND`, so there is nothing to distribute;
/// - the rewrite would produce `MAX_CNF_REWRITE_CONJUNCTS` or more clauses,
///   which keeps predicates from growing too large.
pub fn cnf_rewrite(expr: Expr) -> Expr {
    // Each OR operand becomes a group of its AND operands:
    // `(a AND b) OR c` is represented as `[[a, b], [c]]`.
    let groups: VecDeque<Vec<&Expr>> = split_binary(&expr, Operator::Or)
        .into_iter()
        .map(|disjunct| split_binary(disjunct, Operator::And))
        .collect();

    // Distribution yields one clause per combination, i.e. the product of the
    // group sizes. Saturating multiplication keeps huge inputs from overflowing.
    let clause_count = groups
        .iter()
        .map(Vec::len)
        .fold(1usize, usize::saturating_mul);
    let has_conjunction = groups.iter().any(|group| group.len() > 1);

    if !has_conjunction || clause_count >= MAX_CNF_REWRITE_CONJUNCTS {
        return expr;
    }

    // The unwraps cannot fail. `split_binary` always yields at least one
    // operand, so every group is non-empty. Hence there is at least one
    // combination, and each combination has one term per group.
    conjunction(
        permutations(groups)
            .into_iter()
            .map(|combo| disjunction(combo.into_iter().cloned()).unwrap()),
    )
    .unwrap()
}
/// Combines `filters` into a single left-deep `AND` chain: `((f0 AND f1) AND f2) ...`.
///
/// Returns `None` if `filters` is empty.
pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
    let mut remaining = filters.into_iter();
    let first = remaining.next()?;
    Some(remaining.fold(first, |lhs, rhs| lhs.and(rhs)))
}
/// Combines `filters` into a single left-deep `OR` chain: `((f0 OR f1) OR f2) ...`.
///
/// Returns `None` if `filters` is empty.
pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
    let mut remaining = filters.into_iter();
    let first = remaining.next()?;
    Some(remaining.fold(first, |lhs, rhs| lhs.or(rhs)))
}
/// Strips every layer of `Expr::Alias` from the top of `expr` and returns the
/// innermost non-alias expression.
#[inline]
pub fn unalias(expr: Expr) -> Expr {
    let mut current = expr;
    loop {
        match current {
            Expr::Alias(inner, _name) => current = *inner,
            unaliased => return unaliased,
        }
    }
}
/// Returns a planning error if any predicate contains an `OR` expression
/// anywhere in its tree.
///
/// Predicates are checked in order, and the first `OR` found aborts the check.
pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> {
    /// Visitor that fails as soon as it reaches an `OR` binary expression.
    struct DisjunctionVisitor;

    impl ExpressionVisitor for DisjunctionVisitor {
        fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
            if let Expr::BinaryExpr(BinaryExpr {
                op: Operator::Or, ..
            }) = expr
            {
                return plan_err!("Optimizing disjunctions not supported!");
            }
            Ok(Recursion::Continue(self))
        }
    }

    predicates
        .iter()
        .try_for_each(|predicate| predicate.accept(DisjunctionVisitor).map(|_| ()))
}
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
let predicate = predicates
.iter()
.skip(1)
.fold(predicates[0].clone(), |acc, predicate| {
and(acc, (*predicate).to_owned())
});
Ok(LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(plan),
)?))
}
/// Partitions `exprs` into join predicates and everything else.
///
/// A filter counts as a join predicate when both of these hold:
/// - it is a binary expression whose two operands are plain columns;
/// - exactly one of those columns resolves (by flat name) to a field of
///   `schema`.
///
/// Every other filter goes into the second returned vector, unchanged.
///
/// # Errors
///
/// Returns a planning error if a filter qualifies as a join predicate but
/// its operator is neither `=` nor `!=`.
pub fn find_join_exprs(
    exprs: Vec<&Expr>,
    schema: &DFSchemaRef,
) -> Result<(Vec<Expr>, Vec<Expr>)> {
    let schema_cols: HashSet<_> = schema
        .fields()
        .iter()
        .map(|field| field.qualified_name())
        .collect();

    let mut joins = vec![];
    let mut others = vec![];
    for filter in exprs {
        // Only `column <op> column` comparisons are join candidates. Match on
        // the borrowed operands so non-column subtrees are never cloned.
        let (l_col, op, r_col) = match filter {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                match (left.as_ref(), right.as_ref()) {
                    (Expr::Column(l), Expr::Column(r)) => (l, *op, r),
                    _ => {
                        others.push(filter.clone());
                        continue;
                    }
                }
            }
            _ => {
                others.push(filter.clone());
                continue;
            }
        };

        // When both columns, or neither column, belong to `schema`, the
        // comparison does not link `schema` to anything outside it.
        let l_in_schema = schema_cols.contains(&l_col.flat_name());
        let r_in_schema = schema_cols.contains(&r_col.flat_name());
        if l_in_schema == r_in_schema {
            others.push(filter.clone());
            continue;
        }

        if !matches!(op, Operator::Eq | Operator::NotEq) {
            return plan_err!(format!("can't optimize {} column comparison", op));
        }
        joins.push(filter.clone());
    }

    Ok((joins, others))
}
/// Converts correlation predicates of the form `col <op> col` into two
/// parallel lists of join columns plus a residual filter.
///
/// For each accepted predicate, the column that resolves (by flat name) to a
/// field of `schema` goes into the second list, and the other column goes
/// into the first. Both columns are re-parsed from their flat names.
///
/// `!=` predicates are treated as join columns only when `include_negated`
/// is true. Otherwise they are collected and returned, `AND`ed together, as
/// the optional third element.
///
/// # Errors
///
/// Returns a planning error in these cases:
/// - a predicate is not a binary expression;
/// - its operator is neither `=` nor `!=`;
/// - either operand is not a column.
pub fn exprs_to_join_cols(
    exprs: &[Expr],
    schema: &DFSchemaRef,
    include_negated: bool,
) -> Result<(Vec<Column>, Vec<Column>, Option<Expr>)> {
    let schema_cols: HashSet<_> = schema
        .fields()
        .iter()
        .map(|field| field.qualified_name())
        .collect();

    let mut lhs_cols: Vec<Column> = vec![];
    let mut rhs_cols: Vec<Column> = vec![];
    let mut residual: Vec<Expr> = vec![];
    for filter in exprs {
        let (left, op, right) = if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = filter
        {
            (left.as_ref().clone(), *op, right.as_ref().clone())
        } else {
            return plan_err!("Invalid correlation expression!");
        };

        match op {
            Operator::Eq => {}
            Operator::NotEq if !include_negated => {
                residual.push(filter.clone());
                continue;
            }
            Operator::NotEq => {}
            _ => return plan_err!(format!("Correlation operator unsupported: {}", op)),
        }

        let left = left.try_into_col()?;
        let right = right.try_into_col()?;
        // Keep the `schema` column on the right-hand side.
        let (other_side, schema_side) = if schema_cols.contains(&left.flat_name()) {
            (right, left)
        } else {
            (left, right)
        };
        lhs_cols.push(Column::from(other_side.flat_name().as_str()));
        rhs_cols.push(Column::from(schema_side.flat_name().as_str()));
    }

    Ok((lhs_cols, rhs_cols, conjunction(residual)))
}
/// Returns the sole element of `slice`.
///
/// # Errors
///
/// Returns a planning error if `slice` is empty or has more than one element.
pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
    match slice.len() {
        0 => plan_err!("No items found!"),
        1 => Ok(&slice[0]),
        _ => plan_err!("More than one item found!"),
    }
}
/// Pairs the columns of `a` (`a.0` followed by `a.1`) with those of `b`
/// (`b.0` followed by `b.1`) position by position.
///
/// Consecutive duplicate pairs are removed, and the result is split back
/// into two parallel vectors. Each column is re-parsed from its flat name.
/// Pairing stops at the shorter of the two sequences.
pub fn merge_cols(
    a: (&[Column], &[Column]),
    b: (&[Column], &[Column]),
) -> (Vec<Column>, Vec<Column>) {
    let renormalize = |c: &Column| Column::from(c.flat_name().as_str());
    let lhs = a.0.iter().chain(a.1.iter()).map(renormalize);
    let rhs = b.0.iter().chain(b.1.iter()).map(renormalize);

    let mut pairs: Vec<(Column, Column)> = lhs.zip(rhs).collect();
    pairs.dedup();
    pairs.into_iter().unzip()
}
/// Returns copies of `cols` that keep each column name but are qualified by
/// `new_table`.
pub fn swap_table(new_table: &str, cols: &[Column]) -> Vec<Column> {
    let mut swapped = Vec::with_capacity(cols.len());
    for column in cols {
        swapped.push(Column {
            relation: Some(new_table.to_string()),
            name: column.name.clone(),
        });
    }
    swapped
}
/// Builds one column expression per entry of `cols`.
///
/// Each expression references the column by its flat (qualified) name and is
/// aliased to the bare column name.
pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
    let mut aliased = Vec::with_capacity(cols.len());
    for column in cols {
        let qualified = column.flat_name();
        aliased.push(col(qualified.as_str()).alias(column.name.as_str()));
    }
    aliased
}
/// Applies `rewriter` to `expr` and, if the rewrite changed the expression's
/// display name, aliases the result back to the original name.
///
/// For `Sort` expressions the name of the sorted inner expression is what is
/// compared and preserved.
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
    R: ExprRewriter<Expr>,
{
    let name_before = name_for_alias(&expr)?;
    let rewritten = expr.rewrite(rewriter)?;
    add_alias_if_changed(name_before, rewritten)
}
/// Returns the display name used for aliasing.
///
/// For a `Sort` this is the name of the expression being sorted (looked up
/// recursively). For anything else it is the expression's own display name.
fn name_for_alias(expr: &Expr) -> Result<String> {
    if let Expr::Sort { expr: inner, .. } = expr {
        name_for_alias(inner)
    } else {
        expr.display_name()
    }
}
/// Returns `expr` unchanged if its alias name already equals `original_name`.
/// Otherwise aliases it to `original_name`.
///
/// A `Sort` wrapper is preserved: the alias goes on the sorted inner
/// expression, and `asc` and `nulls_first` are carried over as-is.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
    if name_for_alias(&expr)? == original_name {
        return Ok(expr);
    }

    let renamed = match expr {
        Expr::Sort {
            expr: inner,
            asc,
            nulls_first,
        } => {
            let aliased_inner = add_alias_if_changed(original_name, *inner)?;
            Expr::Sort {
                expr: Box::new(aliased_inner),
                asc,
                nulls_first,
            }
        }
        other => other.alias(original_name),
    };
    Ok(renamed)
}
#[cfg(test)]
mod tests {
// Unit tests for the expression-splitting, conjunction/disjunction,
// permutation (Cartesian product), CNF-rewrite and name-preserving rewrite
// helpers in the parent module.
use super::*;
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_expr::expr::Cast;
use datafusion_expr::{col, lit, or, utils::expr_to_columns};
use std::collections::HashSet;
// Brings `Add::add` into scope for the `col(..).add(lit(..))` calls below.
use std::ops::Add;
// A bare column is not an AND, so it comes back as the only element.
#[test]
fn test_split_conjunction() {
let expr = col("a");
let result = split_conjunction(&expr);
assert_eq!(result, vec![&expr]);
}
#[test]
fn test_split_conjunction_two() {
let expr = col("a").eq(lit(5)).and(col("b"));
let expr1 = col("a").eq(lit(5));
let expr2 = col("b");
let result = split_conjunction(&expr);
assert_eq!(result, vec![&expr1, &expr2]);
}
// Aliases on AND operands are stripped during splitting.
#[test]
fn test_split_conjunction_alias() {
let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
let expr1 = col("a").eq(lit(5));
let expr2 = col("b");
let result = split_conjunction(&expr);
assert_eq!(result, vec![&expr1, &expr2]);
}
// An OR at the top level is not split.
#[test]
fn test_split_conjunction_or() {
let expr = col("a").eq(lit(5)).or(col("b"));
let result = split_conjunction(&expr);
assert_eq!(result, vec![&expr]);
}
#[test]
fn test_split_binary_owned() {
let expr = col("a");
assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
}
#[test]
fn test_split_binary_owned_two() {
assert_eq!(
split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
vec![col("a").eq(lit(5)), col("b")]
);
}
// Splitting on AND leaves an OR expression intact.
#[test]
fn test_split_binary_owned_different_op() {
let expr = col("a").eq(lit(5)).or(col("b"));
assert_eq!(
split_binary_owned(expr.clone(), Operator::And),
vec![expr]
);
}
#[test]
fn test_split_conjunction_owned() {
let expr = col("a");
assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
}
#[test]
fn test_split_conjunction_owned_two() {
assert_eq!(
split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
vec![col("a").eq(lit(5)), col("b")]
);
}
#[test]
fn test_split_conjunction_owned_alias() {
assert_eq!(
split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
vec![
col("a").eq(lit(5)),
col("b"),
]
);
}
#[test]
fn test_conjunction_empty() {
assert_eq!(conjunction(vec![]), None);
}
// `conjunction` builds a left-deep chain, so the right-deep form must differ.
#[test]
fn test_conjunction() {
let expr = conjunction(vec![col("a"), col("b"), col("c")]);
assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
}
#[test]
fn test_disjunction_empty() {
assert_eq!(disjunction(vec![]), None);
}
// `disjunction` builds a left-deep chain, so the right-deep form must differ.
#[test]
fn test_disjunction() {
let expr = disjunction(vec![col("a"), col("b"), col("c")]);
assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
}
#[test]
fn test_split_conjunction_owned_or() {
let expr = col("a").eq(lit(5)).or(col("b"));
assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
}
// Collecting columns from the same expression twice yields a single entry.
#[test]
fn test_collect_expr() -> Result<()> {
let mut accum: HashSet<Column> = HashSet::new();
expr_to_columns(
&Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
&mut accum,
)?;
expr_to_columns(
&Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
&mut accum,
)?;
assert_eq!(1, accum.len());
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}
// Each case rewrites `expr_from` into `rewrite_to` and checks that the
// display name (of the sorted expression, for `Sort`) is preserved.
#[test]
fn test_rewrite_preserving_name() {
test_rewrite(col("a"), col("a"));
test_rewrite(col("a"), col("b"));
test_rewrite(
col("a"),
Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
);
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));
test_rewrite(
Expr::Sort {
expr: Box::new(col("a").add(lit(1i32))),
asc: true,
nulls_first: false,
},
Expr::Sort {
expr: Box::new(col("b").add(lit(2i64))),
asc: true,
nulls_first: false,
},
);
}
/// Rewrites `expr_from` with a rewriter whose `mutate` always returns
/// `rewrite_to`, then asserts that `rewrite_preserving_name` kept the original
/// display name.
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
}
impl ExprRewriter for TestRewriter {
fn mutate(&mut self, _: Expr) -> Result<Expr> {
Ok(self.rewrite_to.clone())
}
}
let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();
// Compare names the same way `name_for_alias` does: look inside `Sort`.
let original_name = match &expr_from {
Expr::Sort { expr, .. } => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();
let new_name = match &expr {
Expr::Sort { expr, .. } => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();
assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
// `permutations` computes the Cartesian product of the groups:
// no groups -> no combinations.
#[test]
fn test_permutations() {
assert_eq!(make_permutations(vec![]), vec![] as Vec<Vec<Expr>>)
}
#[test]
fn test_permutations_one() {
assert_eq!(
make_permutations(vec![vec![col("a")]]),
vec![vec![col("a")]]
)
}
// A single group yields one singleton combination per member.
#[test]
fn test_permutations_two() {
assert_eq!(
make_permutations(vec![vec![col("a"), col("b")]]),
vec![vec![col("a")], vec![col("b")]]
)
}
#[test]
fn test_permutations_two_and_one() {
assert_eq!(
make_permutations(vec![vec![col("a"), col("b")], vec![col("c")]]),
vec![vec![col("a"), col("c")], vec![col("b"), col("c")]]
)
}
// Combinations come out in lexicographic order of group positions.
#[test]
fn test_permutations_two_and_one_and_two() {
assert_eq!(
make_permutations(vec![
vec![col("a"), col("b")],
vec![col("c")],
vec![col("d"), col("e")]
]),
vec![
vec![col("a"), col("c"), col("d")],
vec![col("a"), col("c"), col("e")],
vec![col("b"), col("c"), col("d")],
vec![col("b"), col("c"), col("e")],
]
)
}
/// Test adapter: converts owned groups into the borrowed `VecDeque` that
/// `permutations` expects, then clones the combinations back into owned
/// expressions for easy comparison.
fn make_permutations(exprs: impl IntoIterator<Item = Vec<Expr>>) -> Vec<Vec<Expr>> {
let exprs = exprs.into_iter().collect::<Vec<_>>();
let exprs: VecDeque<Vec<&Expr>> = exprs
.iter()
.map(|exprs| exprs.iter().collect::<Vec<&Expr>>())
.collect();
permutations(exprs)
.into_iter()
.map(|exprs| exprs.into_iter().cloned().collect())
.collect()
}
#[test]
fn test_rewrite_cnf() {
let a_1 = col("a").eq(lit(1i64));
let a_2 = col("a").eq(lit(2i64));
let b_1 = col("b").eq(lit(1i64));
let b_2 = col("b").eq(lit(2i64));
// A plain AND (no OR) is rebuilt into the same expression.
let expr1 = and(a_1.clone(), b_2.clone());
let expect = expr1.clone();
assert_eq!(expect, cnf_rewrite(expr1));
// Nested ANDs are flattened into a left-deep AND chain.
let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone()));
let expect = and(a_1.clone(), b_2.clone())
.and(a_2.clone())
.and(b_1.clone());
assert_eq!(expect, cnf_rewrite(expr1));
// A plain OR has no AND to distribute over and is returned unchanged.
let expr1 = or(a_1.clone(), b_2.clone());
let expect = expr1.clone();
assert_eq!(expect, cnf_rewrite(expr1));
// (a1 AND b2) OR (a2 AND b1) distributes into 2 x 2 = 4 OR clauses.
let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone()));
let a1_or_a2 = or(a_1.clone(), a_2.clone());
let a1_or_b1 = or(a_1.clone(), b_1.clone());
let b2_or_a2 = or(b_2.clone(), a_2.clone());
let b2_or_b1 = or(b_2.clone(), b_1.clone());
let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1);
assert_eq!(expect, cnf_rewrite(expr1));
// a1 OR b2 OR (a2 AND b1) distributes into 1 x 1 x 2 = 2 OR clauses.
let a1_or_b2 = or(a_1.clone(), b_2.clone());
let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone()));
let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone()));
assert_eq!(expect, cnf_rewrite(expr1));
// A pure OR chain is returned unchanged.
let expr1 = or(or(a_1, b_2), or(a_2, b_1));
let expect = expr1.clone();
assert_eq!(expect, cnf_rewrite(expr1));
}
// Each OR operand is a chain of 10 conjuncts, so distribution would produce
// 10 * 10 = 100 clauses. That is not below MAX_CNF_REWRITE_CONJUNCTS, so the
// expression must come back unchanged.
#[test]
fn test_rewrite_cnf_overflow() {
let mut expr1 = col("test1").eq(lit(1i64));
let expr2 = col("test2").eq(lit(2i64));
for _i in 0..9 {
expr1 = expr1.clone().and(expr2.clone());
}
let expr3 = expr1.clone();
let expr = or(expr1, expr3);
assert_eq!(expr, cnf_rewrite(expr.clone()));
}
}