use std::sync::Arc;
use crate::arrow::record_batch::RecordBatch;
use crate::dataframe::*;
use crate::error::Result;
use crate::execution::context::{ExecutionContext, ExecutionContextState};
use crate::logical_plan::{col, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder};
use arrow::datatypes::Schema;
use async_trait::async_trait;
/// Implementation of the `DataFrame` API that builds up a `LogicalPlan`:
/// each transformation method wraps the current plan in a new plan node,
/// and `collect` turns the plan into a physical plan and executes it.
pub struct DataFrameImpl {
    // Snapshot of the execution context's state; used to reconstruct an
    // `ExecutionContext` at execution time and to resolve registered functions.
    ctx_state: ExecutionContextState,
    // The logical plan representing the query built so far.
    plan: LogicalPlan,
}
impl DataFrameImpl {
pub fn new(ctx_state: ExecutionContextState, plan: &LogicalPlan) -> Self {
Self {
ctx_state,
plan: plan.clone(),
}
}
}
#[async_trait]
impl DataFrame for DataFrameImpl {
    /// Create a projection limited to the named columns.
    ///
    /// Returns an error if any name is not present in the plan's schema.
    fn select_columns(&self, columns: Vec<&str>) -> Result<Arc<dyn DataFrame>> {
        let exprs = columns
            .into_iter()
            .map(|name| {
                // Validate that the column exists before building the
                // expression; `index_of` errors on unknown names, and the
                // fallible collect below short-circuits on the first error.
                self.plan
                    .schema()
                    .index_of(name)
                    .map(|_| col(name))
                    .map_err(|e| e.into())
            })
            .collect::<Result<Vec<_>>>()?;
        self.select(exprs)
    }

    /// Create a projection from arbitrary expressions.
    fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .project(expr_list)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Keep only the rows for which `predicate` evaluates to true.
    fn filter(&self, predicate: Expr) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .filter(predicate)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Perform an aggregation, grouping by `group_expr` and computing the
    /// aggregates in `aggr_expr`.
    fn aggregate(
        &self,
        group_expr: Vec<Expr>,
        aggr_expr: Vec<Expr>,
    ) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .aggregate(group_expr, aggr_expr)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Limit the result set to at most `n` rows.
    fn limit(&self, n: usize) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan).limit(n)?.build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Sort by the given sort expressions.
    fn sort(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan).sort(expr)?.build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Return a clone of the logical plan built so far.
    fn to_logical_plan(&self) -> LogicalPlan {
        self.plan.clone()
    }

    /// Optimize the logical plan, convert it to a physical plan, execute it,
    /// and collect all result batches into memory.
    async fn collect(&self) -> Result<Vec<RecordBatch>> {
        // Rebuild an ExecutionContext from the stored state so the plan can
        // be optimized and executed with the same registered tables/functions.
        let ctx = ExecutionContext::from(self.ctx_state.clone());
        let plan = ctx.optimize(&self.plan)?;
        let plan = ctx.create_physical_plan(&plan)?;
        Ok(ctx.collect(plan).await?)
    }

    /// The schema produced by this DataFrame's logical plan.
    fn schema(&self) -> &Schema {
        self.plan.schema()
    }

    /// Wrap the plan in an EXPLAIN node; executing the resulting frame
    /// yields the plan description rather than the query results.
    fn explain(&self, verbose: bool) -> Result<Arc<dyn DataFrame>> {
        let plan = LogicalPlanBuilder::from(&self.plan)
            .explain(verbose)?
            .build()?;
        Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
    }

    /// Access the function registry backing this frame (e.g. to look up UDFs).
    fn registry(&self) -> &dyn FunctionRegistry {
        &self.ctx_state
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::csv::CsvReadOptions;
    use crate::execution::context::ExecutionContext;
    use crate::logical_plan::*;
    use crate::{physical_plan::functions::ScalarFunctionImplementation, test};
    use arrow::{array::ArrayRef, datatypes::DataType};

    // select_columns should produce the same logical plan as the equivalent
    // SQL projection.
    #[test]
    fn select_columns() -> Result<()> {
        // build plan using Table API
        let t = test_table()?;
        let t2 = t.select_columns(vec!["c1", "c2", "c11"])?;
        let plan = t2.to_logical_plan();

        // build the equivalent plan from SQL and compare
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // select with explicit column expressions should match the SQL projection.
    #[test]
    fn select_expr() -> Result<()> {
        let t = test_table()?;
        let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
        let plan = t2.to_logical_plan();

        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100")?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // aggregate with a group expression and several aggregate functions
    // should match the SQL GROUP BY plan.
    #[test]
    fn aggregate() -> Result<()> {
        let df = test_table()?;
        let group_expr = vec![col("c1")];
        let aggr_expr = vec![
            min(col("c12")),
            max(col("c12")),
            avg(col("c12")),
            sum(col("c12")),
            count(col("c12")),
        ];

        let df = df.aggregate(group_expr.clone(), aggr_expr.clone())?;
        let plan = df.to_logical_plan();

        let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \
        FROM aggregate_test_100 \
        GROUP BY c1";
        let sql_plan = create_plan(sql)?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // limit chained after a projection should match the SQL LIMIT plan.
    #[test]
    fn limit() -> Result<()> {
        let t = test_table()?;
        let t2 = t.select_columns(vec!["c1", "c2", "c11"])?.limit(10)?;
        let plan = t2.to_logical_plan();

        let sql_plan =
        create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // explain should wrap the plan the same way the SQL EXPLAIN statement does.
    #[test]
    fn explain() -> Result<()> {
        let df = test_table()?;
        let df = df
        .select_columns(vec!["c1", "c2", "c11"])?
        .limit(10)?
        .explain(false)?;
        let plan = df.to_logical_plan();

        let sql_plan =
        create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // A UDF registered on the context should be reachable through the
    // DataFrame's registry() and usable in a select, matching the SQL plan.
    #[test]
    fn registry() -> Result<()> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;

        // the implementation is never invoked here, only the plan is compared
        let my_fn: ScalarFunctionImplementation =
        Arc::new(|_: &[ArrayRef]| unimplemented!("my_fn is not implemented"));

        // register the UDF so it can be resolved by name
        ctx.register_udf(create_udf(
            "my_fn",
            vec![DataType::Float64],
            Arc::new(DataType::Float64),
            my_fn,
        ));

        let df = ctx.table("aggregate_test_100")?;
        let f = df.registry();
        let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
        let plan = df.to_logical_plan();

        let sql_plan =
        ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    /// Compare two plans by their Debug representation.
    fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
        assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
    }

    /// Create a logical plan from a SQL string against the test table.
    fn create_plan(sql: &str) -> Result<LogicalPlan> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.create_logical_plan(sql)
    }

    /// Return a DataFrame over the registered test CSV table.
    fn test_table() -> Result<Arc<dyn DataFrame + 'static>> {
        let mut ctx = ExecutionContext::new();
        register_aggregate_csv(&mut ctx)?;
        ctx.table("aggregate_test_100")
    }

    /// Register the aggregate_test_100 CSV fixture with the given context.
    fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
        let schema = test::aggr_test_schema();
        let testdata = test::arrow_testdata_path();
        ctx.register_csv(
            "aggregate_test_100",
            &format!("{}/csv/aggregate_test_100.csv", testdata),
            CsvReadOptions::new().schema(&schema),
        )?;
        Ok(())
    }
}