use std::sync::{Arc, Mutex};

use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::{ExecutionPlan, Partition, PhysicalExpr};

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::{RecordBatch, RecordBatchReader};
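
/// Execution plan for a projection: evaluates a set of physical expressions
/// against every record batch produced by its input, yielding batches with
/// one column per expression.
///
/// A minimal construction sketch (illustrative only; `source` stands in for
/// any upstream plan, and `Column` selects the first input column):
///
/// ```ignore
/// // `source` is any upstream Arc<dyn ExecutionPlan> (hypothetical here)
/// let projection = ProjectionExec::try_new(
///     vec![Arc::new(Column::new(0, "c1"))],
///     source,
/// )?;
/// assert_eq!("c1", projection.schema().field(0).name());
/// ```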
pub struct ProjectionExec {
    /// The expressions to evaluate, one per output column
    expr: Vec<Arc<dyn PhysicalExpr>>,
    /// The schema of the output batches, derived from the expressions
    schema: SchemaRef,
    /// The input plan that the expressions are evaluated against
    input: Arc<dyn ExecutionPlan>,
}

impl ProjectionExec {
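    /// Create a projection that evaluates `expr` against the output of
    /// `input`, deriving the output schema from the expressions. Fails if
    /// any expression cannot be resolved against the input schema.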
    pub fn try_new(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        let input_schema = input.schema();

        // Derive one output field per projection expression
        let fields: Result<Vec<_>> = expr
            .iter()
            .map(|e| e.to_schema_field(&input_schema))
            .collect();
        let schema = Arc::new(Schema::new(fields?));

        // `expr` and `input` are owned here, so no clones are needed
        Ok(Self { expr, schema, input })
    }
}

impl ExecutionPlan for ProjectionExec {
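    /// The output schema of the projection, as derived in `try_new`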
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
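
    /// Wrap each partition of the input plan in a `ProjectionPartition`, so
    /// the projection is applied lazily, batch by batch, at execution time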
    fn partitions(&self) -> Result<Vec<Arc<dyn Partition>>> {
        let partitions: Vec<Arc<dyn Partition>> = self
            .input
            .partitions()?
            .iter()
            .map(|p| {
                Arc::new(ProjectionPartition {
                    schema: self.schema.clone(),
                    expr: self.expr.clone(),
                    input: p.clone(),
                }) as Arc<dyn Partition>
            })
            .collect();
        Ok(partitions)
    }
}
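
/// One partition of a projection: applies the projection expressions to a
/// single partition of the input plan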
struct ProjectionPartition {
    schema: SchemaRef,
    expr: Vec<Arc<dyn PhysicalExpr>>,
    input: Arc<dyn Partition>,
}

impl Partition for ProjectionPartition {
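    /// Execute the input partition and wrap its reader in a
    /// `ProjectionIterator` that projects each batch as it is read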
    fn execute(&self) -> Result<Arc<Mutex<dyn RecordBatchReader + Send + Sync>>> {
        Ok(Arc::new(Mutex::new(ProjectionIterator {
            schema: self.schema.clone(),
            expr: self.expr.clone(),
            input: self.input.execute()?,
        })))
    }
}
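
/// A `RecordBatchReader` that evaluates the projection expressions against
/// each batch read from its input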
struct ProjectionIterator {
    schema: SchemaRef,
    expr: Vec<Arc<dyn PhysicalExpr>>,
    input: Arc<Mutex<dyn RecordBatchReader + Send + Sync>>,
}

impl RecordBatchReader for ProjectionIterator {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
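
    /// Read the next batch from the input and evaluate each projection
    /// expression against it, producing a batch with the projected columns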
    fn next_batch(&mut self) -> ArrowResult<Option<RecordBatch>> {
        let mut input = self.input.lock().unwrap();
        match input.next_batch()? {
            Some(batch) => {
                // Evaluate every projection expression against the input batch
                let arrays: Result<Vec<_>> =
                    self.expr.iter().map(|expr| expr.evaluate(&batch)).collect();
                // This trait method returns an Arrow result, so convert any
                // execution error into an Arrow "external" error
                Ok(Some(RecordBatch::try_new(
                    self.schema.clone(),
                    arrays.map_err(ExecutionError::into_arrow_external_error)?,
                )?))
            }
            None => Ok(None),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
    use crate::execution::physical_plan::expressions::Column;
    use crate::test;

    #[test]
    fn project_first_column() -> Result<()> {
        let schema = test::aggr_test_schema();

        let partitions = 4;
        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;

        let csv =
            CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

        // Project only the first column ("c1") of the CSV source
        let projection = ProjectionExec::try_new(
            vec![Arc::new(Column::new(0, schema.field(0).name()))],
            Arc::new(csv),
        )?;

        assert_eq!("c1", projection.schema.field(0).name().as_str());

        let mut partition_count = 0;
        let mut row_count = 0;
        for partition in projection.partitions()? {
            partition_count += 1;
            let iterator = partition.execute()?;
            let mut iterator = iterator.lock().unwrap();
            while let Some(batch) = iterator.next_batch()? {
                // Each projected batch should contain exactly one column
                assert_eq!(1, batch.num_columns());
                row_count += batch.num_rows();
            }
        }

        // The fixture holds 100 rows in total, spread across the partitions
        assert_eq!(partitions, partition_count);
        assert_eq!(100, row_count);

        Ok(())
    }
}