use std::{any::Any, sync::Arc};
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use futures::StreamExt;
use log::debug;
use super::{
expressions::PhysicalSortExpr,
metrics::{ExecutionPlanMetricsSet, MetricsSet},
ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};
use crate::execution::context::TaskContext;
use crate::{
error::Result,
physical_plan::{expressions, metrics::BaselineMetrics},
};
/// Physical plan node that unions the output of multiple child plans.
///
/// The node does not merge or reorder rows: its output partitioning is the
/// concatenation of the children's partitions (see `output_partitioning`),
/// and each output partition is served by exactly one child.
#[derive(Debug)]
pub struct UnionExec {
    /// Child execution plans. All children are assumed to share the first
    /// child's schema (`schema()` returns `inputs[0].schema()`).
    inputs: Vec<Arc<dyn ExecutionPlan>>,
    /// Execution metrics, shared across all partitions of this node.
    metrics: ExecutionPlanMetricsSet,
}
impl UnionExec {
    /// Create a new union over the given child plans.
    pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
        let metrics = ExecutionPlanMetricsSet::new();
        Self { inputs, metrics }
    }

    /// The child plans being unioned, in order.
    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
        &self.inputs
    }
}
impl ExecutionPlan for UnionExec {
    /// Return this node as `Any` so callers can downcast to `UnionExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the union output. All children are assumed to share the
    /// first child's schema, which is returned directly.
    fn schema(&self) -> SchemaRef {
        self.inputs[0].schema()
    }

    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        self.inputs.clone()
    }

    /// The union exposes the concatenation of its children's partitions:
    /// the total count is the sum of each child's partition count.
    fn output_partitioning(&self) -> Partitioning {
        let num_partitions = self
            .inputs
            .iter()
            .map(|plan| plan.output_partitioning().partition_count())
            .sum();
        Partitioning::UnknownPartitioning(num_partitions)
    }

    /// Concatenating inputs provides no global ordering guarantee.
    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        None
    }

    fn relies_on_input_order(&self) -> bool {
        false
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(UnionExec::new(children)))
    }

    /// Execute one output partition by delegating to the child that owns it.
    ///
    /// Output partitions map onto the children in order: the first child
    /// serves partitions `0..p0`, the second `p0..p0 + p1`, and so on.
    ///
    /// Returns `DataFusionError::Execution` if `partition` is out of range.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        debug!("Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        // Time spent locating (not draining) the child stream.
        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();
        // Walk the children, peeling off each child's partition count until
        // the requested partition falls inside one of them. Keep the caller's
        // original index intact for diagnostics.
        let mut remaining = partition;
        for input in self.inputs.iter() {
            // Hoisted: the original computed partition_count() twice per
            // iteration (once in the condition, once in the else branch).
            let input_partitions = input.output_partitioning().partition_count();
            if remaining < input_partitions {
                let stream = input.execute(remaining, context)?;
                debug!("Found a Union partition to execute");
                return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics)));
            }
            remaining -= input_partitions;
        }
        // Report the partition index the caller asked for — the original
        // logged the residual value left over after subtraction, which was
        // misleading in logs and error messages.
        debug!("Error in Union: Partition {} not found", partition);
        Err(crate::error::DataFusionError::Execution(format!(
            "Partition {} not found in Union",
            partition
        )))
    }

    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default => {
                write!(f, "UnionExec")
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Combine the children's statistics pairwise; an empty input list
    /// yields `Statistics::default()`.
    fn statistics(&self) -> Statistics {
        self.inputs
            .iter()
            .map(|ep| ep.statistics())
            .reduce(stats_union)
            .unwrap_or_default()
    }

    /// Union already multiplexes its children's partitions, so repartitioning
    /// its input adds nothing.
    fn benefits_from_input_partitioning(&self) -> bool {
        false
    }
}
/// Wrapper around a `SendableRecordBatchStream` that feeds every poll result
/// into a `BaselineMetrics` via `record_poll`.
struct ObservedStream {
    /// The underlying stream being observed.
    inner: SendableRecordBatchStream,
    /// Metrics updated on every poll (via `record_poll`; presumably output
    /// row counts and completion timing — defined in the metrics module).
    baseline_metrics: BaselineMetrics,
}
impl ObservedStream {
fn new(inner: SendableRecordBatchStream, baseline_metrics: BaselineMetrics) -> Self {
Self {
inner,
baseline_metrics,
}
}
}
impl RecordBatchStream for ObservedStream {
fn schema(&self) -> arrow::datatypes::SchemaRef {
self.inner.schema()
}
}
impl futures::Stream for ObservedStream {
    type Item = arrow::error::Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Poll the wrapped stream first, then hand the result to the metrics
        // recorder (which returns it unchanged to the caller). The two steps
        // must stay separate: polling borrows `self.inner` mutably while
        // `record_poll` borrows `self.baseline_metrics`.
        let inner_poll = self.inner.poll_next_unpin(cx);
        self.baseline_metrics.record_poll(inner_poll)
    }
}
/// Merge two per-column statistics as produced by a union of their columns.
///
/// Counts add and min/max widen when both sides are known; anything unknown
/// on either side becomes unknown in the result.
fn col_stats_union(
    mut left: ColumnStatistics,
    right: ColumnStatistics,
) -> ColumnStatistics {
    // Distinct counts cannot be combined without the underlying values.
    left.distinct_count = None;
    // min/max survive only when both sides are known AND the comparison
    // helper succeeds; any failure collapses the statistic to None.
    left.min_value = match (left.min_value, right.min_value) {
        (Some(a), Some(b)) => expressions::helpers::min(&a, &b).ok(),
        _ => None,
    };
    left.max_value = match (left.max_value, right.max_value) {
        (Some(a), Some(b)) => expressions::helpers::max(&a, &b).ok(),
        _ => None,
    };
    // Null counts add when both sides are known.
    left.null_count = match (left.null_count, right.null_count) {
        (Some(a), Some(b)) => Some(a + b),
        _ => None,
    };
    left
}
/// Merge two plan-level `Statistics` as produced by a union of their plans.
fn stats_union(mut left: Statistics, right: Statistics) -> Statistics {
    // The result is exact only when both sides are exact.
    left.is_exact = left.is_exact && right.is_exact;
    // Row and byte counts add when both sides are known.
    left.num_rows = match (left.num_rows, right.num_rows) {
        (Some(a), Some(b)) => Some(a + b),
        _ => None,
    };
    left.total_byte_size = match (left.total_byte_size, right.total_byte_size) {
        (Some(a), Some(b)) => Some(a + b),
        _ => None,
    };
    // Combine column statistics pairwise; if either side lacks them the
    // whole set becomes unknown.
    left.column_statistics = match (left.column_statistics, right.column_statistics) {
        (Some(a), Some(b)) => Some(
            a.into_iter()
                .zip(b)
                .map(|(ca, cb)| col_stats_union(ca, cb))
                .collect(),
        ),
        _ => None,
    };
    left
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use crate::prelude::SessionContext;
    use crate::{physical_plan::collect, scalar::ScalarValue};
    use arrow::record_batch::RecordBatch;

    /// A union of a 4-partition scan and a 5-partition scan should expose
    /// 4 + 5 = 9 partitions and produce one batch per partition when collected.
    #[tokio::test]
    async fn test_union_partitions() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        // Two CSV scans with different partition counts.
        let csv = test::scan_partitioned_csv(4)?;
        let csv2 = test::scan_partitioned_csv(5)?;
        let union_exec = Arc::new(UnionExec::new(vec![csv, csv2]));
        assert_eq!(union_exec.output_partitioning().partition_count(), 9);
        let result: Vec<RecordBatch> = collect(union_exec, task_ctx).await?;
        // One output batch per partition — assumes each scan partition yields
        // exactly one batch for this fixture.
        assert_eq!(result.len(), 9);
        Ok(())
    }

    /// `stats_union` semantics: counts add, min/max widen across both sides,
    /// exactness requires both sides exact, and anything unknown (None) on
    /// either side becomes unknown in the result.
    #[tokio::test]
    async fn test_stats_union() {
        let left = Statistics {
            is_exact: true,
            num_rows: Some(5),
            total_byte_size: Some(23),
            column_statistics: Some(vec![
                // Fully-known int column.
                ColumnStatistics {
                    distinct_count: Some(5),
                    max_value: Some(ScalarValue::Int64(Some(21))),
                    min_value: Some(ScalarValue::Int64(Some(-4))),
                    null_count: Some(0),
                },
                // String column with known null count.
                ColumnStatistics {
                    distinct_count: Some(1),
                    max_value: Some(ScalarValue::Utf8(Some(String::from("x")))),
                    min_value: Some(ScalarValue::Utf8(Some(String::from("a")))),
                    null_count: Some(3),
                },
                // Float column with partially unknown statistics.
                ColumnStatistics {
                    distinct_count: None,
                    max_value: Some(ScalarValue::Float32(Some(1.1))),
                    min_value: Some(ScalarValue::Float32(Some(0.1))),
                    null_count: None,
                },
            ]),
        };
        let right = Statistics {
            is_exact: true,
            num_rows: Some(7),
            total_byte_size: Some(29),
            column_statistics: Some(vec![
                ColumnStatistics {
                    distinct_count: Some(3),
                    max_value: Some(ScalarValue::Int64(Some(34))),
                    min_value: Some(ScalarValue::Int64(Some(1))),
                    null_count: Some(1),
                },
                ColumnStatistics {
                    distinct_count: None,
                    max_value: Some(ScalarValue::Utf8(Some(String::from("c")))),
                    min_value: Some(ScalarValue::Utf8(Some(String::from("b")))),
                    null_count: None,
                },
                // Entirely unknown column.
                ColumnStatistics {
                    distinct_count: None,
                    max_value: None,
                    min_value: None,
                    null_count: None,
                },
            ]),
        };
        let result = stats_union(left, right);
        let expected = Statistics {
            is_exact: true,
            // 5 + 7 rows, 23 + 29 bytes.
            num_rows: Some(12),
            total_byte_size: Some(52),
            column_statistics: Some(vec![
                // distinct_count always collapses to None; min/max widen.
                ColumnStatistics {
                    distinct_count: None,
                    max_value: Some(ScalarValue::Int64(Some(34))),
                    min_value: Some(ScalarValue::Int64(Some(-4))),
                    null_count: Some(1),
                },
                // right's None null_count poisons the sum.
                ColumnStatistics {
                    distinct_count: None,
                    max_value: Some(ScalarValue::Utf8(Some(String::from("x")))),
                    min_value: Some(ScalarValue::Utf8(Some(String::from("a")))),
                    null_count: None,
                },
                // right side fully unknown → result fully unknown.
                ColumnStatistics {
                    distinct_count: None,
                    max_value: None,
                    min_value: None,
                    null_count: None,
                },
            ]),
        };
        assert_eq!(result, expected);
    }
}