//! Simple [`ExecutionPlan`] implementations and streams used as test
//! doubles for exercising the physical plan machinery.

use async_trait::async_trait;
use std::{
any::Any,
sync::Arc,
task::{Context, Poll},
};
use tokio::sync::Barrier;
use arrow::{
datatypes::{DataType, Field, Schema, SchemaRef},
error::{ArrowError, Result as ArrowResult},
record_batch::RecordBatch,
};
use futures::{Stream, StreamExt};
use tokio_stream::wrappers::ReceiverStream;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
};
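
/// A cheaply cloneable, thread-safe counter that records how many batches a
/// [`TestStream`] has returned so far.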
#[derive(Debug, Default, Clone)]
pub struct BatchIndex {
inner: std::sync::Arc<std::sync::Mutex<usize>>,
}
impl BatchIndex {
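    /// Return the current index value.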
pub fn value(&self) -> usize {
let inner = self.inner.lock().unwrap();
*inner
}
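    /// Increment the index by one.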
pub fn incr(&self) {
let mut inner = self.inner.lock().unwrap();
*inner += 1;
}
}
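
/// A [`Stream`] of pre-built [`RecordBatch`]es that advances a shared
/// [`BatchIndex`] each time a batch is returned, so tests can observe how
/// far the stream has been consumed.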
#[derive(Debug, Default)]
pub(crate) struct TestStream {
data: Vec<RecordBatch>,
index: BatchIndex,
}
impl TestStream {
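    /// Create a stream that yields `data` in order.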
pub fn new(data: Vec<RecordBatch>) -> Self {
Self {
data,
..Default::default()
}
}
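    /// Return a handle to the shared index that tracks consumption.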
pub fn index(&self) -> BatchIndex {
self.index.clone()
}
}
impl Stream for TestStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
        let next_batch = self.index.value();

        Poll::Ready(if next_batch < self.data.len() {
            self.index.incr();
            Some(Ok(self.data[next_batch].clone()))
} else {
None
})
}
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Note: reports the total number of batches, not the number still remaining
        (self.data.len(), Some(self.data.len()))
    }
}
impl RecordBatchStream for TestStream {
fn schema(&self) -> SchemaRef {
self.data[0].schema()
}
}
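
/// A single-partition [`ExecutionPlan`] that replays a predetermined
/// sequence of batches and errors. The batches are forwarded through a
/// spawned task and a channel, so they arrive asynchronously rather than
/// all at once.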
#[derive(Debug)]
pub struct MockExec {
data: Vec<ArrowResult<RecordBatch>>,
schema: SchemaRef,
}
impl MockExec {
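    /// Create a new exec that returns `data` (batches and/or errors) with
    /// the given `schema` for its single partition.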
pub fn new(data: Vec<ArrowResult<RecordBatch>>, schema: SchemaRef) -> Self {
Self { data, schema }
}
}
#[async_trait]
impl ExecutionPlan for MockExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(1)
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
assert_eq!(partition, 0);
let schema = self.schema();
let data: Vec<_> = self
.data
.iter()
.map(|r| match r {
Ok(batch) => Ok(batch.clone()),
Err(e) => Err(clone_error(e)),
})
.collect();
let (tx, rx) = tokio::sync::mpsc::channel(2);
tokio::task::spawn(async move {
for batch in data {
println!("Sending batch via delayed stream");
if let Err(e) = tx.send(batch).await {
                println!("ERROR sending batch via delayed stream: {}", e);
}
}
});
let stream = DelayedStream {
schema,
inner: ReceiverStream::new(rx),
};
Ok(Box::pin(stream))
}
}
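
/// Clone an [`ArrowError`]. [`ArrowError`] does not implement `Clone`, so
/// only the `ComputeError` variant is supported here; anything else panics
/// via `unimplemented!`.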
fn clone_error(e: &ArrowError) -> ArrowError {
use ArrowError::*;
match e {
ComputeError(msg) => ComputeError(msg.to_string()),
_ => unimplemented!(),
}
}
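
/// A [`RecordBatchStream`] that reads its batches from a tokio channel,
/// decoupling the task that produces batches from the consumer.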
#[derive(Debug)]
pub struct DelayedStream {
schema: SchemaRef,
inner: ReceiverStream<ArrowResult<RecordBatch>>,
}
impl Stream for DelayedStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx)
}
}
impl RecordBatchStream for DelayedStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
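
/// An [`ExecutionPlan`] that holds back its batches until the test calls
/// [`BarrierExec::wait`], at which point every partition starts sending.
/// Useful for tests that need to control exactly when data flows.
///
/// A rough usage sketch (the `batch` and `schema` bindings are placeholders
/// the caller must provide):
///
/// ```ignore
/// let exec = BarrierExec::new(vec![vec![batch]], schema);
/// let mut stream = exec.execute(0).await?;
/// exec.wait().await; // release the partition task
/// let first = stream.next().await; // now batches start to arrive
/// ```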
#[derive(Debug)]
pub struct BarrierExec {
data: Vec<Vec<RecordBatch>>,
schema: SchemaRef,
barrier: Arc<Barrier>,
}
impl BarrierExec {
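    /// Create a new exec with one barrier slot per partition plus one for
    /// the caller of [`BarrierExec::wait`].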
pub fn new(data: Vec<Vec<RecordBatch>>, schema: SchemaRef) -> Self {
let barrier = Arc::new(Barrier::new(data.len() + 1));
Self {
data,
schema,
barrier,
}
}
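    /// Wait on the barrier from the test side. Once every partition task is
    /// also waiting, all of them are released and begin sending batches.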
pub async fn wait(&self) {
println!("BarrierExec::wait waiting on barrier");
self.barrier.wait().await;
println!("BarrierExec::wait done waiting");
}
}
#[async_trait]
impl ExecutionPlan for BarrierExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(self.data.len())
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
assert!(partition < self.data.len());
let schema = self.schema();
let (tx, rx) = tokio::sync::mpsc::channel(2);
let data = self.data[partition].clone();
let b = self.barrier.clone();
tokio::task::spawn(async move {
println!("Partition {} waiting on barrier", partition);
b.wait().await;
for batch in data {
println!("Partition {} sending batch", partition);
if let Err(e) = tx.send(Ok(batch)).await {
                println!("ERROR sending batch via barrier stream: {}", e);
}
}
});
let stream = DelayedStream {
schema,
inner: ReceiverStream::new(rx),
};
Ok(Box::pin(stream))
}
}
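
/// An [`ExecutionPlan`] whose `execute` call always fails with an internal
/// error, used to test error propagation.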
#[derive(Debug)]
pub struct ErrorExec {
schema: SchemaRef,
}
impl ErrorExec {
pub fn new() -> Self {
let schema = Arc::new(Schema::new(vec![Field::new(
"dummy",
DataType::Int64,
true,
)]));
Self { schema }
}
}
#[async_trait]
impl ExecutionPlan for ErrorExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(1)
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
Err(DataFusionError::Internal(format!(
"ErrorExec, unsurprisingly, errored in partition {}",
partition
)))
}
}