use std::sync::Arc;
use crate::physical_plan::displayable;
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use futures::stream::BoxStream;
use futures::{Future, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::task::JoinSet;
use super::metrics::BaselineMetrics;
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
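
/// Builder for [`RecordBatchReceiverStream`] that propagates errors
/// and panics correctly.
///
/// One or more tasks spawned via [`Self::spawn`] or [`Self::run_input`]
/// produce [`RecordBatch`]es and send them to a single `Receiver`,
/// which can improve parallelism. When the stream returned by
/// [`Self::build`] is dropped, all spawned tasks are aborted.
///
/// A minimal usage sketch (`plan`, `partition`, and `ctx` are
/// hypothetical values, so this is not a compiled doctest):
///
/// ```ignore
/// let mut builder = RecordBatchReceiverStream::builder(schema, 2);
/// builder.run_input(plan, partition, ctx);
/// let stream = builder.build();
/// ```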
pub struct RecordBatchReceiverStreamBuilder {
tx: Sender<Result<RecordBatch>>,
rx: Receiver<Result<RecordBatch>>,
schema: SchemaRef,
join_set: JoinSet<()>,
}
impl RecordBatchReceiverStreamBuilder {
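    /// Create new channels with the specified buffer size (`capacity`
    /// is the maximum number of in-flight [`RecordBatch`]es).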
pub fn new(schema: SchemaRef, capacity: usize) -> Self {
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
Self {
tx,
rx,
schema,
join_set: JoinSet::new(),
}
}
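
    /// Get a handle for sending [`RecordBatch`]es to the output.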
pub fn tx(&self) -> Sender<Result<RecordBatch>> {
self.tx.clone()
}
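
    /// Spawn a task on the tokio thread pool that will be aborted if
    /// this builder (or the stream built from it) is dropped.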
pub fn spawn<F>(&mut self, task: F)
where
F: Future<Output = ()>,
F: Send + 'static,
{
self.join_set.spawn(task);
}
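
    /// Spawn a blocking task that will be aborted if this builder (or
    /// the stream built from it) is dropped. Often used for tasks that
    /// write to the sender returned by [`Self::tx`].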
pub fn spawn_blocking<F>(&mut self, f: F)
where
F: FnOnce(),
F: Send + 'static,
{
self.join_set.spawn_blocking(f);
}
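
    /// Runs the specified `partition` of the `input` [`ExecutionPlan`]
    /// on the tokio thread pool, forwarding its output to this stream.
    /// If the input produces an error, that error is sent to the
    /// output and no further batches are forwarded.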
pub(crate) fn run_input(
&mut self,
input: Arc<dyn ExecutionPlan>,
partition: usize,
context: Arc<TaskContext>,
) {
let output = self.tx();
self.spawn(async move {
let mut stream = match input.execute(partition, context) {
Err(e) => {
output.send(Err(e)).await.ok();
debug!(
"Stopping execution: error executing input: {}",
displayable(input.as_ref()).one_line()
);
return;
}
Ok(stream) => stream,
};
while let Some(item) = stream.next().await {
let is_err = item.is_err();
if output.send(item).await.is_err() {
debug!(
"Stopping execution: output is gone, plan cancelling: {}",
displayable(input.as_ref()).one_line()
);
return;
}
if is_err {
debug!(
"Stopping execution: plan returned error: {}",
displayable(input.as_ref()).one_line()
);
return;
}
}
});
}
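
    /// Create a stream of all [`RecordBatch`]es written to `tx`.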
pub fn build(self) -> SendableRecordBatchStream {
let Self {
tx,
rx,
schema,
mut join_set,
} = self;
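        // Drop the builder's own sender so the channel closes once all
        // spawned tasks (which hold clones of `tx`) have completed.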
drop(tx);
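        // Future that waits for all spawned tasks to finish, resuming
        // any panic on the consumer's thread and converting other task
        // failures into an internal error.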
let check = async move {
while let Some(result) = join_set.join_next().await {
match result {
                    Ok(()) => continue,
                    Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
} else {
return Some(Err(DataFusionError::Internal(format!(
"Non Panic Task error: {e}"
))));
}
}
}
}
None
};
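        // Turn the check future into a stream that yields at most one
        // item (the error, if any).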
let check_stream = futures::stream::once(check)
.filter_map(|item| async move { item });
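        // Convert the receiver into a stream of record batch results.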
let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
let next_item = rx.recv().await;
next_item.map(|next_item| (next_item, rx))
});
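        // Merge the two streams so that polling the output also drives
        // the task monitor, surfacing task errors and panics promptly.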
let inner = futures::stream::select(rx_stream, check_stream).boxed();
Box::pin(RecordBatchReceiverStream { schema, inner })
}
}
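
/// Adapter for a tokio [`Receiver`] that implements
/// [`RecordBatchStream`], combining the received batches with the
/// results of the producing tasks (propagating errors and panics).
/// Use [`Self::builder`] to construct one.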
pub struct RecordBatchReceiverStream {
schema: SchemaRef,
inner: BoxStream<'static, Result<RecordBatch>>,
}
impl RecordBatchReceiverStream {
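    /// Create a builder with an internal buffer of `capacity` batches.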
pub fn builder(
schema: SchemaRef,
capacity: usize,
) -> RecordBatchReceiverStreamBuilder {
RecordBatchReceiverStreamBuilder::new(schema, capacity)
}
}
impl Stream for RecordBatchReceiverStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx)
}
}
impl RecordBatchStream for RecordBatchReceiverStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
pin_project! {
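    /// Combines a [`Stream`] with a [`SchemaRef`], implementing
    /// [`RecordBatchStream`] for the combination.
    ///
    /// A minimal sketch (assuming `schema` and a `stream` of
    /// `Result<RecordBatch>` items are in scope; not a compiled
    /// doctest):
    ///
    /// ```ignore
    /// let adapted = RecordBatchStreamAdapter::new(schema, stream);
    /// ```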
pub struct RecordBatchStreamAdapter<S> {
schema: SchemaRef,
#[pin]
stream: S,
}
}
impl<S> RecordBatchStreamAdapter<S> {
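    /// Create a new [`RecordBatchStreamAdapter`] from the provided
    /// schema and stream.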
pub fn new(schema: SchemaRef, stream: S) -> Self {
Self { schema, stream }
}
}
impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RecordBatchStreamAdapter")
.field("schema", &self.schema)
.finish()
}
}
impl<S> Stream for RecordBatchStreamAdapter<S>
where
S: Stream<Item = Result<RecordBatch>>,
{
type Item = Result<RecordBatch>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.project().stream.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.stream.size_hint()
}
}
impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
where
S: Stream<Item = Result<RecordBatch>>,
{
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
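
/// Stream wrapper that records [`BaselineMetrics`] (such as output row
/// count and elapsed compute time) for a particular
/// [`SendableRecordBatchStream`], likely one partition.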
pub(crate) struct ObservedStream {
inner: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
}
impl ObservedStream {
pub fn new(
inner: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Self {
Self {
inner,
baseline_metrics,
}
}
}
impl RecordBatchStream for ObservedStream {
fn schema(&self) -> arrow::datatypes::SchemaRef {
self.inner.schema()
}
}
impl futures::Stream for ObservedStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let poll = self.inner.poll_next_unpin(cx);
self.baseline_metrics.record_poll(poll)
}
}
#[cfg(test)]
mod test {
use super::*;
use arrow_schema::{DataType, Field, Schema};
use crate::{
execution::context::SessionContext,
test::exec::{
assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
},
};
fn schema() -> SchemaRef {
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
}
#[tokio::test]
#[should_panic(expected = "PanickingStream did panic")]
async fn record_batch_receiver_stream_propagates_panics() {
let schema = schema();
let num_partitions = 10;
let input = PanicExec::new(schema.clone(), num_partitions);
consume(input, 10).await
}
#[tokio::test]
#[should_panic(expected = "PanickingStream did panic: 1")]
async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
let schema = schema();
let num_partitions = 2;
let input = PanicExec::new(schema.clone(), num_partitions)
.with_partition_panic(0, 10)
.with_partition_panic(1, 3);
let max_batches = 5;
consume(input, max_batches).await
}
#[tokio::test]
async fn record_batch_receiver_stream_drop_cancel() {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let schema = schema();
let input = BlockingExec::new(schema.clone(), 1);
let refs = input.refs();
let mut builder = RecordBatchReceiverStream::builder(schema, 2);
builder.run_input(Arc::new(input), 0, task_ctx.clone());
let stream = builder.build();
assert!(std::sync::Weak::strong_count(&refs) > 0);
drop(stream);
assert_strong_count_converges_to_zero(refs).await;
}
#[tokio::test]
async fn record_batch_receiver_stream_error_does_not_drive_completion() {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let schema = schema();
let error_stream = MockExec::new(
vec![
Err(DataFusionError::Execution("Test1".to_string())),
Err(DataFusionError::Execution("Test2".to_string())),
],
schema.clone(),
)
.with_use_task(false);
let mut builder = RecordBatchReceiverStream::builder(schema, 2);
builder.run_input(Arc::new(error_stream), 0, task_ctx.clone());
let mut stream = builder.build();
let first_batch = stream.next().await.unwrap();
let first_err = first_batch.unwrap_err();
assert_eq!(first_err.to_string(), "Execution error: Test1");
assert!(stream.next().await.is_none());
}
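
    /// Consumes all batches produced by `input`, panicking the test if
    /// `max_batches` batches are seen before the expected panic
    /// propagates from the stream.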
async fn consume(input: PanicExec, max_batches: usize) {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let input = Arc::new(input);
let num_partitions = input.output_partitioning().partition_count();
let mut builder =
RecordBatchReceiverStream::builder(input.schema(), num_partitions);
for partition in 0..num_partitions {
builder.run_input(input.clone(), partition, task_ctx.clone());
}
let mut stream = builder.build();
let mut num_batches = 0;
while let Some(next) = stream.next().await {
next.unwrap();
num_batches += 1;
assert!(
num_batches < max_batches,
"Got the limit of {num_batches} batches before seeing panic"
);
}
}
}