use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode},
Result,
};
use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec};
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
use sqlparser::ast::Ident;
/// Returns a copy of `plan` with relation qualifiers removed around `Union` nodes.
///
/// The plan is walked bottom-up:
/// - every `LogicalPlan::Union` has its schema replaced by
///   `schema.strip_qualifiers()`;
/// - every `LogicalPlan::Sort` whose direct input is a `Union` has the `relation`
///   of each column reference in its sort expressions cleared (see
///   `rewrite_sort_expr_for_union`).
///
/// All other nodes are left unchanged.
///
/// # Errors
/// Propagates any error produced while rewriting the sort expressions.
pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> {
    plan.clone()
        .transform_up(|node| match node {
            LogicalPlan::Union(mut union) => {
                let unqualified = Arc::unwrap_or_clone(union.schema).strip_qualifiers();
                union.schema = Arc::new(unqualified);
                Ok(Transformed::yes(LogicalPlan::Union(union)))
            }
            // Only sorts sitting directly on top of a union need their column
            // references unqualified; any other sort falls through unchanged.
            LogicalPlan::Sort(Sort { expr, input, fetch })
                if matches!(input.as_ref(), LogicalPlan::Union(_)) =>
            {
                Ok(Transformed::yes(LogicalPlan::Sort(Sort {
                    expr: rewrite_sort_expr_for_union(expr)?,
                    input,
                    fetch,
                })))
            }
            other => Ok(Transformed::no(other)),
        })
        .data()
}
/// Clears the `relation` qualifier of every `Expr::Column` found anywhere inside
/// the given sort expressions, leaving all other expression nodes untouched.
///
/// # Errors
/// Propagates any error returned by the underlying tree traversal.
fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
    transform_sort_vec(exprs, &mut |sort_key| {
        sort_key.transform_up(|node| match node {
            Expr::Column(mut column) => {
                column.relation = None;
                Ok(Transformed::yes(Expr::Column(column)))
            }
            other => Ok(Transformed::no(other)),
        })
    })
    .data()
}
pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
p: &Projection,
) -> Option<LogicalPlan> {
let LogicalPlan::Sort(sort) = p.input.as_ref() else {
return None;
};
let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else {
return None;
};
let mut map = HashMap::new();
let inner_exprs = inner_p
.expr
.iter()
.enumerate()
.map(|(i, f)| match f {
Expr::Alias(alias) => {
let a = Expr::Column(alias.name.clone().into());
map.insert(a.clone(), f.clone());
a
}
Expr::Column(_) => {
map.insert(
Expr::Column(inner_p.schema.field(i).name().into()),
f.clone(),
);
f.clone()
}
_ => {
let a = Expr::Column(inner_p.schema.field(i).name().into());
map.insert(a.clone(), f.clone());
a
}
})
.collect::<Vec<_>>();
let mut collects = p.expr.clone();
for sort in &sort.expr {
collects.push(sort.expr.clone());
}
let outer_collects = collects.iter().map(Expr::to_string).collect::<HashSet<_>>();
let inner_collects = inner_exprs
.iter()
.map(Expr::to_string)
.collect::<HashSet<_>>();
if outer_collects == inner_collects {
let mut sort = sort.clone();
let mut inner_p = inner_p.clone();
let new_exprs = p
.expr
.iter()
.map(|e| map.get(e).unwrap_or(e).clone())
.collect::<Vec<_>>();
inner_p.expr.clone_from(&new_exprs);
sort.input = Arc::new(LogicalPlan::Projection(inner_p));
Some(LogicalPlan::Sort(sort))
} else {
None
}
}
/// Detects a subquery alias whose input is an "alias-only" projection stacked on
/// another projection, for example:
///
/// ```text
/// SubqueryAlias: t
///   Projection: j1.j1_id AS id
///     Projection: j1.j1_id
/// ```
///
/// The inner projection may be reached through `Limit`, `Distinct` or `Sort`
/// nodes (see `find_projection`).
///
/// The pattern matches only when both projections have the same number of
/// expressions. In addition, at each position the outer expression must be an
/// alias whose aliased expression prints the same as the inner expression:
/// - for inner columns, their `to_string()`;
/// - otherwise, the inner schema field name at that position.
///
/// On a match, returns the outer projection's input together with the alias
/// names, one per position. Otherwise returns the subquery alias's direct input
/// and an empty vector.
pub(super) fn subquery_alias_inner_query_and_columns(
    subquery_alias: &datafusion_expr::SubqueryAlias,
) -> (&LogicalPlan, Vec<Ident>) {
    let plan: &LogicalPlan = subquery_alias.input.as_ref();

    let LogicalPlan::Projection(outer_projections) = plan else {
        return (plan, vec![]);
    };

    let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else {
        return (plan, vec![]);
    };

    // The two projections must line up one-to-one. Previously the outer list was
    // indexed by inner position, which panicked when the outer projection was
    // shorter. When it was longer, its extra output columns were silently dropped
    // once the outer projection was skipped.
    if outer_projections.expr.len() != inner_projection.expr.len() {
        return (plan, vec![]);
    }

    let mut columns: Vec<Ident> = Vec::with_capacity(inner_projection.expr.len());
    for (i, (inner_expr, outer_expr)) in inner_projection
        .expr
        .iter()
        .zip(outer_projections.expr.iter())
        .enumerate()
    {
        let Expr::Alias(outer_alias) = outer_expr else {
            return (plan, vec![]);
        };

        // Non-column inner expressions are referenced from the outer projection
        // by the name stored in the inner projection's schema.
        let inner_expr_string = match inner_expr {
            Expr::Column(_) => inner_expr.to_string(),
            _ => inner_projection.schema.field(i).name().clone(),
        };

        if outer_alias.expr.to_string() != inner_expr_string {
            return (plan, vec![]);
        }

        columns.push(outer_alias.name.as_str().into());
    }

    (outer_projections.input.as_ref(), columns)
}
/// Returns a copy of `projection` in which the i-th expression is wrapped in an
/// alias named after the i-th entry of `aliases`.
///
/// - Column expressions keep their `relation` on the new alias; other expressions
///   get an alias with no relation.
/// - If `aliases` yields fewer items than there are expressions, the remaining
///   expressions are kept unchanged rather than dropped.
/// - Extra aliases beyond the number of expressions are ignored.
///
/// Only `expr` is replaced; the projection's stored schema is copied unchanged
/// from the input.
pub(super) fn inject_column_aliases(
    projection: &datafusion_expr::Projection,
    aliases: impl IntoIterator<Item = Ident>,
) -> LogicalPlan {
    let mut updated_projection = projection.clone();
    let mut aliases = aliases.into_iter();

    let new_exprs = updated_projection
        .expr
        .into_iter()
        .map(|expr| {
            // Do not truncate the projection when aliases run out (the previous
            // `zip` silently dropped trailing expressions).
            let Some(col_alias) = aliases.next() else {
                return expr;
            };

            // Previously only columns were aliased. Non-column expressions then
            // kept their default name, so references to the injected alias could
            // not be resolved.
            let relation = match &expr {
                Expr::Column(col) => col.relation.clone(),
                _ => None,
            };

            Expr::Alias(Alias {
                expr: Box::new(expr),
                relation,
                name: col_alias.value,
            })
        })
        .collect::<Vec<_>>();

    updated_projection.expr = new_exprs;
    LogicalPlan::Projection(updated_projection)
}
/// Finds the first `Projection` reachable from `logical_plan`, looking through
/// any chain of `Limit`, `Distinct` and `Sort` nodes.
///
/// Returns `None` as soon as any other node kind is encountered.
fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
    let mut current = logical_plan;
    loop {
        current = match current {
            LogicalPlan::Projection(projection) => return Some(projection),
            LogicalPlan::Limit(limit) => limit.input.as_ref(),
            LogicalPlan::Distinct(distinct) => distinct.input().as_ref(),
            LogicalPlan::Sort(sort) => sort.input.as_ref(),
            _ => return None,
        };
    }
}