use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::execution::expression::CompiledExpr;
use crate::execution::relation::Relation;
pub(super) struct ProjectRelation {
schema: Arc<Schema>,
input: Rc<RefCell<Relation>>,
expr: Vec<CompiledExpr>,
}
impl ProjectRelation {
pub fn new(
input: Rc<RefCell<Relation>>,
expr: Vec<CompiledExpr>,
schema: Arc<Schema>,
) -> Self {
ProjectRelation {
input,
expr,
schema,
}
}
}
impl Relation for ProjectRelation {
fn next(&mut self) -> Result<Option<RecordBatch>> {
match self.input.borrow_mut().next()? {
Some(batch) => {
let projected_columns: Result<Vec<ArrayRef>> =
self.expr.iter().map(|e| e.invoke(&batch)).collect();
let schema = Schema::new(
self.expr
.iter()
.map(|e| Field::new(&e.name(), e.data_type().clone(), true))
.collect(),
);
let projected_batch: RecordBatch =
RecordBatch::try_new(Arc::new(schema), projected_columns?)?;
Ok(Some(projected_batch))
}
None => Ok(None),
}
}
fn schema(&self) -> &Arc<Schema> {
&self.schema
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::CsvBatchIterator;
    use crate::execution::context::ExecutionContext;
    use crate::execution::expression;
    use crate::execution::relation::DataSourceRelation;
    use crate::logicalplan::Expr;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Mutex;

    /// Projecting column 0 of the aggregate test CSV yields single-column
    /// batches whose only field is `c1`.
    #[test]
    fn project_first_column() {
        // One uniquely named field per CSV column, in file order.
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Utf8, false),
            Field::new("c2", DataType::UInt32, false),
            Field::new("c3", DataType::Int8, false),
            Field::new("c4", DataType::Int16, false),
            Field::new("c5", DataType::Int32, false),
            Field::new("c6", DataType::Int64, false),
            Field::new("c7", DataType::UInt8, false),
            Field::new("c8", DataType::UInt16, false),
            Field::new("c9", DataType::UInt32, false),
            Field::new("c10", DataType::UInt64, false),
            Field::new("c11", DataType::Float32, false),
            Field::new("c12", DataType::Float64, false),
            Field::new("c13", DataType::Utf8, false),
        ]));
        let ds = CsvBatchIterator::new(
            "../../testing/data/csv/aggregate_test_100.csv",
            schema.clone(),
            true,
            &None,
            1024,
        );
        let relation = Rc::new(RefCell::new(DataSourceRelation::new(Arc::new(
            Mutex::new(ds),
        ))));
        let context = ExecutionContext::new();
        let projection_expr =
            vec![
                expression::compile_expr(&context, &Expr::Column(0), schema.as_ref())
                    .unwrap(),
            ];

        // The projection reports its output schema (one nullable column, as
        // produced by `next()`), not the 13-column input schema.
        let projected_schema = Arc::new(Schema::new(vec![Field::new(
            "c1",
            DataType::Utf8,
            true,
        )]));
        let mut projection = ProjectRelation::new(relation, projection_expr, projected_schema);

        let batch = projection.next().unwrap().unwrap();
        assert_eq!(1, batch.num_columns());
        assert_eq!("c1", batch.schema().field(0).name());
        assert_eq!(
            batch.schema().field(0).name(),
            projection.schema().field(0).name()
        );
    }
}