use std::any::Any;
use std::fmt;
use std::fmt::Debug;
use std::sync::Arc;

use arrow::array::{ArrayRef, RecordBatch, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{Distribution, EquivalenceProperties};
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::metrics::MetricsSet;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::ExecutionPlanProperties;
use datafusion_physical_plan::{
    execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
    PlanProperties, SendableRecordBatchStream,
};
use futures::StreamExt;
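
/// `DataSink` implements writing a stream of [`RecordBatch`]es to a
/// user-defined destination (a table, a file, a remote service, ...).
///
/// The [`DisplayAs`] implementation is used when formatting this sink in
/// `EXPLAIN` plan output.
///
/// # Example
///
/// A minimal sketch of an implementation that counts the incoming rows and
/// discards the data; the `DiscardingSink` name is hypothetical and not part
/// of this crate:
///
/// ```ignore
/// #[derive(Debug)]
/// struct DiscardingSink {
///     schema: SchemaRef,
/// }
///
/// impl DisplayAs for DiscardingSink {
///     fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
///         write!(f, "DiscardingSink")
///     }
/// }
///
/// #[async_trait]
/// impl DataSink for DiscardingSink {
///     fn as_any(&self) -> &dyn Any {
///         self
///     }
///
///     fn schema(&self) -> &SchemaRef {
///         &self.schema
///     }
///
///     async fn write_all(
///         &self,
///         mut data: SendableRecordBatchStream,
///         _context: &Arc<TaskContext>,
///     ) -> Result<u64> {
///         // Drain the stream, counting rows as they arrive.
///         let mut count = 0u64;
///         while let Some(batch) = data.next().await {
///             count += batch?.num_rows() as u64;
///         }
///         Ok(count)
///     }
/// }
/// ```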
#[async_trait]
pub trait DataSink: DisplayAs + Debug + Send + Sync {
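    /// Returns the sink as [`Any`] so that it can be downcast to a concrete
    /// implementation.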
fn as_any(&self) -> &dyn Any;
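
    /// Returns a snapshot of any metrics this sink has collected so far
    /// (`None` by default).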
fn metrics(&self) -> Option<MetricsSet> {
None
}
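
    /// Returns the schema that this sink expects the input data to have.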
fn schema(&self) -> &SchemaRef;
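
    /// Writes the entire `data` stream to the sink, returning the number of
    /// rows written once the stream is exhausted.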
async fn write_all(
&self,
data: SendableRecordBatchStream,
context: &Arc<TaskContext>,
) -> Result<u64>;
}
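
/// Execution plan that writes the record batches produced by its input to a
/// [`DataSink`], returning a single row containing the number of rows written
/// in a `UInt64` column named `count`.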
#[derive(Clone)]
pub struct DataSinkExec {
    /// Input plan that produces the record batches to be written
    input: Arc<dyn ExecutionPlan>,
    /// Sink to which the data is written
    sink: Arc<dyn DataSink>,
    /// Schema of the single-row count batch returned by this node
    count_schema: SchemaRef,
    /// Optional sort order required of the input data
    sort_order: Option<LexRequirement>,
    /// Cached plan properties (schema, partitioning, etc.)
    cache: PlanProperties,
}
impl Debug for DataSinkExec {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "DataSinkExec schema: {:?}", self.count_schema)
}
}
impl DataSinkExec {
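    /// Creates a plan that writes the output of `input` to `sink`, optionally
    /// requiring `sort_order` on the input.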
pub fn new(
input: Arc<dyn ExecutionPlan>,
sink: Arc<dyn DataSink>,
sort_order: Option<LexRequirement>,
) -> Self {
        let count_schema = make_count_schema();
        let cache = Self::create_schema(&input, Arc::clone(&count_schema));
        Self {
            input,
            sink,
            count_schema,
            sort_order,
            cache,
        }
}
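
    /// Returns the input [`ExecutionPlan`] whose output is written to the sink.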
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
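
    /// Returns the [`DataSink`] this plan writes to.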
pub fn sink(&self) -> &dyn DataSink {
self.sink.as_ref()
}
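
    /// Returns the optional sort order required of the input data.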
pub fn sort_order(&self) -> &Option<LexRequirement> {
&self.sort_order
}
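
    /// Creates the cached [`PlanProperties`]: the node produces a single
    /// partition with the count schema, and inherits pipeline behavior and
    /// boundedness from the input.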
fn create_schema(
input: &Arc<dyn ExecutionPlan>,
schema: SchemaRef,
) -> PlanProperties {
let eq_properties = EquivalenceProperties::new(schema);
PlanProperties::new(
eq_properties,
Partitioning::UnknownPartitioning(1),
input.pipeline_behavior(),
input.boundedness(),
)
}
}
impl DisplayAs for DataSinkExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(f, "DataSinkExec: sink=")?;
self.sink.fmt_as(t, f)
}
DisplayFormatType::TreeRender => self.sink().fmt_as(t, f),
}
}
}
impl ExecutionPlan for DataSinkExec {
fn name(&self) -> &'static str {
"DataSinkExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.cache
}
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
vec![false]
}
fn required_input_distribution(&self) -> Vec<Distribution> {
vec![Distribution::SinglePartition; self.children().len()]
}
fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
vec![self.sort_order.as_ref().cloned()]
}
fn maintains_input_order(&self) -> Vec<bool> {
vec![true]
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::new(
Arc::clone(&children[0]),
Arc::clone(&self.sink),
self.sort_order.clone(),
)))
}
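
    /// Streams the input into the sink and returns a single-row stream
    /// containing the resulting row count. Only partition 0 may be executed.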
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
if partition != 0 {
return internal_err!("DataSinkExec can only be called on partition 0!");
}
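        // Execute partition 0 of the input, which is expected to produce
        // batches matching the sink's schema.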
let data = execute_input_stream(
Arc::clone(&self.input),
Arc::clone(self.sink.schema()),
0,
Arc::clone(&context),
)?;
let count_schema = Arc::clone(&self.count_schema);
let sink = Arc::clone(&self.sink);
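        // The write runs lazily when the returned stream is polled; the
        // sink's reported row count is wrapped in a single-row `count` batch.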
let stream = futures::stream::once(async move {
sink.write_all(data, &context).await.map(make_count_batch)
})
.boxed();
Ok(Box::pin(RecordBatchStreamAdapter::new(
count_schema,
stream,
)))
}
fn metrics(&self) -> Option<MetricsSet> {
self.sink.metrics()
}
}
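
/// Creates a single-row [`RecordBatch`] reporting `count` rows written:
///
/// ```text
/// +-------+
/// | count |
/// +-------+
/// | 6     |
/// +-------+
/// ```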
fn make_count_batch(count: u64) -> RecordBatch {
let array = Arc::new(UInt64Array::from(vec![count])) as ArrayRef;
RecordBatch::try_from_iter_with_nullable(vec![("count", array, false)]).unwrap()
}
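
/// Creates the output schema of [`DataSinkExec`]: a single non-nullable
/// `UInt64` column named `count`.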
fn make_count_schema() -> SchemaRef {
Arc::new(Schema::new(vec![Field::new(
"count",
DataType::UInt64,
false,
)]))
}