use crate::error::{ExecutionError, Result};
use crate::logical_plan::LogicalPlan;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use std::{collections::HashSet, sync::Arc};
use utils::optimize_explain;
pub struct ProjectionPushDown {}
impl OptimizerRule for ProjectionPushDown {
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let required_columns = plan
.schema()
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<HashSet<String>>();
return optimize_plan(self, plan, &required_columns, false);
}
fn name(&self) -> &str {
return "projection_push_down";
}
}
impl ProjectionPushDown {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
fn get_projected_schema(
schema: &Schema,
projection: &Option<Vec<usize>>,
required_columns: &HashSet<String>,
has_projection: bool,
) -> Result<(Vec<usize>, SchemaRef)> {
if projection.is_some() {
return Err(ExecutionError::General(
"Cannot run projection push-down rule more than once".to_string(),
));
}
let mut projection: Vec<usize> = required_columns
.iter()
.map(|name| schema.index_of(name))
.filter_map(ArrowResult::ok)
.collect();
if projection.is_empty() {
if has_projection {
projection.push(0);
} else {
projection = schema
.fields()
.iter()
.enumerate()
.map(|(i, _)| i)
.collect::<Vec<usize>>();
}
}
projection.sort();
let mut projected_fields: Vec<Field> = Vec::with_capacity(projection.len());
for i in &projection {
projected_fields.push(schema.fields()[*i].clone());
}
Ok((projection, SchemaRef::new(Schema::new(projected_fields))))
}
fn optimize_plan(
optimizer: &mut ProjectionPushDown,
plan: &LogicalPlan,
required_columns: &HashSet<String>, has_projection: bool,
) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Projection {
input,
expr,
schema,
} => {
let mut new_expr = Vec::new();
let mut new_fields = Vec::new();
let mut new_required_columns = HashSet::new();
schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
if required_columns.contains(field.name()) {
new_expr.push(expr[i].clone());
new_fields.push(field.clone());
utils::expr_to_column_names(&expr[i], &mut new_required_columns)
} else {
Ok(())
}
})
.collect::<Result<()>>()?;
let new_input =
optimize_plan(optimizer, &input, &new_required_columns, true)?;
if new_fields.len() == 0 {
Ok(new_input)
} else {
Ok(LogicalPlan::Projection {
expr: new_expr,
input: Arc::new(new_input),
schema: SchemaRef::new(Schema::new(new_fields)),
})
}
}
LogicalPlan::Aggregate {
schema,
input,
group_expr,
aggr_expr,
..
} => {
let mut new_required_columns = HashSet::new();
utils::exprlist_to_column_names(group_expr, &mut new_required_columns)?;
let mut new_aggr_expr = Vec::new();
aggr_expr
.iter()
.map(|expr| {
let name = &expr.name(&schema)?;
if required_columns.contains(name) {
new_aggr_expr.push(expr.clone());
new_required_columns.insert(name.clone());
utils::expr_to_column_names(expr, &mut new_required_columns)
} else {
Ok(())
}
})
.collect::<Result<()>>()?;
let new_schema = Schema::new(
schema
.fields()
.iter()
.filter(|x| new_required_columns.contains(x.name()))
.cloned()
.collect(),
);
Ok(LogicalPlan::Aggregate {
group_expr: group_expr.clone(),
aggr_expr: new_aggr_expr,
input: Arc::new(optimize_plan(
optimizer,
&input,
&new_required_columns,
true,
)?),
schema: SchemaRef::new(new_schema),
})
}
LogicalPlan::TableScan {
schema_name,
table_name,
table_schema,
projection,
..
} => {
let (projection, projected_schema) = get_projected_schema(
&table_schema,
projection,
required_columns,
has_projection,
)?;
Ok(LogicalPlan::TableScan {
schema_name: schema_name.to_string(),
table_name: table_name.to_string(),
table_schema: table_schema.clone(),
projection: Some(projection),
projected_schema: projected_schema,
})
}
LogicalPlan::InMemoryScan {
data,
schema,
projection,
..
} => {
let (projection, projected_schema) = get_projected_schema(
&schema,
projection,
required_columns,
has_projection,
)?;
Ok(LogicalPlan::InMemoryScan {
data: data.clone(),
schema: schema.clone(),
projection: Some(projection),
projected_schema: projected_schema,
})
}
LogicalPlan::CsvScan {
path,
has_header,
delimiter,
schema,
projection,
..
} => {
let (projection, projected_schema) = get_projected_schema(
&schema,
projection,
required_columns,
has_projection,
)?;
Ok(LogicalPlan::CsvScan {
path: path.to_owned(),
has_header: *has_header,
schema: schema.clone(),
delimiter: *delimiter,
projection: Some(projection),
projected_schema: projected_schema,
})
}
LogicalPlan::ParquetScan {
path,
schema,
projection,
..
} => {
let (projection, projected_schema) = get_projected_schema(
&schema,
projection,
required_columns,
has_projection,
)?;
Ok(LogicalPlan::ParquetScan {
path: path.to_owned(),
schema: schema.clone(),
projection: Some(projection),
projected_schema: projected_schema,
})
}
LogicalPlan::Explain {
verbose,
plan,
stringified_plans,
schema,
} => optimize_explain(optimizer, *verbose, &*plan, stringified_plans, &*schema),
LogicalPlan::Limit { .. }
| LogicalPlan::Filter { .. }
| LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Sort { .. }
| LogicalPlan::CreateExternalTable { .. }
| LogicalPlan::Extension { .. } => {
let expr = utils::expressions(plan);
let mut new_required_columns = required_columns.clone();
utils::exprlist_to_column_names(&expr, &mut new_required_columns)?;
let inputs = utils::inputs(plan);
let new_inputs = inputs
.iter()
.map(|plan| {
optimize_plan(optimizer, plan, &new_required_columns, has_projection)
})
.collect::<Result<Vec<_>>>()?;
utils::from_plan(plan, &expr, &new_inputs)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::logical_plan::{col, lit};
use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder};
use crate::test::*;
use arrow::datatypes::DataType;
#[test]
fn aggregate_no_group_by() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.aggregate(vec![], vec![max(col("b"))])?
.build()?;
let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
\n TableScan: test projection=Some([1])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn aggregate_group_by() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.aggregate(vec![col("c")], vec![max(col("b"))])?
.build()?;
let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\
\n TableScan: test projection=Some([1, 2])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn aggregate_no_group_by_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.filter(col("c"))?
.aggregate(vec![], vec![max(col("b"))])?
.build()?;
let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
\n Filter: #c\
\n TableScan: test projection=Some([1, 2])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn cast() -> Result<()> {
let table_scan = test_table_scan()?;
let projection = LogicalPlanBuilder::from(&table_scan)
.project(vec![Expr::Cast {
expr: Box::new(col("c")),
data_type: DataType::Float64,
}])?
.build()?;
let expected = "Projection: CAST(#c AS Float64)\
\n TableScan: test projection=Some([2])";
assert_optimized_plan_eq(&projection, expected);
Ok(())
}
#[test]
fn table_scan_projected_schema() -> Result<()> {
let table_scan = test_table_scan()?;
assert_eq!(3, table_scan.schema().fields().len());
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
assert_fields_eq(&plan, vec!["a", "b"]);
let expected = "Projection: #a, #b\
\n TableScan: test projection=Some([0, 1])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn table_limit() -> Result<()> {
let table_scan = test_table_scan()?;
assert_eq!(3, table_scan.schema().fields().len());
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("c"), col("a")])?
.limit(5)?
.build()?;
assert_fields_eq(&plan, vec!["c", "a"]);
let expected = "Limit: 5\
\n Projection: #c, #a\
\n TableScan: test projection=Some([0, 2])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn table_scan_without_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan).build()?;
let expected = "TableScan: test projection=Some([0, 1, 2])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn table_scan_with_literal_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![lit(1_i64), lit(2_i64)])?
.build()?;
let expected = "Projection: Int64(1), Int64(2)\
\n TableScan: test projection=Some([0])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn table_unused_column() -> Result<()> {
let table_scan = test_table_scan()?;
assert_eq!(3, table_scan.schema().fields().len());
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("c"), col("a"), col("b")])?
.filter(col("c").gt(lit(1)))?
.aggregate(vec![col("c")], vec![max(col("a"))])?
.build()?;
assert_fields_eq(&plan, vec!["c", "MAX(a)"]);
let expected = "\
Aggregate: groupBy=[[#c]], aggr=[[MAX(#a)]]\
\n Filter: #c Gt Int32(1)\
\n Projection: #c, #a\
\n TableScan: test projection=Some([0, 2])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn table_unused_projection() -> Result<()> {
let table_scan = test_table_scan()?;
assert_eq!(3, table_scan.schema().fields().len());
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("b")])?
.project(vec![lit(1).alias("a")])?
.build()?;
assert_fields_eq(&plan, vec!["a"]);
let expected = "\
Projection: Int32(1) AS a\
\n TableScan: test projection=Some([0])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
#[test]
fn table_unused_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
assert_eq!(3, table_scan.schema().fields().len());
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
.aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
.filter(col("c").gt(lit(1)))?
.project(vec![col("c"), col("a"), col("MAX(b)")])?
.build()?;
assert_fields_eq(&plan, vec!["c", "a", "MAX(b)"]);
let expected = "\
Projection: #c, #a, #MAX(b)\
\n Filter: #c Gt Int32(1)\
\n Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b)]]\
\n TableScan: test projection=Some([0, 1, 2])";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
let optimized_plan = optimize(plan).expect("failed to optimize plan");
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
}
fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let mut rule = ProjectionPushDown::new();
rule.optimize(plan)
}
}