use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator},
Result,
};
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort};
/// Returns a copy of `plan` in which every `Union` node carries a schema with
/// its qualifiers stripped, and every `Sort` placed directly on top of a
/// `Union` has the relation removed from the columns in its sort expressions.
///
/// The input plan is cloned and rewritten bottom-up; all other nodes are
/// passed through untouched.
///
/// # Errors
///
/// Propagates any error produced while rewriting the sort expressions or
/// while traversing the plan tree.
pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> {
    plan.clone()
        .transform_up(|node| match node {
            LogicalPlan::Union(mut union) => {
                // Reuse the schema allocation when this is the only owner,
                // otherwise fall back to cloning the shared value.
                let unqualified = Arc::try_unwrap(union.schema)
                    .unwrap_or_else(|shared| (*shared).clone())
                    .strip_qualifiers();
                union.schema = Arc::new(unqualified);
                Ok(Transformed::yes(LogicalPlan::Union(union)))
            }
            // Only a sort whose direct input is a union needs its expressions rewritten.
            LogicalPlan::Sort(sort)
                if matches!(sort.input.as_ref(), LogicalPlan::Union(_)) =>
            {
                let Sort { expr, input, fetch } = sort;
                let expr = rewrite_sort_expr_for_union(expr)?;
                Ok(Transformed::yes(LogicalPlan::Sort(Sort {
                    expr,
                    input,
                    fetch,
                })))
            }
            other => Ok(Transformed::no(other)),
        })
        .data()
}
/// Clears the `relation` of every `Expr::Column` found anywhere inside the
/// given sort expressions and returns the rewritten expressions.
///
/// Each expression is walked bottom-up; nodes that are not columns are
/// returned unchanged.
///
/// # Errors
///
/// Propagates any error raised by the expression tree traversal.
fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
    exprs
        .into_iter()
        .map_until_stop_and_collect(|sort_expr| {
            sort_expr.transform_up(|node| match node {
                Expr::Column(mut column) => {
                    column.relation = None;
                    Ok(Transformed::yes(Expr::Column(column)))
                }
                other => Ok(Transformed::no(other)),
            })
        })
        .data()
}
pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
p: &Projection,
) -> Option<LogicalPlan> {
let LogicalPlan::Sort(sort) = p.input.as_ref() else {
return None;
};
let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else {
return None;
};
let mut map = HashMap::new();
let inner_exprs = inner_p
.expr
.iter()
.map(|f| {
if let Expr::Alias(alias) = f {
let a = Expr::Column(alias.name.clone().into());
map.insert(a.clone(), f.clone());
a
} else {
f.clone()
}
})
.collect::<Vec<_>>();
let mut collects = p.expr.clone();
for expr in &sort.expr {
if let Expr::Sort(s) = expr {
collects.push(s.expr.as_ref().clone());
}
}
if collects.iter().collect::<HashSet<_>>()
== inner_exprs.iter().collect::<HashSet<_>>()
{
let mut sort = sort.clone();
let mut inner_p = inner_p.clone();
let new_exprs = p
.expr
.iter()
.map(|e| map.get(e).unwrap_or(e).clone())
.collect::<Vec<_>>();
inner_p.expr.clone_from(&new_exprs);
sort.input = Arc::new(LogicalPlan::Projection(inner_p));
Some(LogicalPlan::Sort(sort))
} else {
None
}
}