use crate::arrow::datatypes::Schema;
use crate::arrow::datatypes::SchemaRef;
use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::{MemTable, TableProvider};
use crate::error::Result;
use crate::execution::{
context::{SessionState, TaskContext},
FunctionRegistry,
};
use crate::logical_expr::{
col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder,
Partitioning, TableType,
};
use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parquet};
use crate::physical_plan::SendableRecordBatchStream;
use crate::physical_plan::{collect, collect_partitioned};
use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
use crate::prelude::SessionContext;
use async_trait::async_trait;
use datafusion_common::{Column, DFSchema};
use datafusion_expr::TableProviderFilterPushDown;
use parking_lot::RwLock;
use parquet::file::properties::WriterProperties;
use std::any::Any;
use std::sync::Arc;
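/// A `DataFrame` is a logical set of rows with the same named columns,
/// represented by a `LogicalPlan` plus the `SessionState` needed to plan and
/// run it. Transformations such as `select`, `filter`, `aggregate` and `join`
/// only build a new plan; nothing is executed until a method such as
/// `collect`, `show` or one of the `write_*` methods is called.
///
/// Example (a minimal sketch; the CSV path and the column names `a`/`b`
/// below are hypothetical):
/// ```no_run
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// // hypothetical input file; any registered table works the same way
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df
///     .filter(col("a").lt_eq(col("b")))?
///     .aggregate(vec![col("a")], vec![min(col("b"))])?
///     .limit(0, Some(100))?;
/// let _results = df.collect().await?;
/// # Ok(())
/// # }
/// ```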
#[derive(Debug)]
pub struct DataFrame {
session_state: Arc<RwLock<SessionState>>,
plan: LogicalPlan,
}
impl DataFrame {
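/// Create a new `DataFrame` from a shared session state and an existing
/// logical plan (the plan is cloned).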
pub fn new(session_state: Arc<RwLock<SessionState>>, plan: &LogicalPlan) -> Self {
Self {
session_state,
plan: plan.clone(),
}
}
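/// Create the physical plan for this `DataFrame`'s logical plan. The session
/// state is snapshotted (with its execution properties marked as started) so
/// the write lock is not held across the `await`.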
pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
let state_cloned = {
let mut state = self.session_state.write();
state.execution_props.start_execution();
state.clone()
};
state_cloned.create_physical_plan(&self.plan).await
}
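/// Create a projection that keeps only the named columns, looked up by
/// unqualified name in the current schema, e.g.
/// `df.select_columns(&["a", "b"])?` (hypothetical column names).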
pub fn select_columns(&self, columns: &[&str]) -> Result<Arc<DataFrame>> {
let fields = columns
.iter()
.map(|name| self.plan.schema().field_with_unqualified_name(name))
.collect::<Result<Vec<_>>>()?;
let expr: Vec<Expr> = fields
.iter()
.map(|f| Expr::Column(f.qualified_column()))
.collect();
self.select(expr)
}
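/// Create a projection from arbitrary expressions; any window expressions in
/// `expr_list` are planned first via `LogicalPlanBuilder::window_plan`.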
pub fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<DataFrame>> {
let window_func_exprs = find_window_exprs(&expr_list);
let plan = if window_func_exprs.is_empty() {
self.plan.clone()
} else {
LogicalPlanBuilder::window_plan(self.plan.clone(), window_func_exprs)?
};
let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&project_plan,
)))
}
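/// Keep only the rows matching `predicate`, e.g.
/// `df.filter(col("a").gt(lit(5)))?` (hypothetical column).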
pub fn filter(&self, predicate: Expr) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.filter(predicate)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
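/// Perform an aggregation: one output row per distinct combination of
/// `group_expr`, with each expression in `aggr_expr` evaluated per group.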
pub fn aggregate(
&self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.aggregate(group_expr, aggr_expr)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
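/// Skip the first `skip` rows and then return at most `fetch` rows
/// (all remaining rows when `fetch` is `None`).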
pub fn limit(&self, skip: usize, fetch: Option<usize>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.limit(skip, fetch)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
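/// Append all rows of `dataframe` to this `DataFrame`, keeping duplicates;
/// both inputs must produce the same schema.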
pub fn union(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.union(dataframe.plan.clone())?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
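/// Union of this `DataFrame` and `dataframe` with duplicate rows removed.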
pub fn union_distinct(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::from(self.plan.clone())
.union_distinct(dataframe.plan.clone())?
.build()?,
)))
}
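/// Return a new `DataFrame` with duplicate rows removed.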
pub fn distinct(&self) -> Result<Arc<DataFrame>> {
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::from(self.plan.clone())
.distinct()?
.build()?,
)))
}
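/// Sort by the given sort expressions, e.g.
/// `df.sort(vec![col("a").sort(true, true)])?` for ascending, nulls first.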
pub fn sort(&self, expr: Vec<Expr>) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.sort(expr)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
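/// Join this `DataFrame` with `right` on equality of the paired columns in
/// `left_cols` / `right_cols` (the two slices must have the same length),
/// with an optional extra `filter` expression applied during the join, e.g.
/// `left.join(right, JoinType::Inner, &["id"], &["id"], None)?`
/// (hypothetical column name).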
pub fn join(
&self,
right: Arc<DataFrame>,
join_type: JoinType,
left_cols: &[&str],
right_cols: &[&str],
filter: Option<Expr>,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.join(
&right.plan,
join_type,
(left_cols.to_vec(), right_cols.to_vec()),
filter,
)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
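/// Repartition the plan according to `partitioning_scheme`
/// (e.g. round-robin or hash partitioning).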
pub fn repartition(
&self,
partitioning_scheme: Partitioning,
) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.repartition(partitioning_scheme)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
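/// Execute the plan and buffer all resulting `RecordBatch`es in memory.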
pub async fn collect(&self) -> Result<Vec<RecordBatch>> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
collect(plan, task_ctx).await
}
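/// Execute the plan and print the results to stdout as a formatted table.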
pub async fn show(&self) -> Result<()> {
let results = self.collect().await?;
Ok(pretty::print_batches(&results)?)
}
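/// Execute the plan and print at most `num` rows to stdout.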
pub async fn show_limit(&self, num: usize) -> Result<()> {
let results = self.limit(0, Some(num))?.collect().await?;
Ok(pretty::print_batches(&results)?)
}
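/// Execute the plan and return a single stream of `RecordBatch`es.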
pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
execute_stream(plan, task_ctx).await
}
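/// Execute the plan and collect the results, preserving partitioning:
/// one inner `Vec<RecordBatch>` per output partition.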
pub async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
collect_partitioned(plan, task_ctx).await
}
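/// Execute the plan and return one `RecordBatch` stream per output partition.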
pub async fn execute_stream_partitioned(
&self,
) -> Result<Vec<SendableRecordBatchStream>> {
let plan = self.create_physical_plan().await?;
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
execute_stream_partitioned(plan, task_ctx).await
}
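/// Return the `DFSchema` describing the output of this `DataFrame`'s logical plan.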
pub fn schema(&self) -> &DFSchema {
self.plan.schema()
}
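/// Return a clone of the underlying logical plan without running the optimizer.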
pub fn to_unoptimized_plan(&self) -> LogicalPlan {
self.plan.clone()
}
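/// Return the logical plan after applying the session's optimizer rules.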
pub fn to_logical_plan(&self) -> Result<LogicalPlan> {
let state = self.session_state.read().clone();
state.optimize(&self.plan)
}
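/// Return a `DataFrame` that describes this plan (`EXPLAIN`); `verbose` adds
/// per-rule detail and `analyze` executes the plan and reports metrics.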
pub fn explain(&self, verbose: bool, analyze: bool) -> Result<Arc<DataFrame>> {
let plan = LogicalPlanBuilder::from(self.plan.clone())
.explain(verbose, analyze)?
.build()?;
Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
}
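/// Return a `FunctionRegistry` (a snapshot of the current session state)
/// that can be used to look up registered UDFs and UDAFs.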
pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
let registry = self.session_state.read().clone();
Arc::new(registry)
}
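/// Rows present in both this `DataFrame` and `dataframe`
/// (`INTERSECT ALL` semantics, so duplicates are kept).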
pub fn intersect(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
let left_plan = self.plan.clone();
let right_plan = dataframe.plan.clone();
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::intersect(left_plan, right_plan, true)?,
)))
}
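/// Rows present in this `DataFrame` but not in `dataframe`
/// (`EXCEPT ALL` semantics).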
pub fn except(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
let left_plan = self.plan.clone();
let right_plan = dataframe.plan.clone();
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&LogicalPlanBuilder::except(left_plan, right_plan, true)?,
)))
}
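/// Execute the plan and write the results as CSV under `path`,
/// one file per output partition.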
pub async fn write_csv(&self, path: &str) -> Result<()> {
let plan = self.create_physical_plan().await?;
let state = self.session_state.read().clone();
plan_to_csv(&state, plan, path).await
}
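/// Execute the plan and write the results as Parquet under `path`,
/// one file per output partition, using the optional `writer_properties`.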
pub async fn write_parquet(
&self,
path: &str,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
let plan = self.create_physical_plan().await?;
let state = self.session_state.read().clone();
plan_to_parquet(&state, plan, path, writer_properties).await
}
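/// Execute the plan and write the results as newline-delimited JSON under
/// `path`, one file per output partition.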
pub async fn write_json(&self, path: impl AsRef<str>) -> Result<()> {
let plan = self.create_physical_plan().await?;
let state = self.session_state.read().clone();
plan_to_json(&state, plan, path).await
}
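/// Add a column named `name` computed from `expr`; if a column with that
/// name already exists it is replaced, otherwise the new column is appended.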
pub fn with_column(&self, name: &str, expr: Expr) -> Result<Arc<DataFrame>> {
let window_func_exprs = find_window_exprs(&[expr.clone()]);
let plan = if window_func_exprs.is_empty() {
self.plan.clone()
} else {
LogicalPlanBuilder::window_plan(self.plan.clone(), window_func_exprs)?
};
let new_column = Expr::Alias(Box::new(expr), name.to_string());
let mut col_exists = false;
let mut fields: Vec<Expr> = plan
.schema()
.fields()
.iter()
.map(|f| {
if f.name() == name {
col_exists = true;
new_column.clone()
} else {
Expr::Column(Column {
relation: None,
name: f.name().into(),
})
}
})
.collect();
if !col_exists {
fields.push(new_column);
}
let project_plan = LogicalPlanBuilder::from(plan).project(fields)?.build()?;
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&project_plan,
)))
}
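/// Rename the column whose qualified name equals `old_name` to `new_name`;
/// if no such column exists the `DataFrame` is returned unchanged.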
pub fn with_column_renamed(
&self,
old_name: &str,
new_name: &str,
) -> Result<Arc<DataFrame>> {
let mut projection = vec![];
let mut rename_applied = false;
for field in self.plan.schema().fields() {
let field_name = field.qualified_name();
if old_name == field_name {
projection.push(col(&field_name).alias(new_name));
rename_applied = true;
} else {
projection.push(col(&field_name));
}
}
if rename_applied {
let project_plan = LogicalPlanBuilder::from(self.plan.clone())
.project(projection)?
.build()?;
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&project_plan,
)))
} else {
Ok(Arc::new(DataFrame::new(
self.session_state.clone(),
&self.plan,
)))
}
}
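/// Execute this `DataFrame` and cache the results in a `MemTable`, returning
/// a new `DataFrame` that scans the cached batches.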
pub async fn cache(&self) -> Result<Arc<DataFrame>> {
let mem_table = MemTable::try_new(
SchemaRef::from(self.schema().clone()),
self.collect_partitioned().await?,
)?;
SessionContext::with_state(self.session_state.read().clone())
.read_table(Arc::new(mem_table))
}
}
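// A `DataFrame` also acts as a view-like `TableProvider` (e.g. when registered
// via `SessionContext::register_table`): `scan` re-plans the wrapped logical
// plan with the requested projection, filters and limit applied.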
#[async_trait]
impl TableProvider for DataFrame {
fn as_any(&self) -> &dyn Any {
self
}
fn get_logical_plan(&self) -> Option<&LogicalPlan> {
Some(&self.plan)
}
fn supports_filter_pushdown(
&self,
_filter: &Expr,
) -> Result<TableProviderFilterPushDown> {
Ok(TableProviderFilterPushDown::Exact)
}
fn schema(&self) -> SchemaRef {
let schema: Schema = self.plan.schema().as_ref().into();
Arc::new(schema)
}
fn table_type(&self) -> TableType {
TableType::View
}
async fn scan(
&self,
_ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
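// apply the projection, if any, by selecting the projected field names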
let mut expr = projection
.as_ref()
.map_or_else(
|| {
Ok(Arc::new(Self::new(self.session_state.clone(), &self.plan))
as Arc<_>)
},
|projection| {
let schema = TableProvider::schema(self).project(projection)?;
let names = schema
.fields()
.iter()
.map(|field| field.name().as_str())
.collect::<Vec<_>>();
self.select_columns(names.as_slice())
},
)?;
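// combine any pushed-down filters into a single conjunctive predicate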
let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
if let Some(filter) = filter {
expr = expr.filter(filter)?
}
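// apply the limit, if any, then create the physical plan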
Self::new(
self.session_state.clone(),
&limit
.map_or_else(|| Ok(expr.clone()), |n| expr.limit(0, Some(n)))?
.plan
.clone(),
)
.create_physical_plan()
.await
}
}
#[cfg(test)]
mod tests {
use std::vec;
use super::*;
use crate::execution::options::{CsvReadOptions, ParquetReadOptions};
use crate::physical_plan::ColumnarValue;
use crate::test_util;
use crate::test_util::parquet_test_data;
use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
use arrow::array::Int32Array;
use arrow::datatypes::DataType;
use datafusion_expr::{
avg, cast, count, count_distinct, create_udf, lit, max, min, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFunction,
};
#[tokio::test]
async fn select_columns() -> Result<()> {
let t = test_table().await?;
let t2 = t.select_columns(&["c1", "c2", "c11"])?;
let plan = t2.plan.clone();
let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100").await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn select_expr() -> Result<()> {
let t = test_table().await?;
let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
let plan = t2.plan.clone();
let sql_plan = create_plan("SELECT c1, c2, c11 FROM aggregate_test_100").await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn select_with_window_exprs() -> Result<()> {
let t = test_table().await?;
let first_row = Expr::WindowFunction {
fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
args: vec![col("aggregate_test_100.c1")],
partition_by: vec![col("aggregate_test_100.c2")],
order_by: vec![],
window_frame: None,
};
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
let sql_plan = create_plan(
"select c1, first_value(c1) over (partition by c2) from aggregate_test_100",
)
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn select_with_periods() -> Result<()> {
let array: Int32Array = [1, 10].into_iter().collect();
let batch =
RecordBatch::try_from_iter(vec![("f.c1", Arc::new(array) as _)]).unwrap();
let ctx = SessionContext::new();
ctx.register_batch("t", batch).unwrap();
let df = ctx.table("t").unwrap().select_columns(&["f.c1"]).unwrap();
let df_results = df.collect().await.unwrap();
assert_batches_sorted_eq!(
vec!["+------+", "| f.c1 |", "+------+", "| 1 |", "| 10 |", "+------+",],
&df_results
);
Ok(())
}
#[tokio::test]
async fn aggregate() -> Result<()> {
let df = test_table().await?;
let group_expr = vec![col("c1")];
let aggr_expr = vec![
min(col("c12")),
max(col("c12")),
avg(col("c12")),
sum(col("c12")),
count(col("c12")),
count_distinct(col("c12")),
];
let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
"| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |",
"| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |",
"| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
],
&df
);
Ok(())
}
#[tokio::test]
async fn join() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
let right = test_table_with_name("c2")
.await?
.select_columns(&["c1", "c3"])?;
let left_rows = left.collect().await?;
let right_rows = right.collect().await?;
let join = left.join(right, JoinType::Inner, &["c1"], &["c1"], None)?;
let join_rows = join.collect().await?;
assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
assert_eq!(2008, join_rows.iter().map(|x| x.num_rows()).sum::<usize>());
Ok(())
}
#[tokio::test]
async fn limit() -> Result<()> {
let t = test_table().await?;
let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(0, Some(10))?;
let plan = t2.plan.clone();
let sql_plan =
create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10").await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn explain() -> Result<()> {
let df = test_table().await?;
let df = df
.select_columns(&["c1", "c2", "c11"])?
.limit(0, Some(10))?
.explain(false, false)?;
let plan = df.plan.clone();
let sql_plan =
create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT 10")
.await?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn registry() -> Result<()> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
let my_fn: ScalarFunctionImplementation =
Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented"));
ctx.register_udf(create_udf(
"my_fn",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
my_fn,
));
let df = ctx.table("aggregate_test_100")?;
let f = df.registry();
let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
let plan = df.plan.clone();
let sql_plan =
ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;
assert_same_plan(&plan, &sql_plan);
Ok(())
}
#[tokio::test]
async fn sendable() {
let df = test_table().await.unwrap();
let task = tokio::task::spawn(async move {
df.select_columns(&["c1"])
.expect("should be usable in a task")
});
task.await.expect("task completed successfully");
}
#[tokio::test]
async fn intersect() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c3"])?;
let plan = df.intersect(df.clone())?;
let result = plan.plan.clone();
let expected = create_plan(
"SELECT c1, c3 FROM aggregate_test_100
INTERSECT ALL SELECT c1, c3 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&result, &expected);
Ok(())
}
#[tokio::test]
async fn except() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c3"])?;
let plan = df.except(df.clone())?;
let result = plan.plan.clone();
let expected = create_plan(
"SELECT c1, c3 FROM aggregate_test_100
EXCEPT ALL SELECT c1, c3 FROM aggregate_test_100",
)
.await?;
assert_same_plan(&result, &expected);
Ok(())
}
#[tokio::test]
async fn register_table() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c12"])?;
let ctx = SessionContext::new();
let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.plan.clone()));
ctx.register_table("test_table", df_impl.clone())?;
let table = ctx.table("test_table")?;
let group_expr = vec![col("c1")];
let aggr_expr = vec![sum(col("c12"))];
let df_results = &df_impl
.aggregate(group_expr.clone(), aggr_expr.clone())?
.collect()
.await?;
let table_results = &table.aggregate(group_expr, aggr_expr)?.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+-----------------------------+",
"| c1 | SUM(aggregate_test_100.c12) |",
"+----+-----------------------------+",
"| a | 10.238448667882977 |",
"| b | 7.797734760124923 |",
"| c | 13.860958726523545 |",
"| d | 8.793968289758968 |",
"| e | 10.206140546981722 |",
"+----+-----------------------------+",
],
df_results
);
assert_batches_sorted_eq!(
vec![
"+----+---------------------+",
"| c1 | SUM(test_table.c12) |",
"+----+---------------------+",
"| a | 10.238448667882977 |",
"| b | 7.797734760124923 |",
"| c | 13.860958726523545 |",
"| d | 8.793968289758968 |",
"| e | 10.206140546981722 |",
"+----+---------------------+",
],
table_results
);
Ok(())
}
fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
}
async fn create_plan(sql: &str) -> Result<LogicalPlan> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
ctx.create_logical_plan(sql)
}
async fn test_table_with_name(name: &str) -> Result<Arc<DataFrame>> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, name).await?;
ctx.table(name)
}
async fn test_table() -> Result<Arc<DataFrame>> {
test_table_with_name("aggregate_test_100").await
}
async fn register_aggregate_csv(
ctx: &mut SessionContext,
table_name: &str,
) -> Result<()> {
let schema = test_util::aggr_test_schema();
let testdata = test_util::arrow_test_data();
ctx.register_csv(
table_name,
&format!("{}/csv/aggregate_test_100.csv", testdata),
CsvReadOptions::new().schema(schema.as_ref()),
)
.await?;
Ok(())
}
#[tokio::test]
async fn with_column() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.plan.clone()));
let df = &df_impl
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
.with_column("sum", col("c2") + col("c3"))?;
let df_results = df.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+-----+-----+",
"| c1 | c2 | c3 | sum |",
"+----+----+-----+-----+",
"| a | 3 | -12 | -9 |",
"| a | 3 | -72 | -69 |",
"| a | 3 | 13 | 16 |",
"| a | 3 | 13 | 16 |",
"| a | 3 | 14 | 17 |",
"| a | 3 | 17 | 20 |",
"+----+----+-----+-----+",
],
&df_results
);
let df_results_overwrite = df
.with_column("c1", col("c2") + col("c3"))?
.collect()
.await?;
assert_batches_sorted_eq!(
vec![
"+-----+----+-----+-----+",
"| c1 | c2 | c3 | sum |",
"+-----+----+-----+-----+",
"| -69 | 3 | -72 | -69 |",
"| -9 | 3 | -12 | -9 |",
"| 16 | 3 | 13 | 16 |",
"| 16 | 3 | 13 | 16 |",
"| 17 | 3 | 14 | 17 |",
"| 20 | 3 | 17 | 20 |",
"+-----+----+-----+-----+",
],
&df_results_overwrite
);
let df_results_overwrite_self =
df.with_column("c2", col("c2") + lit(1))?.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+-----+-----+",
"| c1 | c2 | c3 | sum |",
"+----+----+-----+-----+",
"| a | 4 | -12 | -9 |",
"| a | 4 | -72 | -69 |",
"| a | 4 | 13 | 16 |",
"| a | 4 | 13 | 16 |",
"| a | 4 | 14 | 17 |",
"| a | 4 | 17 | 20 |",
"+----+----+-----+-----+",
],
&df_results_overwrite_self
);
Ok(())
}
#[tokio::test]
async fn with_column_renamed() -> Result<()> {
let df = test_table()
.await?
.select_columns(&["c1", "c2", "c3"])?
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
.limit(0, Some(1))?
.sort(vec![
col("c1").sort(true, true),
col("c2").sort(true, true),
col("c3").sort(true, true),
])?
.with_column("sum", col("c2") + col("c3"))?;
let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+----+-------+",
"| c1 | c2 | c3 | total |",
"+----+----+----+-------+",
"| a | 3 | 13 | 16 |",
"+----+----+----+-------+",
],
&df_sum_renamed
);
Ok(())
}
#[tokio::test]
async fn with_column_renamed_join() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
ctx.register_table("t1", df.clone())?;
ctx.register_table("t2", df)?;
let df = ctx
.table("t1")?
.join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)?
.sort(vec![
col("t1.c1").sort(true, true),
col("t1.c2").sort(true, true),
col("t1.c3").sort(true, true),
col("t2.c1").sort(true, true),
col("t2.c2").sort(true, true),
col("t2.c3").sort(true, true),
])?
.limit(0, Some(1))?;
let df_results = df.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+-----+----+----+-----+",
"| c1 | c2 | c3 | c1 | c2 | c3 |",
"+----+----+-----+----+----+-----+",
"| a | 1 | -85 | a | 1 | -85 |",
"+----+----+-----+----+----+-----+",
],
&df_results
);
let df_renamed = df.with_column_renamed("t1.c1", "AAA")?;
assert_eq!("\
Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST\
\n Inner Join: t1.c1 = t2.c1\
\n TableScan: t1\
\n TableScan: t2",
format!("{:?}", df_renamed.to_unoptimized_plan())
);
assert_eq!("\
Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1\
\n Inner Join: t1.c1 = t2.c1\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3, alias=t1\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\
\n TableScan: aggregate_test_100 projection=[c1, c2, c3]\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3, alias=t2\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\
\n TableScan: aggregate_test_100 projection=[c1, c2, c3]",
format!("{:?}", df_renamed.to_logical_plan()?)
);
let df_results = df_renamed.collect().await?;
assert_batches_sorted_eq!(
vec![
"+-----+----+-----+----+----+-----+",
"| AAA | c2 | c3 | c1 | c2 | c3 |",
"+-----+----+-----+----+----+-----+",
"| a | 1 | -85 | a | 1 | -85 |",
"+-----+----+-----+----+----+-----+",
],
&df_results
);
Ok(())
}
#[tokio::test]
async fn filter_pushdown_dataframe() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_parquet(
"test",
&format!("{}/alltypes_plain.snappy.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;
ctx.register_table("t1", ctx.table("test")?)?;
let df = ctx
.table("t1")?
.filter(col("id").eq(lit(1)))?
.select_columns(&["bool_col", "int_col"])?;
let plan = df.explain(false, false)?.collect().await?;
let formatted = pretty::pretty_format_batches(&plan).unwrap().to_string();
assert!(formatted.contains("predicate=id_min@0 <= 1 AND 1 <= id_max@1"));
Ok(())
}
#[tokio::test]
async fn cast_expr_test() -> Result<()> {
let df = test_table()
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
let df_results = df.collect().await?;
df.show().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+-----+",
"| c2 | c3 | sum |",
"+----+----+-----+",
"| 2 | 1 | 3 |",
"+----+----+-----+",
],
&df_results
);
Ok(())
}
#[tokio::test]
async fn row_writer_resize_test() -> Result<()> {
let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new(
"column_1",
DataType::Utf8,
false,
)]));
let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::StringArray::from(vec![
Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"),
Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"),
]))
],
)?;
let ctx = SessionContext::new();
ctx.register_batch("test", data)?;
let sql = r#"
SELECT
COUNT(1)
FROM
test
GROUP BY
column_1"#;
let df = ctx.sql(sql).await.unwrap();
df.show_limit(10).await.unwrap();
Ok(())
}
#[tokio::test]
async fn with_column_name() -> Result<()> {
let array: Int32Array = [1, 10].into_iter().collect();
let batch =
RecordBatch::try_from_iter(vec![("f.c1", Arc::new(array) as _)]).unwrap();
let ctx = SessionContext::new();
ctx.register_batch("t", batch).unwrap();
let df = ctx
.table("t")
.unwrap()
.with_column("f.c2", lit("hello"))
.unwrap();
let df_results = df.collect().await.unwrap();
assert_batches_sorted_eq!(
vec![
"+------+-------+",
"| f.c1 | f.c2 |",
"+------+-------+",
"| 1 | hello |",
"| 10 | hello |",
"+------+-------+",
],
&df_results
);
Ok(())
}
#[tokio::test]
async fn cache_test() -> Result<()> {
let df = test_table()
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
let cached_df = df.cache().await?;
assert_eq!(
"TableScan: ?table? projection=[c2, c3, sum]",
format!("{:?}", cached_df.to_logical_plan()?)
);
let df_results = df.collect().await?;
let cached_df_results = cached_df.collect().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+-----+",
"| c2 | c3 | sum |",
"+----+----+-----+",
"| 2 | 1 | 3 |",
"+----+----+-----+",
],
&cached_df_results
);
assert_eq!(&df_results, &cached_df_results);
Ok(())
}
}