use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result};
use sqlparser::ast::TableConstraint;
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
/// A single table-level constraint, expressed through column positions.
///
/// The stored indices are positions into the field list of the schema the
/// constraint was resolved against (see `Constraints::new_from_table_constraints`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Constraint {
/// Positions of the columns declared as the primary key.
PrimaryKey(Vec<usize>),
/// Positions of the columns declared as a (non-primary) unique constraint.
Unique(Vec<usize>),
}
/// The collection of constraints attached to a table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Constraints {
// Constraints in the order they were supplied at construction time.
inner: Vec<Constraint>,
}
impl Constraints {
    /// Creates a `Constraints` value that holds no constraint at all.
    pub fn empty() -> Self {
        Constraints::new(vec![])
    }

    /// Wraps an already-resolved list of constraints.
    ///
    /// Kept private: the only ways to build a `Constraints` from this module's
    /// public API are [`Constraints::empty`] and
    /// [`Constraints::new_from_table_constraints`].
    fn new(constraints: Vec<Constraint>) -> Self {
        Self { inner: constraints }
    }

    /// Converts parsed SQL table constraints into [`Constraints`], resolving
    /// every referenced column to its position in `df_schema`.
    ///
    /// A column is matched by comparing the constraint's column identifier
    /// against the qualified name of each schema field.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::Execution` when a column named by a primary key or
    ///   unique constraint cannot be found in `df_schema`.
    /// * `DataFusionError::Plan` for foreign key, check, index and
    ///   fulltext/spatial constraints, none of which are supported.
    pub fn new_from_table_constraints(
        constraints: &[TableConstraint],
        df_schema: &DFSchemaRef,
    ) -> Result<Self> {
        let constraints = constraints
            .iter()
            .map(|c: &TableConstraint| match c {
                TableConstraint::Unique {
                    columns,
                    is_primary,
                    ..
                } => {
                    let indices = columns
                        .iter()
                        .map(|pk| {
                            df_schema
                                .fields()
                                .iter()
                                // Compare in place; no need to clone the identifier
                                // for every field that is inspected.
                                .position(|item| item.qualified_name() == pk.value)
                                .ok_or_else(|| {
                                    // Report the kind of constraint that actually
                                    // failed instead of always blaming the primary key.
                                    let message = if *is_primary {
                                        "Primary key doesn't exist"
                                    } else {
                                        "Unique constraint column doesn't exist"
                                    };
                                    DataFusionError::Execution(message.to_string())
                                })
                        })
                        .collect::<Result<Vec<_>>>()?;
                    Ok(if *is_primary {
                        Constraint::PrimaryKey(indices)
                    } else {
                        Constraint::Unique(indices)
                    })
                }
                TableConstraint::ForeignKey { .. } => Err(DataFusionError::Plan(
                    "Foreign key constraints are not currently supported".to_string(),
                )),
                TableConstraint::Check { .. } => Err(DataFusionError::Plan(
                    "Check constraints are not currently supported".to_string(),
                )),
                TableConstraint::Index { .. } => Err(DataFusionError::Plan(
                    "Indexes are not currently supported".to_string(),
                )),
                TableConstraint::FulltextOrSpatial { .. } => Err(DataFusionError::Plan(
                    "Indexes are not currently supported".to_string(),
                )),
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Constraints::new(constraints))
    }

    /// Returns `true` when no constraint is stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
impl Display for Constraints {
    /// Renders the constraints as ` constraints=[A, B, ...]` (note the leading
    /// space), using the `Debug` form of each entry. Nothing is written when
    /// there are no constraints.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Every entry has a non-empty `Debug` rendering, so the joined text is
        // empty exactly when there are no entries.
        if self.inner.is_empty() {
            return Ok(());
        }
        let rendered = self
            .inner
            .iter()
            .map(|constraint| format!("{constraint:?}"))
            .collect::<Vec<String>>()
            .join(", ");
        write!(f, " constraints=[{rendered}]")
    }
}
/// A functional dependency between two groups of column positions: the
/// columns at `source_indices` determine the columns at `target_indices`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionalDependence {
    /// Positions of the determinant (key) columns.
    pub source_indices: Vec<usize>,
    /// Positions of the dependent columns.
    pub target_indices: Vec<usize>,
    /// Nullability flag of the dependency. Dependencies derived from a unique
    /// constraint set it to `true`, those derived from a primary key to `false`.
    pub nullable: bool,
    /// The [`Dependency`] mode attached to this dependency.
    pub mode: Dependency,
}

/// Mode of a [`FunctionalDependence`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dependency {
    /// Mode assigned to dependencies built directly from table constraints.
    Single,
    /// Default mode of a freshly created dependency; also assigned to both
    /// sides of an inner/left/right join.
    Multi,
}

impl FunctionalDependence {
    /// Builds a dependency from its source and target positions and its
    /// nullability. The mode starts out as [`Dependency::Multi`]; use
    /// [`FunctionalDependence::with_mode`] to change it.
    pub fn new(
        source_indices: Vec<usize>,
        target_indices: Vec<usize>,
        nullable: bool,
    ) -> Self {
        let mode = Dependency::Multi;
        Self {
            source_indices,
            target_indices,
            nullable,
            mode,
        }
    }

    /// Consumes the dependency and returns it with `mode` replaced.
    pub fn with_mode(self, mode: Dependency) -> Self {
        Self { mode, ..self }
    }
}
/// A collection of [`FunctionalDependence`] entries that belong to one schema.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionalDependencies {
// The individual dependencies, in insertion order.
deps: Vec<FunctionalDependence>,
}
impl FunctionalDependencies {
pub fn empty() -> Self {
Self { deps: vec![] }
}
pub fn new(dependencies: Vec<FunctionalDependence>) -> Self {
Self { deps: dependencies }
}
pub fn new_from_constraints(
constraints: Option<&Constraints>,
n_field: usize,
) -> Self {
if let Some(Constraints { inner: constraints }) = constraints {
let dependencies = constraints
.iter()
.map(|constraint| {
let dependency = match constraint {
Constraint::PrimaryKey(indices) => FunctionalDependence::new(
indices.to_vec(),
(0..n_field).collect::<Vec<_>>(),
false,
),
Constraint::Unique(indices) => FunctionalDependence::new(
indices.to_vec(),
(0..n_field).collect::<Vec<_>>(),
true,
),
};
dependency.with_mode(Dependency::Single)
})
.collect::<Vec<_>>();
Self::new(dependencies)
} else {
Self::empty()
}
}
pub fn with_dependency(mut self, mode: Dependency) -> Self {
self.deps.iter_mut().for_each(|item| item.mode = mode);
self
}
pub fn extend(&mut self, other: FunctionalDependencies) {
self.deps.extend(other.deps);
}
pub fn add_offset(&mut self, offset: usize) {
self.deps.iter_mut().for_each(
|FunctionalDependence {
source_indices,
target_indices,
..
}| {
*source_indices = add_offset_to_vec(source_indices, offset);
*target_indices = add_offset_to_vec(target_indices, offset);
},
)
}
pub fn project_functional_dependencies(
&self,
proj_indices: &[usize],
n_out: usize,
) -> FunctionalDependencies {
let mut projected_func_dependencies = vec![];
for FunctionalDependence {
source_indices,
target_indices,
nullable,
mode,
} in &self.deps
{
let new_source_indices =
update_elements_with_matching_indices(source_indices, proj_indices);
let new_target_indices = if *mode == Dependency::Single {
(0..n_out).collect()
} else {
update_elements_with_matching_indices(target_indices, proj_indices)
};
if new_source_indices.len() == source_indices.len() {
let new_func_dependence = FunctionalDependence::new(
new_source_indices,
new_target_indices,
*nullable,
)
.with_mode(*mode);
projected_func_dependencies.push(new_func_dependence);
}
}
FunctionalDependencies::new(projected_func_dependencies)
}
pub fn join(
&self,
other: &FunctionalDependencies,
join_type: &JoinType,
left_cols_len: usize,
) -> FunctionalDependencies {
let mut right_func_dependencies = other.clone();
let mut left_func_dependencies = self.clone();
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right => {
right_func_dependencies.add_offset(left_cols_len);
left_func_dependencies =
left_func_dependencies.with_dependency(Dependency::Multi);
right_func_dependencies =
right_func_dependencies.with_dependency(Dependency::Multi);
if *join_type == JoinType::Left {
right_func_dependencies.downgrade_dependencies();
} else if *join_type == JoinType::Right {
left_func_dependencies.downgrade_dependencies();
}
left_func_dependencies.extend(right_func_dependencies);
left_func_dependencies
}
JoinType::LeftSemi | JoinType::LeftAnti => {
left_func_dependencies
}
JoinType::RightSemi | JoinType::RightAnti => {
right_func_dependencies
}
JoinType::Full => {
FunctionalDependencies::empty()
}
}
}
fn downgrade_dependencies(&mut self) {
self.deps.retain(|item| !item.nullable);
self.deps.iter_mut().for_each(|item| item.nullable = true);
}
pub fn extend_target_indices(&mut self, n_out: usize) {
self.deps.iter_mut().for_each(
|FunctionalDependence {
mode,
target_indices,
..
}| {
if *mode == Dependency::Single {
*target_indices = (0..n_out).collect::<Vec<_>>();
}
},
)
}
}
/// Computes the functional dependencies of an aggregation's output schema.
///
/// * `aggr_input_schema` - schema of the aggregation input; its functional
///   dependencies are the starting point.
/// * `group_by_expr_names` - names of the GROUP BY expressions; positions in
///   this slice are used as source indices of the result.
/// * `aggr_schema` - output schema; every resulting dependency targets all of
///   its field positions.
///
/// An input dependency is carried over when every one of its source fields
/// (by qualified name) appears among the GROUP BY names. When there is
/// exactly one GROUP BY expression, a [`Dependency::Single`] dependency with
/// source `[0]` is added, replacing the first carried-over dependency whose
/// sources contain `0`.
///
/// Source indices of carried-over dependencies are emitted in ascending
/// order, so the result is deterministic.
///
/// # Panics
///
/// Panics if an input dependency refers to a position outside
/// `aggr_input_schema`, or if there is a single GROUP BY expression while
/// `aggr_schema` has no fields.
pub fn aggregate_functional_dependencies(
    aggr_input_schema: &DFSchema,
    group_by_expr_names: &[String],
    aggr_schema: &DFSchema,
) -> FunctionalDependencies {
    let mut aggregate_func_dependencies = vec![];
    let aggr_input_fields = aggr_input_schema.fields();
    let aggr_fields = aggr_schema.fields();
    // Every dependency of the output determines all output columns.
    let target_indices = (0..aggr_fields.len()).collect::<Vec<_>>();
    let func_dependencies = aggr_input_schema.functional_dependencies();
    for FunctionalDependence {
        source_indices,
        nullable,
        mode,
        ..
    } in &func_dependencies.deps
    {
        let source_field_names = source_indices
            .iter()
            .map(|&idx| aggr_input_fields[idx].qualified_name())
            .collect::<Vec<_>>();
        // `enumerate` yields each position exactly once, so a plain `Vec`
        // cannot contain duplicates and, unlike a `HashSet`, keeps the
        // indices in a stable ascending order.
        let mut new_source_indices = Vec::with_capacity(source_indices.len());
        for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() {
            if source_field_names.contains(group_by_expr_name) {
                new_source_indices.push(idx);
            }
        }
        // Keep the dependency only when the number of matching GROUP BY
        // positions equals the number of source columns.
        if new_source_indices.len() == source_indices.len() {
            aggregate_func_dependencies.push(
                FunctionalDependence::new(
                    new_source_indices,
                    target_indices.clone(),
                    *nullable,
                )
                .with_mode(*mode),
            );
        }
    }
    if group_by_expr_names.len() == 1 {
        // Drop the first carried-over dependency whose sources contain 0;
        // the `Single` dependency pushed below also has source `[0]`.
        if let Some(idx) = aggregate_func_dependencies
            .iter()
            .position(|item| item.source_indices.contains(&0))
        {
            aggregate_func_dependencies.remove(idx);
        }
        aggregate_func_dependencies.push(
            FunctionalDependence::new(
                vec![0],
                target_indices,
                aggr_fields[0].is_nullable(),
            )
            .with_mode(Dependency::Single),
        );
    }
    FunctionalDependencies::new(aggregate_func_dependencies)
}
/// Returns the positions of all columns that are functionally determined by
/// the given GROUP BY expressions.
///
/// For every functional dependency of `schema` whose source fields (by
/// qualified name) are all contained in `group_by_expr_names`, the
/// dependency's target indices are added to the result. Duplicates are
/// removed and the indices are returned in ascending order so that the output
/// is deterministic.
///
/// Returns `None` when no target index was collected.
///
/// # Panics
///
/// Panics if a dependency refers to a position outside `schema`.
pub fn get_target_functional_dependencies(
    schema: &DFSchema,
    group_by_expr_names: &[String],
) -> Option<Vec<usize>> {
    let mut combined_target_indices: HashSet<usize> = HashSet::new();
    let dependencies = schema.functional_dependencies();
    let field_names = schema
        .fields()
        .iter()
        .map(|item| item.qualified_name())
        .collect::<Vec<_>>();
    for FunctionalDependence {
        source_indices,
        target_indices,
        ..
    } in &dependencies.deps
    {
        // Check the names in place; there is no need to clone them first.
        let all_sources_grouped = source_indices
            .iter()
            .all(|&source_idx| group_by_expr_names.contains(&field_names[source_idx]));
        if all_sources_grouped {
            combined_target_indices.extend(target_indices.iter().copied());
        }
    }
    if combined_target_indices.is_empty() {
        return None;
    }
    // `HashSet` iteration order is arbitrary; sort for a stable result.
    let mut result = combined_target_indices.into_iter().collect::<Vec<_>>();
    result.sort_unstable();
    Some(result)
}
/// Translates `entries` into positions within `proj_indices`.
///
/// Each entry is replaced by the position of its first occurrence in
/// `proj_indices`; entries that do not occur there are dropped. The relative
/// order of the surviving entries is preserved.
fn update_elements_with_matching_indices(
    entries: &[usize],
    proj_indices: &[usize],
) -> Vec<usize> {
    let mut remapped = Vec::new();
    for entry in entries {
        let found = proj_indices.iter().position(|candidate| candidate == entry);
        if let Some(new_position) = found {
            remapped.push(new_position);
        }
    }
    remapped
}
/// Returns a new vector holding every element of `in_data` with `offset`
/// added to it, in the same order.
fn add_offset_to_vec<T: Copy + std::ops::Add<Output = T>>(
    in_data: &[T],
    offset: T,
) -> Vec<T> {
    let mut shifted = Vec::with_capacity(in_data.len());
    for &value in in_data {
        shifted.push(value + offset);
    }
    shifted
}