use std::{
any::Any,
pin::Pin,
sync::{Arc, Weak},
task::{Context, Poll},
};
use arrow::{
datatypes::{DataType, Field, Schema, SchemaRef},
error::{ArrowError, Result as ArrowResult},
record_batch::RecordBatch,
};
use async_trait::async_trait;
use futures::Stream;
use tokio::sync::Barrier;
use crate::error::{DataFusionError, Result};
use crate::execution::runtime_env::RuntimeEnv;
use crate::physical_plan::{
common, expressions::PhysicalSortExpr, stream::RecordBatchReceiverStream,
DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};
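/// Thread-safe index into the batches returned so far, shared between a
/// [`TestStream`] and the test that drives it.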
#[derive(Debug, Default, Clone)]
pub struct BatchIndex {
inner: std::sync::Arc<std::sync::Mutex<usize>>,
}
impl BatchIndex {
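/// Return the current index.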
pub fn value(&self) -> usize {
let inner = self.inner.lock().unwrap();
*inner
}
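/// Increment the current index by one.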
pub fn incr(&self) {
let mut inner = self.inner.lock().unwrap();
*inner += 1;
}
}
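/// A [`Stream`] over a fixed vector of [`RecordBatch`]es that records how many
/// batches have been yielded via a shared [`BatchIndex`]. Note that `schema`
/// panics if the stream was created with no batches.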
#[derive(Debug, Default)]
pub(crate) struct TestStream {
data: Vec<RecordBatch>,
index: BatchIndex,
}
impl TestStream {
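/// Create a stream over a (non-empty) vector of record batches.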
pub fn new(data: Vec<RecordBatch>) -> Self {
Self {
data,
..Default::default()
}
}
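/// Return a handle to the index counter for this stream.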
pub fn index(&self) -> BatchIndex {
self.index.clone()
}
}
impl Stream for TestStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let next_batch = self.index.value();
Poll::Ready(if next_batch < self.data.len() {
self.index.incr();
Some(Ok(self.data[next_batch].clone()))
} else {
None
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = self.data.len() - self.index.value();
(remaining, Some(remaining))
}
}
impl RecordBatchStream for TestStream {
fn schema(&self) -> SchemaRef {
self.data[0].schema()
}
}
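/// A mock [`ExecutionPlan`] with a single partition that yields a
/// pre-baked sequence of `Ok` batches and/or errors, useful for testing
/// how downstream operators handle both cases.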
#[derive(Debug)]
pub struct MockExec {
data: Vec<ArrowResult<RecordBatch>>,
schema: SchemaRef,
}
impl MockExec {
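/// Create a new `MockExec` that returns `data` in order when executed.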
pub fn new(data: Vec<ArrowResult<RecordBatch>>, schema: SchemaRef) -> Self {
Self { data, schema }
}
}
#[async_trait]
impl ExecutionPlan for MockExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(1)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
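/// Returns a stream that yields the pre-baked data from a spawned task.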
async fn execute(
&self,
partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
assert_eq!(partition, 0);
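// `ArrowResult` does not implement `Clone`, so rebuild each entry by hand.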
let data: Vec<_> = self
.data
.iter()
.map(|r| match r {
Ok(batch) => Ok(batch.clone()),
Err(e) => Err(clone_error(e)),
})
.collect();
let (tx, rx) = tokio::sync::mpsc::channel(2);
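// Send the batches from a separate task so the stream is delayed
// relative to the call to `execute`.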
let join_handle = tokio::task::spawn(async move {
for batch in data {
println!("Sending batch via delayed stream");
if let Err(e) = tx.send(batch).await {
println!("ERROR batch via delayed stream: {}", e);
}
}
});
Ok(RecordBatchReceiverStream::create(
&self.schema,
rx,
join_handle,
))
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "MockExec")
}
}
}
fn statistics(&self) -> Statistics {
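// Clones the data to compute statistics; panics if any entry is an error.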
let data: ArrowResult<Vec<_>> = self
.data
.iter()
.map(|r| match r {
Ok(batch) => Ok(batch.clone()),
Err(e) => Err(clone_error(e)),
})
.collect();
let data = data.unwrap();
common::compute_record_batch_statistics(&[data], &self.schema, None)
}
}
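/// Best-effort clone of an [`ArrowError`] (which does not implement `Clone`);
/// only `ComputeError` is currently supported.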
fn clone_error(e: &ArrowError) -> ArrowError {
use ArrowError::*;
match e {
ComputeError(msg) => ComputeError(msg.to_string()),
_ => unimplemented!(),
}
}
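/// A mock [`ExecutionPlan`] that holds back its partitions' data behind a
/// [`Barrier`]: no stream produces a batch until `wait` is called.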
#[derive(Debug)]
pub struct BarrierExec {
data: Vec<Vec<RecordBatch>>,
schema: SchemaRef,
barrier: Arc<Barrier>,
}
impl BarrierExec {
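/// Create a new exec with one partition per entry in `data`.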
pub fn new(data: Vec<Vec<RecordBatch>>, schema: SchemaRef) -> Self {
let barrier = Arc::new(Barrier::new(data.len() + 1));
Self {
data,
schema,
barrier,
}
}
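/// Release the barrier, waiting until all partition streams (plus this
/// caller) have reached it.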
pub async fn wait(&self) {
println!("BarrierExec::wait waiting on barrier");
self.barrier.wait().await;
println!("BarrierExec::wait done waiting");
}
}
#[async_trait]
impl ExecutionPlan for BarrierExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(self.data.len())
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn execute(
&self,
partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
assert!(partition < self.data.len());
let (tx, rx) = tokio::sync::mpsc::channel(2);
let data = self.data[partition].clone();
let b = self.barrier.clone();
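// Send this partition's data from a task that first waits on the barrier.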
let join_handle = tokio::task::spawn(async move {
println!("Partition {} waiting on barrier", partition);
b.wait().await;
for batch in data {
println!("Partition {} sending batch", partition);
if let Err(e) = tx.send(Ok(batch)).await {
println!("ERROR batch via barrier stream stream: {}", e);
}
}
});
Ok(RecordBatchReceiverStream::create(
&self.schema,
rx,
join_handle,
))
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "BarrierExec")
}
}
}
fn statistics(&self) -> Statistics {
common::compute_record_batch_statistics(&self.data, &self.schema, None)
}
}
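/// A mock [`ExecutionPlan`] that always returns an error from `execute`.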
#[derive(Debug)]
pub struct ErrorExec {
schema: SchemaRef,
}
impl Default for ErrorExec {
fn default() -> Self {
Self::new()
}
}
impl ErrorExec {
pub fn new() -> Self {
let schema = Arc::new(Schema::new(vec![Field::new(
"dummy",
DataType::Int64,
true,
)]));
Self { schema }
}
}
#[async_trait]
impl ExecutionPlan for ErrorExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(1)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn execute(
&self,
partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
Err(DataFusionError::Internal(format!(
"ErrorExec, unsurprisingly, errored in partition {}",
partition
)))
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "ErrorExec")
}
}
}
fn statistics(&self) -> Statistics {
Statistics::default()
}
}
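/// A mock [`ExecutionPlan`] that simply returns the statistics it was
/// constructed with, for testing statistics propagation.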
#[derive(Debug, Clone)]
pub struct StatisticsExec {
stats: Statistics,
schema: Arc<Schema>,
}
impl StatisticsExec {
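/// Create a new exec reporting `stats` for the given `schema`; if column
/// statistics are provided there must be one entry per field.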
pub fn new(stats: Statistics, schema: Schema) -> Self {
assert!(
stats
.column_statistics
.as_ref()
.map(|cols| cols.len() == schema.fields().len())
.unwrap_or(true),
"if defined, the column statistics vector length should be the number of fields"
);
Self {
stats,
schema: Arc::new(schema),
}
}
}
#[async_trait]
impl ExecutionPlan for StatisticsExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(2)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
&self,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
if children.is_empty() {
Ok(Arc::new(self.clone()))
} else {
Err(DataFusionError::Internal(
"Children cannot be replaced in CustomExecutionPlan".to_owned(),
))
}
}
async fn execute(
&self,
_partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
unimplemented!("This plan only serves for testing statistics")
}
fn statistics(&self) -> Statistics {
self.stats.clone()
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(
f,
"StatisticsExec: col_count={}, row_count={:?}",
self.schema.fields().len(),
self.stats.num_rows,
)
}
}
}
}
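/// An [`ExecutionPlan`] whose streams block forever (always return
/// [`Poll::Pending`]), useful for testing cancellation and cleanup.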
#[derive(Debug)]
pub struct BlockingExec {
schema: SchemaRef,
n_partitions: usize,
refs: Arc<()>,
}
impl BlockingExec {
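/// Create a new `BlockingExec` with the given schema and number of partitions.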
pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
Self {
schema,
n_partitions,
refs: Default::default(),
}
}
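/// Weak pointer that can be used for ref-counting this execution plan and its
/// streams: a non-zero strong count means the plan itself or one of its
/// streams is still alive.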
pub fn refs(&self) -> Weak<()> {
Arc::downgrade(&self.refs)
}
}
#[async_trait]
impl ExecutionPlan for BlockingExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![]
}
fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(self.n_partitions)
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
fn with_new_children(
&self,
_: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Err(DataFusionError::Internal(format!(
"Children cannot be replaced in {:?}",
self
)))
}
async fn execute(
&self,
_partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(BlockingStream {
schema: Arc::clone(&self.schema),
_refs: Arc::clone(&self.refs),
}))
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "BlockingExec",)
}
}
}
fn statistics(&self) -> Statistics {
unimplemented!()
}
}
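/// A [`RecordBatchStream`] that is pending forever.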
#[derive(Debug)]
pub struct BlockingStream {
schema: SchemaRef,
_refs: Arc<()>,
}
impl Stream for BlockingStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
Poll::Pending
}
}
impl RecordBatchStream for BlockingStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
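/// Asserts that the strong count of the given [`Weak`] pointer converges to
/// zero within a 10 second timeout, polling every 10ms.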
pub async fn assert_strong_count_converges_to_zero<T>(refs: Weak<T>) {
tokio::time::timeout(std::time::Duration::from_secs(10), async {
loop {
if dbg!(Weak::strong_count(&refs)) == 0 {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
})
.await
.unwrap();
}
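// A minimal usage sketch (not part of the original test suite) showing how
// `BlockingExec` and `assert_strong_count_converges_to_zero` compose: the
// strong count only reaches zero once both the plan and its stream have been
// dropped. Assumes `RuntimeEnv::default()` is available to build a runtime.
#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn blocking_exec_releases_refs_on_drop() {
let schema = Arc::new(Schema::new(vec![Field::new(
"a",
DataType::Int32,
false,
)]));
let exec = BlockingExec::new(schema, 1);
let refs = exec.refs();

// Executing hands out a stream that holds a strong reference.
let runtime = Arc::new(RuntimeEnv::default());
let stream = exec.execute(0, runtime).await.unwrap();

// Dropping both the stream and the plan releases all strong references.
drop(stream);
drop(exec);
assert_strong_count_converges_to_zero(refs).await;
}
}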