use std::sync::Arc;
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator},
Result,
};
use datafusion_expr::{Expr, LogicalPlan, Sort};
/// Returns a copy of `plan` in which every `Union` node carries a schema with
/// its qualifiers stripped, and every `Sort` sitting directly on top of a
/// `Union` has the relation removed from the columns in its sort expressions.
///
/// The input plan is cloned first, so the caller's plan is left untouched.
///
/// # Errors
///
/// Propagates any error produced while traversing the plan or while rewriting
/// the sort expressions.
pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> {
    plan.clone()
        .transform_up(|node| match node {
            LogicalPlan::Union(mut union) => {
                // Reuse the schema allocation when this is the only owner;
                // otherwise fall back to cloning the shared schema.
                let unqualified = Arc::try_unwrap(union.schema)
                    .unwrap_or_else(|shared| (*shared).clone())
                    .strip_qualifiers();
                union.schema = Arc::new(unqualified);
                Ok(Transformed::yes(LogicalPlan::Union(union)))
            }
            // Only a sort whose direct input is a union needs its expressions
            // rewritten; presumably so they line up with the unqualified
            // union schema produced above.
            LogicalPlan::Sort(sort) if matches!(&*sort.input, LogicalPlan::Union(_)) => {
                let Sort { expr, input, fetch } = sort;
                let expr = rewrite_sort_expr_for_union(expr)?;
                Ok(Transformed::yes(LogicalPlan::Sort(Sort {
                    expr,
                    input,
                    fetch,
                })))
            }
            // Every other node (including sorts over non-union inputs) is
            // passed through unchanged.
            other => Ok(Transformed::no(other)),
        })
        .data()
}
/// Rewrites each sort expression so that every `Expr::Column` found anywhere
/// inside it has its `relation` cleared.
///
/// Column nodes are always reported as transformed, even when their relation
/// was already `None`; all other expression nodes are left as they are.
///
/// # Errors
///
/// Propagates any error raised while walking an expression tree.
fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
    exprs
        .into_iter()
        .map_until_stop_and_collect(|sort_expr| {
            sort_expr.transform_up(|node| match node {
                Expr::Column(mut column) => {
                    // Drop the qualifier so the column is referenced by name only.
                    column.relation = None;
                    Ok(Transformed::yes(Expr::Column(column)))
                }
                other => Ok(Transformed::no(other)),
            })
        })
        .data()
}