use std::sync::Arc;
use crate::arrow::datatypes::DataType;
use crate::arrow::record_batch::RecordBatch;
use crate::error::{ExecutionError, Result};
use crate::execution::context::ExecutionContext;
use crate::logicalplan::LogicalPlanBuilder;
use crate::logicalplan::{Expr, LogicalPlan};
use crate::table::*;
use arrow::datatypes::Schema;
/// Implementation of the [`Table`] API that is backed by a logical plan.
///
/// The transformation methods in this file never mutate the wrapped plan;
/// they build a new plan from it and wrap that in a new `TableImpl`.
#[derive(Debug, Clone)]
pub struct TableImpl {
    /// Logical plan that this table represents.
    plan: LogicalPlan,
}
impl TableImpl {
pub fn new(plan: &LogicalPlan) -> Self {
Self { plan: plan.clone() }
}
}
impl Table for TableImpl {
    /// Apply a projection based on a list of column names.
    ///
    /// # Errors
    ///
    /// Returns an error if any name cannot be resolved against the schema of
    /// the underlying plan, or if building the projection fails.
    fn select_columns(&self, columns: Vec<&str>) -> Result<Arc<dyn Table>> {
        // Resolve every name through `col` so that name lookup and error
        // conversion live in exactly one place.
        let exprs = columns
            .iter()
            .map(|name| self.col(name))
            .collect::<Result<Vec<_>>>()?;
        self.select(exprs)
    }

    /// Create a projection based on arbitrary expressions.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the plan builder.
    fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn Table>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .project(expr_list)?
            .build()?;
        Ok(Arc::new(TableImpl::new(&plan)))
    }

    /// Create a new table whose plan applies `expr` as a filter on top of
    /// this table's plan.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the plan builder.
    fn filter(&self, expr: Expr) -> Result<Arc<dyn Table>> {
        let plan = LogicalPlanBuilder::from(&self.plan).filter(expr)?.build()?;
        Ok(Arc::new(TableImpl::new(&plan)))
    }

    /// Create a new table whose plan aggregates this table's plan using the
    /// given grouping and aggregate expressions.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the plan builder.
    fn aggregate(
        &self,
        group_expr: Vec<Expr>,
        aggr_expr: Vec<Expr>,
    ) -> Result<Arc<dyn Table>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .aggregate(group_expr, aggr_expr)?
            .build()?;
        Ok(Arc::new(TableImpl::new(&plan)))
    }

    /// Create a new table whose plan applies a limit of `n` on top of this
    /// table's plan.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the plan builder.
    fn limit(&self, n: usize) -> Result<Arc<dyn Table>> {
        let plan = LogicalPlanBuilder::from(&self.plan).limit(n)?.build()?;
        Ok(Arc::new(TableImpl::new(&plan)))
    }

    /// Return an expression referencing the column called `name`, identified
    /// by its index in the schema of the underlying plan.
    ///
    /// # Errors
    ///
    /// Returns an error if the schema lookup for `name` fails.
    fn col(&self, name: &str) -> Result<Expr> {
        Ok(Expr::Column(self.plan.schema().index_of(name)?))
    }

    /// Create an expression to represent the `MIN` aggregate function.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be determined.
    fn min(&self, expr: &Expr) -> Result<Expr> {
        self.aggregate_expr("MIN", expr)
    }

    /// Create an expression to represent the `MAX` aggregate function.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be determined.
    fn max(&self, expr: &Expr) -> Result<Expr> {
        self.aggregate_expr("MAX", expr)
    }

    /// Create an expression to represent the `SUM` aggregate function.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be determined.
    fn sum(&self, expr: &Expr) -> Result<Expr> {
        self.aggregate_expr("SUM", expr)
    }

    /// Create an expression to represent the `AVG` aggregate function.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be determined.
    fn avg(&self, expr: &Expr) -> Result<Expr> {
        self.aggregate_expr("AVG", expr)
    }

    /// Create an expression to represent the `COUNT` aggregate function.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be determined.
    fn count(&self, expr: &Expr) -> Result<Expr> {
        self.aggregate_expr("COUNT", expr)
    }

    /// Return a copy of the logical plan represented by this table.
    fn to_logical_plan(&self) -> LogicalPlan {
        self.plan.clone()
    }

    /// Run the plan through `ctx` and return the resulting record batches.
    ///
    /// `batch_size` is passed through unchanged to
    /// `ExecutionContext::collect_plan`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `ExecutionContext::collect_plan`.
    fn collect(
        &self,
        ctx: &mut ExecutionContext,
        batch_size: usize,
    ) -> Result<Vec<RecordBatch>> {
        // The context only needs a shared borrow, so the plan is not cloned.
        ctx.collect_plan(&self.plan, batch_size)
    }

    /// Return the schema of the underlying logical plan.
    fn schema(&self) -> &Schema {
        self.plan.schema().as_ref()
    }
}
impl TableImpl {
    /// Determine the data type of an expression against this table's schema.
    ///
    /// Only plain column references (`Expr::Column`) are supported.
    ///
    /// # Errors
    ///
    /// Returns `ExecutionError::General` if `expr` is not a column reference,
    /// or if the column index does not exist in the schema.
    fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
        match expr {
            Expr::Column(i) => {
                // Checked lookup: a column expression can be constructed with
                // an arbitrary index, and indexing the schema directly would
                // panic instead of reporting an error.
                let schema = self.plan.schema();
                schema
                    .fields()
                    .get(*i)
                    .map(|field| field.data_type().clone())
                    .ok_or_else(|| {
                        ExecutionError::General(format!(
                            "Column index {} is out of bounds for a schema with {} fields",
                            i,
                            schema.fields().len()
                        ))
                    })
            }
            _ => Err(ExecutionError::General(format!(
                "Could not determine data type for expr {:?}",
                expr
            ))),
        }
    }

    /// Create an aggregate function expression called `name` with `expr` as
    /// its only argument.
    ///
    /// # Errors
    ///
    /// Returns an error if the data type of `expr` cannot be determined.
    fn aggregate_expr(&self, name: &str, expr: &Expr) -> Result<Expr> {
        // NOTE(review): the return type is always the argument's data type.
        // That looks questionable for COUNT and AVG - confirm against the
        // types the SQL planner assigns to these functions.
        let return_type = self.get_data_type(expr)?;
        Ok(Expr::AggregateFunction {
            name: name.to_string(),
            args: vec![expr.clone()],
            return_type,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::csv::CsvReadOptions;
    use crate::execution::context::ExecutionContext;
    use crate::test;

    #[test]
    fn select_columns() -> Result<()> {
        // Build the plan using the Table API.
        let t = test_table()?;
        let t2 = t.select_columns(vec!["c1", "c2", "c11"])?;
        let plan = t2.to_logical_plan();

        // Build the equivalent query using SQL.
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

        // The two plans should be identical.
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn select_expr() -> Result<()> {
        // Build the plan using the Table API.
        let t = test_table()?;
        let t2 = t.select(vec![t.col("c1")?, t.col("c2")?, t.col("c11")?])?;
        let plan = t2.to_logical_plan();

        // Build the equivalent query using SQL.
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;

        // The two plans should be identical.
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn aggregate() -> Result<()> {
        // Build the plan using the Table API.
        let t = test_table()?;
        let group_expr = vec![t.col("c1")?];
        let c12 = t.col("c12")?;
        let aggr_expr = vec![
            t.min(&c12)?,
            t.max(&c12)?,
            t.avg(&c12)?,
            t.sum(&c12)?,
            t.count(&c12)?,
        ];

        // Both vectors are moved into the call; neither is used afterwards.
        let t2 = t.aggregate(group_expr, aggr_expr)?;

        let plan = t2.to_logical_plan();

        // Build the equivalent query using SQL.
        let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \
                   FROM aggregate_test_100 \
                   GROUP BY c1";
        let sql_plan = create_plan(sql)?;

        // The two plans should be identical.
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    #[test]
    fn limit() -> Result<()> {
        // Build the plan using the Table API.
        let t = test_table()?;
        let t2 = t.select_columns(vec!["c1", "c2", "c11"])?.limit(10)?;
        let plan = t2.to_logical_plan();

        // Build the equivalent query using SQL.
        let sql_plan =
            create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;

        // The two plans should be identical.
        assert_same_plan(&plan, &sql_plan);

        Ok(())
    }

    /// Compare two plans via their `Debug` representation.
    fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
        assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
    }

    /// Create a logical plan from a SQL query against the registered test table.
    fn create_plan(sql: &str) -> Result<LogicalPlan> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.create_logical_plan(sql)
    }

    /// Return the `aggregate_test_100` table from a freshly prepared context.
    fn test_table() -> Result<Arc<dyn Table + 'static>> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.table("aggregate_test_100")
    }

    /// Register the `aggregate_test_100` CSV test file with `ctx`.
    fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
        let schema = test::aggr_test_schema();
        let testdata = test::arrow_testdata_path();
        ctx.register_csv(
            "aggregate_test_100",
            &format!("{}/csv/aggregate_test_100.csv", testdata),
            CsvReadOptions::new().schema(&schema),
        )?;
        Ok(())
    }
}