use crate::physical_plan::with_new_children_if_necessary;
use crate::physical_plan::ExecutionPlan;
use datafusion_common::Result;
use std::sync::Arc;
/// A tree node type whose subtrees can be rewritten by the traversal helpers
/// defined on this trait.
///
/// Implementors only provide [`map_children`](Self::map_children); the
/// traversal methods (`transform_using`, `transform`, `transform_down` and
/// `transform_up`) are all built on top of it.
pub trait TreeNodeRewritable: Clone {
    /// Rewrites this tree depth first with a stateful [`TreeNodeRewriter`].
    ///
    /// `rewriter.pre_visit` is consulted on every node before its children, and
    /// its answer decides what happens next:
    ///
    /// * [`RewriteRecursion::Mutate`]: `mutate` is applied to this node right
    ///   away and its children are not visited.
    /// * [`RewriteRecursion::Stop`]: the node is returned unchanged and its
    ///   children are not visited.
    /// * [`RewriteRecursion::Continue`]: the children are rewritten first, then
    ///   `mutate` is applied to the rebuilt node.
    /// * [`RewriteRecursion::Skip`]: the children are rewritten, but `mutate`
    ///   is not applied to this node.
    ///
    /// # Errors
    ///
    /// Errors returned by `pre_visit`, `mutate` or `map_children` are propagated
    /// immediately, ending the traversal.
    fn transform_using<R: TreeNodeRewriter<Self>>(
        self,
        rewriter: &mut R,
    ) -> Result<Self> {
        match rewriter.pre_visit(&self)? {
            RewriteRecursion::Stop => Ok(self),
            RewriteRecursion::Mutate => rewriter.mutate(self),
            RewriteRecursion::Skip => {
                self.map_children(|child| child.transform_using(rewriter))
            }
            RewriteRecursion::Continue => {
                let rebuilt =
                    self.map_children(|child| child.transform_using(rewriter))?;
                rewriter.mutate(rebuilt)
            }
        }
    }

    /// Recursively applies `op` to every node of the tree. This is shorthand
    /// for [`transform_up`](Self::transform_up), which visits nodes in post-order.
    ///
    /// `op` returns `Ok(Some(new))` to replace a node and `Ok(None)` to keep it.
    fn transform<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Option<Self>>,
    {
        Self::transform_up(self, op)
    }

    /// Applies `op` to this node first, then recursively to the children of
    /// the resulting node (pre-order).
    ///
    /// `op` receives a clone of the node. When it returns `Ok(None)`, the
    /// original node is kept. Children introduced by a replacement are visited
    /// as well.
    fn transform_down<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Option<Self>>,
    {
        let current = if let Some(replacement) = op(self.clone())? {
            replacement
        } else {
            self
        };
        current.map_children(|child| child.transform_down(op))
    }

    /// Recursively transforms the children first, then applies `op` to the
    /// rebuilt node (post-order).
    ///
    /// When `op` returns `Ok(None)`, the rebuilt node is kept unchanged.
    fn transform_up<F>(self, op: &F) -> Result<Self>
    where
        F: Fn(Self) -> Result<Option<Self>>,
    {
        let rebuilt = self.map_children(|child| child.transform_up(op))?;
        // `op` consumes its argument, so keep a copy to fall back on.
        let fallback = rebuilt.clone();
        Ok(op(rebuilt)?.unwrap_or(fallback))
    }

    /// Rebuilds this node after passing each of its direct children through
    /// `transform`. This is the single hook the traversal helpers above
    /// recurse through.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>;
}
/// A stateful rewriter driven by [`TreeNodeRewritable::transform_using`].
///
/// For each node, `pre_visit` is consulted first to decide how the recursion
/// proceeds (see [`RewriteRecursion`]); `mutate` performs the actual rewrite.
pub trait TreeNodeRewriter<N: TreeNodeRewritable>: Sized {
    /// Decides how to proceed at a node before any of its children are visited.
    ///
    /// The default implementation ignores the node and always answers
    /// [`RewriteRecursion::Continue`].
    fn pre_visit(&mut self, _node: &N) -> Result<RewriteRecursion> {
        let decision = RewriteRecursion::Continue;
        Ok(decision)
    }

    /// Rewrites `node`, returning its replacement.
    fn mutate(&mut self, node: N) -> Result<N>;
}
/// Controls how [`TreeNodeRewritable::transform_using`] proceeds at a node,
/// as decided by [`TreeNodeRewriter::pre_visit`].
// NOTE(review): `dead_code` is presumably allowed because not every variant is
// constructed inside this crate; confirm before removing the attribute.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteRecursion {
    /// Rewrite the children first, then call `mutate` on this node.
    Continue,
    /// Call `mutate` on this node immediately, without visiting its children.
    Mutate,
    /// Return this node unchanged, without visiting its children.
    Stop,
    /// Rewrite the children, but do not call `mutate` on this node.
    Skip,
}
impl TreeNodeRewritable for Arc<dyn ExecutionPlan> {
    /// Passes every child plan through `transform` and rebuilds this plan with
    /// the results via `with_new_children_if_necessary`.
    ///
    /// Plans without children are returned as-is. The first error produced by
    /// `transform` is returned immediately.
    fn map_children<F>(self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        let children = self.children();
        if children.is_empty() {
            return Ok(self);
        }
        let rewritten = children
            .into_iter()
            .map(transform)
            .collect::<Result<Vec<_>>>()?;
        with_new_children_if_necessary(self, rewritten)
    }
}