use crate::error::{DataFusionError, Result};
use crate::logical_plan::JoinType;
use crate::physical_plan::expressions::Column;
use arrow::datatypes::{Field, Schema};
use arrow::error::ArrowError;
use datafusion_physical_expr::PhysicalExpr;
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
use parking_lot::Mutex;
use std::collections::HashSet;
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll};
/// The "on" clause of a join: a list of `(left, right)` column pairs.
/// The first element of each pair is looked up on the left input and the
/// second on the right input (see `check_join_set_is_valid`).
pub type JoinOn = Vec<(Column, Column)>;
/// Borrowed, slice form of [`JoinOn`].
pub type JoinOnRef<'a> = &'a [(Column, Column)];
/// Checks whether the schemas `left` and `right` together with the column
/// pairs in `on` describe a valid join.
///
/// Every field of each schema is converted into a [`Column`] built from the
/// field's name and its position in the schema; the two resulting sets are
/// then validated by `check_join_set_is_valid`.
///
/// # Errors
/// Propagates the error returned by `check_join_set_is_valid`.
pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
    // Both sides are converted the same way, so share one conversion routine.
    let columns_of = |schema: &Schema| -> HashSet<Column> {
        let mut columns = HashSet::new();
        for (position, field) in schema.fields().iter().enumerate() {
            columns.insert(Column::new(field.name(), position));
        }
        columns
    };

    let left_columns = columns_of(left);
    let right_columns = columns_of(right);
    check_join_set_is_valid(&left_columns, &right_columns, on)
}
/// Checks that every column referenced by `on` exists on the matching side.
///
/// * `left` / `right` - the columns available on each join input.
/// * `on` - `(left, right)` column pairs; the first element of each pair must
///   be contained in `left`, the second in `right`.
///
/// # Errors
/// Returns `DataFusionError::Plan` listing the missing columns of both sides
/// when at least one referenced column is absent.
fn check_join_set_is_valid(
    left: &HashSet<Column>,
    right: &HashSet<Column>,
    on: &[(Column, Column)],
) -> Result<()> {
    // Collect references rather than clones: the missing sets are only used
    // for the emptiness check and for the error message.
    let left_missing: HashSet<&Column> = on
        .iter()
        .map(|(left_col, _)| left_col)
        .filter(|column| !left.contains(*column))
        .collect();
    let right_missing: HashSet<&Column> = on
        .iter()
        .map(|(_, right_col)| right_col)
        .filter(|column| !right.contains(*column))
        .collect();

    if !left_missing.is_empty() || !right_missing.is_empty() {
        return Err(DataFusionError::Plan(format!(
            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
            left_missing,
            right_missing,
        )));
    }
    Ok(())
}
/// Identifies which input of a join something belongs to.
///
/// Derives `PartialEq`/`Eq` so that sides can be compared directly with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinSide {
    /// The left input of the join.
    Left,
    /// The right input of the join.
    Right,
}
/// Identifies a column by its position within one input of a join.
#[derive(Debug, Clone)]
pub struct ColumnIndex {
/// Position of the column within the input selected by `side`.
pub index: usize,
/// The join input (left or right) the column is taken from.
pub side: JoinSide,
}
/// Bundles a physical expression with the column indices and schema it is
/// associated with.
// NOTE(review): presumably this is the extra (non-equi) filter evaluated on
// joined rows — confirm against the join operators that consume it.
#[derive(Debug, Clone)]
pub struct JoinFilter {
/// The filter expression.
expression: Arc<dyn PhysicalExpr>,
/// Columns, tagged with the join side they come from, associated with the
/// expression.
column_indices: Vec<ColumnIndex>,
/// Schema associated with the expression.
// NOTE(review): presumably the schema of the intermediate batch built from
// `column_indices` — confirm.
schema: Schema,
}
impl JoinFilter {
    /// Creates a new `JoinFilter` from its three components.
    pub fn new(
        expression: Arc<dyn PhysicalExpr>,
        column_indices: Vec<ColumnIndex>,
        schema: Schema,
    ) -> JoinFilter {
        Self {
            expression,
            column_indices,
            schema,
        }
    }

    /// Builds a list of [`ColumnIndex`] entries: first every index in
    /// `left_indices` tagged `JoinSide::Left`, then every index in
    /// `right_indices` tagged `JoinSide::Right`, each in input order.
    pub fn build_column_indices(
        left_indices: Vec<usize>,
        right_indices: Vec<usize>,
    ) -> Vec<ColumnIndex> {
        let mut tagged = Vec::with_capacity(left_indices.len() + right_indices.len());
        for index in left_indices {
            tagged.push(ColumnIndex {
                index,
                side: JoinSide::Left,
            });
        }
        for index in right_indices {
            tagged.push(ColumnIndex {
                index,
                side: JoinSide::Right,
            });
        }
        tagged
    }

    /// Returns the filter expression.
    pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
        &self.expression
    }

    /// Returns the column indices stored in this filter.
    pub fn column_indices(&self) -> &[ColumnIndex] {
        self.column_indices.as_slice()
    }

    /// Returns the schema stored in this filter.
    pub fn schema(&self) -> &Schema {
        &self.schema
    }
}
/// Returns a copy of `old_field` as it should appear in the join output.
///
/// The copy is forced to be nullable when the join type requires it for the
/// side the field comes from (`is_left` selects the side):
/// * `Left` forces right-side fields nullable,
/// * `Right` forces left-side fields nullable,
/// * `Full` forces fields of both sides nullable,
/// * `Inner`, `Semi` and `Anti` leave nullability untouched.
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
    let must_be_nullable = match join_type {
        JoinType::Inner | JoinType::Semi | JoinType::Anti => false,
        JoinType::Left => !is_left,
        JoinType::Right => is_left,
        JoinType::Full => true,
    };

    let field = old_field.clone();
    if must_be_nullable {
        field.with_nullable(true)
    } else {
        field
    }
}
/// Creates the output schema of a join together with, for each output field,
/// the [`ColumnIndex`] (input side and position) it originates from.
///
/// * `Inner`, `Left`, `Right` and `Full` joins output all left fields followed
///   by all right fields, each passed through `output_join_field` so that
///   nullability matches the join type.
/// * `Semi` and `Anti` joins output only the left fields, cloned unchanged.
pub fn build_join_schema(
    left: &Schema,
    right: &Schema,
    join_type: &JoinType,
) -> (Schema, Vec<ColumnIndex>) {
    let mut fields: Vec<Field> = Vec::new();
    let mut column_indices: Vec<ColumnIndex> = Vec::new();

    match join_type {
        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
            for (index, field) in left.fields().iter().enumerate() {
                fields.push(output_join_field(field, join_type, true));
                column_indices.push(ColumnIndex {
                    index,
                    side: JoinSide::Left,
                });
            }
            for (index, field) in right.fields().iter().enumerate() {
                fields.push(output_join_field(field, join_type, false));
                column_indices.push(ColumnIndex {
                    index,
                    side: JoinSide::Right,
                });
            }
        }
        JoinType::Semi | JoinType::Anti => {
            for (index, field) in left.fields().iter().enumerate() {
                fields.push(field.clone());
                column_indices.push(ColumnIndex {
                    index,
                    side: JoinSide::Left,
                });
            }
        }
    }

    (Schema::new(fields), column_indices)
}
/// Lazily-initialised holder of a shared asynchronous computation.
///
/// The slot starts out empty; `once` fills it on the first call and hands out
/// clones of the stored [`OnceFut`] on every call.
pub(crate) struct OnceAsync<T> {
/// `None` until `once` has been called; afterwards the stored future.
fut: Mutex<Option<OnceFut<T>>>,
}
impl<T> Default for OnceAsync<T> {
    /// Creates an `OnceAsync` whose slot is still empty.
    fn default() -> Self {
        let fut = Mutex::new(None);
        Self { fut }
    }
}
impl<T> std::fmt::Debug for OnceAsync<T> {
    /// Writes only the type name; the wrapped future is not inspected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("OnceAsync")
    }
}
impl<T: 'static> OnceAsync<T> {
    /// Returns a handle to the shared computation, creating it if necessary.
    ///
    /// If the slot is still empty, `f` is invoked to build the future, which
    /// is wrapped in an [`OnceFut`] and stored. In every case a clone of the
    /// stored [`OnceFut`] is returned; when the slot is already filled `f` is
    /// dropped without being called.
    pub(crate) fn once<F, Fut>(&self, f: F) -> OnceFut<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        // The lock is held while `f` runs and while the handle is cloned.
        let mut slot = self.fut.lock();
        let stored = slot.get_or_insert_with(|| OnceFut::new(f()));
        stored.clone()
    }
}
/// The not-yet-completed form of an [`OnceFut`]: a boxed `'static` future
/// producing `Arc<Result<T>>`, wrapped in `Shared` so it can be cloned.
type OnceFutPending<T> = Shared<BoxFuture<'static, Arc<Result<T>>>>;
/// A cloneable handle to a computation whose `Result<T>` is stored behind an
/// `Arc`. `get` polls it and, once finished, yields a reference to the result.
pub(crate) struct OnceFut<T> {
/// Either the still-pending shared future or its completed result.
state: OnceFutState<T>,
}
// Implemented by hand so that no `T: Clone` bound is required.
impl<T> Clone for OnceFut<T> {
    /// Clones the handle by cloning its internal state.
    fn clone(&self) -> Self {
        let state = self.state.clone();
        Self { state }
    }
}
/// Internal state of an [`OnceFut`].
enum OnceFutState<T> {
/// The shared future has not yet been observed to complete by this handle.
Pending(OnceFutPending<T>),
/// The future completed; holds its output.
Ready(Arc<Result<T>>),
}
impl<T> Clone for OnceFutState<T> {
fn clone(&self) -> Self {
match self {
Self::Pending(p) => Self::Pending(p.clone()),
Self::Ready(r) => Self::Ready(r.clone()),
}
}
}
impl<T: 'static> OnceFut<T> {
    /// Wraps `fut` into a new, pending `OnceFut`.
    ///
    /// The future's output is placed in an `Arc`, and the future itself is
    /// boxed and made shareable.
    pub(crate) fn new<Fut>(fut: Fut) -> Self
    where
        Fut: Future<Output = Result<T>> + Send + 'static,
    {
        let shared = fut.map(Arc::new).boxed().shared();
        Self {
            state: OnceFutState::Pending(shared),
        }
    }

    /// Polls the computation and returns a reference to its result.
    ///
    /// Returns `Poll::Pending` while the underlying future is not finished.
    /// Once it is, the output is stored in this handle and `Poll::Ready` is
    /// returned with either a reference to the value or, if the computation
    /// failed, an `ArrowError::ExternalError` built from the error's string
    /// representation.
    pub(crate) fn get(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<&T, ArrowError>> {
        // Drive the shared future; `ready!` returns early while it is pending.
        if let OnceFutState::Pending(shared) = &mut self.state {
            let output = ready!(shared.poll_unpin(cx));
            self.state = OnceFutState::Ready(output);
        }

        let result = match &self.state {
            OnceFutState::Ready(result) => result,
            // The block above either returned or switched the state to `Ready`.
            OnceFutState::Pending(_) => unreachable!(),
        };

        match &**result {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(e) => Poll::Ready(Err(ArrowError::ExternalError(e.to_string().into()))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;

    /// Runs `check_join_set_is_valid` on explicit column lists.
    fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
        let left_set: HashSet<Column> = left.iter().cloned().collect();
        let right_set: HashSet<Column> = right.iter().cloned().collect();
        check_join_set_is_valid(&left_set, &right_set, on)
    }

    #[test]
    fn check_valid() -> Result<()> {
        let left = [Column::new("a", 0), Column::new("b1", 1)];
        let right = [Column::new("a", 0), Column::new("b2", 1)];
        let on = [(Column::new("a", 0), Column::new("a", 0))];

        check(&left, &right, &on)
    }

    #[test]
    fn check_not_in_right() {
        let left = [Column::new("a", 0), Column::new("b", 1)];
        let right = [Column::new("b", 0)];
        let on = [(Column::new("a", 0), Column::new("a", 0))];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_err());
    }

    #[test]
    fn check_not_in_left() {
        let left = [Column::new("b", 0)];
        let right = [Column::new("a", 0)];
        let on = [(Column::new("a", 0), Column::new("a", 0))];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_err());
    }

    #[test]
    fn check_collision() {
        // column "a" would appear both in left and right
        let left = [Column::new("a", 0), Column::new("c", 1)];
        let right = [Column::new("a", 0), Column::new("b", 1)];
        let on = [(Column::new("a", 0), Column::new("b", 1))];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_ok());
    }

    #[test]
    fn check_in_right() {
        let left = [Column::new("a", 0), Column::new("c", 1)];
        let right = [Column::new("b", 0)];
        let on = [(Column::new("a", 0), Column::new("b", 0))];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_join_schema() -> Result<()> {
        let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]);
        let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]);

        // Each case: (left input, right input, join type, expected left, expected right)
        let cases = vec![
            // Inner joins keep the nullability of both inputs
            (&a, &b, JoinType::Inner, &a, &b),
            (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
            (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
            (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
            // Left joins make the right side nullable
            (&a, &b, JoinType::Left, &a, &b_nulls),
            (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
            (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
            (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
            // Right joins make the left side nullable
            (&a, &b, JoinType::Right, &a_nulls, &b),
            (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
            (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
            (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
            // Full joins make both sides nullable
            (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
            (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
            (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
            (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
        ];

        for (left_in, right_in, join_type, left_out, right_out) in cases {
            let (actual_schema, _) = build_join_schema(left_in, right_in, &join_type);

            let expected_fields: Vec<Field> = left_out
                .fields()
                .iter()
                .chain(right_out.fields().iter())
                .cloned()
                .collect();
            let expected_schema = Schema::new(expected_fields);

            let left_field = &left_in.fields()[0];
            let right_field = &right_in.fields()[0];
            assert_eq!(
                actual_schema,
                expected_schema,
                "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
                left_field.name(),
                left_field.is_nullable(),
                right_field.name(),
                right_field.is_nullable(),
                join_type
            );
        }

        Ok(())
    }
}