use crate::arrow::datatypes::Schema;
use crate::arrow::datatypes::SchemaRef;
use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::TableProvider;
use crate::error::Result;
use crate::execution::{
context::{SessionState, TaskContext},
FunctionRegistry,
};
use crate::logical_expr::{utils::find_window_exprs, TableType};
use crate::logical_plan::{
col, DFSchema, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning,
};
use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parquet};
use crate::physical_plan::SendableRecordBatchStream;
use crate::physical_plan::{collect, collect_partitioned};
use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
use crate::scalar::ScalarValue;
use async_trait::async_trait;
use parking_lot::RwLock;
use parquet::file::properties::WriterProperties;
use std::any::Any;
use std::sync::Arc;
/// A lazily evaluated query, represented as a [`LogicalPlan`] plus the session
/// state needed to plan and execute it.
///
/// Transformation methods (`select`, `filter`, `join`, ...) only build a new
/// logical plan on top of this one and return a new `DataFrame`. Execution
/// happens in action methods such as `collect`, `show`, `execute_stream` or the
/// `write_*` methods.
#[derive(Debug)]
pub struct DataFrame {
    /// Session state shared through an `Arc<RwLock<_>>`; it is read (never
    /// written) by this type when planning and executing.
    session_state: Arc<RwLock<SessionState>>,
    /// Logical plan describing the rows this DataFrame produces.
    plan: LogicalPlan,
}
impl DataFrame {
pub fn new(session_state: Arc<RwLock<SessionState>>, plan: &LogicalPlan) -> Self {
Self {
session_state,
plan: plan.clone(),
}
}
pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
let state = self.session_state.read().clone();
state.create_physical_plan(&self.plan).await
}
pub fn select_columns(&self, columns: &[&str]) -> Result<Arc<DataFrame>> {
let fields = columns
.iter()
.map(|name| self.plan.schema().field_with_unqualified_name(name))
.collect::<Result<Vec<_>>>()?;
let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
self.select(expr)
}
pub fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<DataFrame>> {
let window_func_exprs = find_window_exprs(&expr_list);
let plan = if window_func_exprs.is_empty() {
self.plan.clone()
} else {
LogicalPlanBuilder::window_plan(self.plan.clone(), window_func_exprs)?
};
let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&project_plan,
)))
}
pub fn filter(&self, predicate: Expr) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.filter(predicate)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn aggregate(
&self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.aggregate(group_expr, aggr_expr)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn limit(
&self,
skip: Option<usize>,
fetch: Option<usize>,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.limit(skip, fetch)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn union(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.union(dataframe.plan.clone())?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn union_distinct(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::from(self.plan.clone())
.union_distinct(dataframe.plan.clone())?
.build()?,
)))
}
pub fn distinct(&self) -> Result<Arc<DataFrame>> {
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::from(self.plan.clone())
.distinct()?
.build()?,
)))
}
pub fn sort(&self, expr: Vec<Expr>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.sort(expr)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn join(
&self,
right: Arc<DataFrame>,
join_type: JoinType,
left_cols: &[&str],
right_cols: &[&str],
filter: Option<Expr>,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.join(
&right.plan.clone(),
join_type,
(left_cols.to_vec(), right_cols.to_vec()),
filter,
)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn repartition(
&self,
partitioning_scheme: Partitioning,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.repartition(partitioning_scheme)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub async fn collect(&self) -> Result<Vec<RecordBatch>> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
collect(plan, task_ctx).await
}
pub async fn show(&self) -> Result<()> {
let results = self.collect().await?;
Ok(pretty::print_batches(&results)?)
}
pub async fn show_limit(&self, num: usize) -> Result<()> {
let results = self.limit(None, Some(num))?.collect().await?;
Ok(pretty::print_batches(&results)?)
}
pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
execute_stream(plan, task_ctx).await
}
pub async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
collect_partitioned(plan, task_ctx).await
}
pub async fn execute_stream_partitioned(
&self,
) -> Result<Vec<SendableRecordBatchStream>> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
execute_stream_partitioned(plan, task_ctx).await
}
pub fn schema(&self) -> &DFSchema {
self.plan.schema()
}
pub fn to_logical_plan(&self) -> Result<LogicalPlan> {
let state = self.session_state.read().clone();
state.optimize(&self.plan)
}
pub fn explain(&self, verbose: bool, analyze: bool) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.explain(verbose, analyze)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
let registry = self.session_state.read().clone();
Arc::new(registry)
}
pub fn intersect(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
let left_plan = self.plan.clone();
let right_plan = dataframe.plan.clone();
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::intersect(left_plan, right_plan, true)?,
)))
}
pub fn except(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
let left_plan = self.plan.clone();
let right_plan = dataframe.plan.clone();
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::except(left_plan, right_plan, true)?,
)))
}
pub async fn write_csv(&self, path: &str) -> Result<()> {
let plan = self.create_physical_plan().await?;
let state = self.session_state.read().clone();
plan_to_csv(&state, plan, path).await
}
pub async fn write_parquet(
&self,
path: &str,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
let plan = self.create_physical_plan().await?;
let state = self.session_state.read().clone();
plan_to_parquet(&state, plan, path, writer_properties).await
}
pub async fn write_json(&self, path: impl AsRef<str>) -> Result<()> {
let plan = self.create_physical_plan().await?;
let state = self.session_state.read().clone();
plan_to_json(&state, plan, path).await
}
pub fn with_column(&self, name: &str, expr: Expr) -> Result<Arc<DataFrame>> {
let window_func_exprs = find_window_exprs(&[expr.clone()]);
let plan = if window_func_exprs.is_empty() {
self.plan.clone()
} else {
LogicalPlanBuilder::window_plan(self.plan.clone(), window_func_exprs)?
};
let new_column = Expr::Alias(Box::new(expr), name.to_string());
let mut col_exists = false;
let mut fields: Vec<Expr> = plan
.schema()
.fields()
.iter()
.map(|f| {
if f.name() == name {
col_exists = true;
new_column.clone()
} else {
col(f.name())
}
})
.collect();
if !col_exists {
fields.push(new_column);
}
let project_plan = LogicalPlanBuilder::from(plan).project(fields)?.build()?;
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&project_plan,
)))
}
}
#[async_trait]
impl TableProvider for DataFrame {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Arrow schema converted from this DataFrame's logical plan schema.
    fn schema(&self) -> SchemaRef {
        let schema: Schema = self.plan.schema().as_ref().into();
        Arc::new(schema)
    }

    /// A DataFrame registered as a table is reported as a view.
    fn table_type(&self) -> TableType {
        TableType::View
    }

    /// Plan a scan of this DataFrame. The optional projection, any pushed-down
    /// `filters` and the optional `limit` are layered on top of its plan.
    ///
    /// `_ctx` is ignored: planning uses this DataFrame's own session state.
    async fn scan(
        &self,
        _ctx: &SessionState,
        projection: &Option<Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Apply the projection (column indices) by selecting the matching names.
        let mut df: Arc<DataFrame> = match projection {
            Some(indices) => {
                let schema = TableProvider::schema(self).project(indices)?;
                let names: Vec<&str> = schema
                    .fields()
                    .iter()
                    .map(|field| field.name().as_str())
                    .collect();
                self.select_columns(&names)?
            }
            None => Arc::new(Self::new(Arc::clone(&self.session_state), &self.plan)),
        };

        // AND the pushed-down filters together. When there are none, no Filter
        // node is added at all (previously an always-true filter was inserted).
        if let Some(predicate) = filters.iter().cloned().reduce(|acc, e| acc.and(e)) {
            df = df.filter(predicate)?;
        }

        if let Some(n) = limit {
            df = df.limit(None, Some(n))?;
        }

        df.create_physical_plan().await
    }
}
// Unit tests for `DataFrame`. Most tests build a plan through the DataFrame API
// and compare its `Debug` rendering with the plan produced by equivalent SQL;
// the others execute against the `aggregate_test_100.csv` test file.
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::execution::options::CsvReadOptions;
    use crate::physical_plan::ColumnarValue;
    use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
    use crate::{logical_plan::*, test_util};
    use arrow::datatypes::DataType;
    use datafusion_expr::Volatility;
    use datafusion_expr::{
        BuiltInWindowFunction, ScalarFunctionImplementation, WindowFunction,
    };

    // `select_columns` yields the same plan as a SQL column projection.
    #[tokio::test]
    async fn select_columns() -> Result<()> {
        let t = test_table().await?;
        let t2 = t.select_columns(&["c1", "c2", "c11"])?;
        let plan = t2.plan.clone();
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100").await?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // `select` with plain column expressions yields the same plan as SQL.
    #[tokio::test]
    async fn select_expr() -> Result<()> {
        let t = test_table().await?;
        let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
        let plan = t2.plan.clone();
        let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100").await?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // A window function passed to `select` is planned like the SQL `OVER` form.
    #[tokio::test]
    async fn select_with_window_exprs() -> Result<()> {
        let t = test_table().await?;
        let first_row = Expr::WindowFunction {
            fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
            args: vec![col("aggregate_test_100.c1")],
            partition_by: vec![col("aggregate_test_100.c2")],
            order_by: vec![],
            window_frame: None,
        };
        let t2 = t.select(vec![col("c1"), first_row])?;
        let plan = t2.plan.clone();
        let sql_plan = create_plan(
            "select c1, first_value(c1) over (partition by c2) from aggregate_test_100",
        )
        .await?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // Executes a grouped aggregation and checks the collected output.
    // NOTE(review): the expected cells below look whitespace-collapsed relative
    // to their `+---+` borders; confirm they match the pretty-printed output.
    #[tokio::test]
    async fn aggregate() -> Result<()> {
        let df = test_table().await?;
        let group_expr = vec![col("c1")];
        let aggr_expr = vec![
            min(col("c12")),
            max(col("c12")),
            avg(col("c12")),
            sum(col("c12")),
            count(col("c12")),
            count_distinct(col("c12")),
        ];
        let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
        assert_batches_sorted_eq!(
            vec![
                "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
                "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
                "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
                "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
                "| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
                "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |",
                "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |",
                "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |",
                "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
            ],
            &df
        );
        Ok(())
    }

    // Inner join of two 100-row tables on c1; the row counts are checked.
    #[tokio::test]
    async fn join() -> Result<()> {
        let left = test_table().await?.select_columns(&["c1", "c2"])?;
        // Registered under a different table name ("c2") so the two sides have
        // distinct qualifiers.
        let right = test_table_with_name("c2")
            .await?
            .select_columns(&["c1", "c3"])?;
        let left_rows = left.collect().await?;
        let right_rows = right.collect().await?;
        let join = left.join(right, JoinType::Inner, &["c1"], &["c1"], None)?;
        let join_rows = join.collect().await?;
        assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
        assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
        assert_eq!(2008, join_rows.iter().map(|x| x.num_rows()).sum::<usize>());
        Ok(())
    }

    // `limit(None, Some(10))` matches SQL `LIMIT 10`.
    #[tokio::test]
    async fn limit() -> Result<()> {
        let t = test_table().await?;
        let t2 = t
            .select_columns(&["c1", "c2", "c11"])?
            .limit(None, Some(10))?;
        let plan = t2.plan.clone();
        let sql_plan =
            create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10").await?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // `explain(false, false)` matches a plain SQL `EXPLAIN`.
    #[tokio::test]
    async fn explain() -> Result<()> {
        let df = test_table().await?;
        let df = df
            .select_columns(&["c1", "c2", "c11"])?
            .limit(None, Some(10))?
            .explain(false, false)?;
        let plan = df.plan.clone();
        let sql_plan =
            create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")
                .await?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // A UDF registered on the context is reachable via `DataFrame::registry`.
    // The UDF body is `unimplemented!`; only plans are compared, so it never runs.
    #[tokio::test]
    async fn registry() -> Result<()> {
        let mut ctx = SessionContext::new();
        register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
        let my_fn: ScalarFunctionImplementation =
            Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented"));
        ctx.register_udf(create_udf(
            "my_fn",
            vec![DataType::Float64],
            Arc::new(DataType::Float64),
            Volatility::Immutable,
            my_fn,
        ));
        let df = ctx.table("aggregate_test_100")?;
        let f = df.registry();
        let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
        let plan = df.plan.clone();
        let sql_plan =
            ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;
        assert_same_plan(&plan, &sql_plan);
        Ok(())
    }

    // A DataFrame can be moved into a spawned tokio task (compile-time check
    // that the type is `Send + 'static`).
    #[tokio::test]
    async fn sendable() {
        let df = test_table().await.unwrap();
        let task = tokio::task::spawn(async move {
            df.select_columns(&["c1"])
                .expect("should be usable in a task")
        });
        task.await.expect("task completed successfully");
    }

    // `intersect` is compared against SQL `INTERSECT ALL`.
    #[tokio::test]
    async fn intersect() -> Result<()> {
        let df = test_table().await?.select_columns(&["c1", "c3"])?;
        let plan = df.intersect(df.clone())?;
        let result = plan.plan.clone();
        let expected = create_plan(
            "SELECT c1, c3 FROM aggregate_test_100
INTERSECT ALL SELECT c1, c3 FROM aggregate_test_100",
        )
        .await?;
        assert_same_plan(&result, &expected);
        Ok(())
    }

    // `except` is compared against SQL `EXCEPT ALL`.
    #[tokio::test]
    async fn except() -> Result<()> {
        let df = test_table().await?.select_columns(&["c1", "c3"])?;
        let plan = df.except(df.clone())?;
        let result = plan.plan.clone();
        let expected = create_plan(
            "SELECT c1, c3 FROM aggregate_test_100
EXCEPT ALL SELECT c1, c3 FROM aggregate_test_100",
        )
        .await?;
        assert_same_plan(&result, &expected);
        Ok(())
    }

    // A DataFrame registered as a table (via `TableProvider`) produces the same
    // aggregate values as the DataFrame itself; only the column qualifier in
    // the header differs.
    #[tokio::test]
    async fn register_table() -> Result<()> {
        let df = test_table().await?.select_columns(&["c1", "c12"])?;
        let ctx = SessionContext::new();
        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.plan.clone()));
        ctx.register_table("test_table", df_impl.clone())?;
        let table = ctx.table("test_table")?;
        let group_expr = vec![col("c1")];
        let aggr_expr = vec![sum(col("c12"))];
        let df_results = &df_impl
            .aggregate(group_expr.clone(), aggr_expr.clone())?
            .collect()
            .await?;
        let table_results = &table.aggregate(group_expr, aggr_expr)?.collect().await?;
        assert_batches_sorted_eq!(
            vec![
                "+----+-----------------------------+",
                "| c1 | SUM(aggregate_test_100.c12) |",
                "+----+-----------------------------+",
                "| a | 10.238448667882977 |",
                "| b | 7.797734760124923 |",
                "| c | 13.860958726523545 |",
                "| d | 8.793968289758968 |",
                "| e | 10.206140546981722 |",
                "+----+-----------------------------+",
            ],
            df_results
        );
        assert_batches_sorted_eq!(
            vec![
                "+----+---------------------+",
                "| c1 | SUM(test_table.c12) |",
                "+----+---------------------+",
                "| a | 10.238448667882977 |",
                "| b | 7.797734760124923 |",
                "| c | 13.860958726523545 |",
                "| d | 8.793968289758968 |",
                "| e | 10.206140546981722 |",
                "+----+---------------------+",
            ],
            table_results
        );
        Ok(())
    }

    /// Assert two logical plans are equal by comparing their `Debug` output.
    fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
        assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
    }

    /// Plan `sql` in a fresh context with `aggregate_test_100` registered.
    async fn create_plan(sql: &str) -> Result<LogicalPlan> {
        let mut ctx = SessionContext::new();
        register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
        ctx.create_logical_plan(sql)
    }

    /// Register the aggregate test CSV as table `name` in a fresh context and
    /// return it as a DataFrame.
    async fn test_table_with_name(name: &str) -> Result<Arc<DataFrame>> {
        let mut ctx = SessionContext::new();
        register_aggregate_csv(&mut ctx, name).await?;
        ctx.table(name)
    }

    /// The aggregate test CSV registered as `aggregate_test_100`.
    async fn test_table() -> Result<Arc<DataFrame>> {
        test_table_with_name("aggregate_test_100").await
    }

    /// Register `csv/aggregate_test_100.csv` from the arrow test data directory
    /// as `table_name`, using the explicit test schema.
    async fn register_aggregate_csv(
        ctx: &mut SessionContext,
        table_name: &str,
    ) -> Result<()> {
        let schema = test_util::aggr_test_schema();
        let testdata = crate::test_util::arrow_test_data();
        ctx.register_csv(
            table_name,
            &format!("{}/csv/aggregate_test_100.csv", testdata),
            CsvReadOptions::new().schema(schema.as_ref()),
        )
        .await?;
        Ok(())
    }

    // `with_column` appends a new column, overwrites another column in place,
    // and can overwrite a column with an expression over itself.
    #[tokio::test]
    async fn with_column() -> Result<()> {
        let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
        let ctx = SessionContext::new();
        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.plan.clone()));
        // Append a new column `sum` after filtering down to a few rows.
        let df = &df_impl
            .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
            .with_column("sum", col("c2") + col("c3"))?;
        let df_results = df.collect().await?;
        assert_batches_sorted_eq!(
            vec![
                "+----+----+-----+-----+",
                "| c1 | c2 | c3 | sum |",
                "+----+----+-----+-----+",
                "| a | 3 | -12 | -9 |",
                "| a | 3 | -72 | -69 |",
                "| a | 3 | 13 | 16 |",
                "| a | 3 | 13 | 16 |",
                "| a | 3 | 14 | 17 |",
                "| a | 3 | 17 | 20 |",
                "+----+----+-----+-----+",
            ],
            &df_results
        );
        // Overwrite an existing column (`c1`) with an expression over others.
        let df_results_overwrite = df
            .with_column("c1", col("c2") + col("c3"))?
            .collect()
            .await?;
        assert_batches_sorted_eq!(
            vec![
                "+-----+----+-----+-----+",
                "| c1 | c2 | c3 | sum |",
                "+-----+----+-----+-----+",
                "| -69 | 3 | -72 | -69 |",
                "| -9 | 3 | -12 | -9 |",
                "| 16 | 3 | 13 | 16 |",
                "| 16 | 3 | 13 | 16 |",
                "| 17 | 3 | 14 | 17 |",
                "| 20 | 3 | 17 | 20 |",
                "+-----+----+-----+-----+",
            ],
            &df_results_overwrite
        );
        // Overwrite a column with an expression that reads the same column.
        let df_results_overwrite_self =
            df.with_column("c2", col("c2") + lit(1))?.collect().await?;
        assert_batches_sorted_eq!(
            vec![
                "+----+----+-----+-----+",
                "| c1 | c2 | c3 | sum |",
                "+----+----+-----+-----+",
                "| a | 4 | -12 | -9 |",
                "| a | 4 | -72 | -69 |",
                "| a | 4 | 13 | 16 |",
                "| a | 4 | 13 | 16 |",
                "| a | 4 | 14 | 17 |",
                "| a | 4 | 17 | 20 |",
                "+----+----+-----+-----+",
            ],
            &df_results_overwrite_self
        );
        Ok(())
    }
}