use super::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{Expr, LogicalPlan};
use std::sync::Arc;
/// Analyzer rule that carries a list of [`FunctionRewrite`]s and runs them
/// over the expressions of a [`LogicalPlan`].
///
/// The derived `Default` yields a rule with no rewrites registered.
#[derive(Default)]
pub struct ApplyFunctionRewrites {
    /// The rewrites to run, stored in the order they are applied.
    function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>,
}
impl ApplyFunctionRewrites {
pub fn new(function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>) -> Self {
Self { function_rewrites }
}
}
impl AnalyzerRule for ApplyFunctionRewrites {
    /// Returns the fixed identifier of this rule.
    fn name(&self) -> &str {
        "apply_function_rewrites"
    }

    /// Rewrites `plan` (and, recursively, all of its inputs) with the
    /// registered function rewrites.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while rewriting the plan.
    fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
        // The recursive worker only needs to borrow the plan.
        let root = &plan;
        self.analyze_internal(root, options)
    }
}
impl ApplyFunctionRewrites {
    /// Recursively rewrites `plan`: inputs first, then the node's own
    /// expressions, and finally rebuilds the node from the rewritten parts.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while rewriting an input, building the
    /// table scan schema, rewriting an expression, or reassembling the plan.
    fn analyze_internal(
        &self,
        plan: &LogicalPlan,
        options: &ConfigOptions,
    ) -> Result<LogicalPlan> {
        // Children are processed before the current node; the first failure
        // aborts the whole rewrite.
        let mut rewritten_inputs = Vec::new();
        for child in plan.inputs() {
            rewritten_inputs.push(self.analyze_internal(child, options)?);
        }

        // Schema handed to the rewrites: the merged schemas of the rewritten
        // inputs, plus the qualified source schema when this node is a scan.
        let mut schema = merge_schema(rewritten_inputs.iter().collect());
        if let LogicalPlan::TableScan(scan) = plan {
            let table_schema = scan.source.schema();
            let qualified = DFSchema::try_from_qualified_schema(&scan.table_name, &table_schema)?;
            schema.merge(&qualified);
        }

        let mut rewriter = OperatorToFunctionRewriter {
            function_rewrites: &self.function_rewrites,
            options,
            schema: &schema,
        };

        // Each expression goes through `rewrite_preserving_name` rather than
        // being rewritten directly.
        let mut rewritten_exprs = Vec::new();
        for expr in plan.expressions() {
            rewritten_exprs.push(rewrite_preserving_name(expr, &mut rewriter)?);
        }

        plan.with_new_exprs(rewritten_exprs, rewritten_inputs)
    }
}
/// Expression rewriter that bundles the borrowed state needed to invoke each
/// [`FunctionRewrite`] on an [`Expr`].
struct OperatorToFunctionRewriter<'a> {
    /// Rewrites to apply, in order.
    function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
    /// Configuration forwarded to every rewrite call.
    options: &'a ConfigOptions,
    /// Schema forwarded to every rewrite call.
    schema: &'a DFSchema,
}
impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
    type Node = Expr;

    /// Runs every registered rewrite over `expr`, feeding the output of each
    /// rewrite into the next one.
    ///
    /// The result is flagged as transformed when at least one rewrite
    /// reported a change; the first rewrite error is returned immediately.
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        let mut current = expr;
        let mut any_changed = false;

        for rule in self.function_rewrites.iter() {
            let outcome = rule.rewrite(current, self.schema, self.options)?;
            any_changed |= outcome.transformed;
            current = outcome.data;
        }

        if any_changed {
            Ok(Transformed::yes(current))
        } else {
            Ok(Transformed::no(current))
        }
    }
}