use super::utils;
use crate::error::Result;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::plan::Projection;
use crate::logical_plan::{Limit, TableScan};
use crate::logical_plan::{LogicalPlan, Union};
use crate::optimizer::optimizer::OptimizerRule;
use std::sync::Arc;
/// Optimizer rule that pushes the row count of `Limit` nodes further down the
/// logical plan: through projections and unions, and into table scans.
#[derive(Default)]
pub struct LimitPushDown {}

impl LimitPushDown {
    /// Creates a new `LimitPushDown` rule.
    ///
    /// The rule carries no state, so this is equivalent to
    /// `LimitPushDown::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Rewrites `plan`, pushing the row limit of `Limit` nodes toward the leaves.
///
/// * `optimizer` - the rule instance; it is only threaded through the recursion.
/// * `upper_limit` - the limit inherited from an enclosing `Limit` that may still be
///   applied at `plan`, or `None` when no limit is being pushed into this subtree.
/// * `plan` - the (sub)plan to rewrite. It is only read; a new plan is returned.
/// * `execution_props` - forwarded unchanged to every recursive call.
///
/// Limits are propagated through `Limit`, `Projection`, `Union` and `TableScan`
/// nodes only. Any other node resets the inherited limit to `None` for its children.
///
/// # Errors
/// Propagates any error raised while rewriting a child plan or while rebuilding a
/// node with `utils::from_plan`.
fn limit_push_down(
    optimizer: &LimitPushDown,
    upper_limit: Option<usize>,
    plan: &LogicalPlan,
    execution_props: &ExecutionProps,
) -> Result<LogicalPlan> {
    // Rewrites one child under `limit` and wraps it for a parent's `input` field.
    let rewrite_child =
        |limit: Option<usize>, child: &LogicalPlan| -> Result<Arc<LogicalPlan>> {
            limit_push_down(optimizer, limit, child, execution_props).map(Arc::new)
        };

    match (plan, upper_limit) {
        // Tighten this `Limit` with the inherited one (if any). Keep pushing the
        // resulting value into its input.
        (LogicalPlan::Limit(Limit { n, input }), _) => {
            let tightest = upper_limit.map_or(*n, |outer| outer.min(*n));
            Ok(LogicalPlan::Limit(Limit {
                n: tightest,
                input: rewrite_child(Some(tightest), input.as_ref())?,
            }))
        }
        // Record the inherited limit on the scan. If the scan already has a
        // smaller limit, keep that one instead.
        (
            LogicalPlan::TableScan(TableScan {
                table_name,
                source,
                projection,
                filters,
                limit,
                projected_schema,
            }),
            Some(outer),
        ) => {
            let scan_limit = limit.map_or(outer, |existing| existing.min(outer));
            Ok(LogicalPlan::TableScan(TableScan {
                table_name: table_name.clone(),
                source: source.clone(),
                projection: projection.clone(),
                filters: filters.clone(),
                limit: Some(scan_limit),
                projected_schema: projected_schema.clone(),
            }))
        }
        // A projection hands the inherited limit (possibly `None`) straight to its
        // input.
        (
            LogicalPlan::Projection(Projection {
                expr,
                input,
                schema,
                alias,
            }),
            _,
        ) => Ok(LogicalPlan::Projection(Projection {
            expr: expr.clone(),
            input: rewrite_child(upper_limit, input.as_ref())?,
            schema: schema.clone(),
            alias: alias.clone(),
        })),
        // Wrap every union input in its own `Limit` of `outer` rows. Keep pushing
        // that limit into the input.
        (
            LogicalPlan::Union(Union {
                inputs,
                alias,
                schema,
            }),
            Some(outer),
        ) => {
            let limited_inputs = inputs
                .iter()
                .map(|branch| {
                    rewrite_child(Some(outer), branch).map(|child| {
                        LogicalPlan::Limit(Limit {
                            n: outer,
                            input: child,
                        })
                    })
                })
                .collect::<Result<_>>()?;
            Ok(LogicalPlan::Union(Union {
                inputs: limited_inputs,
                alias: alias.clone(),
                schema: schema.clone(),
            }))
        }
        // Any other node stops propagation. Its children are still visited, but
        // with no inherited limit.
        _ => {
            let exprs = plan.expressions();
            let children = plan
                .inputs()
                .into_iter()
                .map(|child| limit_push_down(optimizer, None, child, execution_props))
                .collect::<Result<Vec<_>>>()?;
            utils::from_plan(plan, &exprs, &children)
        }
    }
}
impl OptimizerRule for LimitPushDown {
    /// Returns a copy of `plan` with limits pushed as far down as this rule
    /// supports. Rewriting starts at the root.
    fn optimize(
        &self,
        plan: &LogicalPlan,
        execution_props: &ExecutionProps,
    ) -> Result<LogicalPlan> {
        // The root has no enclosing `Limit`, so nothing is inherited yet.
        let inherited_limit = None;
        limit_push_down(self, inherited_limit, plan, execution_props)
    }

    /// Returns this rule's name.
    fn name(&self) -> &str {
        "limit_push_down"
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        logical_plan::{col, max, LogicalPlan, LogicalPlanBuilder},
        test::*,
    };

    /// Runs `LimitPushDown` over `plan` and asserts that the `{:?}` rendering of
    /// the optimized plan equals `expected`.
    ///
    /// Panics if the rule returns an error or if the rendering differs. The
    /// comparison is on the `Debug` text, so `expected` must reproduce that
    /// text's line breaks and indentation exactly.
    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
        let rule = LimitPushDown::new();
        let optimized_plan = rule
            .optimize(plan, &ExecutionProps::new())
            .expect("failed to optimize plan");
        let formatted_plan = format!("{:?}", optimized_plan);
        assert_eq!(formatted_plan, expected);
    }

    /// A limit above a projection is pushed through it into the table scan.
    #[test]
    fn limit_pushdown_projection_table_provider() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .limit(1000)?
            .build()?;
        let expected = "Limit: 1000\
        \n Projection: #test.a\
        \n TableScan: test projection=None, limit=1000";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// With two stacked limits, the inner `Limit` is tightened to the outer,
    /// smaller value, and that value is what reaches the scan.
    #[test]
    fn limit_push_down_take_smaller_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .limit(1000)?
            .limit(10)?
            .build()?;
        let expected = "Limit: 10\
        \n Limit: 10\
        \n TableScan: test projection=None, limit=10";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// An aggregate is not one of the nodes the rule pushes through. The limit
    /// therefore stays above it, and the scan keeps no limit.
    #[test]
    fn limit_doesnt_push_down_aggregation() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![max(col("b"))])?
            .limit(1000)?
            .build()?;
        let expected = "Limit: 1000\
        \n Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\
        \n TableScan: test projection=None";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A limit above a union wraps each union input in its own `Limit`. That
    /// limit is then pushed into each input's scan.
    #[test]
    fn limit_should_push_down_union() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan.clone())
            .union(LogicalPlanBuilder::from(table_scan).build()?)?
            .limit(1000)?
            .build()?;
        let expected = "Limit: 1000\
        \n Union\
        \n Limit: 1000\
        \n TableScan: test projection=None, limit=1000\
        \n Limit: 1000\
        \n TableScan: test projection=None, limit=1000";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }

    /// A limit blocked by an aggregate does not prevent a deeper `Limit`, below
    /// the aggregate, from being pushed into its own scan.
    #[test]
    fn multi_stage_limit_recurses_to_deeper_limit() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .limit(1000)?
            .aggregate(vec![col("a")], vec![max(col("b"))])?
            .limit(10)?
            .build()?;
        let expected = "Limit: 10\
        \n Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\
        \n Limit: 1000\
        \n TableScan: test projection=None, limit=1000";
        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }
}