use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use crate::utils::normalize_ident;
use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
use datafusion_expr::{Expr, JoinType, LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
use std::collections::{HashMap, HashSet};
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn plan_table_with_joins(
&self,
t: TableWithJoins,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let origin_planner_context = planner_context.clone();
let left = self.create_relation(t.relation, planner_context)?;
match t.joins.len() {
0 => {
*planner_context = origin_planner_context;
Ok(left)
}
_ => {
let mut joins = t.joins.into_iter();
*planner_context = origin_planner_context.clone();
let mut left = self.parse_relation_join(
left,
joins.next().unwrap(), planner_context,
)?;
for join in joins {
*planner_context = origin_planner_context.clone();
left = self.parse_relation_join(left, join, planner_context)?;
}
*planner_context = origin_planner_context;
Ok(left)
}
}
}
fn parse_relation_join(
&self,
left: LogicalPlan,
join: Join,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let right = self.create_relation(join.relation, planner_context)?;
match join.join_operator {
JoinOperator::LeftOuter(constraint) => {
self.parse_join(left, right, constraint, JoinType::Left, planner_context)
}
JoinOperator::RightOuter(constraint) => {
self.parse_join(left, right, constraint, JoinType::Right, planner_context)
}
JoinOperator::Inner(constraint) => {
self.parse_join(left, right, constraint, JoinType::Inner, planner_context)
}
JoinOperator::FullOuter(constraint) => {
self.parse_join(left, right, constraint, JoinType::Full, planner_context)
}
JoinOperator::CrossJoin => self.parse_cross_join(left, right),
other => Err(DataFusionError::NotImplemented(format!(
"Unsupported JOIN operator {other:?}"
))),
}
}
fn parse_cross_join(
&self,
left: LogicalPlan,
right: LogicalPlan,
) -> Result<LogicalPlan> {
LogicalPlanBuilder::from(left).cross_join(right)?.build()
}
fn parse_join(
&self,
left: LogicalPlan,
right: LogicalPlan,
constraint: JoinConstraint,
join_type: JoinType,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
match constraint {
JoinConstraint::On(sql_expr) => {
let join_schema = left.schema().join(right.schema())?;
let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?;
ensure_any_column_reference_is_unambiguous(
&expr,
&[left.schema().clone(), right.schema().clone()],
)?;
let using_columns = expr.to_columns()?;
let filter = normalize_col_with_schemas(
expr,
&[left.schema(), right.schema()],
&[using_columns],
)?;
LogicalPlanBuilder::from(left)
.join(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
Some(filter),
)?
.build()
}
JoinConstraint::Using(idents) => {
let keys: Vec<Column> = idents
.into_iter()
.map(|x| Column::from_name(normalize_ident(x)))
.collect();
LogicalPlanBuilder::from(left)
.join_using(right, join_type, keys)?
.build()
}
JoinConstraint::Natural => {
let left_cols: HashSet<&String> = left
.schema()
.fields()
.iter()
.map(|f| f.field().name())
.collect();
let keys: Vec<Column> = right
.schema()
.fields()
.iter()
.map(|f| f.field().name())
.filter(|f| left_cols.contains(f))
.map(Column::from_name)
.collect();
if keys.is_empty() {
self.parse_cross_join(left, right)
} else {
LogicalPlanBuilder::from(left)
.join_using(right, join_type, keys)?
.build()
}
}
JoinConstraint::None => Err(DataFusionError::NotImplemented(
"NONE constraint is not supported".to_string(),
)),
}
}
}
/// Checks that no unqualified column referenced by `expr` matches a field
/// name more than once across `schemas`.
///
/// Only columns whose `relation` is `None` are examined; qualified references
/// are skipped. When exactly one schema is supplied the check is bypassed and
/// `Ok(())` is returned immediately.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` naming the first unqualified column whose
/// name was seen a second time while scanning the schema fields in order,
/// followed by the qualified names of the fields it could refer to. Also
/// propagates any error from `expr.to_columns()`.
fn ensure_any_column_reference_is_unambiguous(
    expr: &Expr,
    schemas: &[DFSchemaRef],
) -> Result<()> {
    if schemas.len() == 1 {
        return Ok(());
    }

    // Occurrence counter for every unqualified column name used by `expr`,
    // starting at zero.
    let columns = expr.to_columns()?;
    let mut occurrences: HashMap<&str, u8> = columns
        .iter()
        .filter(|col| col.relation.is_none())
        .map(|col| (col.name.as_str(), 0))
        .collect();

    // Scan every field of every schema in order. Names that `expr` does not
    // reference unqualified are ignored. The scan stops at the first name
    // whose counter reaches 2, so the `u8` counter never grows past 2.
    let first_ambiguous = schemas
        .iter()
        .flat_map(|schema| schema.fields())
        .map(|field| field.name())
        .find(|name| match occurrences.get_mut(name.as_str()) {
            Some(seen) => {
                *seen += 1;
                *seen >= 2
            }
            None => false,
        });

    match first_ambiguous {
        None => Ok(()),
        Some(col_name) => {
            // One candidate per schema in which the unqualified lookup
            // succeeds; schemas where it fails are skipped.
            let candidates: Vec<String> = schemas
                .iter()
                .filter_map(|schema| {
                    schema
                        .field_with_unqualified_name(col_name)
                        .ok()
                        .map(|f| f.qualified_name())
                })
                .collect();
            let joined = candidates.join(",");
            Err(DataFusionError::Plan(format!(
                "reference \'{}\' is ambiguous, could be {};",
                col_name, joined,
            )))
        }
    }
}