use crate::utils::conjunction;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::utils::exprlist_to_columns;
use datafusion_expr::{
and,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
or,
utils::from_plan,
BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown,
};
use std::collections::{HashMap, HashSet};
use std::iter::once;
use std::sync::Arc;
/// Optimizer rule registered under the name `push_down_filter`.
///
/// Its `optimize` implementation moves `Filter` predicates below (or into)
/// the filter's child node so that rows are discarded as early as possible.
/// The struct carries no state; all inputs come from the plan being optimized.
#[derive(Default)]
pub struct PushDownFilter {}
/// Reports, for a join node, which inputs are "preserved" with respect to
/// predicates that sit in a `Filter` *above* the join.
///
/// Returns `(left_preserved, right_preserved)`. A predicate that only
/// references columns of a preserved side is allowed to be pushed into that
/// side's input by `push_down_all_join`.
///
/// # Errors
///
/// Returns an internal error when `plan` is neither a `Join` nor a
/// `CrossJoin`.
fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
    let preserved = match plan {
        LogicalPlan::CrossJoin(_) => (true, true),
        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
            JoinType::Inner => (true, true),
            JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (false, true),
            JoinType::Full => (false, false),
        },
        _ => {
            return Err(DataFusionError::Internal(
                "lr_is_preserved only valid for JOIN nodes".to_string(),
            ))
        }
    };
    Ok(preserved)
}
/// Reports, for a join node, which inputs are "preserved" with respect to
/// predicates that are part of the join's own ON filter.
///
/// Returns `(left_preserved, right_preserved)`. An ON-filter conjunct that
/// only references columns of a preserved side is allowed to be pushed into
/// that side's input by `push_down_all_join`.
///
/// # Errors
///
/// Returns an internal error for `CrossJoin` nodes and for any node that is
/// not a `Join`.
fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
    // Resolve the join type first; every non-`Join` node is an error.
    let join_type = match plan {
        LogicalPlan::Join(Join { join_type, .. }) => join_type,
        LogicalPlan::CrossJoin(_) => {
            return Err(DataFusionError::Internal(
                "on_lr_is_preserved cannot be applied to CROSSJOIN nodes".to_string(),
            ))
        }
        _ => {
            return Err(DataFusionError::Internal(
                "on_lr_is_preserved only valid for JOIN nodes".to_string(),
            ))
        }
    };
    Ok(match join_type {
        JoinType::Inner => (true, true),
        JoinType::Left => (false, true),
        JoinType::Right => (true, false),
        JoinType::Full
        | JoinType::LeftSemi
        | JoinType::LeftAnti
        | JoinType::RightSemi
        | JoinType::RightAnti => (false, false),
    })
}
/// Returns `true` when every column referenced by `predicate` can be found
/// in `schema`, either under its qualified or its unqualified form.
///
/// A predicate that references no columns at all trivially satisfies this.
///
/// # Errors
///
/// Propagates any error raised while collecting the predicate's columns.
fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> {
    // Accept both spellings of each field so that `t.a` and `a` both match.
    let mut schema_columns = HashSet::new();
    for field in schema.fields() {
        schema_columns.insert(field.qualified_column());
        schema_columns.insert(field.unqualified_column());
    }
    let referenced = predicate.to_columns()?;
    Ok(referenced.is_subset(&schema_columns))
}
/// Derives extra single-side predicates from top-level `OR` filters.
///
/// For every entry of `filters` that is an `OR` expression, both operands
/// are reduced (via `extract_or_clause`) to the part that only references
/// columns of `schema`. When both operands yield something, the two parts
/// are re-joined with `OR` and returned. Filters that are not `OR`
/// expressions contribute nothing.
///
/// Returns an empty vector when `preserved` is `false`.
fn extract_or_clauses_for_join(
    filters: &[&Expr],
    schema: &DFSchema,
    preserved: bool,
) -> Vec<Expr> {
    if !preserved {
        return Vec::new();
    }

    // Accept both the qualified and unqualified spelling of each field.
    let mut schema_columns = HashSet::new();
    for field in schema.fields() {
        schema_columns.insert(field.qualified_column());
        schema_columns.insert(field.unqualified_column());
    }

    filters
        .iter()
        .filter_map(|filter| match filter {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::Or,
                right,
            }) => {
                let from_left = extract_or_clause(left.as_ref(), &schema_columns);
                let from_right = extract_or_clause(right.as_ref(), &schema_columns);
                match (from_left, from_right) {
                    (Some(lhs), Some(rhs)) => Some(or(lhs, rhs)),
                    _ => None,
                }
            }
            _ => None,
        })
        .collect()
}
/// Reduces `expr` to the portion that only references `schema_columns`.
///
/// * `OR`: both operands must reduce to something, otherwise nothing can be
///   extracted (dropping one operand of an `OR` would make the predicate
///   stricter than the original).
/// * `AND`: each operand is reduced independently; whatever survives is kept
///   (dropping one operand of an `AND` only makes the predicate weaker).
/// * Any other expression: kept as-is when all of its columns are contained
///   in `schema_columns`, otherwise dropped.
///
/// For example, with only column `a` available,
/// `(a = 1 AND b = 2) OR (a = 3 AND c = 4)` reduces to `a = 1 OR a = 3`.
///
/// Returns `None` when nothing can be extracted. If the columns of a leaf
/// expression cannot be determined, the leaf is treated as not extractable
/// rather than aborting the whole optimization with a panic.
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
    match expr {
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::Or,
            right,
        }) => {
            // Both sides are required; bail out as soon as one is missing.
            let lhs = extract_or_clause(left, schema_columns)?;
            let rhs = extract_or_clause(right, schema_columns)?;
            Some(or(lhs, rhs))
        }
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::And,
            right,
        }) => {
            let lhs = extract_or_clause(left, schema_columns);
            let rhs = extract_or_clause(right, schema_columns);
            match (lhs, rhs) {
                (Some(lhs), Some(rhs)) => Some(and(lhs, rhs)),
                (Some(only), None) | (None, Some(only)) => Some(only),
                (None, None) => None,
            }
        }
        _ => {
            // A failure to list the columns means "cannot extract", not a panic.
            let columns = expr.to_columns().ok()?;
            if columns.is_subset(schema_columns) {
                Some(expr.clone())
            } else {
                None
            }
        }
    }
}
fn push_down_all_join(
predicates: Vec<Expr>,
plan: &LogicalPlan,
left: &LogicalPlan,
right: &LogicalPlan,
on_filter: Vec<Expr>,
) -> Result<LogicalPlan> {
let on_filter_empty = on_filter.is_empty();
let (left_preserved, right_preserved) = lr_is_preserved(plan)?;
let mut left_push = vec![];
let mut right_push = vec![];
let mut keep_predicates = vec![];
for predicate in predicates {
if left_preserved && can_pushdown_join_predicate(&predicate, left.schema())? {
left_push.push(predicate);
} else if right_preserved
&& can_pushdown_join_predicate(&predicate, right.schema())?
{
right_push.push(predicate);
} else {
keep_predicates.push(predicate);
}
}
let mut keep_condition = vec![];
if !on_filter.is_empty() {
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?;
for on in on_filter {
if on_left_preserved && can_pushdown_join_predicate(&on, left.schema())? {
left_push.push(on)
} else if on_right_preserved
&& can_pushdown_join_predicate(&on, right.schema())?
{
right_push.push(on)
} else {
keep_condition.push(on)
}
}
}
let or_to_left = extract_or_clauses_for_join(
&keep_predicates.iter().collect::<Vec<_>>(),
left.schema(),
left_preserved,
);
let or_to_right = extract_or_clauses_for_join(
&keep_predicates.iter().collect::<Vec<_>>(),
right.schema(),
right_preserved,
);
let on_or_to_left = extract_or_clauses_for_join(
&keep_condition.iter().collect::<Vec<_>>(),
left.schema(),
left_preserved,
);
let on_or_to_right = extract_or_clauses_for_join(
&keep_condition.iter().collect::<Vec<_>>(),
right.schema(),
right_preserved,
);
left_push.extend(or_to_left);
left_push.extend(on_or_to_left);
right_push.extend(or_to_right);
right_push.extend(on_or_to_right);
let left = match conjunction(left_push) {
Some(predicate) => {
LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?)
}
None => left.clone(),
};
let right = match conjunction(right_push) {
Some(predicate) => {
LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?)
}
None => right.clone(),
};
let expr = plan.expressions();
let expr = if !on_filter_empty && keep_condition.is_empty() {
expr[..expr.len() - 1].to_vec()
} else if !keep_condition.is_empty() {
expr[..expr.len() - 1]
.iter()
.cloned()
.chain(once(keep_condition.into_iter().reduce(Expr::and).unwrap()))
.collect()
} else {
plan.expressions()
};
let plan = from_plan(plan, &expr, &[left, right])?;
if keep_predicates.is_empty() {
Ok(plan)
} else {
match conjunction(keep_predicates) {
Some(predicate) => Ok(LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(plan),
)?)),
None => Ok(plan),
}
}
}
/// Prepares the predicates for a `Join` node and hands them to
/// `push_down_all_join`.
///
/// # Arguments
///
/// * `plan` - the join as a `LogicalPlan`.
/// * `join` - the same join, destructured.
/// * `parent_predicate` - predicate of a `Filter` directly above the join,
///   if there is one. It is rewritten to CNF and split into conjuncts.
///
/// For inner joins, every conjunct (from the parent filter or the ON filter)
/// that references an equi-join key column is additionally duplicated with
/// each such column replaced by its counterpart on the other side of the
/// join, so the restriction can reach both inputs.
///
/// Returns `Ok(None)` when there is neither an ON filter nor a parent
/// predicate, i.e. nothing to push.
///
/// # Errors
///
/// Propagates errors from column collection, column replacement and from
/// `push_down_all_join`.
fn push_down_join(
    plan: &LogicalPlan,
    join: &Join,
    parent_predicate: Option<&Expr>,
) -> Result<Option<LogicalPlan>> {
    let mut predicates = parent_predicate.map_or_else(Vec::new, |predicate| {
        utils::split_conjunction_owned(utils::cnf_rewrite(predicate.clone()))
    });
    let on_filters = match &join.filter {
        Some(filter) => utils::split_conjunction_owned(filter.clone()),
        None => Vec::new(),
    };

    if join.join_type == JoinType::Inner {
        let mut mirrored = Vec::new();
        for predicate in predicates.iter().chain(on_filters.iter()) {
            let columns = predicate.to_columns()?;

            // Map every referenced join-key column to the key it is equated
            // with on the other side (first matching key pair wins).
            let mut join_cols_to_replace = HashMap::new();
            for col in columns.iter() {
                let counterpart = join.on.iter().find_map(|(l, r)| {
                    if col == l {
                        Some(r)
                    } else if col == r {
                        Some(l)
                    } else {
                        None
                    }
                });
                if let Some(other_side) = counterpart {
                    join_cols_to_replace.insert(col, other_side);
                }
            }

            // Predicates that do not touch a join key yield nothing new.
            if join_cols_to_replace.is_empty() {
                continue;
            }
            mirrored.push(replace_col(predicate.clone(), &join_cols_to_replace)?);
        }
        predicates.extend(mirrored);
    }

    if on_filters.is_empty() && predicates.is_empty() {
        return Ok(None);
    }

    let pushed =
        push_down_all_join(predicates, plan, &join.left, &join.right, on_filters)?;
    Ok(Some(pushed))
}
impl OptimizerRule for PushDownFilter {
/// Name under which this rule is reported.
fn name(&self) -> &str {
"push_down_filter"
}
/// Applies the rule to `plan` and, afterwards, to the children of the
/// rewritten plan.
///
/// Only two node kinds are handled at the top: a `Filter` (whose predicate is
/// pushed into or below its child, depending on the child's kind) and a bare
/// `Join` (whose ON filter is pushed into its inputs). Every other node is
/// left as-is and only its children are optimized.
///
/// Errors from plan reconstruction, column collection and the table source's
/// push-down capability query are propagated.
fn optimize(
&self,
plan: &LogicalPlan,
optimizer_config: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
let filter = match plan {
LogicalPlan::Filter(filter) => filter,
// A join without a Filter above it: only its ON filter can be pushed.
LogicalPlan::Join(join) => {
let optimized_plan = push_down_join(plan, join, None)?;
return match optimized_plan {
Some(optimized_plan) => {
utils::optimize_children(self, &optimized_plan, optimizer_config)
}
None => utils::optimize_children(self, plan, optimizer_config),
};
}
_ => return utils::optimize_children(self, plan, optimizer_config),
};
let child_plan = &**filter.input();
let new_plan = match child_plan {
// Filter over Filter: merge both predicates with AND into a single Filter
// and run this rule again on the merged node.
LogicalPlan::Filter(child_filter) => {
let new_predicate =
and(filter.predicate().clone(), child_filter.predicate().clone());
let new_plan = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input().clone(),
)?);
return self.optimize(&new_plan, optimizer_config);
}
// Swap the Filter with its child: the Filter takes over the child's first
// input and the child is re-created on top of the Filter.
LogicalPlan::Repartition(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Sort(_) => {
let new_filter =
plan.with_new_inputs(&[
(**(child_plan.inputs().get(0).unwrap())).clone()
])?;
child_plan.with_new_inputs(&[new_filter])?
}
// Rewrite the predicate from the alias' field names to the columns of the
// alias' input (fields are matched by position), then filter below the alias.
LogicalPlan::SubqueryAlias(subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, field) in
subquery_alias.input.schema().fields().iter().enumerate()
{
replace_map.insert(
subquery_alias
.schema
.fields()
.get(i)
.unwrap()
.qualified_name(),
Expr::Column(field.qualified_column()),
);
}
let new_predicate =
replace_cols_by_name(filter.predicate().clone(), &replace_map)?;
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
subquery_alias.input.clone(),
)?);
child_plan.with_new_inputs(&[new_filter])?
}
// Replace every projected output column in the predicate by the expression
// that produces it (with a top-level alias stripped), then filter below the
// projection.
LogicalPlan::Projection(projection) => {
let replace_map = projection
.schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
let expr = match &projection.expr[i] {
Expr::Alias(expr, _) => expr.as_ref().clone(),
expr => expr.clone(),
};
(field.qualified_name(), expr)
})
.collect::<HashMap<_, _>>();
let new_filter = LogicalPlan::Filter(Filter::try_new(
replace_cols_by_name(filter.predicate().clone(), &replace_map)?,
projection.input.clone(),
)?);
child_plan.with_new_inputs(&[new_filter])?
}
// Copy the Filter above every input of the union, rewriting the union's
// field names to that input's columns (fields are matched by position).
LogicalPlan::Union(union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
for (i, field) in input.schema().fields().iter().enumerate() {
replace_map.insert(
union.schema.fields().get(i).unwrap().qualified_name(),
Expr::Column(field.qualified_column()),
);
}
let push_predicate =
replace_cols_by_name(filter.predicate().clone(), &replace_map)?;
inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
push_predicate,
input.clone(),
)?)))
}
LogicalPlan::Union(Union {
inputs,
schema: plan.schema().clone(),
})
}
LogicalPlan::Aggregate(agg) => {
// `used_columns` = columns referenced inside the aggregate expressions plus
// the output names of those expressions.
let mut used_columns = HashSet::new();
exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
let agg_columns = agg
.aggr_expr
.iter()
.map(|x| Ok(Column::from_name(x.display_name()?)))
.collect::<Result<HashSet<_>>>()?;
used_columns.extend(agg_columns);
let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
filter.predicate().clone(),
));
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
// A conjunct stays above the aggregate when it references no column at all
// or touches any of `used_columns`; otherwise it is pushed below.
for expr in predicates {
let columns = expr.to_columns()?;
if columns.is_empty()
|| !columns
.intersection(&used_columns)
.collect::<HashSet<_>>()
.is_empty()
{
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
}
}
let child = match conjunction(push_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new((*agg.input).clone()),
)?),
None => (*agg.input).clone(),
};
// Re-create the aggregate (same expressions) on top of the new child.
let new_agg = from_plan(
filter.input(),
&filter.input().expressions(),
&vec![child],
)?;
match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_agg),
)?),
None => new_agg,
}
}
// Delegate to `push_down_join`; `None` means nothing could be pushed and the
// original Filter-over-Join is kept.
LogicalPlan::Join(join) => {
match push_down_join(filter.input(), join, Some(filter.predicate()))? {
Some(optimized_plan) => optimized_plan,
None => plan.clone(),
}
}
// A cross join has no ON filter, so only the Filter's conjuncts are pushed.
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
filter.predicate().clone(),
));
push_down_all_join(predicates, filter.input(), left, right, vec![])?
}
LogicalPlan::TableScan(scan) => {
let mut new_scan_filters = scan.filters.clone();
let mut new_predicate = vec![];
let filter_predicates = utils::split_conjunction_owned(
utils::cnf_rewrite(filter.predicate().clone()),
);
// Ask the table source, per conjunct, how far it can be pushed:
//   Unsupported -> stays in the Filter only,
//   Inexact     -> handed to the scan and also kept in the Filter,
//   Exact       -> handed to the scan only.
for filter_expr in &filter_predicates {
let (preserve_filter_node, add_to_provider) =
match scan.source.supports_filter_pushdown(filter_expr)? {
TableProviderFilterPushDown::Unsupported => (true, false),
TableProviderFilterPushDown::Inexact => (true, true),
TableProviderFilterPushDown::Exact => (false, true),
};
if preserve_filter_node {
new_predicate.push(filter_expr.clone());
}
if add_to_provider {
// Do not hand the same filter to the scan twice.
if new_scan_filters.contains(filter_expr) {
continue;
}
new_scan_filters.push(filter_expr.clone());
}
}
let new_scan = LogicalPlan::TableScan(TableScan {
source: scan.source.clone(),
projection: scan.projection.clone(),
projected_schema: scan.projected_schema.clone(),
table_name: scan.table_name.clone(),
filters: new_scan_filters,
fetch: scan.fetch,
});
match conjunction(new_predicate) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_scan),
)?),
None => new_scan,
}
}
// Any other child: the Filter is left where it is.
_ => plan.clone(),
};
// Continue with the children of the rewritten plan.
utils::optimize_children(self, &new_plan, optimizer_config)
}
}
impl PushDownFilter {
    /// Creates a new instance of the rule. The rule holds no state.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Rewrites `e`, substituting every column whose flat name is a key of
/// `replace_map` with a clone of the mapped expression. Columns without an
/// entry and all non-column expressions are left untouched.
///
/// # Errors
///
/// Propagates any error raised by the expression rewriter.
fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<Expr> {
    /// Rewriter that performs the name-based lookup for each visited node.
    struct ColumnReplacer<'a> {
        replace_map: &'a HashMap<String, Expr>,
    }

    impl<'a> ExprRewriter for ColumnReplacer<'a> {
        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
            let replacement = match &expr {
                Expr::Column(column) => {
                    self.replace_map.get(&column.flat_name()).cloned()
                }
                _ => None,
            };
            Ok(replacement.unwrap_or(expr))
        }
    }

    let mut replacer = ColumnReplacer { replace_map };
    e.rewrite(&mut replacer)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test::*;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::DFSchema;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
};
use std::sync::Arc;
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
let optimized_plan = PushDownFilter::new()
.optimize(plan, &mut OptimizerConfig::new())
.expect("failed to optimize plan");
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(plan.schema(), optimized_plan.schema());
assert_eq!(expected, formatted_plan);
Ok(())
}
#[test]
fn filter_before_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a, test.b\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_after_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(0, Some(10))?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: test.a = Int64(1)\
\n Limit: skip=0, fetch=10\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_no_columns() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(0i64).eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: Int64(0) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_jump_2_plans() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.project(vec![col("c"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.c, test.b\
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_move_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
.filter(col("a").gt(lit(10i64)))?
.build()?;
let expected = "\
Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_complex_group_by() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "Filter: test.b > Int64(10)\
\n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "\
Filter: b > Int64(10)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_keep_partial_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
let filter = f1.or(f2);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(filter)?
.build()?;
let expected = "\
Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a AS b, test.c\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
fn add(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
Operator::Plus,
Box::new(right),
))
}
fn multiply(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
Operator::Multiply,
Box::new(right),
))
}
#[test]
fn complex_expression() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: b = Int64(1)\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test"
);
let expected = "\
Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: test.a * Int32(2) + test.c = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn complex_plan() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: a = Int64(1)\
\n Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test"
);
let expected = "\
Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn multi_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(col("b").gt(lit(10i64)))?
.filter(col("SUM(test.c)").gt(lit(10i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: SUM(test.c) > Int64(10)\
\n Filter: b > Int64(10)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test"
);
let expected = "\
Filter: SUM(test.c) > Int64(10)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn split_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(and(
col("SUM(test.c)").gt(lit(10i64)),
and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))),
))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: SUM(test.c) > Int64(10) AND b > Int64(10) AND SUM(test.c) < Int64(20)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test"
);
let expected = "\
Filter: SUM(test.c) > Int64(10) AND SUM(test.c) < Int64(20)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn double_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(0, Some(20))?
.limit(0, Some(10))?
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a, test.b\
\n Filter: test.a = Int64(1)\
\n Limit: skip=0, fetch=10\
\n Limit: skip=0, fetch=20\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn union_all() -> Result<()> {
let table_scan = test_table_scan()?;
let table_scan2 = test_table_scan_with_name("test2")?;
let plan = LogicalPlanBuilder::from(table_scan)
.union(LogicalPlanBuilder::from(table_scan2).build()?)?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "Union\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
\n Filter: test2.a = Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn union_all_on_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let table = LogicalPlanBuilder::from(table_scan)
.project_with_alias(vec![col("a").alias("b")], Some("test2".to_string()))?;
let plan = table
.union(table.build()?)?
.filter(col("b").eq(lit(1i64)))?
.build()?;
let expected = "Union\n SubqueryAlias: test2\
\n Projection: test.a AS b\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
\n SubqueryAlias: test2\
\n Projection: test.a AS b\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn test_union_different_schema() -> Result<()> {
let left = LogicalPlanBuilder::from(test_table_scan()?)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let schema = Schema::new(vec![
Field::new("d", DataType::UInt32, false),
Field::new("e", DataType::UInt32, false),
Field::new("f", DataType::UInt32, false),
]);
let right = table_scan(Some("test1"), &schema, None)?
.project(vec![col("d"), col("e"), col("f")])?
.build()?;
let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
let plan = LogicalPlanBuilder::from(left)
.cross_join(&right)?
.project(vec![col("test.a"), col("test1.d")])?
.filter(filter)?
.build()?;
let expected = "Projection: test.a, test1.d\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int32(1)\
\n TableScan: test\
\n Projection: test1.d, test1.e, test1.f\
\n Filter: test1.d > Int32(2)\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test1")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
let plan = LogicalPlanBuilder::from(left)
.cross_join(&right)?
.project(vec![col("test.a"), col("test1.a")])?
.filter(filter)?
.build()?;
let expected = "Projection: test.a, test1.a\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int32(1)\
\n TableScan: test\
\n Projection: test1.a, test1.b, test1.c\
\n Filter: test1.a > Int32(2)\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_2_breaks_limits() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.filter(col("a").lt_eq(lit(1i64)))?
.limit(0, Some(1))?
.project(vec![col("a")])?
.filter(col("a").gt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test.a >= Int64(1)\
\n Projection: test.a\
\n Limit: skip=0, fetch=1\
\n Filter: test.a <= Int64(1)\
\n Projection: test.a\
\n TableScan: test"
);
let expected = "\
Projection: test.a\
\n Filter: test.a >= Int64(1)\
\n Limit: skip=0, fetch=1\
\n Projection: test.a\
\n Filter: test.a <= Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn two_filters_on_same_depth() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.limit(0, Some(1))?
.filter(col("a").lt_eq(lit(1i64)))?
.filter(col("a").gt_eq(lit(1i64)))?
.project(vec![col("a")])?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Projection: test.a\
\n Filter: test.a >= Int64(1)\
\n Filter: test.a <= Int64(1)\
\n Limit: skip=0, fetch=1\
\n TableScan: test"
);
let expected = "\
Projection: test.a\
\n Filter: test.a >= Int64(1) AND test.a <= Int64(1)\
\n Limit: skip=0, fetch=1\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filters_user_defined_node() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
let plan = user_defined::new(plan);
let expected = "\
TestUserDefined\
\n Filter: test.a <= Int64(1)\
\n TableScan: test";
assert_eq!(format!("{:?}", plan), expected);
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
None,
)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test.a <= Int64(1)\
\n Inner Join: test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2"
);
let expected = "\
Inner Join: test.a = test2.a\
\n Filter: test.a <= Int64(1)\
\n TableScan: test\
\n Projection: test2.a\
\n Filter: test2.a <= Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Inner,
vec![Column::from_name("a".to_string())],
)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test.a <= Int64(1)\
\n Inner Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2"
);
let expected = "\
Inner Join: Using test.a = test2.a\
\n Filter: test.a <= Int64(1)\
\n TableScan: test\
\n Projection: test2.a\
\n Filter: test2.a <= Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
None,
)?
.filter(col("c").lt_eq(col("b")))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test.c <= test2.b\
\n Inner Join: test.a = test2.a\
\n Projection: test.a, test.c\
\n TableScan: test\
\n Projection: test2.a, test2.b\
\n TableScan: test2"
);
let expected = &format!("{:?}", plan);
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_join_on_one_side() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let table_scan_right = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(table_scan_right)
.project(vec![col("a"), col("c")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
None,
)?
.filter(col("b").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test.b <= Int64(1)\
\n Inner Join: test.a = test2.a\
\n Projection: test.a, test.b\
\n TableScan: test\
\n Projection: test2.a, test2.c\
\n TableScan: test2"
);
let expected = "\
Inner Join: test.a = test2.a\
\n Projection: test.a, test.b\
\n Filter: test.b <= Int64(1)\
\n TableScan: test\
\n Projection: test2.a, test2.c\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_using_left_join() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Left,
vec![Column::from_name("a".to_string())],
)?
.filter(col("test2.a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test2.a <= Int64(1)\
\n Left Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2"
);
let expected = "\
Filter: test2.a <= Int64(1)\
\n Left Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn filter_using_right_join() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Right,
vec![Column::from_name("a".to_string())],
)?
.filter(col("test.a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: test.a <= Int64(1)\
\n Right Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2"
);
let expected = "\
Filter: test.a <= Int64(1)\
\n Right Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
/// An unqualified predicate on the shared `USING` column `a` above a left
/// join resolves to `test.a` and is expected to be pushed down to the left
/// input only.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    // `Column::from_name` accepts a `&str` directly (as the sibling join tests
    // do), so no intermediate `String` is allocated here.
    let plan = LogicalPlanBuilder::from(left)
        .join_using(&right, JoinType::Left, vec![Column::from_name("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Sanity-check the plan before optimization.
    assert_eq!(
        format!("{:?}", plan),
        "Filter: test.a <= Int64(1)\
        \n Left Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    // The filter ends up directly above the left table scan.
    let expected = "\
    Left Join: Using test.a = test2.a\
    \n Filter: test.a <= Int64(1)\
    \n TableScan: test\
    \n Projection: test2.a\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected)
}
/// A predicate on `test2.a` (the right side of the shared `USING` column)
/// above a right join is expected to be pushed down to the right input only,
/// below its projection.
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    // `Column::from_name` accepts a `&str` directly (as the sibling join tests
    // do), so no intermediate `String` is allocated here.
    let plan = LogicalPlanBuilder::from(left)
        .join_using(&right, JoinType::Right, vec![Column::from_name("a")])?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Sanity-check the plan before optimization.
    assert_eq!(
        format!("{:?}", plan),
        "Filter: test2.a <= Int64(1)\
        \n Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    // The filter ends up directly above the right table scan.
    let expected = "\
    Right Join: Using test.a = test2.a\
    \n TableScan: test\
    \n Projection: test2.a\
    \n Filter: test2.a <= Int64(1)\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected)
}
/// Inner join with an `ON` filter made of three conjuncts: the two
/// single-side conjuncts are expected to move to their respective inputs,
/// while the conjunct referencing both sides stays in the join filter.
#[test]
fn join_on_with_filter() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.c").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
    \n Projection: test.a, test.b, test.c\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n TableScan: test2";
    assert_eq!(format!("{:?}", plan), expected_before);

    // Only the two-sided conjunct is left on the join.
    let expected_after = "\
    Inner Join: test.a = test2.a Filter: test.b < test2.b\
    \n Projection: test.a, test.b, test.c\
    \n Filter: test.c > UInt32(1)\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n Filter: test2.c > UInt32(4)\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Inner join whose `ON` filter consists solely of single-side conjuncts:
/// both are expected to be pushed to their inputs, leaving the join without
/// any filter.
#[test]
fn join_filter_removed() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.b").gt(lit(1u32));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(right_only);
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
    \n Projection: test.a, test.b, test.c\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n TableScan: test2";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The join no longer carries a filter.
    let expected_after = "\
    Inner Join: test.a = test2.a\
    \n Projection: test.a, test.b, test.c\
    \n Filter: test.b > UInt32(1)\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n Filter: test2.c > UInt32(4)\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Inner join on `test.a = test2.b` with an `ON` filter over the left join
/// key only: the expected plan filters the left input on `test.a` and also
/// the right input on the matching key `test2.b`.
#[test]
fn join_filter_on_common() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;

    let on_filter = col("test.a").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
    \n Projection: test.a\
    \n TableScan: test\
    \n Projection: test2.b\
    \n TableScan: test2";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The predicate shows up on both join inputs.
    let expected_after = "\
    Inner Join: test.a = test2.b\
    \n Projection: test.a\
    \n Filter: test.a > UInt32(1)\
    \n TableScan: test\
    \n Projection: test2.b\
    \n Filter: test2.b > UInt32(1)\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Left join with a three-conjunct `ON` filter: only the conjunct that
/// references the right input alone is expected to be pushed down; the
/// left-only and two-sided conjuncts stay in the join filter.
#[test]
fn left_join_on_with_filter() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Left, join_keys, Some(on_filter))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
    \n Projection: test.a, test.b, test.c\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n TableScan: test2";
    assert_eq!(format!("{:?}", plan), expected_before);

    // Only the right-side conjunct has moved below the join.
    let expected_after = "\
    Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b\
    \n Projection: test.a, test.b, test.c\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n Filter: test2.c > UInt32(4)\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Right join with a three-conjunct `ON` filter: only the conjunct that
/// references the left input alone is expected to be pushed down; the
/// right-only and two-sided conjuncts stay in the join filter.
#[test]
fn right_join_on_with_filter() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Right, join_keys, Some(on_filter))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
    \n Projection: test.a, test.b, test.c\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n TableScan: test2";
    assert_eq!(format!("{:?}", plan), expected_before);

    // Only the left-side conjunct has moved below the join.
    let expected_after = "\
    Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\
    \n Projection: test.a, test.b, test.c\
    \n Filter: test.a > UInt32(1)\
    \n TableScan: test\
    \n Projection: test2.a, test2.b, test2.c\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Full join with a three-conjunct `ON` filter: nothing is expected to be
/// pushed down, so the optimized plan must render exactly like the input.
#[test]
fn full_join_on_with_filter() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let on_filter = left_only.and(both_sides).and(right_only);
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Full, join_keys, Some(on_filter))?
        .build()?;

    // Render once; the same text serves as the "before" check and as the
    // expected optimized output.
    let rendered = format!("{:?}", plan);
    assert_eq!(
        rendered,
        "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );
    assert_optimized_plan_eq(&plan, &rendered)
}
/// Test-only table source whose answer to filter push-down queries is fixed
/// at construction time.
struct PushDownProvider {
    /// Level of filter push-down support this provider reports for every
    /// expression it is asked about.
    pub filter_support: TableProviderFilterPushDown,
}
// `TableSource` implementation for the test provider: a base table with two
// nullable `Int32` columns, `a` and `b`, that reports the configured
// `filter_support` for every filter expression.
#[async_trait]
impl TableSource for PushDownProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Schema with the nullable `Int32` columns `a` and `b`.
    fn schema(&self) -> SchemaRef {
        use arrow::datatypes::{DataType, Field, Schema};

        let nullable_int32 = |name: &str| Field::new(name, DataType::Int32, true);
        let fields = vec![nullable_int32("a"), nullable_int32("b")];
        Arc::new(Schema::new(fields))
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Returns a copy of `filter_support`, regardless of the expression.
    fn supports_filter_pushdown(
        &self,
        _filter: &Expr,
    ) -> Result<TableProviderFilterPushDown> {
        let support = self.filter_support.clone();
        Ok(support)
    }
}
/// Builds `Filter: a = Int64(1)` on top of a scan of table `test` backed by a
/// [`PushDownProvider`] that reports `filter_support` for every filter.
///
/// The scan starts with no filters, no projection and no fetch limit.
fn table_scan_with_pushdown_provider(
    filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
    let provider = PushDownProvider { filter_support };
    // Derive the schema before the provider is moved into the scan.
    let df_schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = TableScan {
        table_name: "test".to_string(),
        source: Arc::new(provider),
        projection: None,
        projected_schema: Arc::new(df_schema),
        filters: vec![],
        fetch: None,
    };
    let predicate = col("a").eq(lit(1i64));
    LogicalPlanBuilder::from(LogicalPlan::TableScan(scan))
        .filter(predicate)?
        .build()
}
/// With `Exact` push-down support the filter is expected to disappear from
/// the plan and be listed in the scan's `full_filters`.
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
    let support = TableProviderFilterPushDown::Exact;
    let plan = table_scan_with_pushdown_provider(support)?;
    let expected = "TableScan: test, full_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(&plan, expected)
}
/// With `Inexact` push-down support the filter is expected to be listed in
/// the scan's `partial_filters` while the `Filter` node is kept above it.
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;
    let expected = "\
    Filter: a = Int64(1)\
    \n TableScan: test, partial_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(&plan, expected)
}
/// Running the rule on an already optimized plan (inexact provider) is
/// expected to give the same result: the filter appears once in
/// `partial_filters` and the `Filter` node is still present.
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    // First pass, done by hand; the assertion helper below runs the rule again.
    let rule = PushDownFilter::new();
    let mut config = OptimizerConfig::new();
    let once_optimized = rule
        .optimize(&plan, &mut config)
        .expect("failed to optimize plan");

    let expected = "\
    Filter: a = Int64(1)\
    \n TableScan: test, partial_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(&once_optimized, expected)
}
/// With `Unsupported` push-down support the plan is expected to keep the
/// `Filter` node and the scan is expected to list no filters.
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
    let support = TableProviderFilterPushDown::Unsupported;
    let plan = table_scan_with_pushdown_provider(support)?;
    let expected = "\
    Filter: a = Int64(1)\
    \n TableScan: test";
    assert_optimized_plan_eq(&plan, expected)
}
/// A scan that already carries both predicates as filters, under a `Filter`
/// with the same two predicates: the expected plan lists each predicate once
/// in `partial_filters` and keeps the `Filter` node (inexact provider).
#[test]
fn multi_combined_filter() -> Result<()> {
    let provider = PushDownProvider {
        filter_support: TableProviderFilterPushDown::Inexact,
    };
    let a_eq_10 = col("a").eq(lit(10i64));
    let b_gt_11 = col("b").gt(lit(11i64));

    // Derive the schema before the provider is moved into the scan.
    let df_schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: "test".to_string(),
        source: Arc::new(provider),
        projection: Some(vec![0]),
        projected_schema: Arc::new(df_schema),
        filters: vec![a_eq_10.clone(), b_gt_11.clone()],
        fetch: None,
    });

    let plan = LogicalPlanBuilder::from(scan)
        .filter(and(a_eq_10, b_gt_11))?
        .project(vec![col("a"), col("b")])?
        .build()?;

    let expected = "Projection: a, b\
    \n Filter: a = Int64(10) AND b > Int64(11)\
    \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]";
    assert_optimized_plan_eq(&plan, expected)
}
/// A filter that refers to the projection alias `b` (for `test.a`) is
/// expected to be pushed below the projection with the alias replaced by
/// the underlying column.
#[test]
fn test_filter_with_alias() -> Result<()> {
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Filter: b > Int64(10) AND test.c > Int64(10)\
    \n Projection: test.a AS b, test.c\
    \n TableScan: test";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The filter now sits below the projection and reads `test.a`.
    let expected_after = "\
    Projection: test.a AS b, test.c\
    \n Filter: test.a > Int64(10) AND test.c > Int64(10)\
    \n TableScan: test\
    ";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Same as the single-projection alias case, but with a second projection
/// re-selecting `b` and `c`: the filter is expected to pass through both
/// projections and end up on the scan, reading `test.a`.
#[test]
fn test_filter_with_alias_2() -> Result<()> {
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Filter: b > Int64(10) AND test.c > Int64(10)\
    \n Projection: b, test.c\
    \n Projection: test.a AS b, test.c\
    \n TableScan: test\
    ";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The filter is below both projections.
    let expected_after = "\
    Projection: b, test.c\
    \n Projection: test.a AS b, test.c\
    \n Filter: test.a > Int64(10) AND test.c > Int64(10)\
    \n TableScan: test\
    ";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Both columns referenced by the filter are projection aliases (`b` for
/// `test.a`, `d` for `test.c`): the pushed-down filter is expected to use
/// the underlying columns for both.
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
    let predicate = and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c").alias("d")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Filter: b > Int64(10) AND d > Int64(10)\
    \n Projection: test.a AS b, test.c AS d\
    \n TableScan: test\
    ";
    assert_eq!(format!("{:?}", plan), expected_before);

    // Both aliases are resolved in the pushed-down filter.
    let expected_after = "\
    Projection: test.a AS b, test.c AS d\
    \n Filter: test.a > Int64(10) AND test.c > Int64(10)\
    \n TableScan: test\
    ";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// Inner join on aliased keys (`c = d`) with an `ON` filter over `c`: the
/// expected plan filters both inputs below their projections, using the
/// un-aliased columns `test.a` and `test2.b`.
#[test]
fn join_filter_with_alias() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;

    let on_filter = col("c").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);

    let plan = LogicalPlanBuilder::from(left)
        .join(&right, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Inner Join: c = d Filter: c > UInt32(1)\
    \n Projection: test.a AS c\
    \n TableScan: test\
    \n Projection: test2.b AS d\
    \n TableScan: test2";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The predicate appears on both inputs with aliases resolved.
    let expected_after = "\
    Inner Join: c = d\
    \n Projection: test.a AS c\
    \n Filter: test.a > UInt32(1)\
    \n TableScan: test\
    \n Projection: test2.b AS d\
    \n Filter: test2.b > UInt32(1)\
    \n TableScan: test2";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// An `IN (...)` list filter on the projection alias `b` is expected to be
/// pushed below the projection with `b` replaced by `test.a`.
#[test]
fn test_in_filter_with_alias() -> Result<()> {
    // UInt32 literals 1 through 4.
    let candidates: Vec<Expr> = (1u32..=4).map(lit).collect();
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
    \n Projection: test.a AS b, test.c\
    \n TableScan: test\
    ";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The filter now sits below the projection and reads `test.a`.
    let expected_after = "\
    Projection: test.a AS b, test.c\
    \n Filter: test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
    \n TableScan: test\
    ";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// An `IN (...)` list filter on alias `b` above two stacked projections is
/// expected to pass through both and end up on the scan, reading `test.a`.
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
    // UInt32 literals 1 through 4.
    let candidates: Vec<Expr> = (1u32..=4).map(lit).collect();
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    // Plan as built, before the rule runs.
    let expected_before = "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
    \n Projection: b, test.c\
    \n Projection: test.a AS b, test.c\
    \n TableScan: test\
    ";
    assert_eq!(format!("{:?}", plan), expected_before);

    // The filter is below both projections.
    let expected_after = "\
    Projection: b, test.c\
    \n Projection: test.a AS b, test.c\
    \n Filter: test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
    \n TableScan: test\
    ";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// An `IN (<subquery>)` filter on the projection alias `b` is expected to be
/// pushed below the projection with `b` replaced by `test.a`; the subquery
/// plan itself is unchanged.
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
    let outer_scan = test_table_scan()?;
    let subquery_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .project(vec![col("c")])?
        .build()?;
    let predicate = in_subquery(col("b"), Arc::new(subquery_plan));

    let plan = LogicalPlanBuilder::from(outer_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Plan as built, before the rule runs.
    assert_eq!(
        format!("{:?}", plan),
        "\
        Filter: b IN (<subquery>)\
        \n Subquery:\
        \n Projection: sq.c\
        \n TableScan: sq\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test"
    );

    // The filter now sits below the projection and reads `test.a`.
    let expected_after = "\
    Projection: test.a AS b, test.c\
    \n Filter: test.a IN (<subquery>)\
    \n Subquery:\
    \n Projection: sq.c\
    \n TableScan: sq\
    \n TableScan: test";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// A filter on `b.a` above two nested aliased projections of the literal
/// `0 AS a` is expected to travel all the way down to the `EmptyRelation`,
/// where it reads `Int64(0) = Int64(1)`.
#[test]
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    let alias = || Some("b".to_owned());
    let plan = LogicalPlanBuilder::empty(true)
        .project_with_alias(vec![lit(0i64).alias("a")], alias())?
        .project_with_alias(vec![col("b.a")], alias())?
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    // Plan as built, before the rule runs.
    assert_eq!(
        format!("{:?}", plan),
        "Projection: b.a\
        \n Filter: b.a = Int64(1)\
        \n SubqueryAlias: b\
        \n Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: Int64(0) AS a\
        \n EmptyRelation"
    );

    // The filter has moved below every projection and alias.
    let expected_after = "Projection: b.a\
    \n SubqueryAlias: b\
    \n Projection: b.a\
    \n SubqueryAlias: b\
    \n Projection: Int64(0) AS a\
    \n Filter: Int64(0) = Int64(1)\
    \n EmptyRelation";
    assert_optimized_plan_eq(&plan, expected_after)
}
/// A filter of the form `(p1 AND q1) OR (p2 AND q2)` above a cross join:
/// the expected plan pushes `test.b > 1 OR test.c < 10` to the left input
/// and keeps a rewritten conjunction above the join. Running the rule on the
/// already optimized plan must yield the same plan.
#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;

    let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
    let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(&right)?
        .filter(or(first_branch, second_branch))?
        .build()?;

    let expected = "\
    Filter: (test.a = d OR test.b = e) AND (test.a = d OR test.c < UInt32(10)) AND (test.b > UInt32(1) OR test.b = e)\
    \n CrossJoin:\
    \n Projection: test.a, test.b, test.c\
    \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
    \n TableScan: test\
    \n Projection: test1.a AS d, test1.a AS e\
    \n TableScan: test1";
    assert_optimized_plan_eq(&plan, expected)?;

    // Optimize by hand once, then let the helper run the rule a second time.
    let rule = PushDownFilter::new();
    let mut config = OptimizerConfig::new();
    let once_optimized = rule
        .optimize(&plan, &mut config)
        .expect("failed to optimize plan");
    assert_optimized_plan_eq(&once_optimized, expected)
}
}