use crate::logical_plan::{DFSchema, Expr, LogicalPlan};
use crate::{
error::{DataFusionError, Result},
logical_plan::{ExpressionVisitor, Recursion},
};
use std::collections::HashMap;
/// Expands a wildcard expression into one `Expr::Column` per schema field.
///
/// When `expr` is `Expr::Wildcard`, the result holds a column expression for
/// every field of `schema`, in the schema's field order. Any other expression
/// is returned unchanged as a single-element vector containing a clone of it.
pub(crate) fn expand_wildcard(expr: &Expr, schema: &DFSchema) -> Vec<Expr> {
    // Anything that is not a wildcard passes through untouched.
    if !matches!(expr, Expr::Wildcard) {
        return vec![expr.clone()];
    }

    let mut columns: Vec<Expr> = Vec::new();
    for field in schema.fields().iter() {
        columns.push(Expr::Column(field.name().to_string()));
    }
    columns
}
/// Collects every aggregate expression (`Expr::AggregateFunction` or
/// `Expr::AggregateUDF`) found within `exprs`.
///
/// The search and de-duplication are delegated to `find_exprs_in_exprs`.
pub(crate) fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
    let is_aggregate = |candidate: &Expr| -> bool {
        matches!(
            candidate,
            Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
        )
    };
    find_exprs_in_exprs(exprs, &is_aggregate)
}
/// Collects every `Expr::Column` found within `exprs`.
///
/// The search and de-duplication are delegated to `find_exprs_in_exprs`.
pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
    let is_column = |candidate: &Expr| -> bool { matches!(candidate, Expr::Column(_)) };
    find_exprs_in_exprs(exprs, &is_column)
}
/// Searches each expression in `exprs` for sub-expressions accepted by
/// `test_fn` and returns them without duplicates.
///
/// Results keep first-seen order: matches from earlier input expressions come
/// before matches from later ones, and a match equal to one already collected
/// is dropped.
fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut unique: Vec<Expr> = Vec::new();
    for expr in exprs {
        for found in find_exprs_in_expr(expr, test_fn) {
            // `Expr` equality drives the de-duplication across input expressions.
            if !unique.contains(&found) {
                unique.push(found);
            }
        }
    }
    unique
}
/// Visitor state used to gather the expressions accepted by a predicate.
struct Finder<'a, F: Fn(&Expr) -> bool> {
    /// Predicate deciding whether a visited expression should be collected.
    test_fn: &'a F,
    /// Expressions collected so far, in visit order and without duplicates.
    exprs: Vec<Expr>,
}
impl<'a, F> Finder<'a, F>
where
F: Fn(&Expr) -> bool,
{
fn new(test_fn: &'a F) -> Self {
Self {
test_fn,
exprs: Vec::new(),
}
}
}
impl<'a, F: Fn(&Expr) -> bool> ExpressionVisitor for Finder<'a, F> {
    /// Tests `expr` against the predicate before its children are visited.
    ///
    /// A matching expression is recorded (unless an equal one is already
    /// stored) and `Recursion::Stop` is returned; a non-matching expression
    /// yields `Recursion::Continue`.
    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
        let is_match = (self.test_fn)(expr);
        if !is_match {
            return Ok(Recursion::Continue(self));
        }

        let already_recorded = self.exprs.contains(expr);
        if !already_recorded {
            self.exprs.push(expr.clone());
        }
        // NOTE(review): `Stop` presumably prevents descending into the matched
        // expression's children — confirm against `ExpressionVisitor`.
        Ok(Recursion::Stop(self))
    }
}
/// Walks `expr` with a `Finder` and returns the sub-expressions accepted by
/// `test_fn`, without duplicates.
///
/// # Panics
///
/// Panics if the traversal returns an error. `Finder::pre_visit` itself never
/// produces one.
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let visitor = Finder::new(test_fn);
    let finished = expr
        .accept(visitor)
        .expect("no way to return error during recursion");
    finished.exprs
}
/// Converts `expr` into an `Expr::Column`.
///
/// A column expression is returned as a clone of itself. Any other expression
/// becomes a column named by `expr.name(..)` evaluated against the schema of
/// `plan`.
///
/// # Errors
///
/// Propagates any error returned while computing the expression's name.
pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    if let Expr::Column(_) = expr {
        return Ok(expr.clone());
    }
    let column_name = expr.name(&plan.schema())?;
    Ok(Expr::Column(column_name))
}
/// Rebuilds `expr`, replacing every nested expression that appears in
/// `base_exprs` with a column expression derived from it.
///
/// Each replacement column is produced by `expr_as_column_expr` against
/// `plan`. Nested expressions not listed in `base_exprs` are cloned and their
/// children rewritten recursively.
///
/// # Errors
///
/// Propagates any error from `expr_as_column_expr`.
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    clone_with_replacement(expr, &|nested_expr| {
        if !base_exprs.contains(nested_expr) {
            return Ok(None);
        }
        expr_as_column_expr(nested_expr, plan).map(Some)
    })
}
/// Reports whether every column referenced by `exprs` is present in
/// `columns`.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if any element of `columns` is not an
/// `Expr::Column`.
pub(crate) fn can_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
) -> Result<bool> {
    // Validate the input first: only plain column expressions are accepted.
    for column in columns {
        if !matches!(column, Expr::Column(_)) {
            return Err(DataFusionError::Internal(
                "Expr::Column are required".to_string(),
            ));
        }
    }

    let referenced = find_column_exprs(exprs);
    let all_available = referenced
        .iter()
        .all(|needed| columns.contains(needed));
    Ok(all_available)
}
/// Returns a clone of `expr` in which selected sub-expressions are swapped
/// out by `replacement_fn`.
///
/// `replacement_fn` is consulted on each expression before its children:
///
/// * `Ok(Some(replacement))` — the replacement is used as-is and the
///   original expression's children are not visited.
/// * `Ok(None)` — the expression is rebuilt with each child rewritten
///   recursively; leaf expressions (`Column`, `Literal`, `ScalarVariable`,
///   `Wildcard`) are cloned.
/// * `Err(_)` — the error is returned immediately.
///
/// # Errors
///
/// Propagates the first error returned by `replacement_fn`.
fn clone_with_replacement<F>(expr: &Expr, replacement_fn: &F) -> Result<Expr>
where
    F: Fn(&Expr) -> Result<Option<Expr>>,
{
    // The replacement function has first say over the current node.
    if let Some(replacement) = replacement_fn(expr)? {
        return Ok(replacement);
    }

    // Rewrites one child and re-boxes the result.
    let rewrite_boxed = |child: &Expr| -> Result<Box<Expr>> {
        clone_with_replacement(child, replacement_fn).map(Box::new)
    };
    // Rewrites a list of children in order, stopping at the first error.
    let rewrite_all = |children: &[Expr]| -> Result<Vec<Expr>> {
        children
            .iter()
            .map(|child| clone_with_replacement(child, replacement_fn))
            .collect()
    };
    // Rewrites an optional boxed child, keeping `None` as `None`.
    let rewrite_optional = |child: &Option<Box<Expr>>| -> Result<Option<Box<Expr>>> {
        child.as_deref().map(|inner| rewrite_boxed(inner)).transpose()
    };

    // No replacement requested: rebuild this node around its rewritten children.
    match expr {
        Expr::AggregateFunction {
            fun,
            args,
            distinct,
        } => Ok(Expr::AggregateFunction {
            fun: fun.clone(),
            args: rewrite_all(args)?,
            distinct: *distinct,
        }),
        Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
            fun: fun.clone(),
            args: rewrite_all(args)?,
        }),
        Expr::Alias(nested_expr, alias_name) => {
            let rewritten = rewrite_boxed(nested_expr)?;
            Ok(Expr::Alias(rewritten, alias_name.clone()))
        }
        Expr::Between {
            expr: nested_expr,
            negated,
            low,
            high,
        } => Ok(Expr::Between {
            expr: rewrite_boxed(nested_expr)?,
            negated: *negated,
            low: rewrite_boxed(low)?,
            high: rewrite_boxed(high)?,
        }),
        Expr::InList {
            expr: nested_expr,
            list,
            negated,
        } => Ok(Expr::InList {
            expr: rewrite_boxed(nested_expr)?,
            list: rewrite_all(list)?,
            negated: *negated,
        }),
        Expr::BinaryExpr { left, right, op } => Ok(Expr::BinaryExpr {
            left: rewrite_boxed(left)?,
            op: *op,
            right: rewrite_boxed(right)?,
        }),
        Expr::Case {
            expr: case_expr_opt,
            when_then_expr,
            else_expr: else_expr_opt,
        } => Ok(Expr::Case {
            expr: rewrite_optional(case_expr_opt)?,
            when_then_expr: when_then_expr
                .iter()
                .map(|(when, then)| Ok((rewrite_boxed(when)?, rewrite_boxed(then)?)))
                .collect::<Result<Vec<(_, _)>>>()?,
            else_expr: rewrite_optional(else_expr_opt)?,
        }),
        Expr::ScalarFunction { fun, args } => Ok(Expr::ScalarFunction {
            fun: fun.clone(),
            args: rewrite_all(args)?,
        }),
        Expr::ScalarUDF { fun, args } => Ok(Expr::ScalarUDF {
            fun: fun.clone(),
            args: rewrite_all(args)?,
        }),
        Expr::Negative(nested_expr) => Ok(Expr::Negative(rewrite_boxed(nested_expr)?)),
        Expr::Not(nested_expr) => Ok(Expr::Not(rewrite_boxed(nested_expr)?)),
        Expr::IsNotNull(nested_expr) => Ok(Expr::IsNotNull(rewrite_boxed(nested_expr)?)),
        Expr::IsNull(nested_expr) => Ok(Expr::IsNull(rewrite_boxed(nested_expr)?)),
        Expr::Cast {
            expr: nested_expr,
            data_type,
        } => Ok(Expr::Cast {
            expr: rewrite_boxed(nested_expr)?,
            data_type: data_type.clone(),
        }),
        Expr::TryCast {
            expr: nested_expr,
            data_type,
        } => Ok(Expr::TryCast {
            expr: rewrite_boxed(nested_expr)?,
            data_type: data_type.clone(),
        }),
        Expr::Sort {
            expr: nested_expr,
            asc,
            nulls_first,
        } => Ok(Expr::Sort {
            expr: rewrite_boxed(nested_expr)?,
            asc: *asc,
            nulls_first: *nulls_first,
        }),
        // Leaf expressions have no children to rewrite.
        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(expr.clone()),
        Expr::Wildcard => Ok(Expr::Wildcard),
    }
}
/// Builds a map from alias name to the aliased expression for every
/// top-level `Expr::Alias` in `exprs`.
///
/// Expressions that are not aliases are ignored. If the same alias name
/// appears more than once, the last occurrence wins.
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    let mut aliases: HashMap<String, Expr> = HashMap::new();
    for expr in exprs {
        if let Expr::Alias(nested_expr, alias_name) = expr {
            aliases.insert(alias_name.clone(), (**nested_expr).clone());
        }
    }
    aliases
}
/// Rebuilds `expr`, replacing every `Expr::Column` whose name is a key of
/// `aliases` with a clone of the mapped expression.
///
/// Columns with no matching alias, and all non-column expressions, are
/// preserved (with their children rewritten recursively).
pub(crate) fn resolve_aliases_to_exprs(
    expr: &Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    clone_with_replacement(expr, &|nested_expr| {
        if let Expr::Column(name) = nested_expr {
            // `None` here means "no alias with this name": keep the column.
            Ok(aliases.get(name).cloned())
        } else {
            Ok(None)
        }
    })
}