#[cfg(feature = "parquet")]
mod parquet;
use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::file_format::csv::CsvFormatFactory;
use crate::datasource::file_format::format_as_file_type;
use crate::datasource::file_format::json::JsonFormatFactory;
use crate::datasource::{
    provider_as_source, DefaultTableSource, MemTable, TableProvider,
};
use crate::error::Result;
use crate::execution::context::{SessionState, TaskContext};
use crate::execution::FunctionRegistry;
use crate::logical_expr::utils::find_window_exprs;
use crate::logical_expr::{
    col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder,
    LogicalPlanBuilderOptions, Partitioning, TableType,
};
use crate::physical_plan::{
    collect, collect_partitioned, execute_stream, execute_stream_partitioned,
    ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_common::config::{CsvOptions, JsonOptions};
use datafusion_common::{
    exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema,
    DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions,
};
use datafusion_expr::select_expr::SelectExpr;
use datafusion_expr::{
    case,
    dml::InsertOp,
    expr::{Alias, ScalarFunction},
    is_null, lit,
    utils::COUNT_STAR_EXPANSION,
    ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_functions::core::coalesce;
use datafusion_functions_aggregate::expr_fn::{
    avg, count, max, median, min, stddev, sum,
};
use datafusion_sql::TableReference;

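/// Options controlling how a [`DataFrame`] is written out by the
/// `DataFrame::write_*` and `write_table` methods: the insert mode, whether to
/// produce a single output file, partitioning columns, and an optional sort
/// order applied before writing.
///
/// A minimal usage sketch (assumes an existing `df: DataFrame` and a writable
/// path):
/// ```ignore
/// let opts = DataFrameWriteOptions::new().with_single_file_output(true);
/// df.write_csv("/tmp/out", opts, None).await?;
/// ```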
pub struct DataFrameWriteOptions {
insert_op: InsertOp,
single_file_output: bool,
partition_by: Vec<String>,
sort_by: Vec<SortExpr>,
}
impl DataFrameWriteOptions {
pub fn new() -> Self {
DataFrameWriteOptions {
insert_op: InsertOp::Append,
single_file_output: false,
partition_by: vec![],
sort_by: vec![],
}
}
pub fn with_insert_operation(mut self, insert_op: InsertOp) -> Self {
self.insert_op = insert_op;
self
}
pub fn with_single_file_output(mut self, single_file_output: bool) -> Self {
self.single_file_output = single_file_output;
self
}
pub fn with_partition_by(mut self, partition_by: Vec<String>) -> Self {
self.partition_by = partition_by;
self
}
pub fn with_sort_by(mut self, sort_by: Vec<SortExpr>) -> Self {
self.sort_by = sort_by;
self
}
}
impl Default for DataFrameWriteOptions {
fn default() -> Self {
Self::new()
}
}
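/// A `DataFrame` pairs a [`LogicalPlan`] with the [`SessionState`] required to
/// plan and execute it. Transformation methods consume `self` and return a new
/// `DataFrame` with an extended plan; nothing is executed until a terminal
/// method such as `collect`, `show`, or `execute_stream` is called.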
#[derive(Debug, Clone)]
pub struct DataFrame {
session_state: Box<SessionState>,
plan: LogicalPlan,
projection_requires_validation: bool,
}
impl DataFrame {
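    /// Creates a new `DataFrame` from an existing [`LogicalPlan`] and the
    /// [`SessionState`] that will be used to plan and execute it.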
pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
Self {
session_state: Box::new(session_state),
plan,
projection_requires_validation: true,
}
}
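    /// Parses a SQL string such as `"a > 5"` against this DataFrame's schema
    /// and returns the resulting [`Expr`].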
pub fn parse_sql_expr(&self, sql: &str) -> Result<Expr> {
let df_schema = self.schema();
self.session_state.create_logical_expr(sql, df_schema)
}
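    /// Optimizes the logical plan and turns it into an [`ExecutionPlan`] ready
    /// to be run against this session's runtime.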
pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
self.session_state.create_physical_plan(&self.plan).await
}
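    /// Projects the DataFrame down to the given columns, looked up by
    /// unqualified name.
    ///
    /// A minimal sketch (assumes `df` has columns `a`, `b`, and `c`):
    /// ```ignore
    /// let df = df.select_columns(&["a", "b"])?;
    /// ```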
pub fn select_columns(self, columns: &[&str]) -> Result<DataFrame> {
let fields = columns
.iter()
.map(|name| {
self.plan
.schema()
.qualified_field_with_unqualified_name(name)
})
.collect::<Result<Vec<_>>>()?;
let expr: Vec<Expr> = fields
.into_iter()
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
.collect();
self.select(expr)
}
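    /// Like [`Self::select`], but parses each string as a SQL expression
    /// first, so strings like `"a * 2"` or `"abs(b)"` are accepted.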
pub fn select_exprs(self, exprs: &[&str]) -> Result<DataFrame> {
let expr_list = exprs
.iter()
.map(|e| self.parse_sql_expr(e))
.collect::<Result<Vec<_>>>()?;
self.select(expr_list)
}
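    /// Projects the DataFrame with the given expressions. Any window function
    /// expressions are planned into a `Window` node before projecting.
    ///
    /// A minimal sketch (assumes `df` has columns `a` and `b`):
    /// ```ignore
    /// use datafusion::prelude::*;
    /// let df = df.select(vec![col("a"), col("b").alias("b2")])?;
    /// ```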
pub fn select(
self,
expr_list: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<DataFrame> {
let expr_list: Vec<SelectExpr> =
expr_list.into_iter().map(|e| e.into()).collect::<Vec<_>>();
let expressions = expr_list.iter().filter_map(|e| match e {
SelectExpr::Expression(expr) => Some(expr),
_ => None,
});
let window_func_exprs = find_window_exprs(expressions);
let plan = if window_func_exprs.is_empty() {
self.plan
} else {
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
};
let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan: project_plan,
projection_requires_validation: false,
})
}
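    /// Returns a projection without the given columns; names that do not
    /// resolve to a field in the schema are silently ignored.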
pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
let fields_to_drop = columns
.iter()
.map(|name| {
self.plan
.schema()
.qualified_field_with_unqualified_name(name)
})
.filter(|r| r.is_ok())
.collect::<Result<Vec<_>>>()?;
let expr: Vec<Expr> = self
.plan
.schema()
.fields()
.into_iter()
.enumerate()
.map(|(idx, _)| self.plan.schema().qualified_field(idx))
.filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
.collect();
self.select(expr)
}
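    /// Expands each container value (e.g. a list) in the given columns into
    /// one row per element, using the default [`UnnestOptions`].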
pub fn unnest_columns(self, columns: &[&str]) -> Result<DataFrame> {
self.unnest_columns_with_options(columns, UnnestOptions::new())
}
pub fn unnest_columns_with_options(
self,
columns: &[&str],
options: UnnestOptions,
) -> Result<DataFrame> {
let columns = columns.iter().map(|c| Column::from(*c)).collect();
let plan = LogicalPlanBuilder::from(self.plan)
.unnest_columns_with_options(columns, options)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
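    /// Keeps only the rows for which `predicate` evaluates to `true`.
    ///
    /// A minimal sketch:
    /// ```ignore
    /// use datafusion::prelude::*;
    /// let df = df.filter(col("a").gt(lit(5)))?;
    /// ```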
pub fn filter(self, predicate: Expr) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.filter(predicate)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
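    /// Performs an aggregation, equivalent to
    /// `SELECT <aggr_expr> FROM ... GROUP BY <group_expr>`.
    ///
    /// A minimal sketch:
    /// ```ignore
    /// use datafusion::prelude::*;
    /// use datafusion::functions_aggregate::expr_fn::min;
    /// let df = df.aggregate(vec![col("a")], vec![min(col("b"))])?;
    /// ```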
pub fn aggregate(
self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(self.plan)
.with_options(options)
.aggregate(group_expr, aggr_expr)?
.build()?;
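        // Grouping sets add an internal grouping-id column; project it back
        // out so callers only see the expressions they asked for.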
let plan = if is_grouping_set {
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
let exprs = plan
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
} else {
plan
};
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: !is_grouping_set,
})
}
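    /// Adds the given window function expressions to the plan as new columns.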
pub fn window(self, window_exprs: Vec<Expr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.window(window_exprs)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
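    /// Skips `skip` rows and then returns at most `fetch` rows, equivalent to
    /// SQL `OFFSET skip LIMIT fetch`.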
pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.limit(skip, fetch)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}
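    /// Unions this DataFrame with another by column position, keeping
    /// duplicate rows (SQL `UNION ALL`); the schemas must be compatible.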
pub fn union(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union(dataframe.plan)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn union_by_name(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union_by_name(dataframe.plan)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn union_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union_distinct(dataframe.plan)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn union_by_name_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.union_by_name_distinct(dataframe.plan)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
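    /// Removes duplicate rows, equivalent to SQL `SELECT DISTINCT *`.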
pub fn distinct(self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).distinct()?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn distinct_on(
self,
on_expr: Vec<Expr>,
select_expr: Vec<Expr>,
sort_expr: Option<Vec<SortExpr>>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.distinct_on(on_expr, select_expr, sort_expr)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
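    /// Computes summary statistics (count, null_count, mean, std, min, max,
    /// median) for each column, returned as a new single-batch DataFrame with
    /// one row per statistic. Non-numeric columns are rendered as strings, and
    /// statistics that do not apply to a column are reported as `"null"`.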
pub async fn describe(self) -> Result<Self> {
let supported_describe_functions =
vec!["count", "null_count", "mean", "std", "min", "max", "median"];
let original_schema_fields = self.schema().fields().iter();
let mut describe_schemas = vec![Field::new("describe", DataType::Utf8, false)];
describe_schemas.extend(original_schema_fields.clone().map(|field| {
if field.data_type().is_numeric() {
Field::new(field.name(), DataType::Float64, true)
} else {
Field::new(field.name(), DataType::Utf8, true)
}
}));
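        // One aggregation plan per statistic; each is collected below and the
        // per-column results are stitched together into a single record batch.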
let describe_record_batch = vec![
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.map(|f| count(ident(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.map(|f| {
sum(case(is_null(ident(f.name())))
.when(lit(true), lit(1))
.otherwise(lit(0))
.unwrap())
.alias(f.name())
})
.collect::<Vec<_>>(),
),
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| avg(ident(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| stddev(ident(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
})
.map(|f| min(ident(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
})
.map(|f| max(ident(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| median(ident(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
),
];
let mut array_ref_vec: Vec<ArrayRef> = vec![Arc::new(StringArray::from(
supported_describe_functions.clone(),
))];
for field in original_schema_fields {
let mut array_datas = vec![];
for result in describe_record_batch.iter() {
let array_ref = match result {
Ok(df) => {
let batches = df.clone().collect().await;
match batches {
Ok(batches)
if batches.len() == 1
&& batches[0]
.column_by_name(field.name())
.is_some() =>
{
let column =
batches[0].column_by_name(field.name()).unwrap();
if column.data_type().is_null() {
Arc::new(StringArray::from(vec!["null"]))
} else if field.data_type().is_numeric() {
cast(column, &DataType::Float64)?
} else {
cast(column, &DataType::Utf8)?
}
}
_ => Arc::new(StringArray::from(vec!["null"])),
}
}
Err(err)
if err.to_string().contains(
"Error during planning: \
Aggregate requires at least one grouping \
or aggregate expression",
) =>
{
Arc::new(StringArray::from(vec!["null"]))
}
Err(e) => return exec_err!("{}", e),
};
array_datas.push(array_ref);
}
array_ref_vec.push(concat(
array_datas
.iter()
.map(|af| af.as_ref())
.collect::<Vec<_>>()
.as_slice(),
)?);
}
let describe_record_batch =
RecordBatch::try_new(Arc::new(Schema::new(describe_schemas)), array_ref_vec)?;
let provider = MemTable::try_new(
describe_record_batch.schema(),
vec![vec![describe_record_batch]],
)?;
let plan = LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
None,
)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}
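    /// Sorts by the given expressions, ascending with nulls last; use
    /// [`Self::sort`] for full control over direction and null ordering.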
pub fn sort_by(self, expr: Vec<Expr>) -> Result<DataFrame> {
self.sort(
expr.into_iter()
.map(|e| e.sort(true, false))
.collect::<Vec<SortExpr>>(),
)
}
pub fn sort(self, expr: Vec<SortExpr>) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}
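    /// Joins this DataFrame with `right` on equality of the paired
    /// `left_cols` and `right_cols`, with an optional extra join `filter`.
    ///
    /// A minimal sketch:
    /// ```ignore
    /// let df = left.join(right, JoinType::Inner, &["id"], &["id"], None)?;
    /// ```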
pub fn join(
self,
right: DataFrame,
join_type: JoinType,
left_cols: &[&str],
right_cols: &[&str],
filter: Option<Expr>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.join(
right.plan,
join_type,
(left_cols.to_vec(), right_cols.to_vec()),
filter,
)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
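    /// Joins this DataFrame with `right` using arbitrary expressions, which
    /// allows non-equality predicates that [`Self::join`] cannot express.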
pub fn join_on(
self,
right: DataFrame,
join_type: JoinType,
on_exprs: impl IntoIterator<Item = Expr>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.join_on(right.plan, join_type, on_exprs)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
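    /// Repartitions the data according to `partitioning_scheme`, e.g.
    /// `Partitioning::RoundRobinBatch(n)` or `Partitioning::Hash(exprs, n)`.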
pub fn repartition(self, partitioning_scheme: Partitioning) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.repartition(partitioning_scheme)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
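    /// Executes the plan and returns the total number of rows, computed by
    /// collecting a `count(1)` aggregation over the plan.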
pub async fn count(self) -> Result<usize> {
let rows = self
.aggregate(
vec![],
vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))],
)?
.collect()
.await?;
let len = *rows
.first()
.and_then(|r| r.columns().first())
.and_then(|c| c.as_any().downcast_ref::<Int64Array>())
.and_then(|a| a.values().first())
.ok_or(DataFusionError::Internal(
"Unexpected output when collecting for count()".to_string(),
))? as usize;
Ok(len)
}
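    /// Executes the plan and buffers all results in memory as a
    /// `Vec<RecordBatch>`; prefer [`Self::execute_stream`] for large outputs.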
pub async fn collect(self) -> Result<Vec<RecordBatch>> {
let task_ctx = Arc::new(self.task_ctx());
let plan = self.create_physical_plan().await?;
collect(plan, task_ctx).await
}
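    /// Executes the plan and prints the results to stdout as a formatted
    /// table.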
pub async fn show(self) -> Result<()> {
println!("{}", self.to_string().await?);
Ok(())
}
pub async fn to_string(self) -> Result<String> {
let options = self.session_state.config().options().format.clone();
let arrow_options: arrow::util::display::FormatOptions = (&options).try_into()?;
let results = self.collect().await?;
Ok(
pretty::pretty_format_batches_with_options(&results, &arrow_options)?
.to_string(),
)
}
pub async fn show_limit(self, num: usize) -> Result<()> {
let results = self.limit(0, Some(num))?.collect().await?;
Ok(pretty::print_batches(&results)?)
}
pub fn task_ctx(&self) -> TaskContext {
TaskContext::from(self.session_state.as_ref())
}
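    /// Executes the plan and returns a single merged stream of record
    /// batches, yielding results incrementally instead of buffering them.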
pub async fn execute_stream(self) -> Result<SendableRecordBatchStream> {
let task_ctx = Arc::new(self.task_ctx());
let plan = self.create_physical_plan().await?;
execute_stream(plan, task_ctx)
}
pub async fn collect_partitioned(self) -> Result<Vec<Vec<RecordBatch>>> {
let task_ctx = Arc::new(self.task_ctx());
let plan = self.create_physical_plan().await?;
collect_partitioned(plan, task_ctx).await
}
pub async fn execute_stream_partitioned(
self,
) -> Result<Vec<SendableRecordBatchStream>> {
let task_ctx = Arc::new(self.task_ctx());
let plan = self.create_physical_plan().await?;
execute_stream_partitioned(plan, task_ctx)
}
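    /// Returns the [`DFSchema`] describing the output of this DataFrame.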
pub fn schema(&self) -> &DFSchema {
self.plan.schema()
}
pub fn logical_plan(&self) -> &LogicalPlan {
&self.plan
}
pub fn into_parts(self) -> (SessionState, LogicalPlan) {
(*self.session_state, self.plan)
}
pub fn into_unoptimized_plan(self) -> LogicalPlan {
self.plan
}
pub fn into_optimized_plan(self) -> Result<LogicalPlan> {
self.session_state.optimize(&self.plan)
}
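    /// Converts this DataFrame into a [`TableProvider`] so it can be
    /// registered as a named view on a `SessionContext`.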
pub fn into_view(self) -> Arc<dyn TableProvider> {
Arc::new(DataFrameTableProvider { plan: self.plan })
}
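    /// Returns a DataFrame whose output describes this plan (`EXPLAIN`);
    /// `verbose` includes intermediate plans and `analyze` executes the plan
    /// and reports runtime metrics.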
pub fn explain(self, verbose: bool, analyze: bool) -> Result<DataFrame> {
let opts = ExplainOption::default()
.with_verbose(verbose)
.with_analyze(analyze);
self.explain_with_options(opts)
}
pub fn explain_with_options(
self,
explain_option: ExplainOption,
) -> Result<DataFrame> {
if matches!(self.plan, LogicalPlan::Explain(_)) {
return plan_err!("Nested EXPLAINs are not supported");
}
let plan = LogicalPlanBuilder::from(self.plan)
.explain_option_format(explain_option)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}
pub fn registry(&self) -> &dyn FunctionRegistry {
self.session_state.as_ref()
}
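    /// Returns the rows present in both inputs, keeping duplicates
    /// (`INTERSECT ALL`); see [`Self::intersect_distinct`] for set semantics.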
pub fn intersect(self, dataframe: DataFrame) -> Result<DataFrame> {
let left_plan = self.plan;
let right_plan = dataframe.plan;
let plan = LogicalPlanBuilder::intersect(left_plan, right_plan, true)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn intersect_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
let left_plan = self.plan;
let right_plan = dataframe.plan;
let plan = LogicalPlanBuilder::intersect(left_plan, right_plan, false)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn except(self, dataframe: DataFrame) -> Result<DataFrame> {
let left_plan = self.plan;
let right_plan = dataframe.plan;
let plan = LogicalPlanBuilder::except(left_plan, right_plan, true)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
pub fn except_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
let left_plan = self.plan;
let right_plan = dataframe.plan;
let plan = LogicalPlanBuilder::except(left_plan, right_plan, false)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
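    /// Executes the plan and inserts its results into an already-registered
    /// table, honoring the insert mode and optional sort order in
    /// `write_options`. The returned batches are the output of the insert
    /// operation (typically a single row count).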
pub async fn write_table(
self,
table_name: &str,
write_options: DataFrameWriteOptions,
) -> Result<Vec<RecordBatch>, DataFusionError> {
let plan = if write_options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(write_options.sort_by)?
.build()?
};
let table_ref: TableReference = table_name.into();
let table_schema = self.session_state.schema_for_ref(table_ref.clone())?;
let target = match table_schema.table(table_ref.table()).await? {
Some(ref provider) => Ok(Arc::clone(provider)),
_ => plan_err!("No table named '{table_name}'"),
}?;
let target = Arc::new(DefaultTableSource::new(target));
let plan = LogicalPlanBuilder::insert_into(
plan,
table_ref,
target,
write_options.insert_op,
)?
.build()?;
DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
}
.collect()
.await
}
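    /// Executes the plan and writes the results to `path` as CSV. Only
    /// `InsertOp::Append` is supported; `writer_options` controls the CSV
    /// format.
    ///
    /// A minimal sketch (assumes `df: DataFrame` and a writable path):
    /// ```ignore
    /// df.write_csv("/tmp/out", DataFrameWriteOptions::new(), None).await?;
    /// ```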
pub async fn write_csv(
self,
path: &str,
options: DataFrameWriteOptions,
writer_options: Option<CsvOptions>,
) -> Result<Vec<RecordBatch>, DataFusionError> {
if options.insert_op != InsertOp::Append {
return not_impl_err!(
"{} is not implemented for DataFrame::write_csv.",
options.insert_op
);
}
let format = if let Some(csv_opts) = writer_options {
Arc::new(CsvFormatFactory::new_with_options(csv_opts))
} else {
Arc::new(CsvFormatFactory::new())
};
let file_type = format_as_file_type(format);
let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(options.sort_by)?
.build()?
};
let plan = LogicalPlanBuilder::copy_to(
plan,
path.into(),
file_type,
HashMap::new(),
options.partition_by,
)?
.build()?;
DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
}
.collect()
.await
}
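    /// Executes the plan and writes the results to `path` as newline-delimited
    /// JSON. Only `InsertOp::Append` is supported.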
pub async fn write_json(
self,
path: &str,
options: DataFrameWriteOptions,
writer_options: Option<JsonOptions>,
) -> Result<Vec<RecordBatch>, DataFusionError> {
if options.insert_op != InsertOp::Append {
return not_impl_err!(
"{} is not implemented for DataFrame::write_json.",
options.insert_op
);
}
let format = if let Some(json_opts) = writer_options {
Arc::new(JsonFormatFactory::new_with_options(json_opts))
} else {
Arc::new(JsonFormatFactory::new())
};
let file_type = format_as_file_type(format);
let plan = if options.sort_by.is_empty() {
self.plan
} else {
LogicalPlanBuilder::from(self.plan)
.sort(options.sort_by)?
.build()?
};
let plan = LogicalPlanBuilder::copy_to(
plan,
path.into(),
file_type,
Default::default(),
options.partition_by,
)?
.build()?;
DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
}
.collect()
.await
}
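    /// Adds `expr` as a column named `name`, replacing an existing column with
    /// the same name if there is one.
    ///
    /// A minimal sketch:
    /// ```ignore
    /// use datafusion::prelude::*;
    /// let df = df.with_column("b_doubled", col("b") * lit(2))?;
    /// ```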
pub fn with_column(self, name: &str, expr: Expr) -> Result<DataFrame> {
let window_func_exprs = find_window_exprs([&expr]);
let (window_fn_str, plan) = if window_func_exprs.is_empty() {
(None, self.plan)
} else {
(
Some(window_func_exprs[0].to_string()),
LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?,
)
};
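        // Rebuild the projection: swap in the new expression where the column
        // name matches, otherwise keep the existing column.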
let mut col_exists = false;
let new_column = expr.alias(name);
let mut fields: Vec<(Expr, bool)> = plan
.schema()
.iter()
.filter_map(|(qualifier, field)| {
if field.name() == name {
col_exists = true;
Some((new_column.clone(), true))
} else {
let e = col(Column::from((qualifier, field)));
window_fn_str
.as_ref()
.filter(|s| *s == &e.to_string())
.is_none()
.then_some((e, self.projection_requires_validation))
}
})
.collect();
if !col_exists {
fields.push((new_column, true));
}
let project_plan = LogicalPlanBuilder::from(plan)
.project_with_validation(fields)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan: project_plan,
projection_requires_validation: false,
})
}
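    /// Renames the column `old_name` to `new_name`; if no such column exists,
    /// the DataFrame is returned unchanged. Name matching honors the
    /// session's `enable_ident_normalization` setting.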
pub fn with_column_renamed(
self,
old_name: impl Into<String>,
new_name: &str,
) -> Result<DataFrame> {
let ident_opts = self
.session_state
.config_options()
.sql_parser
.enable_ident_normalization;
let old_column: Column = if ident_opts {
Column::from_qualified_name(old_name)
} else {
Column::from_qualified_name_ignore_case(old_name)
};
let (qualifier_rename, field_rename) =
match self.plan.schema().qualified_field_from_column(&old_column) {
Ok(qualifier_and_field) => qualifier_and_field,
Err(DataFusionError::SchemaError(e, _))
if matches!(*e, SchemaError::FieldNotFound { .. }) =>
{
return Ok(self);
}
Err(err) => return Err(err),
};
let projection = self
.plan
.schema()
.iter()
.map(|(qualifier, field)| {
if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename {
(
col(Column::from((qualifier, field)))
.alias_qualified(qualifier.cloned(), new_name),
false,
)
} else {
(col(Column::from((qualifier, field))), false)
}
})
.collect::<Vec<_>>();
let project_plan = LogicalPlanBuilder::from(self.plan)
.project_with_validation(projection)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan: project_plan,
projection_requires_validation: false,
})
}
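    /// Replaces placeholders such as `$1` in the plan with the given
    /// parameter values.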
pub fn with_param_values(self, query_values: impl Into<ParamValues>) -> Result<Self> {
let plan = self.plan.with_param_values(query_values)?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}
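    /// Eagerly executes the plan and returns a new DataFrame backed by an
    /// in-memory [`MemTable`] holding the results, preserving partitioning.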
pub async fn cache(self) -> Result<DataFrame> {
let context = SessionContext::new_with_state((*self.session_state).clone());
let plan = self.clone().create_physical_plan().await?;
let schema = plan.schema();
let task_ctx = Arc::new(self.task_ctx());
let partitions = collect_partitioned(plan, task_ctx).await?;
let mem_table = MemTable::try_new(schema, partitions)?;
context.read_table(Arc::new(mem_table))
}
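    /// Wraps the plan in a subquery alias so columns can be referenced as
    /// `alias.column`.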
pub fn alias(self, alias: &str) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan).alias(alias)?.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}
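    /// Replaces NULLs with `value` (via `coalesce`) in the named columns, or
    /// in every column when `columns` is empty. Columns whose type `value`
    /// cannot be cast to are left untouched.
    ///
    /// A minimal sketch:
    /// ```ignore
    /// use datafusion::common::ScalarValue;
    /// let df = df.fill_null(ScalarValue::Int64(Some(0)), vec!["a".to_string()])?;
    /// ```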
pub fn fill_null(
&self,
value: ScalarValue,
columns: Vec<String>,
) -> Result<DataFrame> {
let cols = if columns.is_empty() {
self.logical_plan()
.schema()
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect()
} else {
self.find_columns(&columns)?
};
let projections = self
.logical_plan()
.schema()
.fields()
.iter()
.map(|field| {
if cols.contains(field) {
match value.clone().cast_to(field.data_type()) {
Ok(fill_value) => Expr::Alias(Alias {
expr: Box::new(Expr::ScalarFunction(ScalarFunction {
func: coalesce(),
args: vec![col(field.name()), lit(fill_value)],
})),
relation: None,
name: field.name().to_string(),
metadata: None,
}),
Err(_) => col(field.name()),
}
} else {
col(field.name())
}
})
.collect::<Vec<_>>();
self.clone().select(projections)
}
fn find_columns(&self, names: &[String]) -> Result<Vec<Field>> {
let schema = self.logical_plan().schema();
names
.iter()
.map(|name| {
schema
.field_with_name(None, name)
.cloned()
.map_err(|_| plan_datafusion_err!("Column '{}' not found", name))
})
.collect()
}
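    /// Builds a DataFrame from `(name, array)` pairs using a fresh
    /// `SessionContext`; every column is declared nullable.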
pub fn from_columns(columns: Vec<(&str, ArrayRef)>) -> Result<Self> {
let fields = columns
.iter()
.map(|(name, array)| Field::new(*name, array.data_type().clone(), true))
.collect::<Vec<_>>();
let arrays = columns
.into_iter()
.map(|(_, array)| array)
.collect::<Vec<_>>();
let schema = Arc::new(Schema::new(fields));
let batch = RecordBatch::try_new(schema, arrays)?;
let ctx = SessionContext::new();
let df = ctx.read_batch(batch)?;
Ok(df)
}
}
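/// Creates a [`DataFrame`] from inline columns, or an empty DataFrame when
/// called with no arguments.
///
/// A minimal sketch (assumes `datafusion` is a dependency of the calling
/// crate):
/// ```ignore
/// let df = dataframe!(
///     "a" => [1, 2, 3],
///     "b" => ["x", "y", "z"],
/// )?;
/// ```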
#[macro_export]
macro_rules! dataframe {
() => {{
use std::sync::Arc;
use datafusion::prelude::SessionContext;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::Schema;
let ctx = SessionContext::new();
let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
ctx.read_batch(batch)
}};
($($name:expr => $data:expr),+ $(,)?) => {{
use datafusion::prelude::DataFrame;
use datafusion::common::test_util::IntoArrayRef;
let columns = vec![
$(
($name, $data.into_array_ref()),
)+
];
DataFrame::from_columns(columns)
}};
}
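/// A [`TableProvider`] backed by a DataFrame's logical plan: scanning it
/// re-plans the stored plan with any pushed-down filters, projection, and
/// limit. This is what [`DataFrame::into_view`] returns.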
#[derive(Debug)]
struct DataFrameTableProvider {
plan: LogicalPlan,
}
#[async_trait]
impl TableProvider for DataFrameTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn get_logical_plan(&self) -> Option<Cow<LogicalPlan>> {
Some(Cow::Borrowed(&self.plan))
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
}
fn schema(&self) -> SchemaRef {
let schema: Schema = self.plan.schema().as_ref().into();
Arc::new(schema)
}
fn table_type(&self) -> TableType {
TableType::View
}
async fn scan(
&self,
state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut expr = LogicalPlanBuilder::from(self.plan.clone());
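        // Fold all pushed-down filters into one conjunctive predicate.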
let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
if let Some(filter) = filter {
expr = expr.filter(filter)?
}
if let Some(p) = projection {
expr = expr.select(p.iter().copied())?
}
if let Some(l) = limit {
expr = expr.limit(0, Some(l))?
}
let plan = expr.build()?;
state.create_physical_plan(&plan).await
}
}