use std::sync::Arc;
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow::datatypes::Schema;
use datafusion_common::{
not_impl_err, plan_err,
tree_node::{TreeNode, TreeNodeRecursion},
Result,
};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource};
use sqlparser::ast::{Query, SetExpr, SetOperator, With};
impl<S: ContextProvider> SqlToRel<'_, S> {
    /// Plans every common table expression (CTE) of a `WITH` clause and
    /// registers each resulting plan in `planner_context` under its
    /// normalized name.
    ///
    /// CTEs are handled in declaration order, and each plan is registered
    /// before the next CTE is planned. When the clause is `WITH RECURSIVE`,
    /// every CTE in it goes through [`Self::recursive_cte`]; otherwise through
    /// [`Self::non_recursive_cte`].
    ///
    /// # Errors
    ///
    /// Returns a planning error if a normalized CTE name is already present in
    /// `planner_context`. Also propagates any error from planning a CTE body
    /// or applying its table alias.
    pub(super) fn plan_with_clause(
        &self,
        with: With,
        planner_context: &mut PlannerContext,
    ) -> Result<()> {
        let is_recursive = with.recursive;
        for cte in with.cte_tables {
            let cte_name = self.ident_normalizer.normalize(cte.alias.name.clone());
            if planner_context.contains_cte(&cte_name) {
                return plan_err!(
                    "WITH query name {cte_name:?} specified more than once"
                );
            }
            let cte_plan = if is_recursive {
                self.recursive_cte(cte_name.clone(), *cte.query, planner_context)?
            } else {
                self.non_recursive_cte(*cte.query, planner_context)?
            };
            let final_plan = self.apply_table_alias(cte_plan, cte.alias)?;
            // For a recursive CTE this re-registers `cte_name`. `recursive_cte`
            // may have temporarily bound that name to its work-table scan.
            planner_context.insert_cte(cte_name, final_plan);
        }
        Ok(())
    }

    /// Plans a CTE body as an ordinary query.
    fn non_recursive_cte(
        &self,
        cte_query: Query,
        planner_context: &mut PlannerContext,
    ) -> Result<LogicalPlan> {
        self.query_to_plan(cte_query, planner_context)
    }

    /// Plans the body of a CTE declared in a `WITH RECURSIVE` clause.
    ///
    /// A body of the form `<static term> UNION [ALL] <recursive term>` becomes
    /// a recursive query when the recursive term scans the CTE's work table.
    /// If it does not scan the work table, the body becomes a plain `UNION`.
    /// Any other body shape is planned with [`Self::non_recursive_cte`].
    ///
    /// # Errors
    ///
    /// Returns a "not implemented" error when the
    /// `execution.enable_recursive_ctes` option is off. Also propagates any
    /// planning error from either term or from creating the work table.
    fn recursive_cte(
        &self,
        cte_name: String,
        mut cte_query: Query,
        planner_context: &mut PlannerContext,
    ) -> Result<LogicalPlan> {
        if !self
            .context_provider
            .options()
            .execution
            .enable_recursive_ctes
        {
            return not_impl_err!("Recursive CTEs are not enabled");
        }

        // Only a top-level `UNION [ALL]` can be recursive. Any other body is
        // put back into the query and planned normally.
        let (left_expr, right_expr, set_quantifier) = match *cte_query.body {
            SetExpr::SetOperation {
                op: SetOperator::Union,
                left,
                right,
                set_quantifier,
            } => (left, right, set_quantifier),
            other => {
                cte_query.body = Box::new(other);
                return self.non_recursive_cte(cte_query, planner_context);
            }
        };

        // The static term is planned first. Its schema is the schema of the
        // work table.
        let static_plan = self.set_expr_to_plan(*left_expr, planner_context)?;
        let work_table_source = self.context_provider.create_cte_work_table(
            &cte_name,
            Arc::new(Schema::from(static_plan.schema().as_ref())),
        )?;
        let work_table_plan = LogicalPlanBuilder::scan(
            cte_name.clone(),
            Arc::clone(&work_table_source),
            None,
        )?
        .build()?;

        // Temporarily bind the CTE name to the work-table scan. Self-references
        // in the recursive term then resolve to that scan.
        planner_context.insert_cte(cte_name.clone(), work_table_plan);
        let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?;

        if !has_work_table_reference(&recursive_plan, &work_table_source) {
            // The second term never refers to the CTE. Drop the temporary
            // binding and plan an ordinary UNION.
            planner_context.remove_cte(&cte_name);
            return self.set_operation_to_plan(
                SetOperator::Union,
                static_plan,
                recursive_plan,
                set_quantifier,
            );
        }

        // `distinct` is true unless the set quantifier is `ALL`.
        let distinct = !Self::is_union_all(set_quantifier)?;
        // `cte_name` is not needed past this point, so move it rather than
        // keeping an extra clone around.
        LogicalPlanBuilder::from(static_plan)
            .to_recursive_query(cte_name, recursive_plan, distinct)?
            .build()
    }
}
/// Returns `true` if `plan` contains a `TableScan` whose source is the same
/// `work_table_source` instance.
///
/// The match uses `Arc` pointer identity (`Arc::ptr_eq`), not table name or
/// schema. A different table that happens to share the CTE's name does not
/// count as a reference.
///
/// NOTE(review): this uses `TreeNode::apply` on the plan. A reference that only
/// appears inside a subquery expression is presumably not visited. Confirm this
/// against `LogicalPlan::apply_with_subqueries` if that case matters.
fn has_work_table_reference(
    plan: &LogicalPlan,
    work_table_source: &Arc<dyn TableSource>,
) -> bool {
    let mut has_reference = false;
    plan.apply(|node| match node {
        LogicalPlan::TableScan(scan) if Arc::ptr_eq(&scan.source, work_table_source) => {
            has_reference = true;
            // One match is enough; stop the traversal early.
            Ok(TreeNodeRecursion::Stop)
        }
        _ => Ok(TreeNodeRecursion::Continue),
    })
    .expect("visitor closure only ever returns Ok, so traversal cannot fail");
    has_reference
}