use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, Result};
use datafusion_expr::{
col,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
logical_plan::{
Aggregate, CrossJoin, Filter, Join, JoinType, Limit, LogicalPlan, Projection,
TableScan, Union,
},
utils::{expr_to_columns, exprlist_to_columns, from_plan},
Expr, TableProviderFilterPushDown,
};
use std::collections::{HashMap, HashSet};
use std::iter::once;
#[derive(Default)]
pub struct FilterPushDown {}
type Predicate = (Expr, HashSet<Column>);
type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
#[derive(Debug, Clone, Default)]
struct State {
filters: Vec<Predicate>,
}
impl State {
fn append_predicates(&mut self, predicates: Predicates) {
predicates
.0
.into_iter()
.zip(predicates.1)
.for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone())))
}
}
fn get_predicates<'a>(
state: &'a State,
used_columns: &HashSet<Column>,
) -> Predicates<'a> {
state
.filters
.iter()
.filter(|(_, columns)| {
!columns
.intersection(used_columns)
.collect::<HashSet<_>>()
.is_empty()
})
.map(|&(ref a, ref b)| (a, b))
.unzip()
}
fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_inputs = plan
.inputs()
.iter()
.map(|input| optimize(input, state.clone()))
.collect::<Result<Vec<_>>>()?;
let expr = plan.expressions();
from_plan(plan, &expr, &new_inputs)
}
fn remove_filters(
filters: &[Predicate],
predicate_columns: &[&HashSet<Column>],
) -> Vec<Predicate> {
filters
.iter()
.filter(|(_, columns)| !predicate_columns.contains(&columns))
.cloned()
.collect::<Vec<_>>()
}
fn issue_filters(
mut state: State,
used_columns: HashSet<Column>,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
if predicates.is_empty() {
return push_down(&state, plan);
}
let plan = utils::add_filter(plan.clone(), &predicates);
state.filters = remove_filters(&state.filters, &predicate_columns);
push_down(&state, &plan)
}
fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
match plan {
LogicalPlan::Join(Join { join_type, .. }) => match join_type {
JoinType::Inner => (true, true),
JoinType::Left => (true, false),
JoinType::Right => (false, true),
JoinType::Full => (false, false),
JoinType::Semi | JoinType::Anti => (true, false),
},
LogicalPlan::CrossJoin(_) => (true, true),
_ => unreachable!("lr_is_preserved only valid for JOIN nodes"),
}
}
fn on_lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
match plan {
LogicalPlan::Join(Join { join_type, .. }) => match join_type {
JoinType::Inner => (true, true),
JoinType::Left => (false, true),
JoinType::Right => (true, false),
JoinType::Full => (false, false),
JoinType::Semi | JoinType::Anti => unreachable!(
"on_lr_is_preserved cannot be appplied to SEMI/ANTI-JOIN nodes"
),
},
LogicalPlan::CrossJoin(_) => {
unreachable!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes")
}
_ => unreachable!("on_lr_is_preserved only valid for JOIN nodes"),
}
}
fn get_pushable_join_predicates<'a>(
filters: &'a [Predicate],
schema: &DFSchema,
preserved: bool,
) -> Predicates<'a> {
if !preserved {
return (vec![], vec![]);
}
let schema_columns = schema
.fields()
.iter()
.flat_map(|f| {
[
f.qualified_column(),
f.unqualified_column(),
]
})
.collect::<HashSet<_>>();
filters
.iter()
.filter(|(_, columns)| {
let all_columns_in_schema = schema_columns
.intersection(columns)
.collect::<HashSet<_>>()
.len()
== columns.len();
all_columns_in_schema
})
.map(|(a, b)| (a, b))
.unzip()
}
fn optimize_join(
mut state: State,
plan: &LogicalPlan,
left: &LogicalPlan,
right: &LogicalPlan,
on_filter: Vec<Predicate>,
) -> Result<LogicalPlan> {
let (left_preserved, right_preserved) = lr_is_preserved(plan);
let to_left =
get_pushable_join_predicates(&state.filters, left.schema(), left_preserved);
let to_right =
get_pushable_join_predicates(&state.filters, right.schema(), right_preserved);
let to_keep: Predicates = state
.filters
.iter()
.filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e))
.map(|(a, b)| (a, b))
.unzip();
let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() {
((vec![], vec![]), (vec![], vec![]), vec![])
} else {
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan);
let on_to_left =
get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved);
let on_to_right =
get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved);
let on_to_keep = on_filter
.iter()
.filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e))
.map(|(a, _)| a.clone())
.collect::<Vec<_>>();
(on_to_left, on_to_right, on_to_keep)
};
let mut left_state = State { filters: vec![] };
left_state.append_predicates(to_left);
left_state.append_predicates(on_to_left);
let left = optimize(left, left_state)?;
let mut right_state = State { filters: vec![] };
right_state.append_predicates(to_right);
right_state.append_predicates(on_to_right);
let right = optimize(right, right_state)?;
let expr = plan.expressions();
let expr = if !on_filter.is_empty() && on_to_keep.is_empty() {
expr[..expr.len() - 1].to_vec()
} else if !on_to_keep.is_empty() {
expr[..expr.len() - 1]
.iter()
.cloned()
.chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap()))
.collect()
} else {
plan.expressions()
};
let plan = from_plan(plan, &expr, &[left, right])?;
if to_keep.0.is_empty() {
Ok(plan)
} else {
let plan = utils::add_filter(plan, &to_keep.0);
state.filters = remove_filters(&state.filters, &to_keep.1);
Ok(plan)
}
}
fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Explain { .. } => {
push_down(&state, plan)
}
LogicalPlan::Analyze { .. } => push_down(&state, plan),
LogicalPlan::Filter(Filter { input, predicate }) => {
let mut predicates = vec![];
utils::split_conjunction(predicate, &mut predicates);
let mut no_col_predicates = vec![];
predicates
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(predicate, &mut columns)?;
if columns.is_empty() {
no_col_predicates.push(predicate)
} else {
state.filters.push((predicate.clone(), columns));
}
Ok(())
})?;
if !no_col_predicates.is_empty() {
Ok(utils::add_filter(
optimize(input, state)?,
&no_col_predicates,
))
} else {
optimize(input, state)
}
}
LogicalPlan::Projection(Projection {
input,
expr,
schema,
alias: _,
}) => {
let projection = schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
let expr = match &expr[i] {
Expr::Alias(expr, _) => expr.as_ref().clone(),
expr => expr.clone(),
};
(field.qualified_name(), expr)
})
.collect::<HashMap<_, _>>();
for (predicate, columns) in state.filters.iter_mut() {
*predicate = replace_cols_by_name(predicate.clone(), &projection)?;
columns.clear();
expr_to_columns(predicate, columns)?;
}
let new_input = optimize(input, state)?;
from_plan(plan, expr, &[new_input])
}
LogicalPlan::Aggregate(Aggregate {
aggr_expr, input, ..
}) => {
let mut used_columns = HashSet::new();
exprlist_to_columns(aggr_expr, &mut used_columns)?;
let agg_columns = aggr_expr
.iter()
.map(|x| Ok(Column::from_name(x.name(input.schema())?)))
.collect::<Result<HashSet<_>>>()?;
used_columns.extend(agg_columns);
issue_filters(state, used_columns, plan)
}
LogicalPlan::Sort { .. } => {
push_down(&state, plan)
}
LogicalPlan::Union(Union {
inputs: _,
schema,
alias: _,
}) => {
let projection = schema
.fields()
.iter()
.map(|field| (field.qualified_name(), col(field.name())))
.collect::<HashMap<_, _>>();
if !projection.is_empty() {
for (predicate, columns) in state.filters.iter_mut() {
*predicate = replace_cols_by_name(predicate.clone(), &projection)?;
columns.clear();
expr_to_columns(predicate, columns)?;
}
}
push_down(&state, plan)
}
LogicalPlan::Limit(Limit { input, .. }) => {
let used_columns = input
.schema()
.fields()
.iter()
.map(|f| f.qualified_column())
.collect::<HashSet<_>>();
issue_filters(state, used_columns, plan)
}
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
optimize_join(state, plan, left, right, vec![])
}
LogicalPlan::Join(Join {
left,
right,
on,
filter,
join_type,
..
}) => {
let on_filters = filter
.as_ref()
.map(|e| {
let mut predicates = vec![];
utils::split_conjunction(e, &mut predicates);
predicates
.into_iter()
.map(|e| {
let mut accum = HashSet::new();
expr_to_columns(e, &mut accum)?;
Ok((e.clone(), accum))
})
.collect::<Result<Vec<_>>>()
})
.unwrap_or_else(|| Ok(vec![]))?;
if *join_type == JoinType::Inner {
let join_side_filters = state
.filters
.iter()
.chain(on_filters.iter())
.filter_map(|(predicate, columns)| {
let mut join_cols_to_replace = HashMap::new();
for col in columns.iter() {
for (l, r) in on {
if col == l {
join_cols_to_replace.insert(col, r);
break;
} else if col == r {
join_cols_to_replace.insert(col, l);
break;
}
}
}
if join_cols_to_replace.is_empty() {
return None;
}
let join_side_predicate =
match replace_col(predicate.clone(), &join_cols_to_replace) {
Ok(p) => p,
Err(e) => {
return Some(Err(e));
}
};
let join_side_columns = columns
.clone()
.into_iter()
.filter(|c| !join_cols_to_replace.contains_key(c))
.chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone()))
.collect();
Some(Ok((join_side_predicate, join_side_columns)))
})
.collect::<Result<Vec<_>>>()?;
state.filters.extend(join_side_filters);
}
optimize_join(state, plan, left, right, on_filters)
}
LogicalPlan::TableScan(TableScan {
source,
projected_schema,
filters,
projection,
table_name,
fetch,
}) => {
let mut used_columns = HashSet::new();
let mut new_filters = filters.clone();
for (filter_expr, cols) in &state.filters {
let (preserve_filter_node, add_to_provider) =
match source.supports_filter_pushdown(filter_expr)? {
TableProviderFilterPushDown::Unsupported => (true, false),
TableProviderFilterPushDown::Inexact => (true, true),
TableProviderFilterPushDown::Exact => (false, true),
};
if preserve_filter_node {
used_columns.extend(cols.clone());
}
if add_to_provider {
if new_filters.contains(filter_expr) {
continue;
}
new_filters.push(filter_expr.clone());
}
}
issue_filters(
state,
used_columns,
&LogicalPlan::TableScan(TableScan {
source: source.clone(),
projection: projection.clone(),
projected_schema: projected_schema.clone(),
table_name: table_name.clone(),
filters: new_filters,
fetch: *fetch,
}),
)
}
_ => {
let used_columns = plan
.schema()
.fields()
.iter()
.map(|f| f.qualified_column())
.collect::<HashSet<_>>();
issue_filters(state, used_columns, plan)
}
}
}
impl OptimizerRule for FilterPushDown {
fn name(&self) -> &str {
"filter_push_down"
}
fn optimize(&self, plan: &LogicalPlan, _: &OptimizerConfig) -> Result<LogicalPlan> {
optimize(plan, State::default())
}
}
impl FilterPushDown {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<Expr> {
struct ColumnReplacer<'a> {
replace_map: &'a HashMap<String, Expr>,
}
impl<'a> ExprRewriter for ColumnReplacer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
if let Expr::Column(c) = &expr {
match self.replace_map.get(&c.flat_name()) {
Some(new_c) => Ok(new_c.clone()),
None => Ok(expr),
}
} else {
Ok(expr)
}
}
}
e.rewrite(&mut ColumnReplacer { replace_map })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test::*;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion_common::DFSchema;
use datafusion_expr::{
and, col, in_list, in_subquery, lit,
logical_plan::{builder::union_with_alias, JoinType},
sum, Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
};
use std::sync::Arc;
fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan {
let rule = FilterPushDown::new();
rule.optimize(plan, &OptimizerConfig::new())
.expect("failed to optimize plan")
}
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
let optimized_plan = optimize_plan(plan);
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
}
#[test]
fn filter_before_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: #test.a, #test.b\
\n Filter: #test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_after_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(None, Some(10))?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: #test.a = Int64(1)\
\n Limit: skip=None, fetch=10\
\n Projection: #test.a, #test.b\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_no_columns() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(0i64).eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: Int64(0) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_jump_2_plans() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.project(vec![col("c"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: #test.c, #test.b\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_move_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
.filter(col("a").gt(lit(10i64)))?
.build()?;
let expected = "\
Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS total_salary]]\
\n Filter: #test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "\
Filter: #b > Int64(10)\
\n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS b]]\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: #test.a AS b, #test.c\
\n Filter: #test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
fn add(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
left: Box::new(left),
op: Operator::Plus,
right: Box::new(right),
}
}
fn multiply(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
left: Box::new(left),
op: Operator::Multiply,
right: Box::new(right),
}
}
#[test]
fn complex_expression() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b = Int64(1)\
\n Projection: #test.a * Int32(2) + #test.c AS b, #test.c\
\n TableScan: test"
);
let expected = "\
Projection: #test.a * Int32(2) + #test.c AS b, #test.c\
\n Filter: #test.a * Int32(2) + #test.c = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn complex_plan() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #a = Int64(1)\
\n Projection: #b * Int32(3) AS a, #test.c\
\n Projection: #test.a * Int32(2) + #test.c AS b, #test.c\
\n TableScan: test"
);
let expected = "\
Projection: #b * Int32(3) AS a, #test.c\
\n Projection: #test.a * Int32(2) + #test.c AS b, #test.c\
\n Filter: #test.a * Int32(2) + #test.c * Int32(3) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn multi_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(col("b").gt(lit(10i64)))?
.filter(col("SUM(test.c)").gt(lit(10i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #SUM(test.c) > Int64(10)\
\n Filter: #b > Int64(10)\
\n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test"
);
let expected = "\
Filter: #SUM(test.c) > Int64(10)\
\n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
\n Projection: #test.a AS b, #test.c\
\n Filter: #test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn split_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(and(
col("SUM(test.c)").gt(lit(10i64)),
and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))),
))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #SUM(test.c) > Int64(10) AND #b > Int64(10) AND #SUM(test.c) < Int64(20)\
\n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test"
);
let expected = "\
Filter: #SUM(test.c) > Int64(10) AND #SUM(test.c) < Int64(20)\
\n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
\n Projection: #test.a AS b, #test.c\
\n Filter: #test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn double_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(None, Some(20))?
.limit(None, Some(10))?
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: #test.a, #test.b\
\n Filter: #test.a = Int64(1)\
\n Limit: skip=None, fetch=10\
\n Limit: skip=None, fetch=20\
\n Projection: #test.a, #test.b\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn union_all() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.union(LogicalPlanBuilder::from(table_scan).build()?)?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Union\
\n Filter: #a = Int64(1)\
\n TableScan: test\
\n Filter: #a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn union_all_with_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let union =
union_with_alias(table_scan.clone(), table_scan, Some("t".to_string()))?;
let plan = LogicalPlanBuilder::from(union)
.filter(col("t.a").eq(lit(1i64)))?
.build()?;
let expected = "\
Union\
\n Filter: #a = Int64(1)\
\n TableScan: test\
\n Filter: #a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_2_breaks_limits() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.filter(col("a").lt_eq(lit(1i64)))?
.limit(None, Some(1))?
.project(vec![col("a")])?
.filter(col("a").gt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Filter: #test.a >= Int64(1)\
\n Projection: #test.a\
\n Limit: skip=None, fetch=1\
\n Filter: #test.a <= Int64(1)\
\n Projection: #test.a\
\n TableScan: test"
);
let expected = "\
Projection: #test.a\
\n Filter: #test.a >= Int64(1)\
\n Limit: skip=None, fetch=1\
\n Projection: #test.a\
\n Filter: #test.a <= Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn two_filters_on_same_depth() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.limit(None, Some(1))?
.filter(col("a").lt_eq(lit(1i64)))?
.filter(col("a").gt_eq(lit(1i64)))?
.project(vec![col("a")])?
.build()?;
assert_eq!(
format!("{:?}", plan),
"Projection: #test.a\
\n Filter: #test.a >= Int64(1)\
\n Filter: #test.a <= Int64(1)\
\n Limit: skip=None, fetch=1\
\n TableScan: test"
);
let expected = "\
Projection: #test.a\
\n Filter: #test.a >= Int64(1) AND #test.a <= Int64(1)\
\n Limit: skip=None, fetch=1\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filters_user_defined_node() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
let plan = crate::test::user_defined::new(plan);
let expected = "\
TestUserDefined\
\n Filter: #test.a <= Int64(1)\
\n TableScan: test";
assert_eq!(format!("{:?}", plan), expected);
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
None,
)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test.a <= Int64(1)\
\n Inner Join: #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2"
);
let expected = "\
Inner Join: #test.a = #test2.a\
\n Filter: #test.a <= Int64(1)\
\n TableScan: test\
\n Projection: #test2.a\
\n Filter: #test2.a <= Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Inner,
vec![Column::from_name("a".to_string())],
)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test.a <= Int64(1)\
\n Inner Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2"
);
let expected = "\
Inner Join: Using #test.a = #test2.a\
\n Filter: #test.a <= Int64(1)\
\n TableScan: test\
\n Projection: #test2.a\
\n Filter: #test2.a <= Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
None,
)?
.filter(col("c").lt_eq(col("b")))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test.c <= #test2.b\
\n Inner Join: #test.a = #test2.a\
\n Projection: #test.a, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b\
\n TableScan: test2"
);
let expected = &format!("{:?}", plan);
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_join_on_one_side() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let table_scan_right = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(table_scan_right)
.project(vec![col("a"), col("c")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
None,
)?
.filter(col("b").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test.b <= Int64(1)\
\n Inner Join: #test.a = #test2.a\
\n Projection: #test.a, #test.b\
\n TableScan: test\
\n Projection: #test2.a, #test2.c\
\n TableScan: test2"
);
let expected = "\
Inner Join: #test.a = #test2.a\
\n Projection: #test.a, #test.b\
\n Filter: #test.b <= Int64(1)\
\n TableScan: test\
\n Projection: #test2.a, #test2.c\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_using_left_join() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Left,
vec![Column::from_name("a".to_string())],
)?
.filter(col("test2.a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test2.a <= Int64(1)\
\n Left Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2"
);
let expected = "\
Filter: #test2.a <= Int64(1)\
\n Left Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_using_right_join() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Right,
vec![Column::from_name("a".to_string())],
)?
.filter(col("test.a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test.a <= Int64(1)\
\n Right Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2"
);
let expected = "\
Filter: #test.a <= Int64(1)\
\n Right Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Left,
vec![Column::from_name("a".to_string())],
)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test.a <= Int64(1)\
\n Left Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2"
);
let expected = "\
Left Join: Using #test.a = #test2.a\
\n Filter: #test.a <= Int64(1)\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
&right,
JoinType::Right,
vec![Column::from_name("a".to_string())],
)?
.filter(col("test2.a").lt_eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #test2.a <= Int64(1)\
\n Right Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n TableScan: test2"
);
let expected = "\
Right Join: Using #test.a = #test2.a\
\n TableScan: test\
\n Projection: #test2.a\
\n Filter: #test2.a <= Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = col("test.c")
.gt(lit(1u32))
.and(col("test.b").lt(col("test2.b")))
.and(col("test2.c").gt(lit(4u32)));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #test.a = #test2.a Filter: #test.c > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2"
);
let expected = "\
Inner Join: #test.a = #test2.a Filter: #test.b < #test2.b\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.c > UInt32(1)\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n Filter: #test2.c > UInt32(4)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn join_filter_removed() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = col("test.b")
.gt(lit(1u32))
.and(col("test2.c").gt(lit(4u32)));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #test.a = #test2.a Filter: #test.b > UInt32(1) AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2"
);
let expected = "\
Inner Join: #test.a = #test2.a\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.b > UInt32(1)\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n Filter: #test2.c > UInt32(4)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn join_filter_on_common() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("b")])?
.build()?;
let filter = col("test.a").gt(lit(1u32));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("a")], vec![Column::from_name("b")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #test.a = #test2.b Filter: #test.a > UInt32(1)\
\n Projection: #test.a\
\n TableScan: test\
\n Projection: #test2.b\
\n TableScan: test2"
);
let expected = "\
Inner Join: #test.a = #test2.b\
\n Projection: #test.a\
\n Filter: #test.a > UInt32(1)\
\n TableScan: test\
\n Projection: #test2.b\
\n Filter: #test2.b > UInt32(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn left_join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = col("test.a")
.gt(lit(1u32))
.and(col("test.b").lt(col("test2.b")))
.and(col("test2.c").gt(lit(4u32)));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Left,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Left Join: #test.a = #test2.a Filter: #test.a > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2"
);
let expected = "\
Left Join: #test.a = #test2.a Filter: #test.a > UInt32(1) AND #test.b < #test2.b\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n Filter: #test2.c > UInt32(4)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn right_join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = col("test.a")
.gt(lit(1u32))
.and(col("test.b").lt(col("test2.b")))
.and(col("test2.c").gt(lit(4u32)));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Right,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Right Join: #test.a = #test2.a Filter: #test.a > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2"
);
let expected = "\
Right Join: #test.a = #test2.a Filter: #test.b < #test2.b AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n Filter: #test.a > UInt32(1)\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn full_join_on_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = col("test.a")
.gt(lit(1u32))
.and(col("test.b").lt(col("test2.b")))
.and(col("test2.c").gt(lit(4u32)));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Full,
(vec![Column::from_name("a")], vec![Column::from_name("a")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Full Join: #test.a = #test2.a Filter: #test.a > UInt32(1) AND #test.b < #test2.b AND #test2.c > UInt32(4)\
\n Projection: #test.a, #test.b, #test.c\
\n TableScan: test\
\n Projection: #test2.a, #test2.b, #test2.c\
\n TableScan: test2"
);
let expected = &format!("{:?}", plan);
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
struct PushDownProvider {
pub filter_support: TableProviderFilterPushDown,
}
#[async_trait]
impl TableSource for PushDownProvider {
fn schema(&self) -> SchemaRef {
Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new(
"a",
arrow::datatypes::DataType::Int32,
true,
),
arrow::datatypes::Field::new(
"b",
arrow::datatypes::DataType::Int32,
true,
),
]))
}
fn table_type(&self) -> TableType {
TableType::Base
}
fn supports_filter_pushdown(
&self,
_: &Expr,
) -> Result<TableProviderFilterPushDown> {
Ok(self.filter_support.clone())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
fn table_scan_with_pushdown_provider(
filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
let test_provider = PushDownProvider { filter_support };
let table_scan = LogicalPlan::TableScan(TableScan {
table_name: "test".to_string(),
filters: vec![],
projected_schema: Arc::new(DFSchema::try_from(
(*test_provider.schema()).clone(),
)?),
projection: None,
source: Arc::new(test_provider),
fetch: None,
});
LogicalPlanBuilder::from(table_scan)
.filter(col("a").eq(lit(1i64)))?
.build()
}
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
let expected = "\
TableScan: test, full_filters=[#a = Int64(1)]";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
let plan =
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
let expected = "\
Filter: #a = Int64(1)\
\n TableScan: test, partial_filters=[#a = Int64(1)]";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
let plan =
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
let optimised_plan = optimize_plan(&plan);
let expected = "\
Filter: #a = Int64(1)\
\n TableScan: test, partial_filters=[#a = Int64(1)]";
assert_optimized_plan_eq(&optimised_plan, expected);
Ok(())
}
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
let plan =
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
let expected = "\
Filter: #a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn multi_combined_filter() -> Result<()> {
let test_provider = PushDownProvider {
filter_support: TableProviderFilterPushDown::Inexact,
};
let table_scan = LogicalPlan::TableScan(TableScan {
table_name: "test".to_string(),
filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
projected_schema: Arc::new(DFSchema::try_from(
(*test_provider.schema()).clone(),
)?),
projection: Some(vec![0]),
source: Arc::new(test_provider),
fetch: None,
});
let plan = LogicalPlanBuilder::from(table_scan)
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
.project(vec![col("a"), col("b")])?
.build()?;
let expected ="Projection: #a, #b\
\n Filter: #a = Int64(10) AND #b > Int64(11)\
\n TableScan: test projection=[a], partial_filters=[#a = Int64(10), #b > Int64(11)]";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn test_filter_with_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b > Int64(10) AND #test.c > Int64(10)\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test\
"
);
let expected = "\
Projection: #test.a AS b, #test.c\
\n Filter: #test.a > Int64(10) AND #test.c > Int64(10)\
\n TableScan: test\
";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn test_filter_with_alias_2() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.project(vec![col("b"), col("c")])?
.filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b > Int64(10) AND #test.c > Int64(10)\
\n Projection: #b, #test.c\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test\
"
);
let expected = "\
Projection: #b, #test.c\
\n Projection: #test.a AS b, #test.c\
\n Filter: #test.a > Int64(10) AND #test.c > Int64(10)\
\n TableScan: test\
";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c").alias("d")])?
.filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b > Int64(10) AND #d > Int64(10)\
\n Projection: #test.a AS b, #test.c AS d\
\n TableScan: test\
"
);
let expected = "\
Projection: #test.a AS b, #test.c AS d\
\n Filter: #test.a > Int64(10) AND #test.c > Int64(10)\
\n TableScan: test\
";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn join_filter_with_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("b").alias("d")])?
.build()?;
let filter = col("c").gt(lit(1u32));
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
JoinType::Inner,
(vec![Column::from_name("c")], vec![Column::from_name("d")]),
Some(filter),
)?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Inner Join: #c = #d Filter: #c > UInt32(1)\
\n Projection: #test.a AS c\
\n TableScan: test\
\n Projection: #test2.b AS d\
\n TableScan: test2"
);
let expected = "\
Inner Join: #c = #d\
\n Projection: #test.a AS c\
\n Filter: #test.a > UInt32(1)\
\n TableScan: test\
\n Projection: #test2.b AS d\
\n Filter: #test2.b > UInt32(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn test_in_filter_with_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(in_list(col("b"), filter_value, false))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test\
"
);
let expected = "\
Projection: #test.a AS b, #test.c\
\n Filter: #test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
\n TableScan: test\
";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
let table_scan = test_table_scan()?;
let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.project(vec![col("b"), col("c")])?
.filter(in_list(col("b"), filter_value, false))?
.build()?;
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
\n Projection: #b, #test.c\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test\
"
);
let expected = "\
Projection: #b, #test.c\
\n Projection: #test.a AS b, #test.c\
\n Filter: #test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
\n TableScan: test\
";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
let table_scan = test_table_scan()?;
let table_scan_sq = test_table_scan_with_name("sq")?;
let subplan = Arc::new(
LogicalPlanBuilder::from(table_scan_sq)
.project(vec![col("c")])?
.build()?,
);
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(in_subquery(col("b"), subplan))?
.build()?;
let expected_before = "\
Filter: #b IN (Subquery: Projection: #sq.c\n TableScan: sq)\
\n Projection: #test.a AS b, #test.c\
\n TableScan: test";
assert_eq!(format!("{:?}", plan), expected_before);
let expected_after = "\
Projection: #test.a AS b, #test.c\
\n Filter: #test.a IN (Subquery: Projection: #sq.c\n TableScan: sq)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected_after);
Ok(())
}
}