use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics};
use arrow::array::{ArrayRef, UInt64Builder};
use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use log::debug;
use tokio_stream::wrappers::UnboundedReceiverStream;
use super::common::{AbortOnDropMany, AbortOnDropSingle};
use super::expressions::PhysicalSortExpr;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::execution::context::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::Stream;
use futures::StreamExt;
use hashbrown::HashMap;
use parking_lot::Mutex;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
/// Message sent over a repartition channel: `Some(..)` carries a batch or an
/// error, while `None` is the marker `wait_for_task` sends when one input task
/// finished successfully (counted by `RepartitionStream::poll_next`).
type MaybeBatch = Option<ArrowResult<RecordBatch>>;
/// Mutable execution state of a `RepartitionExec`, kept behind a mutex.
#[derive(Debug)]
struct RepartitionExecState {
/// One unbounded channel per output partition, keyed by output partition
/// index. `execute` fills this map on its first call and removes an entry
/// when it hands that partition's receiver to the returned stream.
channels:
HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
/// Wrapper around the join handles of the tasks spawned by `execute`; every
/// returned stream keeps a clone of this `Arc`.
/// NOTE(review): presumably aborts those tasks once the last clone is
/// dropped, as the helper's name suggests — confirm in `super::common`.
abort_helper: Arc<AbortOnDropMany<()>>,
}
/// Splits incoming record batches across output partitions according to the
/// partitioning scheme it was created with.
pub struct BatchPartitioner {
// Scheme-specific state (round-robin cursor or hashing scratch space).
state: BatchPartitionerState,
// Time metric started while hash partitioning runs; the round-robin path
// does not touch it.
timer: metrics::Time,
}
/// Per-scheme state of a `BatchPartitioner`.
enum BatchPartitionerState {
/// Rows are routed by the hash of the evaluated expressions.
Hash {
// Hasher state; `BatchPartitioner::try_new` seeds it with fixed values.
random_state: ahash::RandomState,
// Expressions evaluated against each batch to produce the hashed columns.
exprs: Vec<Arc<dyn PhysicalExpr>>,
// Number of output partitions (the hash is taken modulo this value).
num_partitions: usize,
// Scratch buffer holding one hash per row; cleared and resized per batch.
hash_buffer: Vec<u64>,
},
/// Whole batches are handed to the output partitions in turn.
RoundRobin {
// Number of output partitions.
num_partitions: usize,
// Output partition that receives the next batch.
next_idx: usize,
},
}
impl BatchPartitioner {
    /// Creates a partitioner for `partitioning`.
    ///
    /// `timer` is the metric started while hash partitioning work is running.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::NotImplemented` for every scheme other than
    /// `Partitioning::RoundRobinBatch` and `Partitioning::Hash`.
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Fixed seeds: every partitioner hashes equal values identically.
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported repartitioning scheme {:?}",
                    other
                )))
            }
        };
        Ok(Self { state, timer })
    }

    /// Partitions `batch` and calls `f(output_partition, partitioned_batch)`
    /// for every piece produced.
    ///
    /// * Round robin: the whole batch goes to the next partition in turn.
    /// * Hash: each row goes to partition `hash % num_partitions`; partitions
    ///   that receive no rows are skipped, so `f` never sees an empty batch.
    ///
    /// The timer is stopped while `f` runs, so time spent in the callback is
    /// not counted as repartitioning time.
    ///
    /// # Errors
    ///
    /// Propagates errors from expression evaluation, hashing, the Arrow
    /// kernels used to build the output batches, and from `f` itself.
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        match &mut self.state {
            BatchPartitionerState::RoundRobin {
                num_partitions,
                next_idx,
            } => {
                let idx = *next_idx;
                *next_idx = (*next_idx + 1) % *num_partitions;
                f(idx, batch)?;
            }
            BatchPartitionerState::Hash {
                random_state,
                exprs,
                num_partitions: partitions,
                hash_buffer,
            } => {
                let mut timer = self.timer.timer();

                let arrays = exprs
                    .iter()
                    .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
                    .collect::<Result<Vec<_>>>()?;

                // Reuse the scratch buffer: one zero-initialised hash per row.
                hash_buffer.clear();
                hash_buffer.resize(batch.num_rows(), 0);
                create_hashes(&arrays, random_state, hash_buffer)?;

                // One builder of row indices per output partition.
                let mut indices: Vec<_> = (0..*partitions)
                    .map(|_| UInt64Builder::new(batch.num_rows()))
                    .collect();
                for (index, hash) in hash_buffer.iter().enumerate() {
                    // Report Arrow failures as errors instead of panicking.
                    indices[(*hash % *partitions as u64) as usize]
                        .append_value(index as u64)
                        .map_err(DataFusionError::ArrowError)?;
                }

                for (partition, mut indices) in indices.into_iter().enumerate() {
                    let indices = indices.finish();
                    if indices.is_empty() {
                        continue;
                    }

                    // Gather the rows selected for this partition, column by column.
                    let columns = batch
                        .columns()
                        .iter()
                        .map(|c| {
                            arrow::compute::take(c.as_ref(), &indices, None)
                                .map_err(DataFusionError::ArrowError)
                        })
                        .collect::<Result<Vec<ArrayRef>>>()?;

                    let batch = RecordBatch::try_new(batch.schema(), columns)
                        .map_err(DataFusionError::ArrowError)?;

                    // Do not count the callback as repartitioning time.
                    timer.stop();
                    f(partition, batch)?;
                    timer.restart();
                }
            }
        }
        Ok(())
    }
}
/// Physical plan node that redistributes the batches produced by `input`
/// across the output partitions described by `partitioning`.
#[derive(Debug)]
pub struct RepartitionExec {
/// Plan whose output is repartitioned.
input: Arc<dyn ExecutionPlan>,
/// Target partitioning scheme; also reported as the output partitioning.
partitioning: Partitioning,
/// Channels and background-task handles, populated by `execute`.
state: Arc<Mutex<RepartitionExecState>>,
/// Metric set from which the per-task timers are created.
metrics: ExecutionPlanMetricsSet,
}
/// Timers recorded by one input-pulling task of a `RepartitionExec`.
#[derive(Debug, Clone)]
struct RepartitionMetrics {
/// Time spent executing the input plan and waiting for its next batch.
fetch_time: metrics::Time,
/// Timer handed to the `BatchPartitioner` for the repartitioning work.
repart_time: metrics::Time,
/// Time spent sending partitioned batches into the output channels.
send_time: metrics::Time,
}
impl RepartitionMetrics {
    /// Registers the `fetch_time`, `repart_time` and `send_time` timers in
    /// `metrics`.
    ///
    /// Each timer carries an `inputPartition` label built from
    /// `input_partition` and is registered for partition `output_partition`.
    pub fn new(
        output_partition: usize,
        input_partition: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        let input_label =
            metrics::Label::new("inputPartition", input_partition.to_string());

        // All three timers differ only in their name.
        let labelled_timer = |name: &'static str| {
            MetricBuilder::new(metrics)
                .with_label(input_label.clone())
                .subset_time(name, output_partition)
        };

        Self {
            fetch_time: labelled_timer("fetch_time"),
            repart_time: labelled_timer("repart_time"),
            send_time: labelled_timer("send_time"),
        }
    }
}
impl RepartitionExec {
    /// The plan whose output is being repartitioned.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// The partitioning scheme applied to the input.
    pub fn partitioning(&self) -> &Partitioning {
        &self.partitioning
    }
}
impl ExecutionPlan for RepartitionExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.input.schema()
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
}
fn relies_on_input_order(&self) -> bool {
false
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(RepartitionExec::try_new(
children[0].clone(),
self.partitioning.clone(),
)?))
}
fn output_partitioning(&self) -> Partitioning {
self.partitioning.clone()
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
debug!(
"Start RepartitionExec::execute for partition: {}",
partition
);
let mut state = self.state.lock();
let num_input_partitions = self.input.output_partitioning().partition_count();
let num_output_partitions = self.partitioning.partition_count();
if state.channels.is_empty() {
for partition in 0..num_output_partitions {
let (sender, receiver) =
mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
state.channels.insert(partition, (sender, receiver));
}
let mut join_handles = Vec::with_capacity(num_input_partitions);
for i in 0..num_input_partitions {
let txs: HashMap<_, _> = state
.channels
.iter()
.map(|(partition, (tx, _rx))| (*partition, tx.clone()))
.collect();
let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
let input_task: JoinHandle<Result<()>> =
tokio::spawn(Self::pull_from_input(
self.input.clone(),
i,
txs.clone(),
self.partitioning.clone(),
r_metrics,
context.clone(),
));
let join_handle = tokio::spawn(Self::wait_for_task(
AbortOnDropSingle::new(input_task),
txs,
));
join_handles.push(join_handle);
}
state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
}
debug!(
"Before returning stream in RepartitionExec::execute for partition: {}",
partition
);
Ok(Box::pin(RepartitionStream {
num_input_partitions,
num_input_partitions_processed: 0,
schema: self.input.schema(),
input: UnboundedReceiverStream::new(
state.channels.remove(&partition).unwrap().1,
),
drop_helper: Arc::clone(&state.abort_helper),
}))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "RepartitionExec: partitioning={:?}", self.partitioning)
}
}
}
fn statistics(&self) -> Statistics {
self.input.statistics()
}
}
impl RepartitionExec {
    /// Creates a repartitioning node over `input`; channels and tasks are
    /// only set up later by `execute`.
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let initial_state = RepartitionExecState {
            channels: HashMap::new(),
            abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
        };
        Ok(Self {
            input,
            partitioning,
            state: Arc::new(Mutex::new(initial_state)),
            metrics: ExecutionPlanMetricsSet::new(),
        })
    }

    /// Pulls batches from input partition `i`, partitions them and sends the
    /// pieces to the matching output channels.
    ///
    /// Stops when the input is exhausted or when every receiver has gone
    /// away. The first error from the partitioner setup, the input plan, or
    /// the partitioning itself is returned.
    async fn pull_from_input(
        input: Arc<dyn ExecutionPlan>,
        i: usize,
        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
        partitioning: Partitioning,
        r_metrics: RepartitionMetrics,
        context: Arc<TaskContext>,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?;

        // Creating the input stream counts as fetch time.
        let execute_timer = r_metrics.fetch_time.timer();
        let mut stream = input.execute(i, context)?;
        execute_timer.done();

        // Keep going only while at least one output is still listening.
        while !txs.is_empty() {
            let fetch_timer = r_metrics.fetch_time.timer();
            let next = stream.next().await;
            fetch_timer.done();

            let batch = if let Some(item) = next {
                item?
            } else {
                break;
            };

            partitioner.partition(batch, |partition, partitioned| {
                let send_timer = r_metrics.send_time.timer();
                // A failed send means the receiver was dropped.
                let receiver_gone = match txs.get(&partition) {
                    Some(tx) => tx.send(Some(Ok(partitioned))).is_err(),
                    None => false,
                };
                if receiver_gone {
                    txs.remove(&partition);
                }
                send_timer.done();
                Ok(())
            })?;
        }

        Ok(())
    }

    /// Waits for `input_task` and tells every output how it ended: `None` on
    /// success, otherwise an `Execution` error carrying the failure text.
    /// Send failures are ignored.
    async fn wait_for_task(
        input_task: AbortOnDropSingle<Result<()>>,
        txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
    ) {
        // `None` means success; `Some(text)` is the failure description.
        let failure: Option<String> = match input_task.await {
            Ok(Ok(())) => None,
            Ok(Err(e)) => Some(e.to_string()),
            Err(e) => Some(format!("Join Error: {}", e)),
        };

        for (_, tx) in txs {
            let message: Option<ArrowResult<RecordBatch>> = failure
                .as_ref()
                .map(|text| Err(DataFusionError::Execution(text.clone()).into()));
            tx.send(message).ok();
        }
    }
}
/// Output stream for one partition of a `RepartitionExec`.
struct RepartitionStream {
/// Number of input partitions feeding this stream.
num_input_partitions: usize,
/// Number of `None` end-of-input markers received so far.
num_input_partitions_processed: usize,
/// Schema returned by `RecordBatchStream::schema`.
schema: SchemaRef,
/// Receiving end of this output partition's channel.
input: UnboundedReceiverStream<Option<ArrowResult<RecordBatch>>>,
/// Never read; holds a clone of the shared background-task handles.
/// NOTE(review): presumably the tasks are aborted once the last clone is
/// dropped, as the helper's name suggests — confirm in `super::common`.
#[allow(dead_code)]
drop_helper: Arc<AbortOnDropMany<()>>,
}
impl Stream for RepartitionStream {
    type Item = ArrowResult<RecordBatch>;

    /// Yields batches and errors from the channel. A `None` message marks one
    /// finished input partition; the stream ends once every input partition
    /// has reported in, or when the channel itself closes.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.input.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                // All senders are gone: nothing more can arrive.
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Some(item))) => return Poll::Ready(Some(item)),
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;
                    if self.num_input_partitions_processed == self.num_input_partitions
                    {
                        return Poll::Ready(None);
                    }
                    // Other inputs are still running: poll the channel again.
                }
            }
        }
    }
}
impl RecordBatchStream for RepartitionStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::from_slice::FromSlice;
use crate::prelude::SessionContext;
use crate::test::create_vec_batches;
use crate::{
assert_batches_sorted_eq,
physical_plan::{collect, expressions::col, memory::MemoryExec},
test::{
assert_is_pending,
exec::{
assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
ErrorExec, MockExec,
},
},
};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{ArrayRef, StringArray},
error::ArrowError,
};
use futures::FutureExt;
use std::collections::HashSet;
#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
    // A single input partition spread round-robin over four outputs.
    let schema = test_schema();
    let input = vec![create_vec_batches(&schema, 50)];

    let outputs = repartition(&schema, input, Partitioning::RoundRobinBatch(4)).await?;

    assert_eq!(4, outputs.len());
    let batch_counts: Vec<usize> = outputs.iter().map(|p| p.len()).collect();
    assert_eq!(vec![13, 13, 12, 12], batch_counts);
    Ok(())
}
#[tokio::test]
async fn many_to_one_round_robin() -> Result<()> {
    // Three identical input partitions funnelled into a single output.
    let schema = test_schema();
    let input = vec![create_vec_batches(&schema, 50); 3];

    let outputs = repartition(&schema, input, Partitioning::RoundRobinBatch(1)).await?;

    assert_eq!(1, outputs.len());
    assert_eq!(150, outputs[0].len());
    Ok(())
}
#[tokio::test]
async fn many_to_many_round_robin() -> Result<()> {
    // Three identical input partitions spread round-robin over five outputs.
    let schema = test_schema();
    let input = vec![create_vec_batches(&schema, 50); 3];

    let outputs = repartition(&schema, input, Partitioning::RoundRobinBatch(5)).await?;

    assert_eq!(5, outputs.len());
    for batches in &outputs {
        assert_eq!(30, batches.len());
    }
    Ok(())
}
#[tokio::test]
async fn many_to_many_hash_partition() -> Result<()> {
    // Hash partitioning on `c0` must keep every row, whatever its target.
    let schema = test_schema();
    let input = vec![create_vec_batches(&schema, 50); 3];
    let scheme = Partitioning::Hash(vec![col("c0", &schema)?], 8);

    let output_partitions = repartition(&schema, input, scheme).await?;

    let total_rows: usize = output_partitions
        .iter()
        .flatten()
        .map(|batch| batch.num_rows())
        .sum();

    assert_eq!(8, output_partitions.len());
    assert_eq!(total_rows, 8 * 50 * 3);
    Ok(())
}
/// Schema with a single non-nullable `UInt32` column named `c0`.
fn test_schema() -> Arc<Schema> {
    let c0 = Field::new("c0", DataType::UInt32, false);
    Arc::new(Schema::new(vec![c0]))
}
/// Runs `input_partitions` through a `RepartitionExec` with `partitioning`
/// and returns the batches of each output partition, drained one output
/// partition after another.
async fn repartition(
    schema: &SchemaRef,
    input_partitions: Vec<Vec<RecordBatch>>,
    partitioning: Partitioning,
) -> Result<Vec<Vec<RecordBatch>>> {
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();

    let source = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
    let exec = RepartitionExec::try_new(Arc::new(source), partitioning)?;

    let partition_count = exec.partitioning.partition_count();
    let mut output_partitions = Vec::with_capacity(partition_count);
    for output_idx in 0..partition_count {
        let mut stream = exec.execute(output_idx, task_ctx.clone())?;
        let mut collected = Vec::new();
        while let Some(item) = stream.next().await {
            collected.push(item?);
        }
        output_partitions.push(collected);
    }
    Ok(output_partitions)
}
#[tokio::test]
async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
    // Same scenario as `many_to_many_round_robin`, driven from a spawned task.
    let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> =
        tokio::spawn(async move {
            let schema = test_schema();
            let input = vec![create_vec_batches(&schema, 50); 3];
            repartition(&schema, input, Partitioning::RoundRobinBatch(5)).await
        });

    let outputs = join_handle
        .await
        .map_err(|e| DataFusionError::Internal(e.to_string()))??;

    assert_eq!(5, outputs.len());
    for batches in &outputs {
        assert_eq!(30, batches.len());
    }
    Ok(())
}
#[tokio::test]
async fn unsupported_partitioning() {
    // An unsupported scheme must surface as an error on the output stream.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();

    let column = Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![("my_awesome_field", column)]).unwrap();
    let schema = batch.schema();
    let input = MockExec::new(vec![Ok(batch)], schema);

    let exec =
        RepartitionExec::try_new(Arc::new(input), Partitioning::UnknownPartitioning(1))
            .unwrap();
    let stream = exec.execute(0, task_ctx).unwrap();

    let message = crate::physical_plan::common::collect(stream)
        .await
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
        "actual: {}",
        message
    );
}
#[tokio::test]
async fn error_for_input_exec() {
    // An input plan that fails in `execute` must fail the output stream too.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();

    let exec = RepartitionExec::try_new(
        Arc::new(ErrorExec::new()),
        Partitioning::RoundRobinBatch(1),
    )
    .unwrap();
    let stream = exec.execute(0, task_ctx).unwrap();

    let message = crate::physical_plan::common::collect(stream)
        .await
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("ErrorExec, unsurprisingly, errored in partition 0"),
        "actual: {}",
        message
    );
}
#[tokio::test]
async fn repartition_with_error_in_stream() {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let batch = RecordBatch::try_from_iter(vec![(
"my_awesome_field",
Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef,
)])
.unwrap();
let err = Err(ArrowError::ComputeError("bad data error".to_string()));
let schema = batch.schema();
let input = MockExec::new(vec![Ok(batch), err], schema);
let partitioning = Partitioning::RoundRobinBatch(1);
let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
let output_stream = exec.execute(0, task_ctx).unwrap();
let result_string = crate::physical_plan::common::collect(output_stream)
.await
.unwrap_err()
.to_string();
assert!(
result_string.contains("bad data error"),
"actual: {}",
result_string
);
}
#[tokio::test]
async fn repartition_with_delayed_stream() {
    // Both batches of a `MockExec` input must come out of a single
    // round-robin output partition.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();
    let batch1 = RecordBatch::try_from_iter(vec![(
        "my_awesome_field",
        Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef,
    )])
    .unwrap();
    let batch2 = RecordBatch::try_from_iter(vec![(
        "my_awesome_field",
        Arc::new(StringArray::from_slice(&["frob", "baz"])) as ArrayRef,
    )])
    .unwrap();

    let schema = batch1.schema();
    let expected_batches = vec![batch1.clone(), batch2.clone()];
    let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
    let partitioning = Partitioning::RoundRobinBatch(1);
    let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

    // Data rows are padded to the header's column width; a narrower row can
    // never equal the pretty-printed table.
    let expected = vec![
        "+------------------+",
        "| my_awesome_field |",
        "+------------------+",
        "| foo              |",
        "| bar              |",
        "| frob             |",
        "| baz              |",
        "+------------------+",
    ];

    // Sanity check: the expected table matches the input batches themselves.
    assert_batches_sorted_eq!(&expected, &expected_batches);

    let output_stream = exec.execute(0, task_ctx).unwrap();
    let batches = crate::physical_plan::common::collect(output_stream)
        .await
        .unwrap();

    assert_batches_sorted_eq!(&expected, &batches);
}
#[tokio::test]
async fn robin_repartition_with_dropping_output_stream() {
    // Dropping one output stream must not disturb the other one.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();
    let partitioning = Partitioning::RoundRobinBatch(2);
    let input = Arc::new(make_barrier_exec());

    let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap();
    let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap();
    let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap();

    // Drop output 0 before the input is released, then read output 1.
    std::mem::drop(output_stream0);
    input.wait().await;
    let batches = crate::physical_plan::common::collect(output_stream1)
        .await
        .unwrap();

    // Data rows are padded to the header's column width; a narrower row can
    // never equal the pretty-printed table.
    let expected = vec![
        "+------------------+",
        "| my_awesome_field |",
        "+------------------+",
        "| baz              |",
        "| frob             |",
        "| gaz              |",
        "| grob             |",
        "+------------------+",
    ];

    assert_batches_sorted_eq!(&expected, &batches);
}
#[tokio::test]
async fn hash_repartition_with_dropping_output_stream() {
    // Output 1 of a hash repartition must be the same whether or not
    // output 0 was dropped before the input started producing.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();
    let key = crate::physical_plan::expressions::Column::new("my_awesome_field", 0);
    let partitioning = Partitioning::Hash(vec![Arc::new(key)], 2);

    // Baseline run: only output 1 is requested.
    let baseline_input = Arc::new(make_barrier_exec());
    let baseline_exec =
        RepartitionExec::try_new(baseline_input.clone(), partitioning.clone()).unwrap();
    let baseline_stream = baseline_exec.execute(1, task_ctx.clone()).unwrap();
    baseline_input.wait().await;
    let batches_without_drop = crate::physical_plan::common::collect(baseline_stream)
        .await
        .unwrap();

    // The baseline holds no duplicates and only values from the source.
    let items_vec = str_batches_to_vec(&batches_without_drop);
    let items_set: HashSet<&str> = items_vec.iter().copied().collect();
    assert_eq!(items_vec.len(), items_set.len());
    let source_values = ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"];
    let source_str_set: HashSet<&str> = source_values.iter().copied().collect();
    assert_eq!(items_set.difference(&source_str_set).count(), 0);

    // Second run: request both outputs and drop output 0 first.
    let input = Arc::new(make_barrier_exec());
    let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap();
    let dropped_stream = exec.execute(0, task_ctx.clone()).unwrap();
    let kept_stream = exec.execute(1, task_ctx.clone()).unwrap();
    std::mem::drop(dropped_stream);
    input.wait().await;
    let batches_with_drop = crate::physical_plan::common::collect(kept_stream)
        .await
        .unwrap();

    assert_eq!(batches_without_drop, batches_with_drop);
}
/// Flattens single-column string batches into the list of their values.
///
/// Panics if a batch does not have exactly one column, if that column is not
/// a `StringArray`, or if it contains a null.
fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
    let mut values = Vec::new();
    for batch in batches {
        assert_eq!(batch.columns().len(), 1);
        let strings = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("Unexpected type for repartitoned batch");
        for value in strings.iter() {
            values.push(value.expect("Unexpected null"));
        }
    }
    values
}
/// Builds a `BarrierExec` with two partitions of two batches each, every
/// batch holding two strings in the single column `my_awesome_field`.
fn make_barrier_exec() -> BarrierExec {
    let single_column_batch = |values: [&'static str; 2]| {
        RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from_slice(&values)) as ArrayRef,
        )])
        .unwrap()
    };

    let first = single_column_batch(["foo", "bar"]);
    // Take the schema before `first` is moved into the partition list.
    let schema = first.schema();

    let partitions = vec![
        vec![first, single_column_batch(["frob", "baz"])],
        vec![
            single_column_batch(["goo", "gar"]),
            single_column_batch(["grob", "gaz"]),
        ],
    ];
    BarrierExec::new(partitions, schema)
}
#[tokio::test]
async fn test_drop_cancel() -> Result<()> {
    // Dropping a pending `collect` future must release every reference to
    // the blocking input.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
    let refs = blocking_exec.refs();

    // NOTE(review): with `UnknownPartitioning` the input task fails in
    // `BatchPartitioner::try_new` before it executes the input — confirm
    // this still exercises the intended cancellation path.
    let repartition_exec = Arc::new(RepartitionExec::try_new(
        blocking_exec,
        Partitioning::UnknownPartitioning(1),
    )?);

    let mut fut = collect(repartition_exec, task_ctx).boxed();
    assert_is_pending(&mut fut);
    drop(fut);
    assert_strong_count_converges_to_zero(refs).await;

    Ok(())
}
#[tokio::test]
async fn hash_repartition_avoid_empty_batch() -> Result<()> {
    // A single row can land in only one of the two hash partitions; the
    // other one must yield no batch at all rather than an empty one.
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();

    let column = Arc::new(StringArray::from_slice(&["foo"])) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![("a", column)]).unwrap();
    let key = crate::physical_plan::expressions::Column::new("a", 0);
    let partitioning = Partitioning::Hash(vec![Arc::new(key)], 2);

    let schema = batch.schema();
    let input = MockExec::new(vec![Ok(batch)], schema);
    let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

    let stream0 = exec.execute(0, task_ctx.clone()).unwrap();
    let batch0 = crate::physical_plan::common::collect(stream0)
        .await
        .unwrap();

    let stream1 = exec.execute(1, task_ctx.clone()).unwrap();
    let batch1 = crate::physical_plan::common::collect(stream1)
        .await
        .unwrap();

    assert!(batch0.is_empty() || batch1.is_empty());
    Ok(())
}
}