use std::any::Any;
use std::sync::Arc;
use super::SendableRecordBatchReader;
use crate::error::{ExecutionError, Result};
use crate::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr};
use arrow::array::BooleanArray;
use arrow::compute::filter;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::{RecordBatch, RecordBatchReader};
use async_trait::async_trait;
/// Execution plan that evaluates a boolean `predicate` against every batch
/// produced by `input` and keeps only the rows selected by it.
///
/// The output schema and partitioning are those of `input`.
#[derive(Debug)]
pub struct FilterExec {
    /// Expression evaluated per batch; `try_new` requires its data type to be
    /// `DataType::Boolean`.
    predicate: Arc<dyn PhysicalExpr>,
    /// The plan whose batches are filtered.
    input: Arc<dyn ExecutionPlan>,
}
impl FilterExec {
    /// Creates a filter that applies `predicate` to the batches produced by `input`.
    ///
    /// # Errors
    ///
    /// Returns the error from resolving the predicate's data type against
    /// `input`'s schema. Returns an `ExecutionError::General` if the predicate
    /// does not produce `DataType::Boolean` values.
    pub fn try_new(
        predicate: Arc<dyn PhysicalExpr>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        let data_type = predicate.data_type(input.schema().as_ref())?;
        match data_type {
            // Both `Arc`s are owned by this function, so move them into the
            // plan instead of cloning them.
            DataType::Boolean => Ok(Self { predicate, input }),
            other => Err(ExecutionError::General(format!(
                "Filter predicate must return boolean values, not {:?}",
                other
            ))),
        }
    }
}
#[async_trait]
impl ExecutionPlan for FilterExec {
    /// Returns `self` as `Any` so callers can downcast to `FilterExec`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Filtering removes rows, not columns, so the schema is the input's.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    /// The single child of this plan is its input.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![self.input.clone()]
    }

    /// Filtering is applied per partition, so the input partitioning is kept.
    fn output_partitioning(&self) -> Partitioning {
        self.input.output_partitioning()
    }

    /// Rebuilds this filter over a replacement child.
    ///
    /// The predicate is re-validated against the new child's schema through
    /// `FilterExec::try_new`.
    ///
    /// # Errors
    ///
    /// Fails if `children` does not contain exactly one plan, or if
    /// `try_new` rejects the predicate for the new child.
    fn with_new_children(
        &self,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.as_slice() {
            [child] => Ok(Arc::new(FilterExec::try_new(
                self.predicate.clone(),
                child.clone(),
            )?)),
            _ => Err(ExecutionError::General(
                "FilterExec wrong number of children".to_string(),
            )),
        }
    }

    /// Executes `partition` of the input and wraps the resulting reader in an
    /// iterator that filters each batch as it is pulled.
    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
        let input = self.input.execute(partition).await?;
        Ok(Box::new(FilterExecIter {
            // `schema()` already returns an owned `SchemaRef`, so no extra clone is needed.
            schema: self.input.schema(),
            predicate: self.predicate.clone(),
            input,
        }))
    }
}
/// Reader returned by `FilterExec::execute` for one input partition; each call
/// to `next` pulls one batch from `input` and filters it with `predicate`.
struct FilterExecIter {
    /// Schema reported through `RecordBatchReader::schema`.
    schema: SchemaRef,
    /// Predicate evaluated against every input batch.
    predicate: Arc<dyn PhysicalExpr>,
    /// Reader for the input partition being filtered.
    input: SendableRecordBatchReader,
}
impl Iterator for FilterExecIter {
type Item = ArrowResult<RecordBatch>;
fn next(&mut self) -> Option<ArrowResult<RecordBatch>> {
match self.input.next() {
Some(Ok(batch)) => {
Some(
self.predicate
.evaluate(&batch)
.map_err(ExecutionError::into_arrow_external_error)
.and_then(|array| {
array
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or(
ExecutionError::InternalError(
"Filter predicate evaluated to non-boolean value"
.to_string(),
)
.into_arrow_external_error(),
)
.and_then(|predicate| {
batch
.columns()
.iter()
.map(|column| filter(column.as_ref(), predicate))
.collect::<ArrowResult<Vec<_>>>()
})
})
.and_then(|columns| {
RecordBatch::try_new(batch.schema().clone(), columns)
}),
)
}
other => other,
}
}
}
impl RecordBatchReader for FilterExecIter {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::Operator;
    use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
    use crate::physical_plan::expressions::*;
    use crate::physical_plan::ExecutionPlan;
    use crate::scalar::ScalarValue;
    use crate::test;
    use std::iter::Iterator;

    /// Filters the partitioned `aggregate_test_100.csv` fixture with
    /// `c2 > 1 AND c2 < 4`. Checks that every column is kept and that the
    /// expected number of rows is selected.
    #[tokio::test]
    async fn simple_predicate() -> Result<()> {
        let schema = test::aggr_test_schema();
        let partition_count = 4;
        let csv_path =
            test::create_partitioned_csv("aggregate_test_100.csv", partition_count)?;
        let read_options = CsvReadOptions::new().schema(&schema);
        let scan = CsvExec::try_new(&csv_path, read_options, None, 1024)?;

        let above_one = binary(
            col("c2"),
            Operator::Gt,
            lit(ScalarValue::from(1u32)),
            &schema,
        )?;
        let below_four = binary(
            col("c2"),
            Operator::Lt,
            lit(ScalarValue::from(4u32)),
            &schema,
        )?;
        let predicate: Arc<dyn PhysicalExpr> =
            binary(above_one, Operator::And, below_four, &schema)?;

        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(FilterExec::try_new(predicate, Arc::new(scan))?);
        let batches = test::execute(plan).await?;

        for batch in &batches {
            assert_eq!(13, batch.num_columns());
        }
        let total_rows = batches
            .iter()
            .fold(0usize, |acc, batch| acc + batch.num_rows());
        assert_eq!(41, total_rows);
        Ok(())
    }
}