use crate::expr::Alias;
use crate::expr_rewriter::normalize_col;
use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{Column, Result};
/// Rewrites the expression of every sort key in `sorts` against `plan`.
///
/// Each item is converted into a [`Sort`]; its expression is passed through
/// `rewrite_sort_col_by_aggs`, while the `asc` and `nulls_first` flags are
/// carried over unchanged.
///
/// # Errors
///
/// Returns the first error produced while rewriting an expression; no
/// further items are processed after that.
pub fn rewrite_sort_cols_by_aggs(
    sorts: impl IntoIterator<Item = impl Into<Sort>>,
    plan: &LogicalPlan,
) -> Result<Vec<Sort>> {
    let mut rewritten_sorts = Vec::new();
    for item in sorts {
        let sort: Sort = item.into();
        let rewritten_expr = rewrite_sort_col_by_aggs(sort.expr, plan)?;
        rewritten_sorts.push(Sort::new(rewritten_expr, sort.asc, sort.nulls_first));
    }
    Ok(rewritten_sorts)
}
/// Rewrites a single sort expression in terms of the expressions of `plan`.
///
/// The rewrite is only attempted when `plan` has exactly one input; for any
/// other number of inputs `expr` is returned unchanged.
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
    let inputs = plan.inputs();
    match inputs.as_slice() {
        &[single_input] => {
            rewrite_in_terms_of_projection(expr, plan.expressions(), single_input)
        }
        _ => Ok(expr),
    }
}
/// Rewrites `expr` so that sub-expressions which correspond to one of
/// `proj_exprs` refer to that projected expression instead.
///
/// Every node visited by `transform` is handled in two stages:
///
/// 1. If the node is equal to one of `proj_exprs`, it is replaced by a
///    column built from that expression's qualified name.
/// 2. Otherwise the node is normalized against `input` with `normalize_col`,
///    and an unqualified column carrying the normalized expression's schema
///    name is searched for inside the `proj_exprs` trees (see `expr_match`).
///    On a hit the node is replaced by the matching sub-expression; when the
///    normalized node is a `Cast` / `TryCast`, the replacement is wrapped in
///    the same kind of cast with the same target type.
///
/// Nodes that cannot be normalized, or for which nothing matches, are left
/// untouched.
fn rewrite_in_terms_of_projection(
    expr: Expr,
    proj_exprs: Vec<Expr>,
    input: &LogicalPlan,
) -> Result<Expr> {
    expr.transform(|node| {
        // Stage 1: the node itself is one of the projected expressions.
        if let Some(projected) = proj_exprs.iter().find(|candidate| **candidate == node) {
            let (relation, name) = projected.qualified_name();
            return Ok(Transformed::yes(Expr::Column(Column::new(relation, name))));
        }

        // Stage 2: normalize first. A normalization failure is not an error
        // here; the node is simply kept as it is.
        // NOTE(review): presumably normalization qualifies column references
        // against `input`'s schema (e.g. `c1` -> `t.c1`) -- confirm in
        // `normalize_col`.
        let normalized = match normalize_col(node.clone(), input) {
            Ok(normalized) => normalized,
            Err(_) => return Ok(Transformed::no(node)),
        };

        // The column we look for is named after the normalized expression.
        let output_name = normalized.schema_name().to_string();
        let target = Expr::Column(Column::new_unqualified(output_name));

        // Walk every projected expression tree looking for `target`.
        // `Stop` only ends the walk of the current tree: the outer loop keeps
        // going, so a hit in a later projected expression overwrites an
        // earlier one.
        let mut hit = None;
        for projection in &proj_exprs {
            projection.apply(|sub_expr| {
                Ok(if expr_match(&target, sub_expr) {
                    hit = Some(sub_expr.clone());
                    TreeNodeRecursion::Stop
                } else {
                    TreeNodeRecursion::Continue
                })
            })?;
        }

        match hit {
            None => Ok(Transformed::no(node)),
            Some(hit) => {
                // Keep an outer cast around the replacement so the target
                // data type of the original sort expression is preserved.
                let rewritten = match normalized {
                    Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
                        expr: Box::new(hit),
                        data_type,
                    }),
                    Expr::TryCast(TryCast { expr: _, data_type }) => {
                        Expr::TryCast(TryCast {
                            expr: Box::new(hit),
                            data_type,
                        })
                    }
                    _ => hit,
                };
                Ok(Transformed::yes(rewritten))
            }
        }
    })
    .data()
}
/// Returns `true` when `expr` equals `needle`, looking through one level of
/// aliasing: for an `Expr::Alias` the aliased inner expression is compared
/// instead of the alias node itself.
fn expr_match(needle: &Expr, expr: &Expr) -> bool {
    match expr {
        Expr::Alias(Alias { expr: inner, .. }) => inner.as_ref() == needle,
        other => other == needle,
    }
}
#[cfg(test)]
mod test {
use std::ops::Add;
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use crate::{
cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
LogicalPlanBuilder,
};
use super::*;
use crate::test::function_stub::avg;
use crate::test::function_stub::min;
/// Sort keys rewritten directly against an aggregate plan
/// (`GROUP BY c1`, `min(c2)`) are expected to come back unchanged.
#[test]
fn rewrite_sort_cols_by_agg() {
    let agg = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2"))])
        .unwrap()
        .build()
        .unwrap();

    // Builds a test case from plain expressions, wrapping both sides in `sort`.
    let case = |desc: &'static str, input: Expr, expected: Expr| TestCase {
        desc,
        input: sort(input),
        expected: sort(expected),
    };

    let cases = vec![
        case("c1 --> c1", col("c1"), col("c1")),
        // NOTE(review): the description mentions `c2`, but the expressions
        // below use `c1` on both sides of the addition.
        case(
            "c1 + c2 --> c1 + c2",
            col("c1") + col("c1"),
            col("c1") + col("c1"),
        ),
        case(
            r#"min(c2) --> "min(c2)"#,
            min(col("c2")),
            min(col("c2")),
        ),
        case(
            r#"c1 + min(c2) --> "c1 + min(c2)"#,
            col("c1") + min(col("c2")),
            col("c1") + min(col("c2")),
        ),
    ];

    cases.into_iter().for_each(|case| case.run(&agg));
}
/// Sort keys rewritten against a projection placed on top of an aggregate:
/// aggregate calls are expected to be replaced by references to the
/// projection's expressions, including aliased ones.
#[test]
fn rewrite_sort_cols_by_agg_alias() {
    let agg = make_input()
        .aggregate(vec![col("c1")], vec![min(col("c2")), avg(col("c3"))])
        .unwrap()
        .project(vec![
            col("c1").add(lit(1)).alias("c1"),
            min(col("c2")),
            avg(col("c3")).alias("average"),
        ])
        .unwrap()
        .build()
        .unwrap();

    // Builds a test case from plain expressions, wrapping both sides in `sort`.
    let case = |desc: &'static str, input: Expr, expected: Expr| TestCase {
        desc,
        input: sort(input),
        expected: sort(expected),
    };

    let cases = vec![
        case(
            "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
            col("c1"),
            col("c1"),
        ),
        case(
            r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
            min(col("c2")),
            col("min(t.c2)"),
        ),
        case(
            r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
            col("c1") + min(col("c2")),
            col("c1") + col("min(t.c2)"),
        ),
        case(
            r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
            avg(col("c3")),
            col("avg(t.c3)").alias("average"),
        ),
    ];

    cases.into_iter().for_each(|case| case.run(&agg));
}
/// A `Cast` / `TryCast` wrapped around a sort column must still wrap the
/// rewritten expression. The plan is two stacked projections, each aliasing
/// `c2` as `c2`.
#[test]
fn preserve_cast() {
    let plan = make_input()
        .project(vec![col("c2").alias("c2")])
        .unwrap()
        .project(vec![col("c2").alias("c2")])
        .unwrap()
        .build()
        .unwrap();

    // Builds a test case from plain expressions, wrapping both sides in `sort`.
    let case = |desc: &'static str, input: Expr, expected: Expr| TestCase {
        desc,
        input: sort(input),
        expected: sort(expected),
    };

    let cases = vec![
        case(
            "Cast is preserved by rewrite_sort_cols_by_aggs",
            cast(col("c2"), DataType::Int64),
            cast(col("c2").alias("c2"), DataType::Int64),
        ),
        case(
            "TryCast is preserved by rewrite_sort_cols_by_aggs",
            try_cast(col("c2"), DataType::Int64),
            try_cast(col("c2").alias("c2"), DataType::Int64),
        ),
    ];

    cases.into_iter().for_each(|case| case.run(&plan));
}
/// One rewrite scenario: a sort key, and the sort key expected after
/// `rewrite_sort_cols_by_aggs` has been applied to it.
struct TestCase {
    /// Human-readable description, printed when the case runs.
    desc: &'static str,
    /// Sort key handed to the rewriter.
    input: Sort,
    /// Sort key the rewriter is expected to produce.
    expected: Sort,
}

impl TestCase {
    /// Rewrites `input` against `input_plan` and asserts that exactly one
    /// sort key comes back and that it equals `expected`.
    fn run(self, input_plan: &LogicalPlan) {
        let desc = self.desc;
        let input = self.input;
        let expected = self.expected;

        println!("running: '{desc}'");

        let rewritten_all =
            rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();
        assert_eq!(rewritten_all.len(), 1);

        // Exactly one element is present (asserted above), so take it.
        let rewritten = rewritten_all.into_iter().next().unwrap();
        assert_eq!(
            rewritten, expected,
            "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
        );
    }
}
/// Returns a plan builder scanning a table named `t` with three nullable
/// columns: `c1` (Int32), `c2` (Utf8) and `c3` (Float64). No projection is
/// passed to the scan.
fn make_input() -> LogicalPlanBuilder {
    let fields = vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
        Field::new("c3", DataType::Float64, true),
    ];
    let table_source = LogicalTableSource::new(Arc::new(Schema::new(fields)));

    LogicalPlanBuilder::scan("t", Arc::new(table_source), None).unwrap()
}
/// Wraps `expr` in a sort key with `asc = true` and `nulls_first = true`.
fn sort(expr: Expr) -> Sort {
    expr.sort(true, true)
}
}