use std::sync::{Arc, Mutex};
use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::{
BatchIterator, ExecutionPlan, Partition, PhysicalExpr,
};
use arrow::array::BooleanArray;
use arrow::compute::filter;
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
/// Physical plan node that filters the rows of its input plan with a predicate.
///
/// The output schema is the input's schema. Each input batch has every column
/// passed through arrow's `filter` kernel using the predicate's boolean result.
pub struct SelectionExec {
    /// Predicate evaluated against each input batch; it must yield a `BooleanArray`.
    expr: Arc<dyn PhysicalExpr>,
    /// Plan whose output rows are filtered.
    input: Arc<dyn ExecutionPlan>,
}

impl SelectionExec {
    /// Creates a selection over `input` that uses `expr` as the row predicate.
    ///
    /// # Errors
    /// This constructor does not currently return an error. Whether the
    /// predicate actually produces booleans is only checked when batches are
    /// pulled through the resulting iterators.
    pub fn try_new(
        expr: Arc<dyn PhysicalExpr>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        // Both arguments are already owned `Arc`s. Move them in instead of
        // cloning, which would only bump and then drop a reference count.
        Ok(Self { expr, input })
    }
}
impl ExecutionPlan for SelectionExec {
    /// Returns the input plan's schema unchanged; selection removes rows, not columns.
    fn schema(&self) -> Arc<Schema> {
        self.input.schema()
    }

    /// Wraps every input partition in a `SelectionPartition` that shares this
    /// plan's predicate.
    ///
    /// # Errors
    /// Propagates any error raised while obtaining the input plan's partitions.
    fn partitions(&self) -> Result<Vec<Arc<dyn Partition>>> {
        let upstream_partitions = self.input.partitions()?;
        let mut wrapped: Vec<Arc<dyn Partition>> =
            Vec::with_capacity(upstream_partitions.len());
        for upstream in upstream_partitions.iter() {
            wrapped.push(Arc::new(SelectionPartition {
                schema: self.input.schema(),
                expr: Arc::clone(&self.expr),
                input: Arc::clone(upstream),
            }));
        }
        Ok(wrapped)
    }
}
/// A single partition of a selection: one upstream partition plus the predicate
/// applied to its batches.
struct SelectionPartition {
    /// Schema reported by the iterators this partition creates.
    schema: Arc<Schema>,
    /// Predicate applied to each batch of the upstream partition.
    expr: Arc<dyn PhysicalExpr>,
    /// Upstream partition that supplies the unfiltered batches.
    input: Arc<dyn Partition>,
}

impl Partition for SelectionPartition {
    /// Executes the upstream partition and returns an iterator over its
    /// batches with the predicate applied.
    ///
    /// # Errors
    /// Propagates any error raised while executing the upstream partition.
    fn execute(&self) -> Result<Arc<Mutex<dyn BatchIterator>>> {
        let upstream = self.input.execute()?;
        let iterator = SelectionIterator {
            schema: Arc::clone(&self.schema),
            expr: Arc::clone(&self.expr),
            input: upstream,
        };
        Ok(Arc::new(Mutex::new(iterator)))
    }
}
/// Pulls batches from an upstream iterator and filters each one with a predicate.
struct SelectionIterator {
    /// Schema returned by `BatchIterator::schema`; set by the creating partition.
    schema: Arc<Schema>,
    /// Predicate evaluated per batch; its result must be a `BooleanArray`.
    expr: Arc<dyn PhysicalExpr>,
    /// Upstream iterator that supplies unfiltered batches.
    input: Arc<Mutex<dyn BatchIterator>>,
}

impl BatchIterator for SelectionIterator {
    fn schema(&self) -> Arc<Schema> {
        self.schema.clone()
    }

    /// Pulls one batch from the upstream iterator and returns it with every
    /// column filtered by the predicate's boolean result. Returns `Ok(None)`
    /// once the upstream iterator is exhausted.
    ///
    /// A filtered batch can contain zero rows. It is still returned, so each
    /// call yields at most one batch per upstream batch.
    ///
    /// # Errors
    /// * The upstream iterator's mutex is poisoned.
    /// * The upstream iterator, predicate evaluation, `filter`, or
    ///   `RecordBatch::try_new` fails.
    /// * The predicate does not evaluate to a `BooleanArray`.
    fn next(&mut self) -> Result<Option<RecordBatch>> {
        // The guard is a temporary of this statement, so the upstream lock is
        // released before the predicate and filter run below. A poisoned
        // mutex is reported as an error instead of panicking.
        let next_batch = self
            .input
            .lock()
            .map_err(|_| {
                ExecutionError::InternalError(
                    "Selection input iterator mutex was poisoned".to_string(),
                )
            })?
            .next()?;

        let batch = match next_batch {
            Some(batch) => batch,
            None => return Ok(None),
        };

        let predicate_result = self.expr.evaluate(&batch)?;
        let mask = predicate_result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                ExecutionError::InternalError(
                    "Predicate evaluated to non-boolean value".to_string(),
                )
            })?;

        // Stops at the first column that fails to filter. The arrow error is
        // converted by `?` exactly as in the per-column version.
        let filtered_arrays = (0..batch.num_columns())
            .map(|i| filter(batch.column(i).as_ref(), mask))
            .collect::<std::result::Result<Vec<_>, _>>()?;

        Ok(Some(RecordBatch::try_new(
            batch.schema().clone(),
            filtered_arrays,
        )?))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::physical_plan::csv::CsvExec;
    use crate::execution::physical_plan::expressions::*;
    use crate::execution::physical_plan::ExecutionPlan;
    use crate::logicalplan::{Operator, ScalarValue};
    use crate::test;
    use std::iter::Iterator;

    /// Keeps rows whose column 1 lies strictly between 1 and 4. Checks that
    /// every output batch still has all 13 columns and that 41 rows survive.
    #[test]
    fn simple_predicate() -> Result<()> {
        let schema = test::aggr_test_schema();
        let partition_count = 4;
        let csv_path =
            test::create_partitioned_csv("aggregate_test_100.csv", partition_count)?;
        let scan = CsvExec::try_new(&csv_path, schema, true, None, 1024)?;

        // col(1) > 1 AND col(1) < 4
        let above_one = binary(col(1), Operator::Gt, lit(ScalarValue::UInt32(1)));
        let below_four = binary(col(1), Operator::Lt, lit(ScalarValue::UInt32(4)));
        let predicate: Arc<dyn PhysicalExpr> = binary(above_one, Operator::And, below_four);

        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(SelectionExec::try_new(predicate, Arc::new(scan))?);
        let batches = test::execute(plan.as_ref())?;

        for batch in batches.iter() {
            assert_eq!(13, batch.num_columns());
        }
        let total_rows = batches
            .iter()
            .fold(0usize, |acc, batch| acc + batch.num_rows());
        assert_eq!(41, total_rows);
        Ok(())
    }
}