use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::ops::Deref;
use std::vec::IntoIter;
use crate::error::_plan_err;
use crate::utils::{merge_and_order_indices, set_difference};
use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result};
use sqlparser::ast::TableConstraint;
/// A single table constraint, expressed in terms of column positions.
///
/// When built by `Constraints::new_from_table_constraints`, the indices are the
/// positions of the referenced columns in the table's schema.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Constraint {
/// `PRIMARY KEY` over the columns at these indices.
PrimaryKey(Vec<usize>),
/// `UNIQUE` constraint over the columns at these indices.
Unique(Vec<usize>),
}
/// An ordered list of [`Constraint`]s attached to a table.
///
/// Dereferences to `[Constraint]`, so slice methods (`iter`, `len`, indexing)
/// are available directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Constraints {
// The constraints, in the order they were supplied.
inner: Vec<Constraint>,
}
impl Constraints {
    /// Creates an empty set of constraints.
    pub fn empty() -> Self {
        Constraints::new_unverified(vec![])
    }

    /// Wraps `constraints` as-is.
    ///
    /// The column indices inside each [`Constraint`] are not validated against
    /// any schema.
    pub fn new_unverified(constraints: Vec<Constraint>) -> Self {
        Self { inner: constraints }
    }

    /// Converts SQL `TableConstraint`s into [`Constraints`].
    ///
    /// Each column name referenced by a constraint is resolved to its position
    /// among the fields of `df_schema`.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::Execution` if a column named by a `UNIQUE` or
    ///   `PRIMARY KEY` constraint is not a field of `df_schema`. For a named
    ///   unique constraint, the message includes the name.
    /// * A plan error for foreign key, check, index and full-text/spatial
    ///   constraints, none of which are supported.
    pub fn new_from_table_constraints(
        constraints: &[TableConstraint],
        df_schema: &DFSchemaRef,
    ) -> Result<Self> {
        // Field names do not depend on the constraint being converted, so
        // resolve them once instead of once per constraint.
        let field_names = df_schema.field_names();
        // Position of the schema field called `column`, if there is one.
        let position_of =
            |column: &String| field_names.iter().position(|item| item == column);

        let constraints = constraints
            .iter()
            .map(|constraint| match constraint {
                TableConstraint::Unique { name, columns, .. } => {
                    let indices = columns
                        .iter()
                        .map(|column| {
                            position_of(&column.value).ok_or_else(|| {
                                // Include the constraint's name (when it has
                                // one) in the error message.
                                let name = name
                                    .as_ref()
                                    .map(|name| format!("with name '{name}' "))
                                    .unwrap_or_default();
                                DataFusionError::Execution(format!(
                                    "Column for unique constraint {}not found in schema: {}",
                                    name, column.value
                                ))
                            })
                        })
                        .collect::<Result<Vec<_>>>()?;
                    Ok(Constraint::Unique(indices))
                }
                TableConstraint::PrimaryKey { columns, .. } => {
                    let indices = columns
                        .iter()
                        .map(|column| {
                            position_of(&column.value).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Column for primary key not found in schema: {}",
                                    column.value
                                ))
                            })
                        })
                        .collect::<Result<Vec<_>>>()?;
                    Ok(Constraint::PrimaryKey(indices))
                }
                TableConstraint::ForeignKey { .. } => {
                    _plan_err!("Foreign key constraints are not currently supported")
                }
                TableConstraint::Check { .. } => {
                    _plan_err!("Check constraints are not currently supported")
                }
                TableConstraint::Index { .. } => {
                    _plan_err!("Indexes are not currently supported")
                }
                // Full-text/spatial indexes intentionally report the same
                // message as plain indexes.
                TableConstraint::FulltextOrSpatial { .. } => {
                    _plan_err!("Indexes are not currently supported")
                }
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Constraints::new_unverified(constraints))
    }

    /// Returns `true` if there are no constraints.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
impl IntoIterator for Constraints {
type Item = Constraint;
type IntoIter = IntoIter<Constraint>;
fn into_iter(self) -> Self::IntoIter {
self.inner.into_iter()
}
}
/// Renders as ` constraints=[...]`, using each constraint's `Debug` form
/// joined by `", "`. An empty set renders as the empty string.
impl Display for Constraints {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let rendered = self
            .inner
            .iter()
            .map(|constraint| format!("{:?}", constraint))
            .collect::<Vec<String>>()
            .join(", ");
        if rendered.is_empty() {
            return write!(f, "");
        }
        write!(f, " constraints=[{rendered}]")
    }
}
/// Exposes the constraints as a slice, so `iter`, `len`, indexing and similar
/// methods work directly on [`Constraints`].
impl Deref for Constraints {
    type Target = [Constraint];

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
/// A functional dependency between columns of a schema.
///
/// The columns at `source_indices` (the determinant) determine the columns at
/// `target_indices`. For example, `get_target_functional_dependencies` treats
/// the targets as usable once all source columns are grouped on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionalDependence {
/// Column indices of the determinant.
pub source_indices: Vec<usize>,
/// Column indices whose values are determined by the determinant.
pub target_indices: Vec<usize>,
/// `false` for dependencies built from primary keys and `true` for unique
/// constraints. Outer joins downgrade surviving dependencies to `true`.
/// Presumably this means the determinant may hold NULLs.
pub nullable: bool,
/// Whether a determinant value may repeat across rows (see [`Dependency`]).
pub mode: Dependency,
}
/// How many rows a single determinant value may occupy.
///
/// * `Single`: used for dependencies derived from primary-key/unique
///   constraints and for GROUP BY keys. Each determinant value is expected to
///   identify a single row (NULLs aside, for nullable dependencies).
/// * `Multi`: a determinant value may repeat across rows. This is the default
///   chosen by `FunctionalDependence::new`, and `FunctionalDependencies::join`
///   switches inner/outer join outputs to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dependency {
Single, Multi, }
impl FunctionalDependence {
    /// Creates a dependency from `source_indices` to `target_indices`.
    ///
    /// The mode defaults to [`Dependency::Multi`]; use
    /// [`FunctionalDependence::with_mode`] to change it.
    pub fn new(
        source_indices: Vec<usize>,
        target_indices: Vec<usize>,
        nullable: bool,
    ) -> Self {
        let mode = Dependency::Multi;
        Self {
            source_indices,
            target_indices,
            nullable,
            mode,
        }
    }

    /// Returns this dependency with its mode replaced by `mode`.
    pub fn with_mode(self, mode: Dependency) -> Self {
        Self { mode, ..self }
    }
}
/// An ordered collection of [`FunctionalDependence`]s that hold for a schema.
///
/// Dereferences to `[FunctionalDependence]` for read access.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionalDependencies {
// The dependencies, in insertion order.
deps: Vec<FunctionalDependence>,
}
impl FunctionalDependencies {
    /// Returns a collection holding no dependencies.
    pub fn empty() -> Self {
        Self::new(Vec::new())
    }

    /// Wraps `dependencies` without any validation.
    pub fn new(dependencies: Vec<FunctionalDependence>) -> Self {
        Self { deps: dependencies }
    }

    /// Builds dependencies from table constraints for a table with `n_field`
    /// columns.
    ///
    /// Each constraint yields one [`Dependency::Single`] dependency.
    /// - Its determinant is the constraint's columns.
    /// - Its targets are all `n_field` columns.
    /// - Primary keys give non-nullable dependencies; unique constraints give
    ///   nullable ones.
    ///
    /// `None` yields an empty collection.
    pub fn new_from_constraints(
        constraints: Option<&Constraints>,
        n_field: usize,
    ) -> Self {
        match constraints {
            None => Self::empty(),
            Some(Constraints { inner }) => {
                let dependencies = inner
                    .iter()
                    .map(|constraint| {
                        let (indices, nullable) = match constraint {
                            Constraint::PrimaryKey(indices) => (indices, false),
                            Constraint::Unique(indices) => (indices, true),
                        };
                        FunctionalDependence::new(
                            indices.clone(),
                            (0..n_field).collect(),
                            nullable,
                        )
                        .with_mode(Dependency::Single)
                    })
                    .collect();
                Self::new(dependencies)
            }
        }
    }

    /// Returns the collection with every dependency's mode set to `mode`.
    pub fn with_dependency(mut self, mode: Dependency) -> Self {
        for dependence in &mut self.deps {
            dependence.mode = mode;
        }
        self
    }

    /// Appends all dependencies of `other` after the existing ones.
    pub fn extend(&mut self, other: FunctionalDependencies) {
        self.deps.extend(other.deps);
    }

    /// Returns `true` if every source and target index is below `n_field`.
    pub fn is_valid(&self, n_field: usize) -> bool {
        self.deps.iter().all(|dependence| {
            dependence.source_indices.iter().all(|&idx| idx < n_field)
                && dependence.target_indices.iter().all(|&idx| idx < n_field)
        })
    }

    /// Shifts every source and target index by `offset`.
    pub fn add_offset(&mut self, offset: usize) {
        for dependence in self.deps.iter_mut() {
            dependence.source_indices =
                add_offset_to_vec(&dependence.source_indices, offset);
            dependence.target_indices =
                add_offset_to_vec(&dependence.target_indices, offset);
        }
    }

    /// Rewrites the dependencies for a projection that produces `n_out`
    /// columns. Output column `i` is input column `proj_indices[i]`.
    ///
    /// - A dependency is kept only if all of its determinant columns survive
    ///   the projection; its indices are remapped to output positions.
    /// - `Single`-mode dependencies target all `n_out` outputs.
    /// - `Multi`-mode dependencies keep only their surviving target columns.
    pub fn project_functional_dependencies(
        &self,
        proj_indices: &[usize],
        n_out: usize,
    ) -> FunctionalDependencies {
        let projected = self
            .deps
            .iter()
            .filter_map(|dependence| {
                let source_indices = update_elements_with_matching_indices(
                    &dependence.source_indices,
                    proj_indices,
                );
                let target_indices = match dependence.mode {
                    Dependency::Single => (0..n_out).collect(),
                    Dependency::Multi => update_elements_with_matching_indices(
                        &dependence.target_indices,
                        proj_indices,
                    ),
                };
                // Drop the dependency if any determinant column was projected away.
                (source_indices.len() == dependence.source_indices.len()).then(|| {
                    FunctionalDependence::new(
                        source_indices,
                        target_indices,
                        dependence.nullable,
                    )
                    .with_mode(dependence.mode)
                })
            })
            .collect();
        FunctionalDependencies::new(projected)
    }

    /// Computes the dependencies of a join between `self` (left input) and
    /// `other` (right input). The left input has `left_cols_len` columns.
    ///
    /// - Inner/left/right joins:
    ///   - Right-side indices are shifted past the left columns.
    ///   - Both sides switch to [`Dependency::Multi`], since a key can match
    ///     several rows.
    ///   - For a left join, the right side is downgraded (it may be padded
    ///     with NULLs); for a right join, the left side is downgraded.
    /// - Semi/anti joins keep only the preserved side's dependencies.
    /// - Full joins keep none.
    pub fn join(
        &self,
        other: &FunctionalDependencies,
        join_type: &JoinType,
        left_cols_len: usize,
    ) -> FunctionalDependencies {
        match join_type {
            JoinType::Inner | JoinType::Left | JoinType::Right => {
                let mut right = other.clone();
                right.add_offset(left_cols_len);
                let mut left = self.clone().with_dependency(Dependency::Multi);
                let mut right = right.with_dependency(Dependency::Multi);
                match join_type {
                    JoinType::Left => right.downgrade_dependencies(),
                    JoinType::Right => left.downgrade_dependencies(),
                    _ => {}
                }
                left.extend(right);
                left
            }
            JoinType::LeftSemi | JoinType::LeftAnti => self.clone(),
            JoinType::RightSemi | JoinType::RightAnti => other.clone(),
            JoinType::Full => FunctionalDependencies::empty(),
        }
    }

    /// Prepares dependencies for a side whose rows may be NULL-padded.
    ///
    /// Dependencies that were already nullable are removed, and the remaining
    /// ones are marked nullable.
    fn downgrade_dependencies(&mut self) {
        self.deps.retain(|dependence| !dependence.nullable);
        for dependence in &mut self.deps {
            dependence.nullable = true;
        }
    }

    /// Makes every `Single`-mode dependency target all `n_out` columns.
    /// `Multi`-mode dependencies are left untouched.
    pub fn extend_target_indices(&mut self, n_out: usize) {
        self.deps
            .iter_mut()
            .filter(|dependence| dependence.mode == Dependency::Single)
            .for_each(|dependence| dependence.target_indices = (0..n_out).collect());
    }
}
/// Read-only slice access to the dependencies (`iter`, `len`, indexing, ...).
impl Deref for FunctionalDependencies {
    type Target = [FunctionalDependence];

    fn deref(&self) -> &Self::Target {
        &self.deps
    }
}
/// Computes the functional dependencies that hold on the output of an
/// aggregation.
///
/// # Arguments
///
/// * `aggr_input_schema` - schema of the aggregation's input. Its functional
///   dependencies are the candidates carried over.
/// * `group_by_expr_names` - names of the GROUP BY expressions. A name's
///   position in this slice is used directly as an output column index, so
///   the code assumes the GROUP BY expressions are the leading fields of
///   `aggr_schema`, in this order.
/// * `aggr_schema` - schema of the aggregation's output.
///
/// # Returns
///
/// An input dependency survives only if every one of its determinant columns
/// is grouped on.
/// - Its determinant is rewritten as GROUP BY positions, and it targets every
///   output column.
/// - Its mode becomes [`Dependency::Single`] when its determinant alone
///   determines the same input columns as the full GROUP BY list; otherwise
///   the input mode is kept.
///
/// With exactly one GROUP BY expression, output column 0 also becomes a
/// `Single`-mode determinant of the whole output. It replaces any surviving
/// dependency whose determinant mentions column 0, and it is nullable iff
/// output field 0 is nullable.
pub fn aggregate_functional_dependencies(
    aggr_input_schema: &DFSchema,
    group_by_expr_names: &[String],
    aggr_schema: &DFSchema,
) -> FunctionalDependencies {
    let mut aggregate_func_dependencies = vec![];
    let aggr_input_fields = aggr_input_schema.field_names();
    let aggr_fields = aggr_schema.fields();
    // Every surviving dependency determines the whole aggregate output.
    let target_indices = (0..aggr_fields.len()).collect::<Vec<_>>();
    let func_dependencies = aggr_input_schema.functional_dependencies();
    // Input columns determined by the complete GROUP BY list. This does not
    // depend on the dependency being examined, so compute it once rather than
    // on every loop iteration.
    let existing_target_indices =
        get_target_functional_dependencies(aggr_input_schema, group_by_expr_names);

    for FunctionalDependence {
        source_indices,
        nullable,
        mode,
        ..
    } in &func_dependencies.deps
    {
        let source_field_names = source_indices
            .iter()
            .map(|&idx| &aggr_input_fields[idx])
            .collect::<Vec<_>>();

        // Re-express the determinant as positions within the GROUP BY list.
        let mut new_source_indices = vec![];
        let mut new_source_field_names = vec![];
        for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() {
            if source_field_names.contains(&group_by_expr_name) {
                new_source_indices.push(idx);
                new_source_field_names.push(group_by_expr_name.clone());
            }
        }

        // Only dependencies whose whole determinant is grouped on survive.
        // Reject the rest before doing the (pure) target lookup below.
        if new_source_indices.len() != source_indices.len() {
            continue;
        }

        let new_target_indices = get_target_functional_dependencies(
            aggr_input_schema,
            &new_source_field_names,
        );
        // If this determinant covers everything the full GROUP BY list covers,
        // each of its values identifies a single output row.
        let mode = if existing_target_indices == new_target_indices
            && new_target_indices.is_some()
        {
            Dependency::Single
        } else {
            *mode
        };
        aggregate_func_dependencies.push(
            FunctionalDependence::new(new_source_indices, target_indices.clone(), *nullable)
                .with_mode(mode),
        );
    }

    // A single GROUP BY key is unique after aggregation.
    if group_by_expr_names.len() == 1 {
        // The dependency added below subsumes any carried-over dependency
        // whose determinant involves output column 0.
        aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0));
        aggregate_func_dependencies.push(
            FunctionalDependence::new(
                vec![0],
                target_indices,
                aggr_fields[0].is_nullable(),
            )
            .with_mode(Dependency::Single),
        );
    }
    FunctionalDependencies::new(aggregate_func_dependencies)
}
pub fn get_target_functional_dependencies(
schema: &DFSchema,
group_by_expr_names: &[String],
) -> Option<Vec<usize>> {
let mut combined_target_indices = HashSet::new();
let dependencies = schema.functional_dependencies();
let field_names = schema.field_names();
for FunctionalDependence {
source_indices,
target_indices,
..
} in &dependencies.deps
{
let source_key_names = source_indices
.iter()
.map(|id_key_idx| &field_names[*id_key_idx])
.collect::<Vec<_>>();
if source_key_names
.iter()
.all(|source_key_name| group_by_expr_names.contains(source_key_name))
{
combined_target_indices.extend(target_indices.iter());
}
}
(!combined_target_indices.is_empty()).then_some({
let mut result = combined_target_indices.into_iter().collect::<Vec<_>>();
result.sort();
result
})
}
/// Determines which GROUP BY expressions are still needed once expressions
/// that are functionally determined by others are dropped.
///
/// The GROUP BY names are first located in `schema`. Then, for each dependency
/// whose determinant columns are all present, its target columns are removed
/// and its determinant columns merged back in. This uses the `crate::utils`
/// helpers `set_difference` / `merge_and_order_indices`, presumably set
/// difference and sorted union — confirm against their definitions.
///
/// Returns positions into `group_by_expr_names`, ordered by schema index.
/// Returns `None` if any GROUP BY name is not a field of `schema`, or if a
/// remaining schema index does not match any GROUP BY name.
pub fn get_required_group_by_exprs_indices(
    schema: &DFSchema,
    group_by_expr_names: &[String],
) -> Option<Vec<usize>> {
    let field_names = schema.field_names();
    let dependencies = schema.functional_dependencies();

    // Schema positions of the GROUP BY expressions; bail out on any unknown name.
    let mut required = Vec::with_capacity(group_by_expr_names.len());
    for name in group_by_expr_names {
        let position = field_names.iter().position(|field| field == name)?;
        required.push(position);
    }
    required.sort();

    for dependence in dependencies.iter() {
        let determinant_present = dependence
            .source_indices
            .iter()
            .all(|idx| required.contains(idx));
        if determinant_present {
            // Determined columns are redundant; keep the determinant instead.
            let remaining = set_difference(&required, &dependence.target_indices);
            required = merge_and_order_indices(remaining, &dependence.source_indices);
        }
    }

    // Translate schema positions back to positions in the GROUP BY list.
    let mut result = Vec::with_capacity(required.len());
    for idx in &required {
        let position = group_by_expr_names
            .iter()
            .position(|name| &field_names[*idx] == name)?;
        result.push(position);
    }
    Some(result)
}
/// Re-maps column indices through a projection.
///
/// For each value in `entries`, emits the position where that value first
/// appears in `proj_indices`. Values absent from `proj_indices` are skipped.
/// The output preserves the order of `entries`.
fn update_elements_with_matching_indices(
    entries: &[usize],
    proj_indices: &[usize],
) -> Vec<usize> {
    let mut remapped = Vec::new();
    for entry in entries {
        if let Some(position) = proj_indices.iter().position(|candidate| candidate == entry) {
            remapped.push(position);
        }
    }
    remapped
}
/// Returns a copy of `in_data` with `offset` added to every element.
fn add_offset_to_vec<T: Copy + std::ops::Add<Output = T>>(
    in_data: &[T],
    offset: T,
) -> Vec<T> {
    let mut shifted = Vec::with_capacity(in_data.len());
    for &value in in_data {
        shifted.push(value + offset);
    }
    shifted
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constraints_iter() {
        let constraints = Constraints::new_unverified(vec![
            Constraint::PrimaryKey(vec![10]),
            Constraint::Unique(vec![20]),
        ]);
        let mut iter = constraints.iter();
        assert_eq!(iter.next(), Some(&Constraint::PrimaryKey(vec![10])));
        assert_eq!(iter.next(), Some(&Constraint::Unique(vec![20])));
        assert_eq!(iter.next(), None);
    }

    /// `Display` lists each constraint's `Debug` form; an empty set prints nothing.
    #[test]
    fn constraints_display() {
        let constraints = Constraints::new_unverified(vec![
            Constraint::PrimaryKey(vec![0]),
            Constraint::Unique(vec![1, 2]),
        ]);
        assert_eq!(
            constraints.to_string(),
            " constraints=[PrimaryKey([0]), Unique([1, 2])]"
        );
        assert_eq!(Constraints::empty().to_string(), "");
    }

    #[test]
    fn test_get_updated_id_keys() {
        let func_dependencies =
            FunctionalDependencies::new(vec![FunctionalDependence::new(
                vec![1],
                vec![0, 1, 2],
                true,
            )]);
        let res = func_dependencies.project_functional_dependencies(&[1, 2], 2);
        let expected = FunctionalDependencies::new(vec![FunctionalDependence::new(
            vec![0],
            vec![0, 1],
            true,
        )]);
        assert_eq!(res, expected);
    }

    /// Primary keys give non-nullable dependencies and unique constraints give
    /// nullable ones; both target every column.
    #[test]
    fn functional_dependencies_from_constraints() {
        let constraints = Constraints::new_unverified(vec![
            Constraint::PrimaryKey(vec![0]),
            Constraint::Unique(vec![1]),
        ]);
        let func_dependencies =
            FunctionalDependencies::new_from_constraints(Some(&constraints), 3);
        assert_eq!(func_dependencies.len(), 2);
        assert_eq!(func_dependencies[0].source_indices, vec![0]);
        assert_eq!(func_dependencies[0].target_indices, vec![0, 1, 2]);
        assert!(!func_dependencies[0].nullable);
        assert_eq!(func_dependencies[1].source_indices, vec![1]);
        assert_eq!(func_dependencies[1].target_indices, vec![0, 1, 2]);
        assert!(func_dependencies[1].nullable);
        assert!(FunctionalDependencies::new_from_constraints(None, 3).is_empty());
    }

    /// Offsets shift both sides of a dependency, which changes validity bounds.
    #[test]
    fn functional_dependencies_add_offset_and_validity() {
        let mut func_dependencies =
            FunctionalDependencies::new(vec![FunctionalDependence::new(
                vec![0],
                vec![0, 1],
                false,
            )]);
        assert!(func_dependencies.is_valid(2));
        func_dependencies.add_offset(3);
        let expected = FunctionalDependencies::new(vec![FunctionalDependence::new(
            vec![3],
            vec![3, 4],
            false,
        )]);
        assert_eq!(func_dependencies, expected);
        assert!(!func_dependencies.is_valid(4));
        assert!(func_dependencies.is_valid(5));
    }
}