use std::fmt::{Display, Formatter};
use std::ops::Deref;
use std::vec::IntoIter;
use crate::utils::{merge_and_order_indices, set_difference};
use crate::{DFSchema, HashSet, JoinType};
/// A constraint declared over a set of columns, each identified by its index.
///
/// `FunctionalDependencies::new_from_constraints` turns either variant into a
/// dependency from the listed columns to every field; the variants differ only
/// in the `nullable` flag that conversion assigns.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum Constraint {
/// Column indices of a primary key; converted into a dependency with
/// `nullable == false`.
PrimaryKey(Vec<usize>),
/// Column indices of a unique key; converted into a dependency with
/// `nullable == true`.
Unique(Vec<usize>),
}
/// An ordered list of [`Constraint`]s.
///
/// The list is stored exactly as handed to `Constraints::new_unverified`; no
/// validation of the contained indices is performed in this file.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Constraints {
/// The constraints, in the order they were supplied.
inner: Vec<Constraint>,
}
impl Constraints {
    /// Creates a `Constraints` value holding no constraints.
    pub fn empty() -> Self {
        Self::new_unverified(Vec::new())
    }

    /// Wraps `constraints` as-is; the indices are not checked in any way.
    pub fn new_unverified(constraints: Vec<Constraint>) -> Self {
        Self { inner: constraints }
    }

    /// Returns `true` when no constraint is stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Re-expresses the constraints in terms of a projection.
    ///
    /// `proj_indices[i]` is the original column index that ends up at output
    /// position `i`. A constraint survives only if every one of its columns is
    /// found in `proj_indices`; its indices are then replaced by the matching
    /// output positions.
    ///
    /// Returns `None` when no constraint survives.
    pub fn project(&self, proj_indices: &[usize]) -> Option<Self> {
        let mut kept = Vec::new();
        for constraint in &self.inner {
            let columns = match constraint {
                Constraint::PrimaryKey(columns) | Constraint::Unique(columns) => columns,
            };
            let remapped = update_elements_with_matching_indices(columns, proj_indices);
            // A shorter result means at least one column was projected away.
            if remapped.len() != columns.len() {
                continue;
            }
            kept.push(match constraint {
                Constraint::PrimaryKey(_) => Constraint::PrimaryKey(remapped),
                Constraint::Unique(_) => Constraint::Unique(remapped),
            });
        }

        if kept.is_empty() {
            None
        } else {
            Some(Constraints::new_unverified(kept))
        }
    }
}
impl Default for Constraints {
fn default() -> Self {
Constraints::empty()
}
}
impl IntoIterator for Constraints {
    type Item = Constraint;
    type IntoIter = IntoIter<Constraint>;

    /// Consumes the collection and yields its constraints in stored order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { inner } = self;
        inner.into_iter()
    }
}
impl Display for Constraints {
    /// Renders as `constraints=[..]`, listing each constraint in its `Debug`
    /// form separated by `", "`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("constraints=[")?;
        for (position, constraint) in self.inner.iter().enumerate() {
            if position > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{constraint:?}")?;
        }
        f.write_str("]")
    }
}
impl Deref for Constraints {
    type Target = [Constraint];

    /// Exposes the stored constraints as a slice.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
/// A single functional dependency between two sets of column indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionalDependence {
/// Indices of the determinant ("source") columns.
pub source_indices: Vec<usize>,
/// Indices of the columns determined by the source columns.
pub target_indices: Vec<usize>,
/// Set to `false` for dependencies built from `Constraint::PrimaryKey` and
/// to `true` for those built from `Constraint::Unique`.
/// NOTE(review): presumably records whether the source columns may hold
/// NULLs -- confirm against callers.
pub nullable: bool,
/// How the dependency is treated by projections and joins; see [`Dependency`].
pub mode: Dependency,
}
/// Mode of a [`FunctionalDependence`].
///
/// * `Single` is assigned to dependencies created from table constraints;
///   projections and `extend_target_indices` treat such a dependency as
///   determining every output column.
/// * `Multi` is the mode set by `FunctionalDependence::new` and the mode every
///   dependency receives after an inner, left or right join.
///
/// NOTE(review): presumably `Single` means a determinant value occurs in at
/// most one row and `Multi` that it may repeat -- confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dependency {
Single, Multi, }
impl FunctionalDependence {
    /// Builds a dependency from `source_indices` to `target_indices`.
    ///
    /// The mode starts out as [`Dependency::Multi`]; use
    /// [`FunctionalDependence::with_mode`] to change it.
    pub fn new(
        source_indices: Vec<usize>,
        target_indices: Vec<usize>,
        nullable: bool,
    ) -> Self {
        let mode = Dependency::Multi;
        Self {
            source_indices,
            target_indices,
            nullable,
            mode,
        }
    }

    /// Returns the same dependency with its mode replaced by `mode`.
    pub fn with_mode(self, mode: Dependency) -> Self {
        Self { mode, ..self }
    }
}
/// A collection of [`FunctionalDependence`] values describing one schema.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionalDependencies {
/// The individual dependencies, in insertion order.
deps: Vec<FunctionalDependence>,
}
impl FunctionalDependencies {
pub fn empty() -> Self {
Self { deps: vec![] }
}
pub fn new(dependencies: Vec<FunctionalDependence>) -> Self {
Self { deps: dependencies }
}
pub fn new_from_constraints(
constraints: Option<&Constraints>,
n_field: usize,
) -> Self {
if let Some(Constraints { inner: constraints }) = constraints {
let dependencies = constraints
.iter()
.map(|constraint| {
let dependency = match constraint {
Constraint::PrimaryKey(indices) => FunctionalDependence::new(
indices.to_vec(),
(0..n_field).collect::<Vec<_>>(),
false,
),
Constraint::Unique(indices) => FunctionalDependence::new(
indices.to_vec(),
(0..n_field).collect::<Vec<_>>(),
true,
),
};
dependency.with_mode(Dependency::Single)
})
.collect::<Vec<_>>();
Self::new(dependencies)
} else {
Self::empty()
}
}
pub fn with_dependency(mut self, mode: Dependency) -> Self {
self.deps.iter_mut().for_each(|item| item.mode = mode);
self
}
pub fn extend(&mut self, other: FunctionalDependencies) {
self.deps.extend(other.deps);
}
pub fn is_valid(&self, n_field: usize) -> bool {
self.deps.iter().all(
|FunctionalDependence {
source_indices,
target_indices,
..
}| {
source_indices
.iter()
.max()
.map(|&max_index| max_index < n_field)
.unwrap_or(true)
&& target_indices
.iter()
.max()
.map(|&max_index| max_index < n_field)
.unwrap_or(true)
},
)
}
pub fn add_offset(&mut self, offset: usize) {
self.deps.iter_mut().for_each(
|FunctionalDependence {
source_indices,
target_indices,
..
}| {
*source_indices = add_offset_to_vec(source_indices, offset);
*target_indices = add_offset_to_vec(target_indices, offset);
},
)
}
pub fn project_functional_dependencies(
&self,
proj_indices: &[usize],
n_out: usize,
) -> FunctionalDependencies {
let mut projected_func_dependencies = vec![];
for FunctionalDependence {
source_indices,
target_indices,
nullable,
mode,
} in &self.deps
{
let new_source_indices =
update_elements_with_matching_indices(source_indices, proj_indices);
let new_target_indices = if *mode == Dependency::Single {
(0..n_out).collect()
} else {
update_elements_with_matching_indices(target_indices, proj_indices)
};
if new_source_indices.len() == source_indices.len() {
let new_func_dependence = FunctionalDependence::new(
new_source_indices,
new_target_indices,
*nullable,
)
.with_mode(*mode);
projected_func_dependencies.push(new_func_dependence);
}
}
FunctionalDependencies::new(projected_func_dependencies)
}
pub fn join(
&self,
other: &FunctionalDependencies,
join_type: &JoinType,
left_cols_len: usize,
) -> FunctionalDependencies {
let mut right_func_dependencies = other.clone();
let mut left_func_dependencies = self.clone();
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right => {
right_func_dependencies.add_offset(left_cols_len);
left_func_dependencies =
left_func_dependencies.with_dependency(Dependency::Multi);
right_func_dependencies =
right_func_dependencies.with_dependency(Dependency::Multi);
if *join_type == JoinType::Left {
right_func_dependencies.downgrade_dependencies();
} else if *join_type == JoinType::Right {
left_func_dependencies.downgrade_dependencies();
}
left_func_dependencies.extend(right_func_dependencies);
left_func_dependencies
}
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
left_func_dependencies
}
JoinType::RightSemi | JoinType::RightAnti => {
right_func_dependencies
}
JoinType::Full => {
FunctionalDependencies::empty()
}
}
}
fn downgrade_dependencies(&mut self) {
self.deps.retain(|item| !item.nullable);
self.deps.iter_mut().for_each(|item| item.nullable = true);
}
pub fn extend_target_indices(&mut self, n_out: usize) {
self.deps.iter_mut().for_each(
|FunctionalDependence {
mode,
target_indices,
..
}| {
if *mode == Dependency::Single {
*target_indices = (0..n_out).collect::<Vec<_>>();
}
},
)
}
}
impl Deref for FunctionalDependencies {
    type Target = [FunctionalDependence];

    /// Exposes the stored dependencies as a slice.
    fn deref(&self) -> &Self::Target {
        &self.deps
    }
}
/// Computes the functional dependencies of an aggregation's output.
///
/// # Arguments
///
/// * `aggr_input_schema` - schema of the aggregation input; its functional
///   dependencies and field names are consulted.
/// * `group_by_expr_names` - names of the group-by expressions. A name's
///   position in this slice is used as its column index in the output.
/// * `aggr_schema` - schema of the aggregation output; supplies the number of
///   output columns and the nullability of the group-by columns.
///
/// # Returns
///
/// * Every input dependency whose source columns all appear among the group-by
///   names, re-indexed to group-by positions and targeting every output
///   column. Its mode becomes [`Dependency::Single`] when the group-by names
///   and the matched source names yield the same non-empty target set from
///   [`get_target_functional_dependencies`]; otherwise the input mode is kept.
/// * If there are group-by expressions and no carried-over dependency has all
///   of its sources inside the group-by columns, one additional
///   [`Dependency::Single`] dependency from all group-by columns to every
///   output column.
pub fn aggregate_functional_dependencies(
    aggr_input_schema: &DFSchema,
    group_by_expr_names: &[String],
    aggr_schema: &DFSchema,
) -> FunctionalDependencies {
    let mut aggregate_func_dependencies = vec![];
    let aggr_input_fields = aggr_input_schema.field_names();
    let aggr_fields = aggr_schema.fields();
    // Every dependency emitted below targets all output columns.
    let target_indices = (0..aggr_schema.fields().len()).collect::<Vec<_>>();
    let func_dependencies = aggr_input_schema.functional_dependencies();
    // Depends only on the function arguments, so it is computed once rather
    // than once per input dependency.
    let existing_target_indices =
        get_target_functional_dependencies(aggr_input_schema, group_by_expr_names);

    for FunctionalDependence {
        source_indices,
        nullable,
        mode,
        ..
    } in &func_dependencies.deps
    {
        let source_field_names = source_indices
            .iter()
            .map(|&idx| &aggr_input_fields[idx])
            .collect::<Vec<_>>();

        // Group-by positions (and names) that match a source column.
        let mut new_source_indices = vec![];
        let mut new_source_field_names = vec![];
        for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() {
            if source_field_names.contains(&group_by_expr_name) {
                new_source_indices.push(idx);
                new_source_field_names.push(group_by_expr_name.clone());
            }
        }

        // The dependency is only carried over when the match count equals the
        // number of source columns; skip the remaining work otherwise.
        if new_source_indices.len() != source_indices.len() {
            continue;
        }

        let new_target_indices = get_target_functional_dependencies(
            aggr_input_schema,
            &new_source_field_names,
        );
        let mode = if existing_target_indices == new_target_indices
            && new_target_indices.is_some()
        {
            Dependency::Single
        } else {
            *mode
        };
        aggregate_func_dependencies.push(
            FunctionalDependence::new(
                new_source_indices,
                target_indices.clone(),
                *nullable,
            )
            .with_mode(mode),
        );
    }

    if !group_by_expr_names.is_empty() {
        let count = group_by_expr_names.len();
        let source_indices = (0..count).collect::<Vec<_>>();
        let nullable = source_indices
            .iter()
            .any(|idx| aggr_fields[*idx].is_nullable());
        // Add the group-by dependency only when no carried-over dependency
        // already has all of its sources within the group-by columns.
        if !aggregate_func_dependencies.iter().any(|item| {
            item.source_indices.iter().all(|idx| idx < &count)
        }) {
            aggregate_func_dependencies.push(
                FunctionalDependence::new(source_indices, target_indices, nullable)
                    .with_mode(Dependency::Single),
            );
        }
    }
    FunctionalDependencies::new(aggregate_func_dependencies)
}
/// Collects the columns determined by the given group-by expressions.
///
/// A dependency of `schema` contributes its target indices when every one of
/// its source columns, looked up by field name, is contained in
/// `group_by_expr_names`.
///
/// # Returns
///
/// The union of the contributed target indices, sorted ascending and without
/// duplicates, or `None` when no target index was contributed.
pub fn get_target_functional_dependencies(
    schema: &DFSchema,
    group_by_expr_names: &[String],
) -> Option<Vec<usize>> {
    let mut combined_target_indices = HashSet::new();
    let dependencies = schema.functional_dependencies();
    let field_names = schema.field_names();
    for dependence in &dependencies.deps {
        // Check the names in place instead of collecting them into a
        // temporary vector first.
        let sources_in_group_by = dependence
            .source_indices
            .iter()
            .all(|&source_idx| group_by_expr_names.contains(&field_names[source_idx]));
        if sources_in_group_by {
            combined_target_indices.extend(dependence.target_indices.iter().copied());
        }
    }

    // Return early so the result vector is only built when it is needed.
    if combined_target_indices.is_empty() {
        return None;
    }
    let mut result = combined_target_indices.into_iter().collect::<Vec<usize>>();
    // Elements are unique integers, so an unstable sort gives the same order.
    result.sort_unstable();
    Some(result)
}
/// Reduces a list of group-by expressions using the schema's functional
/// dependencies.
///
/// Each name is first resolved to its field index in `schema`. Then, for every
/// dependency whose source columns are all currently in the index set, the
/// dependency's targets are removed from the set and its sources merged back
/// in. The remaining field indices are finally translated into positions
/// within `group_by_expr_names`.
///
/// Returns `None` when a group-by name is not a field of `schema`, or when a
/// remaining field cannot be found among `group_by_expr_names`.
pub fn get_required_group_by_exprs_indices(
    schema: &DFSchema,
    group_by_expr_names: &[String],
) -> Option<Vec<usize>> {
    let dependencies = schema.functional_dependencies();
    let field_names = schema.field_names();

    // Resolve every group-by name to its field index.
    let mut required = Vec::with_capacity(group_by_expr_names.len());
    for expr_name in group_by_expr_names {
        let field_idx = field_names
            .iter()
            .position(|field_name| field_name == expr_name)?;
        required.push(field_idx);
    }
    required.sort();

    for dependence in &dependencies.deps {
        let determinant_covered = dependence
            .source_indices
            .iter()
            .all(|source_idx| required.contains(source_idx));
        if determinant_covered {
            let without_targets = set_difference(&required, &dependence.target_indices);
            required =
                merge_and_order_indices(without_targets, &dependence.source_indices);
        }
    }

    // Translate the surviving field indices into group-by positions.
    let mut positions = Vec::with_capacity(required.len());
    for field_idx in required {
        let position = group_by_expr_names
            .iter()
            .position(|name| &field_names[field_idx] == name)?;
        positions.push(position);
    }
    Some(positions)
}
/// Maps each value of `entries` to the position of its first occurrence in
/// `proj_indices`.
///
/// Values that do not occur in `proj_indices` are skipped, so the result may
/// be shorter than `entries`. The order of `entries` is preserved.
fn update_elements_with_matching_indices(
    entries: &[usize],
    proj_indices: &[usize],
) -> Vec<usize> {
    let mut remapped = Vec::new();
    for entry in entries {
        let found = proj_indices.iter().position(|candidate| candidate == entry);
        if let Some(position) = found {
            remapped.push(position);
        }
    }
    remapped
}
/// Returns a new vector holding every element of `in_data` with `offset`
/// added to it, in the same order.
fn add_offset_to_vec<T: Copy + std::ops::Add<Output = T>>(
    in_data: &[T],
    offset: T,
) -> Vec<T> {
    let mut shifted = Vec::with_capacity(in_data.len());
    for &value in in_data {
        shifted.push(value + offset);
    }
    shifted
}
// Unit tests for constraint iteration/projection and for projecting
// functional dependencies.
#[cfg(test)]
mod tests {
use super::*;
// `Constraints` derefs to a slice, so `iter()` yields the constraints in the
// order they were supplied.
#[test]
fn constraints_iter() {
let constraints = Constraints::new_unverified(vec![
Constraint::PrimaryKey(vec![10]),
Constraint::Unique(vec![20]),
]);
let mut iter = constraints.iter();
assert_eq!(iter.next(), Some(&Constraint::PrimaryKey(vec![10])));
assert_eq!(iter.next(), Some(&Constraint::Unique(vec![20])));
assert_eq!(iter.next(), None);
}
// Projecting onto columns [1, 2, 3] keeps the primary key (re-indexed to
// [0, 1]) and drops the unique constraint because column 0 is not projected.
// Projecting onto [0] alone keeps nothing, which is reported as `None`.
#[test]
fn test_project_constraints() {
let constraints = Constraints::new_unverified(vec![
Constraint::PrimaryKey(vec![1, 2]),
Constraint::Unique(vec![0, 3]),
]);
let projected = constraints.project(&[1, 2, 3]).unwrap();
assert_eq!(
projected,
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0, 1])])
);
assert!(constraints.project(&[0]).is_none());
}
// A `Multi` dependency 1 -> {0, 1, 2} projected onto columns [1, 2] becomes
// 0 -> {0, 1}: the source is re-indexed and the unprojected target 0 is
// dropped.
#[test]
fn test_get_updated_id_keys() {
let fund_dependencies =
FunctionalDependencies::new(vec![FunctionalDependence::new(
vec![1],
vec![0, 1, 2],
true,
)]);
let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2);
let expected = FunctionalDependencies::new(vec![FunctionalDependence::new(
vec![0],
vec![0, 1],
true,
)]);
assert_eq!(res, expected);
}
}