use std::any::Any;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
use ahash::RandomState;
use futures::{
stream::{Stream, StreamExt},
Future,
};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::{
Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan,
Partitioning, PhysicalExpr,
};
use crate::scalar::ScalarValue;
use arrow::{array::ArrayRef, compute, compute::cast};
use arrow::{
array::{Array, UInt32Builder},
error::{ArrowError, Result as ArrowResult},
};
use arrow::{
datatypes::{Field, Schema, SchemaRef},
record_batch::RecordBatch,
};
use hashbrown::raw::RawTable;
use pin_project_lite::pin_project;
use crate::execution::runtime_env::RuntimeEnv;
use async_trait::async_trait;
use super::common::AbortOnDropSingle;
use super::expressions::PhysicalSortExpr;
use super::metrics::{
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
use super::Statistics;
use super::{expressions::Column, RecordBatchStream, SendableRecordBatchStream};
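/// Hash aggregate modes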
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// Partial aggregate that can be applied in parallel across input partitions
    Partial,
    /// Final aggregate that produces a single partition of output
    Final,
    /// Final aggregate that works on pre-partitioned data, requiring that all
    /// rows with a given grouping key are in the same partition
    FinalPartitioned,
}
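/// Hash aggregate execution plan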
#[derive(Debug)]
pub struct HashAggregateExec {
    /// Aggregation mode (partial, final, final-partitioned)
    mode: AggregateMode,
    /// Grouping expressions together with their output column names
    group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Aggregate expressions
    aggr_expr: Vec<Arc<dyn AggregateExpr>>,
    /// Input plan, could be a partial aggregate or the input to the aggregate
    input: Arc<dyn ExecutionPlan>,
    /// Schema after the aggregate is applied
    schema: SchemaRef,
    /// Input schema before any aggregation is applied; for a partial aggregate
    /// this is the same as `input.schema()`, for a final aggregate it is the
    /// schema of the input to the corresponding partial aggregate
    input_schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
}
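/// Computes the output schema: the group-by columns followed by either the
/// accumulators' intermediate state fields (Partial) or final fields (Final)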
fn create_schema(
input_schema: &Schema,
group_expr: &[(Arc<dyn PhysicalExpr>, String)],
aggr_expr: &[Arc<dyn AggregateExpr>],
mode: AggregateMode,
) -> Result<Schema> {
let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
for (expr, name) in group_expr {
fields.push(Field::new(
name,
expr.data_type(input_schema)?,
expr.nullable(input_schema)?,
))
}
match mode {
AggregateMode::Partial => {
for expr in aggr_expr {
fields.extend(expr.state_fields()?.iter().cloned())
}
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
for expr in aggr_expr {
fields.push(expr.field()?)
}
}
}
Ok(Schema::new(fields))
}
impl HashAggregateExec {
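    /// Create a new hash aggregate execution plan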
pub fn try_new(
mode: AggregateMode,
group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
let schema = create_schema(&input.schema(), &group_expr, &aggr_expr, mode)?;
let schema = Arc::new(schema);
Ok(HashAggregateExec {
mode,
group_expr,
aggr_expr,
input,
schema,
input_schema,
metrics: ExecutionPlanMetricsSet::new(),
})
}
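    /// Aggregation mode (partial, final, final-partitioned)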
pub fn mode(&self) -> &AggregateMode {
&self.mode
}
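    /// Grouping expressions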
pub fn group_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
&self.group_expr
}
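    /// Aggregate expressions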
pub fn aggr_expr(&self) -> &[Arc<dyn AggregateExpr>] {
&self.aggr_expr
}
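    /// Input plan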
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
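    /// Get the input schema before any aggregates are applied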
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}
}
#[async_trait]
impl ExecutionPlan for HashAggregateExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
}
fn required_child_distribution(&self) -> Distribution {
match &self.mode {
AggregateMode::Partial => Distribution::UnspecifiedDistribution,
AggregateMode::FinalPartitioned => Distribution::HashPartitioned(
self.group_expr.iter().map(|x| x.0.clone()).collect(),
),
AggregateMode::Final => Distribution::SinglePartition,
}
}
fn output_partitioning(&self) -> Partitioning {
self.input.output_partitioning()
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn relies_on_input_order(&self) -> bool {
false
}
async fn execute(
&self,
partition: usize,
runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
let input = self.input.execute(partition, runtime).await?;
let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
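        // grouping by nothing requires only a single accumulator set, handled
        // by the simpler HashAggregateStream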
if self.group_expr.is_empty() {
Ok(Box::pin(HashAggregateStream::new(
self.mode,
self.schema.clone(),
self.aggr_expr.clone(),
input,
baseline_metrics,
)))
} else {
Ok(Box::pin(GroupedHashAggregateStream::new(
self.mode,
self.schema.clone(),
group_expr,
self.aggr_expr.clone(),
input,
baseline_metrics,
)))
}
}
fn with_new_children(
&self,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
match children.len() {
1 => Ok(Arc::new(HashAggregateExec::try_new(
self.mode,
self.group_expr.clone(),
self.aggr_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
)?)),
_ => Err(DataFusionError::Internal(
"HashAggregateExec wrong number of children".to_string(),
)),
}
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "HashAggregateExec: mode={:?}", self.mode)?;
let g: Vec<String> = self
.group_expr
.iter()
.map(|(e, alias)| {
let e = e.to_string();
if &e != alias {
format!("{} as {}", e, alias)
} else {
e
}
})
.collect();
write!(f, ", gby=[{}]", g.join(", "))?;
let a: Vec<String> = self
.aggr_expr
.iter()
.map(|agg| agg.name().to_string())
.collect();
write!(f, ", aggr=[{}]", a.join(", "))?;
}
}
Ok(())
}
fn statistics(&self) -> Statistics {
match self.mode {
AggregateMode::Final | AggregateMode::FinalPartitioned
if self.group_expr.is_empty() =>
{
Statistics {
num_rows: Some(1),
is_exact: true,
..Default::default()
}
}
_ => Statistics::default(),
}
}
}
pin_project! {
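    /// Stream for grouped hash aggregation: the work runs in a spawned task
    /// and the single output batch is received over a oneshot channel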
struct GroupedHashAggregateStream {
schema: SchemaRef,
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
drop_helper: AbortOnDropSingle<()>,
}
}
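/// Updates `accumulators` with the rows of `batch`, creating a new group state
/// for each group key seen for the first time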
fn group_aggregate_batch(
mode: &AggregateMode,
random_state: &RandomState,
group_expr: &[Arc<dyn PhysicalExpr>],
aggr_expr: &[Arc<dyn AggregateExpr>],
batch: RecordBatch,
mut accumulators: Accumulators,
aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<Accumulators> {
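    // evaluate the grouping expressions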
let group_values = evaluate(group_expr, &batch)?;
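    // evaluate the aggregation expressions once for the whole batch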
let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
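    // track which entries in `accumulators` have rows in this batch to aggregate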
let mut groups_with_rows = vec![];
let mut batch_hashes = vec![0; batch.num_rows()];
create_hashes(&group_values, random_state, &mut batch_hashes)?;
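    // for each input row, find or create its group entry: lookup uses the
    // precomputed hash, verified by comparing the stored group values for equality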
for (row, hash) in batch_hashes.into_iter().enumerate() {
let Accumulators { map, group_states } = &mut accumulators;
let entry = map.get_mut(hash, |(_hash, group_idx)| {
let group_state = &group_states[*group_idx];
group_values
.iter()
.zip(group_state.group_by_values.iter())
.all(|(array, scalar)| scalar.eq_array(array, row))
});
        match entry {
            // this group value already has an entry: remember the row
            Some((_hash, group_idx)) => {
                let group_state = &mut group_states[*group_idx];
                if group_state.indices.is_empty() {
                    groups_with_rows.push(*group_idx);
                };
                group_state.indices.push(row as u32);
            }
            // no entry yet: create a new group state
            None => {
                let accumulator_set = create_accumulators(aggr_expr)?;
                // copy the group values out of the arrays into `ScalarValue`s
                let group_by_values = group_values
                    .iter()
                    .map(|col| ScalarValue::try_from_array(col, row))
                    .collect::<Result<Vec<_>>>()?;
                let group_state = GroupState {
                    group_by_values: group_by_values.into_boxed_slice(),
                    accumulator_set,
                    indices: vec![row as u32],
                };
                let group_idx = group_states.len();
                group_states.push(group_state);
                groups_with_rows.push(group_idx);
                // the hasher function reuses the precomputed hash value
                map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
            }
        };
}
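    // Collect the batch row indices of every group that received rows, plus
    // offsets delimiting each group's slice of the flattened index list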
let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
let mut offsets = vec![0];
let mut offset_so_far = 0;
for group_idx in groups_with_rows.iter() {
let indices = &accumulators.group_states[*group_idx].indices;
batch_indices.append_slice(indices)?;
offset_so_far += indices.len();
offsets.push(offset_so_far);
}
let batch_indices = batch_indices.finish();
let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
.iter()
.map(|array| {
array
.iter()
.map(|array| {
                    compute::take(
                        array.as_ref(),
                        &batch_indices,
                        None, // None: no index check on `batch_indices`
                    )
                    .unwrap()
})
.collect()
})
.collect();
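    // slice the taken arrays per group (via the offsets) and update or merge
    // each group's accumulators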
groups_with_rows
.iter()
.zip(offsets.windows(2))
.try_for_each(|(group_idx, offsets)| {
let group_state = &mut accumulators.group_states[*group_idx];
group_state
.accumulator_set
.iter_mut()
.zip(values.iter())
.map(|(accumulator, aggr_array)| {
(
accumulator,
aggr_array
.iter()
.map(|array| {
array.slice(offsets[0], offsets[1] - offsets[0])
})
.collect::<Vec<ArrayRef>>(),
)
})
                .try_for_each(|(accumulator, values)| match mode {
                    AggregateMode::Partial => accumulator.update_batch(&values),
                    AggregateMode::FinalPartitioned | AggregateMode::Final => {
                        // in Final modes the inputs are partial states: merge them
                        accumulator.merge_batch(&values)
                    }
                })
                .and({
                    // this group's scratch indices have been consumed
                    group_state.indices.clear();
                    Ok(())
                })
})?;
Ok(accumulators)
}
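/// Drives the input stream to completion, folding every batch into the group
/// states, then builds the single output batch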
async fn compute_grouped_hash_aggregate(
mode: AggregateMode,
schema: SchemaRef,
group_expr: Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
mut input: SendableRecordBatchStream,
elapsed_compute: metrics::Time,
) -> ArrowResult<RecordBatch> {
let timer = elapsed_compute.timer();
let aggregate_expressions =
aggregate_expressions(&aggr_expr, &mode, group_expr.len())?;
let random_state = RandomState::new();
let mut accumulators = Accumulators::default();
timer.done();
while let Some(batch) = input.next().await {
let batch = batch?;
let timer = elapsed_compute.timer();
accumulators = group_aggregate_batch(
&mode,
&random_state,
&group_expr,
&aggr_expr,
batch,
accumulators,
&aggregate_expressions,
)?;
timer.done();
}
let timer = elapsed_compute.timer();
let batch = create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema);
timer.done();
batch
}
impl GroupedHashAggregateStream {
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
group_expr: Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();
let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
let join_handle = tokio::spawn(async move {
let result = compute_grouped_hash_aggregate(
mode,
schema_clone,
group_expr,
aggr_expr,
input,
elapsed_compute,
)
.await
.record_output(&baseline_metrics);
tx.send(result).ok();
});
Self {
schema,
output: rx,
finished: false,
drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
type AccumulatorItem = Box<dyn Accumulator>;
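/// The state that is built for each output group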
#[derive(Debug)]
struct GroupState {
    /// The actual group-by values, one for each group column
    group_by_values: Box<[ScalarValue]>,
    /// The accumulator for each aggregate expression in the query
    accumulator_set: Vec<AccumulatorItem>,
    /// Scratch space to collect the indices of the rows in the current batch
    /// that belong to this group
    indices: Vec<u32>,
}
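/// The state of all the groups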
#[derive(Default)]
struct Accumulators {
    /// Logically maps group values to an index in `group_states`; uses the raw
    /// hashbrown API so the keys themselves are not stored in the table, only
    /// `(hash, index into group_states)` pairs
    map: RawTable<(u64, usize)>,
    /// State for each group
    group_states: Vec<GroupState>,
}
impl std::fmt::Debug for Accumulators {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let map_string = "RawTable";
f.debug_struct("Accumulators")
.field("map", &map_string)
.field("group_states", &self.group_states)
.finish()
}
}
impl Stream for GroupedHashAggregateStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.finished {
return Poll::Ready(None);
}
let this = self.project();
let output_poll = this.output.poll(cx);
match output_poll {
Poll::Ready(result) => {
*this.finished = true;
let result = match result {
                    // an error here means the spawned task was cancelled or panicked
                    Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
                    Ok(result) => result,
};
Poll::Ready(Some(result))
}
Poll::Pending => Poll::Pending,
}
}
}
impl RecordBatchStream for GroupedHashAggregateStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
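/// Evaluates expressions against a record batch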
fn evaluate(
expr: &[Arc<dyn PhysicalExpr>],
batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
expr.iter()
.map(|expr| expr.evaluate(batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()
}
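/// Evaluates each set of expressions against a record batch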
fn evaluate_many(
expr: &[Vec<Arc<dyn PhysicalExpr>>],
batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
expr.iter()
.map(|expr| evaluate(expr, batch))
.collect::<Result<Vec<_>>>()
}
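/// Creates `Column` expressions that select the intermediate state fields of
/// `expr` from a partial-aggregate output, starting at column `index_base`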
fn merge_expressions(
index_base: usize,
expr: &Arc<dyn AggregateExpr>,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
Ok(expr
.state_fields()?
.iter()
.enumerate()
.map(|(idx, f)| {
Arc::new(Column::new(f.name(), index_base + idx)) as Arc<dyn PhysicalExpr>
})
.collect::<Vec<_>>())
}
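/// Returns the physical expressions to evaluate against a batch for each
/// aggregate: the aggregate's own input expressions in Partial mode, or
/// columns selecting its intermediate state fields in Final modes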
fn aggregate_expressions(
aggr_expr: &[Arc<dyn AggregateExpr>],
mode: &AggregateMode,
col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
match mode {
AggregateMode::Partial => {
Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
let mut col_idx_base = col_idx_base;
Ok(aggr_expr
.iter()
.map(|agg| {
let exprs = merge_expressions(col_idx_base, agg)?;
col_idx_base += exprs.len();
Ok(exprs)
})
.collect::<Result<Vec<_>>>()?)
}
}
}
pin_project! {
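    /// Stream for hash aggregation without group-by expressions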
pub struct HashAggregateStream {
schema: SchemaRef,
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
drop_helper: AbortOnDropSingle<()>,
}
}
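/// Drives the input stream to completion, updating a single accumulator set,
/// then builds the single output batch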
async fn compute_hash_aggregate(
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
mut input: SendableRecordBatchStream,
elapsed_compute: metrics::Time,
) -> ArrowResult<RecordBatch> {
let timer = elapsed_compute.timer();
let mut accumulators = create_accumulators(&aggr_expr)?;
let expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
let expressions = Arc::new(expressions);
timer.done();
while let Some(batch) = input.next().await {
let batch = batch?;
let timer = elapsed_compute.timer();
aggregate_batch(&mode, &batch, &mut accumulators, &expressions)?;
timer.done();
}
let timer = elapsed_compute.timer();
let batch = finalize_aggregation(&accumulators, &mode)
.map(|columns| RecordBatch::try_new(schema.clone(), columns))?;
timer.done();
batch
}
impl HashAggregateStream {
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();
let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
let join_handle = tokio::spawn(async move {
let result = compute_hash_aggregate(
mode,
schema_clone,
aggr_expr,
input,
elapsed_compute,
)
.await
.record_output(&baseline_metrics);
tx.send(result).ok();
});
Self {
schema,
output: rx,
finished: false,
drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
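/// Evaluates each aggregate's input expressions against `batch` and updates
/// (Partial mode) or merges (Final modes) the corresponding accumulator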
fn aggregate_batch(
mode: &AggregateMode,
batch: &RecordBatch,
accumulators: &mut [AccumulatorItem],
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<()> {
accumulators
.iter_mut()
.zip(expressions)
.try_for_each(|(accum, expr)| {
let values = &expr
.iter()
.map(|e| e.evaluate(batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
match mode {
AggregateMode::Partial => accum.update_batch(values),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
accum.merge_batch(values)
}
}
})
}
impl Stream for HashAggregateStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.finished {
return Poll::Ready(None);
}
let this = self.project();
let output_poll = this.output.poll(cx);
match output_poll {
Poll::Ready(result) => {
*this.finished = true;
let result = match result {
                    // an error here means the spawned task was cancelled or panicked
                    Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
                    Ok(result) => result,
};
Poll::Ready(Some(result))
}
Poll::Pending => Poll::Pending,
}
}
}
impl RecordBatchStream for HashAggregateStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
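/// Creates a RecordBatch from the group keys and the accumulators' states
/// (Partial mode) or evaluated values (Final modes), cast to the output schema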
fn create_batch_from_map(
mode: &AggregateMode,
accumulators: &Accumulators,
num_group_expr: usize,
output_schema: &Schema,
) -> ArrowResult<RecordBatch> {
if accumulators.group_states.is_empty() {
return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
}
let accs = &accumulators.group_states[0].accumulator_set;
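    // number of output columns each accumulator contributes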
let mut acc_data_types: Vec<usize> = vec![];
match mode {
AggregateMode::Partial => {
for acc in accs.iter() {
let state = acc.state()?;
acc_data_types.push(state.len());
}
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
acc_data_types = vec![1; accs.len()];
}
}
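    // build the group-by key columns from the stored scalar values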
let mut columns = (0..num_group_expr)
.map(|i| {
ScalarValue::iter_to_array(
accumulators
.group_states
.iter()
.map(|group_state| group_state.group_by_values[i].clone()),
)
})
.collect::<Result<Vec<_>>>()?;
    // add the state or evaluated columns, one per accumulator state field
    for (acc_idx, &state_len) in acc_data_types.iter().enumerate() {
        for field_idx in 0..state_len {
            match mode {
                AggregateMode::Partial => {
                    let res = ScalarValue::iter_to_array(
                        accumulators.group_states.iter().map(|group_state| {
                            let state =
                                group_state.accumulator_set[acc_idx].state().unwrap();
                            state[field_idx].clone()
                        }),
                    )?;
                    columns.push(res);
                }
                AggregateMode::Final | AggregateMode::FinalPartitioned => {
                    let res = ScalarValue::iter_to_array(
                        accumulators.group_states.iter().map(|group_state| {
                            group_state.accumulator_set[acc_idx].evaluate().unwrap()
                        }),
                    )?;
                    columns.push(res);
                }
            }
        }
    }
let columns = columns
.iter()
.zip(output_schema.fields().iter())
.map(|(col, desired_field)| cast(col, desired_field.data_type()))
.collect::<ArrowResult<Vec<_>>>()?;
RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
}
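/// Creates one accumulator per aggregate expression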
fn create_accumulators(
aggr_expr: &[Arc<dyn AggregateExpr>],
) -> Result<Vec<AccumulatorItem>> {
aggr_expr
.iter()
.map(|expr| expr.create_accumulator())
.collect::<Result<Vec<_>>>()
}
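/// Returns one array per output column: the accumulators' intermediate states
/// in Partial mode, or their final values in Final modes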
fn finalize_aggregation(
accumulators: &[AccumulatorItem],
mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
match mode {
AggregateMode::Partial => {
            // build one output array per state field of each accumulator
            let state_arrays = accumulators
                .iter()
                .map(|accumulator| accumulator.state())
                .map(|value| {
                    value.map(|e| {
                        e.iter().map(|v| v.to_array()).collect::<Vec<ArrayRef>>()
                    })
                })
                .collect::<Result<Vec<_>>>()?;
            Ok(state_arrays.iter().flatten().cloned().collect::<Vec<_>>())
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
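            // evaluate each accumulator to its final value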
accumulators
.iter()
.map(|accumulator| accumulator.evaluate().map(|v| v.to_array()))
.collect::<Result<Vec<ArrayRef>>>()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::from_slice::FromSlice;
use crate::physical_plan::expressions::{col, Avg};
use crate::test::assert_is_pending;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
use crate::{assert_batches_sorted_eq, physical_plan::common};
use arrow::array::{Float64Array, UInt32Array};
use arrow::datatypes::DataType;
use futures::FutureExt;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Float64, false),
]));
(
schema.clone(),
vec![
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt32Array::from_slice(&[2, 3, 4, 4])),
Arc::new(Float64Array::from_slice(&[1.0, 2.0, 3.0, 4.0])),
],
)
.unwrap(),
RecordBatch::try_new(
schema,
vec![
Arc::new(UInt32Array::from_slice(&[2, 3, 3, 4])),
Arc::new(Float64Array::from_slice(&[1.0, 2.0, 3.0, 4.0])),
],
)
.unwrap(),
],
)
}
async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
let input_schema = input.schema();
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &input_schema)?, "a".to_string())];
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &input_schema)?,
"AVG(b)".to_string(),
DataType::Float64,
))];
let runtime = Arc::new(RuntimeEnv::default());
let partial_aggregate = Arc::new(HashAggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
input,
input_schema.clone(),
)?);
let result =
common::collect(partial_aggregate.execute(0, runtime.clone()).await?).await?;
        let expected = vec![
            "+---+---------------+-------------+",
            "| a | AVG(b)[count] | AVG(b)[sum] |",
            "+---+---------------+-------------+",
            "| 2 | 2             | 2           |",
            "| 3 | 3             | 7           |",
            "| 4 | 3             | 11          |",
            "+---+---------------+-------------+",
        ];
assert_batches_sorted_eq!(expected, &result);
let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
.map(|i| col(&groups[i].1, &input_schema))
.collect::<Result<_>>()?;
let merged_aggregate = Arc::new(HashAggregateExec::try_new(
AggregateMode::Final,
final_group
.iter()
.enumerate()
.map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
.collect(),
aggregates,
merge,
input_schema,
)?);
let result =
common::collect(merged_aggregate.execute(0, runtime.clone()).await?).await?;
assert_eq!(result.len(), 1);
let batch = &result[0];
assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 3);
        let expected = vec![
            "+---+--------------------+",
            "| a | AVG(b)             |",
            "+---+--------------------+",
            "| 2 | 1                  |",
            "| 3 | 2.3333333333333335 |",
            "| 4 | 3.6666666666666665 |",
            "+---+--------------------+",
        ];
assert_batches_sorted_eq!(&expected, &result);
let metrics = merged_aggregate.metrics().unwrap();
let output_rows = metrics.output_rows().unwrap();
assert_eq!(3, output_rows);
Ok(())
}
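    /// An ExecutionPlan producing the demo data that optionally yields to the
    /// runtime before returning its first batch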
#[derive(Debug)]
struct TestYieldingExec {
pub yield_first: bool,
}
#[async_trait]
impl ExecutionPlan for TestYieldingExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
some_data().0
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![]
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(1)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn with_new_children(
&self,
_: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Err(DataFusionError::Internal(format!(
"Children cannot be replaced in {:?}",
self
)))
}
async fn execute(
&self,
_partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
            let stream = if self.yield_first {
                TestYieldingStream::New
            } else {
                TestYieldingStream::Yielded
            };
Ok(Box::pin(stream))
}
fn statistics(&self) -> Statistics {
let (_, batches) = some_data();
common::compute_record_batch_statistics(&[batches], &self.schema(), None)
}
}
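    /// A stream over the demo data; when started in the `New` state it yields
    /// to the runtime once before returning records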
enum TestYieldingStream {
New,
Yielded,
ReturnedBatch1,
ReturnedBatch2,
}
impl Stream for TestYieldingStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match &*self {
TestYieldingStream::New => {
*(self.as_mut()) = TestYieldingStream::Yielded;
cx.waker().wake_by_ref();
Poll::Pending
}
TestYieldingStream::Yielded => {
*(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
Poll::Ready(Some(Ok(some_data().1[0].clone())))
}
TestYieldingStream::ReturnedBatch1 => {
*(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
Poll::Ready(Some(Ok(some_data().1[1].clone())))
}
TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
}
}
}
impl RecordBatchStream for TestYieldingStream {
fn schema(&self) -> SchemaRef {
some_data().0
}
}
#[tokio::test]
async fn aggregate_source_not_yielding() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });
check_aggregates(input).await
}
#[tokio::test]
async fn aggregate_source_with_yielding() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });
check_aggregates(input).await
}
#[tokio::test]
async fn test_drop_cancel_without_groups() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
let groups = vec![];
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("a", &schema)?,
"AVG(a)".to_string(),
DataType::Float64,
))];
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
let refs = blocking_exec.refs();
let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
blocking_exec,
schema,
)?);
let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime);
let mut fut = fut.boxed();
assert_is_pending(&mut fut);
drop(fut);
assert_strong_count_converges_to_zero(refs).await;
Ok(())
}
#[tokio::test]
async fn test_drop_cancel_with_groups() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, true),
Field::new("b", DataType::Float32, true),
]));
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "a".to_string())];
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &schema)?,
"AVG(b)".to_string(),
DataType::Float64,
))];
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
let refs = blocking_exec.refs();
let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
blocking_exec,
schema,
)?);
let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime);
let mut fut = fut.boxed();
assert_is_pending(&mut fut);
drop(fut);
assert_strong_count_converges_to_zero(refs).await;
Ok(())
}
}