use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
/// A tree-shaped value whose nodes can be rewritten or visited recursively.
///
/// Implementors provide only [`map_children`](Self::map_children) and
/// [`apply_children`](Self::apply_children); every traversal in this trait is
/// a default method built on top of those two primitives.
pub trait TreeNodeRewritable: Clone {
    /// Rewrites this subtree under the control of `rewriter`.
    ///
    /// `rewriter.pre_visit` is asked about each node before its children are
    /// touched, and its answer selects what happens next:
    ///
    /// * `Mutate`: `rewriter.mutate` is applied to the node right away and
    ///   its result is returned; the children are not traversed.
    /// * `Stop`: the node is returned as it is; the children are not
    ///   traversed and `mutate` is not called.
    /// * `Continue`: the children are rewritten first, then `mutate` is
    ///   applied to the node that holds the rewritten children.
    /// * `Skip`: the children are rewritten, but `mutate` is not applied to
    ///   this node.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `pre_visit`, `mutate` or
    /// [`map_children`](Self::map_children).
    fn transform_using<R: TreeNodeRewriter<Self>>(self, rewriter: &mut R) -> Result<Self> {
        match rewriter.pre_visit(&self)? {
            RewriteRecursion::Mutate => rewriter.mutate(self),
            RewriteRecursion::Stop => Ok(self),
            RewriteRecursion::Continue => {
                let with_new_children =
                    self.map_children(|child| child.transform_using(rewriter))?;
                rewriter.mutate(with_new_children)
            }
            RewriteRecursion::Skip => {
                self.map_children(|child| child.transform_using(rewriter))
            }
        }
    }

    /// Shorthand for [`transform_up`](Self::transform_up).
    fn transform<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Option<Self>>,
    {
        self.transform_up(op)
    }

    /// Applies `op` to this node first and then to the children of the result.
    ///
    /// `op` returns `Some(replacement)` to substitute a node, or `None` to keep
    /// it. When a node is replaced, the traversal descends into the children of
    /// the replacement.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `op` or
    /// [`map_children`](Self::map_children).
    fn transform_down<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Option<Self>>,
    {
        // `op` consumes its argument, so it is handed a copy and the original
        // is kept as the fallback for the `None` ("no change") answer.
        let rewritten = op(self.clone())?.unwrap_or(self);
        rewritten.map_children(|child| child.transform_down(op))
    }

    /// Rewrites the children first and then applies `op` to the node that
    /// holds the rewritten children.
    ///
    /// `op` returns `Some(replacement)` to substitute a node, or `None` to keep
    /// it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `op` or
    /// [`map_children`](Self::map_children).
    fn transform_up<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Option<Self>>,
    {
        let with_new_children = self.map_children(|child| child.transform_up(op))?;
        // `op` takes ownership, so a copy is set aside to return when it
        // answers `None`.
        let unchanged = with_new_children.clone();
        Ok(op(with_new_children)?.unwrap_or(unchanged))
    }

    /// Passes each direct child of this node to `transform` and returns the
    /// node rebuilt from the results.
    ///
    /// This is a required method; the traversals above rely on it to reach the
    /// children.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>;

    /// Calls `func` on this node and then, through
    /// [`apply_children`](Self::apply_children), on everything below it.
    ///
    /// # Errors
    ///
    /// Returns the first error produced; if `func` fails on this node the
    /// children are not visited.
    fn for_each<F>(&self, func: &F) -> Result<()>
    where
        F: Fn(&Self) -> Result<()>,
    {
        func(self).and_then(|()| self.apply_children(|child| child.for_each(func)))
    }

    /// Visits everything below this node through
    /// [`apply_children`](Self::apply_children) and only then calls `func` on
    /// the node itself.
    ///
    /// # Errors
    ///
    /// Returns the first error produced; if visiting the children fails,
    /// `func` is not called on this node.
    fn for_each_up<F>(&self, func: &F) -> Result<()>
    where
        F: Fn(&Self) -> Result<()>,
    {
        self.apply_children(|child| child.for_each_up(func))
            .and_then(|()| func(self))
    }

    /// Calls `func` with a reference to each direct child of this node.
    ///
    /// This is a required method; the visitors above rely on it to reach the
    /// children.
    fn apply_children<F>(&self, func: F) -> Result<()>
    where
        F: Fn(&Self) -> Result<()>;
}
/// Callbacks used by [`TreeNodeRewritable::transform_using`] to decide how a
/// tree is walked and how each node is rewritten.
pub trait TreeNodeRewriter<N>: Sized
where
    N: TreeNodeRewritable,
{
    /// Inspects `_node` before its children are processed and reports how the
    /// traversal should proceed for it.
    ///
    /// The provided implementation ignores the node and always answers
    /// [`RewriteRecursion::Continue`].
    fn pre_visit(&mut self, _node: &N) -> Result<RewriteRecursion> {
        // Default policy: no node is treated specially.
        Ok(RewriteRecursion::Continue)
    }

    /// Consumes `node` and returns the node that should take its place.
    fn mutate(&mut self, node: N) -> Result<N>;
}
/// Answer given by [`TreeNodeRewriter::pre_visit`] that tells
/// [`TreeNodeRewritable::transform_using`] how to treat the node it was asked
/// about.
// NOTE(review): `dead_code` is silenced as in the original; presumably not
// every variant is constructed inside this crate. Confirm and drop if unneeded.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteRecursion {
    /// Rewrite the node's children, then call `mutate` on the node.
    Continue,
    /// Call `mutate` on the node immediately; its children are not traversed.
    Mutate,
    /// Return the node unchanged; neither it nor its children are rewritten.
    Stop,
    /// Rewrite the node's children but do not call `mutate` on the node itself.
    Skip,
}
/// Tree-rewriting support for [`LogicalPlan`]: the plans returned by
/// `inputs()` are treated as the node's children.
impl TreeNodeRewritable for LogicalPlan {
    /// Rewrites every input of this plan with `transform` and rebuilds the
    /// plan around the results via `with_new_inputs`.
    ///
    /// A plan that has no inputs is returned as it is, without calling
    /// `with_new_inputs`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `transform` (later inputs are not
    /// transformed), or the error reported by `with_new_inputs`.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        // `inputs()` hands out borrowed plans while `transform` wants owned
        // ones, so each input is cloned up front.
        let owned_inputs: Vec<_> = self.inputs().into_iter().cloned().collect();
        if owned_inputs.is_empty() {
            return Ok(self);
        }

        let rewritten_inputs = owned_inputs
            .into_iter()
            .map(transform)
            .collect::<Result<Vec<_>>>()?;
        self.with_new_inputs(rewritten_inputs.as_slice())
    }

    /// Calls `func` on each input of this plan, in the order `inputs()`
    /// returns them.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `func`.
    fn apply_children<F>(&self, func: F) -> Result<()>
    where
        F: Fn(&Self) -> Result<()>,
    {
        for input in self.inputs() {
            func(input)?;
        }
        Ok(())
    }
}