use crate::{expr::GroupingSet, Expr};
use datafusion_common::Result;
/// Tells the traversal how to proceed after `pre_visit` has run on a node.
///
/// Both variants carry the visitor itself, so it can keep being threaded
/// through the rest of the walk (visitors are passed by value).
pub enum Recursion<V: ExpressionVisitor> {
/// Keep going: `Expr::accept` visits the node's children and then calls
/// `post_visit` on the node.
Continue(V),
/// Prune here: `Expr::accept` returns the visitor immediately, without
/// visiting the node's children and without calling `post_visit` on it.
Stop(V),
}
/// A visitor that is threaded by value through a tree of `E` nodes
/// (`E` defaults to `Expr`).
///
/// Each callback consumes the visitor and hands it back, so any state the
/// visitor accumulates travels with it from node to node.
pub trait ExpressionVisitor<E: ExprVisitable = Expr>: Sized {
/// Called on a node before any of its children are visited.
///
/// Returns the visitor wrapped in a [`Recursion`] that decides whether the
/// traversal descends into the node's children (`Continue`) or skips them
/// (`Stop`).
///
/// # Errors
///
/// An `Err` aborts the traversal and is propagated to the caller of `accept`.
fn pre_visit(self, expr: &E) -> Result<Recursion<Self>>
where
Self: ExpressionVisitor;
/// Called on a node after all of its children have been visited.
///
/// The default implementation does nothing and returns the visitor unchanged.
fn post_visit(self, _expr: &E) -> Result<Self> {
Ok(self)
}
}
/// Implemented by node types that can be walked by an [`ExpressionVisitor`].
pub trait ExprVisitable: Sized {
/// Runs `visitor` over `self`, consuming the visitor and returning it
/// (or the first error encountered) once the walk is done.
fn accept<V: ExpressionVisitor<Self>>(&self, visitor: V) -> Result<V>;
}
impl ExprVisitable for Expr {
fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
let visitor = match visitor.pre_visit(self)? {
Recursion::Continue(visitor) => visitor,
Recursion::Stop(visitor) => return Ok(visitor),
};
let visitor = match self {
Expr::Alias(expr, _)
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsTrue(expr)
| Expr::IsFalse(expr)
| Expr::IsUnknown(expr)
| Expr::IsNotTrue(expr)
| Expr::IsNotFalse(expr)
| Expr::IsNotUnknown(expr)
| Expr::IsNull(expr)
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::InSubquery { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| {
v.and_then(|v| {
exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v)))
})
})
}
Expr::Column(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
| Expr::Exists { .. }
| Expr::ScalarSubquery(_)
| Expr::Wildcard
| Expr::QualifiedWildcard { .. } => Ok(visitor),
Expr::BinaryExpr { left, right, .. } => {
let visitor = left.accept(visitor)?;
right.accept(visitor)
}
Expr::Like { expr, pattern, .. } => {
let visitor = expr.accept(visitor)?;
pattern.accept(visitor)
}
Expr::ILike { expr, pattern, .. } => {
let visitor = expr.accept(visitor)?;
pattern.accept(visitor)
}
Expr::SimilarTo { expr, pattern, .. } => {
let visitor = expr.accept(visitor)?;
pattern.accept(visitor)
}
Expr::Between {
expr, low, high, ..
} => {
let visitor = expr.accept(visitor)?;
let visitor = low.accept(visitor)?;
high.accept(visitor)
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let visitor = if let Some(expr) = expr.as_ref() {
expr.accept(visitor)
} else {
Ok(visitor)
}?;
let visitor = when_then_expr.iter().try_fold(
visitor,
|visitor, (when, then)| {
let visitor = when.accept(visitor)?;
then.accept(visitor)
},
)?;
if let Some(else_expr) = else_expr.as_ref() {
else_expr.accept(visitor)
} else {
Ok(visitor)
}
}
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
Expr::WindowFunction {
args,
partition_by,
order_by,
..
} => {
let visitor = args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = partition_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = order_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
Ok(visitor)
}
Expr::InList { expr, list, .. } => {
let visitor = expr.accept(visitor)?;
list.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))
}
}?;
visitor.post_visit(self)
}
}