use crate::error::{DataFusionError, Result};
use crate::logical_plan::{DFSchema, Expr, LogicalPlan};
/// Expands a wildcard expression into one column expression per schema field.
///
/// When `expr` is `Expr::Wildcard`, the result holds an `Expr::Column` for
/// every field of `schema`, named after the field and in the order the
/// schema yields its fields. Any other expression is returned unchanged as
/// the sole element of the result.
pub(crate) fn expand_wildcard(expr: &Expr, schema: &DFSchema) -> Vec<Expr> {
    if let Expr::Wildcard = expr {
        let mut columns = Vec::new();
        for field in schema.fields().iter() {
            columns.push(Expr::Column(field.name().to_string()));
        }
        columns
    } else {
        // Not a wildcard: hand back a copy of the expression itself.
        vec![expr.clone()]
    }
}
/// Searches `exprs` for aggregate expressions.
///
/// Delegates to `find_exprs_in_exprs` with a predicate that accepts
/// `Expr::AggregateFunction` and `Expr::AggregateUDF` nodes.
pub(crate) fn find_aggregate_exprs(exprs: &Vec<Expr>) -> Vec<Expr> {
    /// True for the two aggregate variants of `Expr`.
    fn is_aggregate(candidate: &Expr) -> bool {
        matches!(
            candidate,
            Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
        )
    }

    find_exprs_in_exprs(exprs, &is_aggregate)
}
/// Searches `exprs` for column references.
///
/// Delegates to `find_exprs_in_exprs` with a predicate that accepts
/// `Expr::Column` nodes only.
pub(crate) fn find_column_exprs(exprs: &Vec<Expr>) -> Vec<Expr> {
    /// True when the expression is a plain column reference.
    fn is_column(candidate: &Expr) -> bool {
        matches!(candidate, Expr::Column(_))
    }

    find_exprs_in_exprs(exprs, &is_column)
}
/// Runs `find_exprs_in_expr` over each expression in `exprs` and merges the
/// results.
///
/// Matches are appended in the order they are found; an expression equal to
/// one already collected is skipped, so the result has no duplicates.
fn find_exprs_in_exprs<F>(exprs: &Vec<Expr>, test_fn: &F) -> Vec<Expr>
where
    F: Fn(&Expr) -> bool,
{
    let mut unique: Vec<Expr> = Vec::new();
    for expr in exprs.iter() {
        for found in find_exprs_in_expr(expr, test_fn) {
            // Keep only the first occurrence of each matching expression.
            if !unique.contains(&found) {
                unique.push(found);
            }
        }
    }
    unique
}
fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
where
F: Fn(&Expr) -> bool,
{
let matched_exprs = if test_fn(expr) {
vec![expr.clone()]
} else {
match expr {
Expr::AggregateFunction { args, .. } => find_exprs_in_exprs(&args, test_fn),
Expr::AggregateUDF { args, .. } => find_exprs_in_exprs(&args, test_fn),
Expr::Alias(nested_expr, _) => {
find_exprs_in_expr(nested_expr.as_ref(), test_fn)
}
Expr::Between {
expr: nested_expr,
low,
high,
..
} => {
let mut matches = vec![];
matches.extend(find_exprs_in_expr(nested_expr.as_ref(), test_fn));
matches.extend(find_exprs_in_expr(low.as_ref(), test_fn));
matches.extend(find_exprs_in_expr(high.as_ref(), test_fn));
matches
}
Expr::BinaryExpr { left, right, .. } => {
let mut matches = vec![];
matches.extend(find_exprs_in_expr(left.as_ref(), test_fn));
matches.extend(find_exprs_in_expr(right.as_ref(), test_fn));
matches
}
Expr::InList {
expr: nested_expr,
list,
..
} => {
let mut matches = vec![];
matches.extend(find_exprs_in_expr(nested_expr.as_ref(), test_fn));
matches.extend(
list.iter()
.flat_map(|expr| find_exprs_in_expr(expr, test_fn))
.collect::<Vec<Expr>>(),
);
matches
}
Expr::Case {
expr: case_expr_opt,
when_then_expr,
else_expr: else_expr_opt,
} => {
let mut matches = vec![];
if let Some(case_expr) = case_expr_opt {
matches.extend(find_exprs_in_expr(case_expr.as_ref(), test_fn));
}
matches.extend(
when_then_expr
.iter()
.flat_map(|(a, b)| vec![a, b])
.flat_map(|expr| find_exprs_in_expr(expr.as_ref(), test_fn))
.collect::<Vec<Expr>>(),
);
if let Some(else_expr) = else_expr_opt {
matches.extend(find_exprs_in_expr(else_expr.as_ref(), test_fn));
}
matches
}
Expr::Cast {
expr: nested_expr, ..
} => find_exprs_in_expr(nested_expr.as_ref(), test_fn),
Expr::IsNotNull(nested_expr) => {
find_exprs_in_expr(nested_expr.as_ref(), test_fn)
}
Expr::IsNull(nested_expr) => {
find_exprs_in_expr(nested_expr.as_ref(), test_fn)
}
Expr::Negative(nested_expr) => {
find_exprs_in_expr(nested_expr.as_ref(), test_fn)
}
Expr::Not(nested_expr) => find_exprs_in_expr(nested_expr.as_ref(), test_fn),
Expr::ScalarFunction { args, .. } => find_exprs_in_exprs(&args, test_fn),
Expr::ScalarUDF { args, .. } => find_exprs_in_exprs(&args, test_fn),
Expr::Sort {
expr: nested_expr, ..
} => find_exprs_in_expr(nested_expr.as_ref(), test_fn),
Expr::Column(_)
| Expr::Literal(_)
| Expr::ScalarVariable(_)
| Expr::Wildcard => vec![],
}
};
matched_exprs.into_iter().fold(vec![], |mut acc, expr| {
if !acc.contains(&expr) {
acc.push(expr)
}
acc
})
}
/// Converts `expr` into a column expression.
///
/// An `Expr::Column` is returned as a clone of itself. Any other expression
/// becomes an `Expr::Column` whose name is whatever `expr.name(..)` produces
/// for the schema of `plan`.
///
/// # Errors
///
/// Propagates the error returned by `Expr::name`.
pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    if matches!(expr, Expr::Column(_)) {
        return Ok(expr.clone());
    }
    let column_name = expr.name(&plan.schema())?;
    Ok(Expr::Column(column_name))
}
/// Rebuilds `expr`, replacing every sub-expression that appears in
/// `base_exprs` with its column form.
///
/// The rewrite is driven by `clone_with_replacement`: each visited node that
/// is contained in `base_exprs` is replaced by the result of
/// `expr_as_column_expr(node, plan)`; all other nodes are left for
/// `clone_with_replacement` to handle.
///
/// # Errors
///
/// Propagates any error from `expr_as_column_expr` or
/// `clone_with_replacement`.
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &Vec<Expr>,
    plan: &LogicalPlan,
) -> Result<Expr> {
    let replace_base_expr = |nested_expr: &Expr| -> Result<Option<Expr>> {
        // Nodes that are not base expressions get no replacement.
        if !base_exprs.contains(nested_expr) {
            return Ok(None);
        }
        expr_as_column_expr(nested_expr, plan).map(Some)
    };

    clone_with_replacement(expr, &replace_base_expr)
}
/// Reports whether every column referenced by `exprs` is present in
/// `columns`.
///
/// The referenced columns are obtained from `find_column_exprs(exprs)`.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if any element of `columns` is not an
/// `Expr::Column`.
pub(crate) fn can_columns_satisfy_exprs(
    columns: &Vec<Expr>,
    exprs: &Vec<Expr>,
) -> Result<bool> {
    // Validate the input first: only plain column expressions are accepted.
    for column in columns.iter() {
        if !matches!(column, Expr::Column(_)) {
            return Err(DataFusionError::Internal(
                "Expr::Column are required".to_string(),
            ));
        }
    }

    let referenced_columns = find_column_exprs(exprs);
    let all_available = referenced_columns
        .iter()
        .all(|referenced| columns.contains(referenced));
    Ok(all_available)
}
/// Produces a copy of `expr` in which selected nodes have been replaced.
///
/// `replacement_fn` is consulted for each node before its children:
///
/// * `Ok(Some(replacement))` — the replacement is returned in place of the
///   node, and the node's children are not visited.
/// * `Ok(None)` — the node is rebuilt with the same variant and non-expression
///   fields, and each child expression is processed recursively. Leaf
///   variants (`Column`, `Literal`, `ScalarVariable`, `Wildcard`) are copied.
/// * `Err(e)` — the error is returned immediately.
///
/// # Errors
///
/// Returns the first error produced by `replacement_fn` during the traversal.
fn clone_with_replacement<F>(expr: &Expr, replacement_fn: &F) -> Result<Expr>
where
    F: Fn(&Expr) -> Result<Option<Expr>>,
{
    // A replacement supplied for this node wins outright.
    if let Some(replacement) = replacement_fn(expr)? {
        return Ok(replacement);
    }

    // Rebuilds a single child expression and re-boxes it.
    let rebuild_boxed = |node: &Expr| -> Result<Box<Expr>> {
        Ok(Box::new(clone_with_replacement(node, replacement_fn)?))
    };
    // Rebuilds a list of child expressions, stopping at the first error.
    let rebuild_all = |nodes: &[Expr]| -> Result<Vec<Expr>> {
        nodes
            .iter()
            .map(|node| clone_with_replacement(node, replacement_fn))
            .collect()
    };
    // Rebuilds an optional boxed child, leaving `None` untouched.
    let rebuild_optional = |slot: &Option<Box<Expr>>| -> Result<Option<Box<Expr>>> {
        match slot {
            Some(node) => Ok(Some(rebuild_boxed(node)?)),
            None => Ok(None),
        }
    };

    let rebuilt = match expr {
        Expr::AggregateFunction {
            fun,
            args,
            distinct,
        } => Expr::AggregateFunction {
            fun: fun.clone(),
            args: rebuild_all(args)?,
            distinct: *distinct,
        },
        Expr::AggregateUDF { fun, args } => Expr::AggregateUDF {
            fun: fun.clone(),
            args: rebuild_all(args)?,
        },
        Expr::Alias(nested_expr, alias_name) => {
            Expr::Alias(rebuild_boxed(nested_expr)?, alias_name.clone())
        }
        Expr::Between {
            expr: nested_expr,
            negated,
            low,
            high,
        } => Expr::Between {
            expr: rebuild_boxed(nested_expr)?,
            negated: *negated,
            low: rebuild_boxed(low)?,
            high: rebuild_boxed(high)?,
        },
        Expr::InList {
            expr: nested_expr,
            list,
            negated,
        } => Expr::InList {
            expr: rebuild_boxed(nested_expr)?,
            list: rebuild_all(list)?,
            negated: *negated,
        },
        Expr::BinaryExpr { left, right, op } => Expr::BinaryExpr {
            left: rebuild_boxed(left)?,
            op: op.clone(),
            right: rebuild_boxed(right)?,
        },
        Expr::Case {
            expr: operand,
            when_then_expr,
            else_expr,
        } => {
            // Children are processed in order: operand, each WHEN/THEN pair, ELSE.
            let new_operand = rebuild_optional(operand)?;
            let mut new_branches = Vec::with_capacity(when_then_expr.len());
            for (when_expr, then_expr) in when_then_expr.iter() {
                let new_when = rebuild_boxed(when_expr)?;
                let new_then = rebuild_boxed(then_expr)?;
                new_branches.push((new_when, new_then));
            }
            let new_else = rebuild_optional(else_expr)?;
            Expr::Case {
                expr: new_operand,
                when_then_expr: new_branches,
                else_expr: new_else,
            }
        }
        Expr::ScalarFunction { fun, args } => Expr::ScalarFunction {
            fun: fun.clone(),
            args: rebuild_all(args)?,
        },
        Expr::ScalarUDF { fun, args } => Expr::ScalarUDF {
            fun: fun.clone(),
            args: rebuild_all(args)?,
        },
        Expr::Negative(nested_expr) => Expr::Negative(rebuild_boxed(nested_expr)?),
        Expr::Not(nested_expr) => Expr::Not(rebuild_boxed(nested_expr)?),
        Expr::IsNotNull(nested_expr) => Expr::IsNotNull(rebuild_boxed(nested_expr)?),
        Expr::IsNull(nested_expr) => Expr::IsNull(rebuild_boxed(nested_expr)?),
        Expr::Cast {
            expr: nested_expr,
            data_type,
        } => Expr::Cast {
            expr: rebuild_boxed(nested_expr)?,
            data_type: data_type.clone(),
        },
        Expr::Sort {
            expr: nested_expr,
            asc,
            nulls_first,
        } => Expr::Sort {
            expr: rebuild_boxed(nested_expr)?,
            asc: *asc,
            nulls_first: *nulls_first,
        },
        // Leaves have no children to rewrite.
        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => expr.clone(),
        Expr::Wildcard => Expr::Wildcard,
    };

    Ok(rebuilt)
}