use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::expressions::PhysicalSortExpr;
use super::{RecordBatchStream, SendableRecordBatchStream, Statistics};
use crate::error::{DataFusionError, Result};
use crate::execution::runtime_env::RuntimeEnv;
use crate::physical_plan::{
    metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
    DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr,
};
use arrow::array::BooleanArray;
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use futures::stream::{Stream, StreamExt};
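
/// FilterExec evaluates a boolean predicate against all input batches to
/// determine which rows to include in its output batches. The output schema
/// is always identical to the input schema.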
#[derive(Debug)]
pub struct FilterExec {
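    /// The expression to filter on; must evaluate to a boolean for every row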
predicate: Arc<dyn PhysicalExpr>,
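    /// The input plan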
input: Arc<dyn ExecutionPlan>,
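    /// Execution metrics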
metrics: ExecutionPlanMetricsSet,
}
impl FilterExec {
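    /// Create a FilterExec on an input, returning a plan error if the
    /// predicate does not evaluate to a boolean against the input schema.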
pub fn try_new(
predicate: Arc<dyn PhysicalExpr>,
input: Arc<dyn ExecutionPlan>,
) -> Result<Self> {
match predicate.data_type(input.schema().as_ref())? {
DataType::Boolean => Ok(Self {
predicate,
                input,
metrics: ExecutionPlanMetricsSet::new(),
}),
other => Err(DataFusionError::Plan(format!(
"Filter predicate must return boolean values, not {:?}",
other
))),
}
}
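
    /// The expression to filter on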
pub fn predicate(&self) -> &Arc<dyn PhysicalExpr> {
&self.predicate
}
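
    /// The input plan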
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
}
#[async_trait]
impl ExecutionPlan for FilterExec {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
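        // the predicate only removes rows, so the output schema matches the input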
self.input.schema()
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
}
fn output_partitioning(&self) -> Partitioning {
self.input.output_partitioning()
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
self.input.output_ordering()
}
fn maintains_input_order(&self) -> bool {
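        // filtering removes rows but does not reorder those that remain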
true
}
fn relies_on_input_order(&self) -> bool {
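        // the predicate is evaluated independently for each row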
false
}
fn with_new_children(
&self,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
match children.len() {
1 => Ok(Arc::new(FilterExec::try_new(
self.predicate.clone(),
children[0].clone(),
)?)),
_ => Err(DataFusionError::Internal(
"FilterExec wrong number of children".to_string(),
)),
}
}
async fn execute(
&self,
partition: usize,
runtime: Arc<RuntimeEnv>,
) -> Result<SendableRecordBatchStream> {
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
Ok(Box::pin(FilterExecStream {
            schema: self.input.schema(),
predicate: self.predicate.clone(),
input: self.input.execute(partition, runtime).await?,
baseline_metrics,
}))
}
fn fmt_as(
&self,
t: DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
match t {
DisplayFormatType::Default => {
write!(f, "FilterExec: {}", self.predicate)
}
}
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn statistics(&self) -> Statistics {
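        // predicate selectivity is unknown, so the input statistics cannot be propagated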
Statistics::default()
}
}
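
/// The FilterExec stream wraps the input stream and applies the predicate
/// expression to determine which rows to include in its output batches.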
struct FilterExecStream {
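    /// Output schema, which is the same as the input schema for this operator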
schema: SchemaRef,
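    /// The expression to filter on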
predicate: Arc<dyn PhysicalExpr>,
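    /// The input partition to filter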
input: SendableRecordBatchStream,
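    /// Runtime metrics recording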
baseline_metrics: BaselineMetrics,
}
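
/// Evaluate the predicate against a record batch and return a new batch
/// containing only the rows for which it evaluated to true.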
fn batch_filter(
batch: &RecordBatch,
predicate: &Arc<dyn PhysicalExpr>,
) -> ArrowResult<RecordBatch> {
predicate
.evaluate(batch)
.map(|v| v.into_array(batch.num_rows()))
.map_err(DataFusionError::into)
.and_then(|array| {
array
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or_else(|| {
DataFusionError::Internal(
"Filter predicate evaluated to non-boolean value".to_string(),
)
.into()
})
.and_then(|filter_array| filter_record_batch(batch, filter_array))
})
}
impl Stream for FilterExecStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let poll = self.input.poll_next_unpin(cx).map(|x| match x {
Some(Ok(batch)) => {
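                // time only the predicate evaluation and filtering, not waiting on input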
let timer = self.baseline_metrics.elapsed_compute().timer();
let filtered_batch = batch_filter(&batch, &self.predicate);
timer.done();
Some(filtered_batch)
}
other => other,
});
self.baseline_metrics.record_poll(poll)
}
fn size_hint(&self) -> (usize, Option<usize>) {
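        // exactly one filtered batch is emitted per input batch, so the input's hint is exact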
self.input.size_hint()
}
}
impl RecordBatchStream for FilterExecStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::datasource::object_store::local::LocalFileSystem;
use crate::physical_plan::expressions::*;
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
use crate::scalar::ScalarValue;
use crate::test;
use crate::test_util;
use crate::{logical_plan::Operator, physical_plan::collect};
use std::iter::Iterator;
#[tokio::test]
async fn simple_predicate() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
let schema = test_util::aggr_test_schema();
let partitions = 4;
let (_, files) =
test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
let csv = CsvExec::new(
FileScanConfig {
object_store: Arc::new(LocalFileSystem {}),
file_schema: Arc::clone(&schema),
file_groups: files,
statistics: Statistics::default(),
projection: None,
limit: None,
table_partition_cols: vec![],
},
true,
b',',
);
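        // c2 > 1 AND c2 < 4 keeps only the rows where c2 is 2 or 3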
let predicate: Arc<dyn PhysicalExpr> = binary(
binary(
col("c2", &schema)?,
Operator::Gt,
lit(ScalarValue::from(1u32)),
&schema,
)?,
Operator::And,
binary(
col("c2", &schema)?,
Operator::Lt,
lit(ScalarValue::from(4u32)),
&schema,
)?,
&schema,
)?;
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?);
let results = collect(filter, runtime).await?;
results
.iter()
.for_each(|batch| assert_eq!(13, batch.num_columns()));
let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
assert_eq!(41, row_count);
Ok(())
}
}