use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use super::error::Result;
use super::expression::RuntimeExpr;
use super::relation::Relation;
/// Relation that applies a list of compiled expressions to every batch
/// produced by an input relation, yielding one output column per expression.
pub struct ProjectRelation {
/// Schema reported by this relation's `schema()` method.
schema: Arc<Schema>,
/// Upstream relation that `next()` pulls batches from.
input: Rc<RefCell<Relation>>,
/// Compiled expressions evaluated against each input batch, in order.
expr: Vec<RuntimeExpr>,
}
impl ProjectRelation {
pub fn new(input: Rc<RefCell<Relation>>, expr: Vec<RuntimeExpr>, schema: Arc<Schema>) -> Self {
ProjectRelation {
input,
expr,
schema,
}
}
}
impl Relation for ProjectRelation {
    /// Fetches the next batch from the input relation and projects it.
    ///
    /// Every expression is evaluated against the input batch, and the
    /// resulting arrays become the columns of the returned batch, in
    /// expression order. Returns `Ok(None)` once the input is exhausted.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the input relation or by an
    /// expression while it is being evaluated.
    fn next(&mut self) -> Result<Option<RecordBatch>> {
        // Bind the batch in its own statement so the `RefCell` borrow on the
        // input is released before any expression runs.
        let batch = match self.input.borrow_mut().next()? {
            Some(batch) => batch,
            None => return Ok(None),
        };

        let projected_columns = self
            .expr
            .iter()
            .map(|e| e.get_func()(&batch))
            .collect::<Result<Vec<ArrayRef>>>()?;

        // Tag the batch with this relation's schema (not an empty one) so
        // that `batch.schema()` agrees with `self.schema()`. The schema is
        // used as supplied at construction; it is not checked against the
        // projected columns here.
        Ok(Some(RecordBatch::new(
            self.schema.clone(),
            projected_columns,
        )))
    }

    /// Returns the schema this relation was constructed with.
    fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }
}
#[cfg(test)]
mod tests {
    use super::super::super::logicalplan::Expr;
    use super::super::context::ExecutionContext;
    use super::super::datasource::CsvDataSource;
    use super::super::expression;
    use super::super::relation::DataSourceRelation;
    use super::*;
    use arrow::csv;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::fs::File;

    /// Projects column 0 of the people CSV and checks that the first batch
    /// produced by the projection has exactly one column.
    #[test]
    fn project_all_columns() {
        // Schema of the CSV file being scanned.
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("first_name", DataType::Utf8, false),
        ]));

        let file = File::open("test/data/people.csv").unwrap();
        let arrow_csv_reader = csv::Reader::new(file, input_schema.clone(), true, 1024, None);
        let ds = CsvDataSource::new(input_schema.clone(), arrow_csv_reader);
        let relation = Rc::new(RefCell::new(DataSourceRelation::new(Rc::new(
            RefCell::new(ds),
        ))));

        let context = Rc::new(ExecutionContext::new());

        // The expression is compiled against the *input* schema, since
        // `Expr::Column(0)` is an index into the scanned columns.
        let projection_expr = vec![expression::compile_expr(
            context,
            &Expr::Column(0),
            input_schema.as_ref(),
        )
        .unwrap()];

        // The projection emits a single column, so its output schema holds
        // only the projected field rather than the full input schema.
        let projected_schema = Arc::new(Schema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]));

        let mut projection = ProjectRelation::new(relation, projection_expr, projected_schema);
        let batch = projection.next().unwrap().unwrap();
        assert_eq!(1, batch.num_columns());
    }
}