use std::collections::HashMap;
use std::vec;
use arrow_schema::{
DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
};
use datafusion_common::{
exec_err, internal_err, plan_err, Column, DFSchemaRef, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::builder::get_struct_unnested_columns;
use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{
col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
};
use indexmap::IndexMap;
use sqlparser::ast::{Ident, Value};
/// Rewrites every column reference in `expr` so it carries the qualifier
/// recorded for that column in `plan`'s schema.
///
/// Errors if a referenced column cannot be found in the plan schema.
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    expr.clone()
        .transform_up(|nested_expr| {
            if let Expr::Column(col) = &nested_expr {
                // Look the column up in the plan schema to recover its
                // (possibly relation-qualified) identity.
                let (qualifier, field) =
                    plan.schema().qualified_field_from_column(col)?;
                Ok(Transformed::yes(Expr::Column(Column::from((
                    qualifier, field,
                )))))
            } else {
                // Non-column nodes pass through untouched.
                Ok(Transformed::no(nested_expr))
            }
        })
        .data()
}
/// Replaces every occurrence of one of `base_exprs` inside `expr` with a
/// column reference to the corresponding output column of `plan`.
///
/// Traverses top-down so that an entire matching subtree is replaced before
/// its children are visited.
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    expr.clone()
        .transform_down(|nested_expr| {
            // Leave anything that is not one of the base expressions alone.
            if !base_exprs.contains(&nested_expr) {
                return Ok(Transformed::no(nested_expr));
            }
            Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
        })
        .data()
}
/// Verifies that every column reference appearing in `exprs` is one of the
/// available `columns`.
///
/// `columns` must contain only `Expr::Column` values (internal error
/// otherwise); `message_prefix` prefixes the plan error raised for the first
/// unsatisfied expression. Grouping sets are checked member by member.
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    message_prefix: &str,
) -> Result<()> {
    // Guard: the available columns must themselves be column expressions.
    for c in columns {
        if !matches!(c, Expr::Column(_)) {
            return internal_err!("Expr::Column are required");
        }
    }
    for e in &find_column_exprs(exprs) {
        match e {
            // Rollup and Cube both carry a flat list of grouping exprs.
            Expr::GroupingSet(GroupingSet::Rollup(exprs))
            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, message_prefix)?;
                }
            }
            // GroupingSets nest one expression list per grouping set.
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                for e in lists_of_exprs.iter().flatten() {
                    check_column_satisfies_expr(columns, e, message_prefix)?;
                }
            }
            _ => check_column_satisfies_expr(columns, e, message_prefix)?,
        }
    }
    Ok(())
}
/// Succeeds iff `expr` is one of the available `columns`; otherwise raises a
/// plan error naming the expression and listing the columns that were
/// available, prefixed by `message_prefix`.
fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    message_prefix: &str,
) -> Result<()> {
    if columns.contains(expr) {
        Ok(())
    } else {
        plan_err!(
            "{}: Expression {} could not be resolved from available columns: {}",
            message_prefix,
            expr,
            expr_vec_fmt!(columns)
        )
    }
}
/// Builds a map from alias name to the aliased expression for every
/// `Expr::Alias` in `exprs`.
///
/// Non-alias expressions are skipped; if the same alias name appears more
/// than once, the last occurrence wins (standard map-insert semantics).
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    let mut aliases = HashMap::new();
    for expr in exprs {
        if let Expr::Alias(Alias {
            expr: aliased,
            name,
            ..
        }) = expr
        {
            aliases.insert(name.clone(), *aliased.clone());
        }
    }
    aliases
}
/// Rewrites a bare integer literal (as in `GROUP BY 1` / `ORDER BY 2`) into
/// the corresponding expression from the SELECT list.
///
/// Positions are 1-based. Any non-Int64-literal expression passes through
/// unchanged; an out-of-range position is a plan error. A top-level alias on
/// the selected expression is stripped.
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    // Only bare Int64 literals act as positional references.
    let position = match &expr {
        Expr::Literal(ScalarValue::Int64(Some(position))) => *position,
        _ => return Ok(expr),
    };
    if position < 1 || position > select_exprs.len() as i64 {
        return plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position, select_exprs.len()
        );
    }
    let select_expr = &select_exprs[(position - 1) as usize];
    // Return the underlying expression, not its alias wrapper.
    Ok(match select_expr {
        Expr::Alias(Alias { expr, .. }) => *expr.clone(),
        _ => select_expr.clone(),
    })
}
/// Substitutes unqualified column references that match a SELECT alias with
/// the expression the alias stands for.
///
/// Qualified columns (those with a relation) are never treated as aliases.
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    expr.transform_up(|nested_expr| {
        // Only bare (relation-less) columns can refer to an alias.
        let substitute = match &nested_expr {
            Expr::Column(c) if c.relation.is_none() => aliases.get(&c.name),
            _ => None,
        };
        Ok(match substitute {
            Some(aliased_expr) => Transformed::yes(aliased_expr.clone()),
            None => Transformed::no(nested_expr),
        })
    })
    .data()
}
/// Returns the smallest `PARTITION BY` key list among `window_exprs`.
///
/// Every element must be a window function, possibly wrapped in a single
/// alias; anything else is an execution error, as is an empty input.
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    let mut all_partition_keys = Vec::with_capacity(window_exprs.len());
    for expr in window_exprs {
        // Unwrap at most one alias layer to reach the window function.
        let keys = match expr {
            Expr::WindowFunction(WindowFunction { partition_by, .. }) => partition_by,
            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
                Expr::WindowFunction(WindowFunction { partition_by, .. }) => {
                    partition_by
                }
                expr => return exec_err!("Impossibly got non-window expr {expr:?}"),
            },
            expr => return exec_err!("Impossibly got non-window expr {expr:?}"),
        };
        all_partition_keys.push(keys);
    }
    // The shortest key list is the common one usable for every window expr.
    let result = all_partition_keys
        .into_iter()
        .min_by_key(|s| s.len())
        .ok_or_else(|| {
            DataFusionError::Execution("No window expressions found".to_owned())
        })?;
    Ok(result)
}
/// Builds an Arrow decimal `DataType` from the optional SQL precision/scale.
///
/// Returns `Decimal128` for precision 1..=38 and `Decimal256` for
/// precision 39..=76. With no precision and no scale, the defaults
/// (`DECIMAL128_MAX_PRECISION`, `DECIMAL_DEFAULT_SCALE`) are used.
///
/// # Errors
/// A plan error if only a scale is given, if precision is 0 or above 76, or
/// if `|scale| > precision`.
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    // Validate in wide integer types first: the previous bare `as u8`/`as i8`
    // casts silently wrapped, so e.g. precision 257 truncated to 1 and was
    // accepted instead of rejected.
    let (precision, scale): (u64, i64) = match (precision, scale) {
        // Saturate an out-of-range scale so it fails the range check below
        // rather than wrapping around.
        (Some(p), Some(s)) => (p, i64::try_from(s).unwrap_or(i64::MAX)),
        (Some(p), None) => (p, 0),
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type")
        }
        (None, None) => {
            (DECIMAL128_MAX_PRECISION as u64, DECIMAL_DEFAULT_SCALE as i64)
        }
    };
    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION as u64
        || scale.unsigned_abs() > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION as u64 {
        // 39..=76 digits needs the 256-bit representation. The narrowing
        // casts are safe: the range check above bounds both values.
        Ok(DataType::Decimal256(precision as u8, scale as i8))
    } else {
        Ok(DataType::Decimal128(precision as u8, scale as i8))
    }
}
/// Normalizes a SQL identifier: a quoted identifier keeps its exact value,
/// while an unquoted one is lower-cased (ASCII only).
pub(crate) fn normalize_ident(id: Ident) -> String {
    if id.quote_style.is_some() {
        id.value
    } else {
        id.value.to_ascii_lowercase()
    }
}
/// Extracts a plain `String` from a SQL literal `Value` where a string
/// representation is well defined.
///
/// Returns `None` for value kinds this helper deliberately does not handle
/// (raw/byte/hex string literals, NULL, placeholders, ...).
pub(crate) fn value_to_string(value: &Value) -> Option<String> {
match value {
Value::SingleQuotedString(s) => Some(s.to_string()),
Value::DollarQuotedString(s) => Some(s.to_string()),
// Numbers and booleans are rendered via their Display impl.
Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
Value::UnicodeStringLiteral(s) => Some(s.to_string()),
Value::EscapedStringLiteral(s) => Some(s.to_string()),
// All remaining literal kinds are unsupported here on purpose.
Value::DoubleQuotedString(_)
| Value::NationalStringLiteral(_)
| Value::SingleQuotedByteStringLiteral(_)
| Value::DoubleQuotedByteStringLiteral(_)
| Value::TripleSingleQuotedString(_)
| Value::TripleDoubleQuotedString(_)
| Value::TripleSingleQuotedByteStringLiteral(_)
| Value::TripleDoubleQuotedByteStringLiteral(_)
| Value::SingleQuotedRawStringLiteral(_)
| Value::DoubleQuotedRawStringLiteral(_)
| Value::TripleSingleQuotedRawStringLiteral(_)
| Value::TripleDoubleQuotedRawStringLiteral(_)
| Value::HexStringLiteral(_)
| Value::Null
| Value::Placeholder(_) => None,
}
}
/// Applies [`rewrite_recursive_unnest_bottom_up`] to each expression in
/// `original_exprs` and concatenates the per-expression results.
///
/// Placeholder columns and inner-projection expressions accumulate across
/// all input expressions; the first error aborts the whole rewrite.
pub(crate) fn rewrite_recursive_unnests_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    let mut rewritten = vec![];
    for expr in original_exprs {
        // A single expression may expand into several output expressions
        // (e.g. a root-level struct unnest yields one column per field).
        rewritten.extend(rewrite_recursive_unnest_bottom_up(
            input,
            unnest_placeholder_columns,
            inner_projection_exprs,
            expr,
        )?);
    }
    Ok(rewritten)
}
/// Tree rewriter that pushes `unnest` arguments down into an inner
/// projection and replaces the unnests with placeholder column references.
struct RecursiveUnnestRewriter<'a> {
// Schema used to resolve the data types of unnest arguments.
input_schema: &'a DFSchemaRef,
// The expression the rewrite started from; used to decide whether an
// unnest sits at the root of the select expression.
root_expr: &'a Expr,
// The outermost unnest of the chain currently being traversed, if any.
top_most_unnest: Option<Unnest>,
// Traversal log: `Some` for each unnest visited on the way down, with
// `None` entries acting as separators between non-consecutive runs.
consecutive_unnest: Vec<Option<Unnest>>,
// Expressions accumulated for the inner projection (deduplicated by
// schema name via `push_projection_dedupl`).
inner_projection_exprs: &'a mut Vec<Expr>,
// Unnest instructions per placeholder column: `None` marks a struct
// unnest, `Some(list)` records list unnestings at their depths.
columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
// Replacement expressions produced when the rewrite applied at the root
// of the expression (a root struct unnest expands to several columns).
transformed_root_exprs: Option<Vec<Expr>>,
}
impl<'a> RecursiveUnnestRewriter<'a> {
    /// Returns the most recent run of consecutive unnests recorded in the
    /// traversal log, innermost-first (the log is scanned in reverse and the
    /// run ends at the first `None` separator).
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            .rev()
            .skip_while(|item| item.is_none())
            .take_while(|item| item.is_some())
            // `take_while` guarantees every remaining item is `Some`, so
            // flattening yields exactly the wrapped unnests. (The previous
            // `.to_owned()` in this chain only cloned the iterator itself —
            // a no-op — before unwrapping each element.)
            .flatten()
            .cloned()
            .collect()
    }
    /// Rewrites one unnest chain whose innermost argument is
    /// `expr_in_unnest`.
    ///
    /// * `level` — number of consecutive unnests (the unnest depth).
    /// * `alias_name` — display name for the resulting output column.
    /// * `struct_allowed` — whether a struct unnest is legal here; it is
    ///   only permitted at the root of a select expression.
    ///
    /// Pushes the unnest argument (aliased to its placeholder name) into the
    /// inner projection, records the unnesting instruction in
    /// `columns_unnestings`, and returns the expression(s) referencing the
    /// post-unnest column(s).
    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();
        // Placeholder names encode the original expression (and depth for
        // lists) so the generated plan stays readable.
        let placeholder_name = format!("unnest_placeholder({})", inner_expr_name);
        let post_unnest_name =
            format!("unnest_placeholder({},depth={})", inner_expr_name, level);
        let placeholder_column = Column::from_name(placeholder_name.clone());
        let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?;
        match data_type {
            DataType::Struct(inner_fields) => {
                if !struct_allowed {
                    return internal_err!("unnest on struct can only be applied at the root level of select expression");
                }
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                // `None` marks a struct unnest; no depth list is needed.
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                // A struct expands into one output column per field.
                Ok(
                    get_struct_unnested_columns(&placeholder_name, &inner_fields)
                        .into_iter()
                        .map(Expr::Column)
                        .collect(),
                )
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                // NOTE(review): the unwrap assumes this column was not also
                // registered as a struct unnest (which stores `None`) —
                // confirm callers cannot mix the two for one column.
                let list_unnestings = list_unnesting.as_mut().unwrap();
                // Deduplicate identical unnestings (same column, same depth).
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}
impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> {
type Node = Expr;
/// Pre-order hook: logs every unnest encountered on the way down so that
/// `f_up` can later reconstruct runs of consecutive unnests.
fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
if let Expr::Unnest(ref unnest_expr) = expr {
let (data_type, _) =
unnest_expr.expr.data_type_and_nullable(self.input_schema)?;
self.consecutive_unnest.push(Some(unnest_expr.clone()));
// Push a `None` separator after a struct unnest so it is not merged
// into a consecutive run with the unnests beneath it.
if let DataType::Struct(_) = data_type {
self.consecutive_unnest.push(None);
}
// Remember the outermost unnest of the chain being entered.
if self.top_most_unnest.is_none() {
self.top_most_unnest = Some(unnest_expr.clone());
}
Ok(Transformed::no(expr))
} else {
// Any non-unnest node breaks a run of consecutive unnests.
self.consecutive_unnest.push(None);
Ok(Transformed::no(expr))
}
}
/// Post-order hook: when leaving the outermost unnest of a consecutive
/// run, rewrites the whole run via `transform`.
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
if let Expr::Unnest(ref traversing_unnest) = expr {
// `top_most_unnest` was set by `f_down` for this chain, so the
// unwrap cannot fail here.
if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
self.top_most_unnest = None;
}
// Reconstruct the latest run of consecutive unnests recorded on the
// way down (innermost first).
let unnest_stack = self.get_latest_consecutive_unnest();
// Only rewrite when this node is the outermost unnest of the run.
if traversing_unnest == unnest_stack.last().unwrap() {
let most_inner = unnest_stack.first().unwrap();
let inner_expr = most_inner.expr.as_ref();
// The run length is the unnest depth, e.g. unnest(unnest(col))
// gives depth 2.
let unnest_recursion = unnest_stack.len();
// Struct unnest is only allowed for a single unnest at the root
// of the select expression.
let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;
let mut transformed_exprs = self.transform(
unnest_recursion,
expr.schema_name().to_string(),
inner_expr,
struct_allowed,
)?;
if struct_allowed {
// A root-level rewrite may expand into several expressions;
// keep them all so the caller can return the full set.
self.transformed_root_exprs = Some(transformed_exprs.clone());
}
// Replace this node with the first transformed expression and
// keep traversing.
return Ok(Transformed::new(
transformed_exprs.swap_remove(0),
true,
TreeNodeRecursion::Continue,
));
}
} else {
self.consecutive_unnest.push(None);
}
// Columns referenced outside any unnest still have to be carried
// through the inner projection.
if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
push_projection_dedupl(self.inner_projection_exprs, expr.clone());
}
Ok(Transformed::no(expr))
}
}
/// Appends `expr` to `projection` unless an expression with the same schema
/// name is already present (deduplication is by rendered schema name, not by
/// structural equality).
fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
    let candidate_name = expr.schema_name().to_string();
    let already_present = projection
        .iter()
        .any(|existing| existing.schema_name().to_string() == candidate_name);
    if !already_present {
        projection.push(expr);
    }
}
/// Rewrites `original_expr` bottom-up, pushing any `unnest` it contains into
/// `inner_projection_exprs` (as placeholder columns) and recording the
/// required unnest operations in `unnest_placeholder_columns`.
///
/// Returns the expression(s) for the outer projection: usually one rewritten
/// expression, but a root-level struct unnest expands into one expression
/// per struct field.
pub(crate) fn rewrite_recursive_unnest_bottom_up(
input: &LogicalPlan,
unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
inner_projection_exprs: &mut Vec<Expr>,
original_expr: &Expr,
) -> Result<Vec<Expr>> {
let mut rewriter = RecursiveUnnestRewriter {
input_schema: input.schema(),
root_expr: original_expr,
top_most_unnest: None,
consecutive_unnest: vec![],
inner_projection_exprs,
columns_unnestings: unnest_placeholder_columns,
transformed_root_exprs: None,
};
let Transformed {
data: transformed_expr,
transformed,
tnr: _,
} = original_expr.clone().rewrite(&mut rewriter)?;
if !transformed {
// No unnest was found; forward the expression through the inner
// projection unchanged.
if matches!(&transformed_expr, Expr::Column(_))
|| matches!(&transformed_expr, Expr::Wildcard { .. })
{
// Columns and wildcards can be projected as-is.
push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
Ok(vec![transformed_expr])
} else {
// Compute the expression in the inner projection and reference it
// by its schema name from the outer projection.
let column_name = transformed_expr.schema_name().to_string();
push_projection_dedupl(inner_projection_exprs, transformed_expr);
Ok(vec![Expr::Column(Column::from_name(column_name))])
}
} else {
// A root-level rewrite may have expanded into several expressions
// (struct unnest); prefer those when present.
if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
return Ok(transformed_root_exprs);
}
Ok(vec![transformed_expr])
}
}
#[cfg(test)]
mod tests {
use std::{ops::Add, sync::Arc};
use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
use arrow_schema::Fields;
use datafusion_common::{Column, DFSchema, Result};
use datafusion_expr::{
col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan,
};
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions_aggregate::expr_fn::count;
use indexmap::IndexMap;
use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
/// Asserts that the formatted contents of `r` equal the expected strings in
/// `l`; entries render as `col` (struct unnest) or `col=>[unnesting, ...]`
/// (list unnestings), in map order.
fn column_unnests_eq(
    l: Vec<&str>,
    r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
) {
    let r_formatted: Vec<String> = r
        .iter()
        .map(|(column, unnestings)| match unnestings {
            None => format!("{}", column),
            Some(unnestings) => format!(
                "{}=>[{}]",
                column,
                unnestings
                    .iter()
                    .map(|u| format!("{}", u))
                    .collect::<Vec<String>>()
                    .join(", ")
            ),
        })
        .collect();
    let l_formatted: Vec<String> = l.iter().map(|s| s.to_string()).collect();
    assert_eq!(l_formatted, r_formatted);
}
/// Multi-level unnests: `unnest(unnest(3d_col))` collapses into one
/// placeholder tracked at depth 2; repeating the same unnest (and later
/// adding a depth-1 unnest of the same column) must not duplicate
/// inner-projection entries.
#[test]
fn test_transform_bottom_unnest_recursive() -> Result<()> {
// Schema: a doubly-nested list column plus a plain Int64 column.
let schema = Schema::new(vec![
Field::new(
"3d_col",
ArrowDataType::List(Arc::new(Field::new(
"2d_col",
ArrowDataType::List(Arc::new(Field::new(
"elements",
ArrowDataType::Int64,
true,
))),
true,
))),
true,
),
Field::new("i64_col", ArrowDataType::Int64, true),
]);
let dfschema = DFSchema::try_from(schema)?;
let input = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(dfschema),
});
let mut unnest_placeholder_columns = IndexMap::new();
let mut inner_projection_exprs = vec![];
// unnest(unnest(3d_col)) + unnest(unnest(3d_col)) + i64_col
let original_expr = unnest(unnest(col("3d_col")))
.add(unnest(unnest(col("3d_col"))))
.add(col("i64_col"));
let transformed_exprs = rewrite_recursive_unnest_bottom_up(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr,
)?;
// Both occurrences of the double unnest resolve to the same depth-2
// placeholder column.
assert_eq!(
transformed_exprs,
vec![col("unnest_placeholder(3d_col,depth=2)")
.alias("UNNEST(UNNEST(3d_col))")
.add(
col("unnest_placeholder(3d_col,depth=2)")
.alias("UNNEST(UNNEST(3d_col))")
)
.add(col("i64_col"))]
);
column_unnests_eq(
vec![
"unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]",
],
&unnest_placeholder_columns,
);
// The inner projection carries each source column exactly once.
assert_eq!(
inner_projection_exprs,
vec![
col("3d_col").alias("unnest_placeholder(3d_col)"),
col("i64_col")
]
);
// A later single-level unnest of the same column reuses the placeholder
// and only appends a depth-1 unnesting instruction.
let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
let transformed_exprs = rewrite_recursive_unnest_bottom_up(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr_2,
)?;
assert_eq!(
transformed_exprs,
vec![
(col("unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
.alias("2d_col")
]
);
column_unnests_eq(
vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"],
&unnest_placeholder_columns,
);
// Inner projection is unchanged: no duplicates were added.
assert_eq!(
inner_projection_exprs,
vec![
col("3d_col").alias("unnest_placeholder(3d_col)"),
col("i64_col")
]
);
Ok(())
}
/// Single-level unnests: a root struct unnest expands into one column per
/// field (tracked with no depth list), and a list unnest inside an
/// arithmetic expression becomes a depth-1 placeholder reference.
#[test]
fn test_transform_bottom_unnest() -> Result<()> {
// Schema: a struct column, a list column, and a plain int column.
let schema = Schema::new(vec![
Field::new(
"struct_col",
ArrowDataType::Struct(Fields::from(vec![
Field::new("field1", ArrowDataType::Int32, false),
Field::new("field2", ArrowDataType::Int32, false),
])),
false,
),
Field::new(
"array_col",
ArrowDataType::List(Arc::new(Field::new(
"item",
ArrowDataType::Int64,
true,
))),
true,
),
Field::new("int_col", ArrowDataType::Int32, false),
]);
let dfschema = DFSchema::try_from(schema)?;
let input = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(dfschema),
});
let mut unnest_placeholder_columns = IndexMap::new();
let mut inner_projection_exprs = vec![];
// Root struct unnest: expands to one column per struct field.
let original_expr = unnest(col("struct_col"));
let transformed_exprs = rewrite_recursive_unnest_bottom_up(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr,
)?;
assert_eq!(
transformed_exprs,
vec![
col("unnest_placeholder(struct_col).field1"),
col("unnest_placeholder(struct_col).field2"),
]
);
// Struct unnests are recorded without a depth list.
column_unnests_eq(
vec!["unnest_placeholder(struct_col)"],
&unnest_placeholder_columns,
);
assert_eq!(
inner_projection_exprs,
vec![col("struct_col").alias("unnest_placeholder(struct_col)"),]
);
// List unnest used inside an arithmetic expression.
let original_expr = unnest(col("array_col")).add(lit(1i64));
let transformed_exprs = rewrite_recursive_unnest_bottom_up(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&original_expr,
)?;
column_unnests_eq(
vec![
"unnest_placeholder(struct_col)",
"unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]",
],
&unnest_placeholder_columns,
);
assert_eq!(
transformed_exprs,
vec![col("unnest_placeholder(array_col,depth=1)")
.alias("UNNEST(array_col)")
.add(lit(1i64))]
);
assert_eq!(
inner_projection_exprs,
vec![
col("struct_col").alias("unnest_placeholder(struct_col)"),
col("array_col").alias("unnest_placeholder(array_col)")
]
);
Ok(())
}
/// Non-consecutive unnests: in `unnest(unnest(struct_list).field)` the field
/// access breaks the run, so only the inner unnest is rewritten (depth 1)
/// and the outer unnest remains in the returned expression. Two such
/// expressions over the same column share one placeholder entry.
#[test]
fn test_transform_non_consecutive_unnests() -> Result<()> {
// Schema: a list of structs whose fields are themselves lists.
let schema = Schema::new(vec![
Field::new(
"struct_list",
ArrowDataType::List(Arc::new(Field::new(
"element",
ArrowDataType::Struct(Fields::from(vec![
Field::new(
"subfield1",
ArrowDataType::List(Arc::new(Field::new(
"i64_element",
ArrowDataType::Int64,
true,
))),
true,
),
Field::new(
"subfield2",
ArrowDataType::List(Arc::new(Field::new(
"utf8_element",
ArrowDataType::Utf8,
true,
))),
true,
),
])),
true,
))),
true,
),
Field::new("int_col", ArrowDataType::Int32, false),
]);
let dfschema = DFSchema::try_from(schema)?;
let input = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(dfschema),
});
let mut unnest_placeholder_columns = IndexMap::new();
let mut inner_projection_exprs = vec![];
// unnest(unnest(struct_list).subfield1): the `.subfield1` access sits
// between the two unnests, so they are not consecutive.
let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
let transformed_exprs = rewrite_recursive_unnest_bottom_up(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&select_expr1,
)?;
// Only the inner unnest was rewritten; the outer one is preserved.
assert_eq!(
transformed_exprs,
vec![unnest(
col("unnest_placeholder(struct_list,depth=1)")
.alias("UNNEST(struct_list)")
.field("subfield1")
)]
);
column_unnests_eq(
vec![
"unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]",
],
&unnest_placeholder_columns,
);
assert_eq!(
inner_projection_exprs,
vec![col("struct_list").alias("unnest_placeholder(struct_list)")]
);
// The second expression reuses the same placeholder and unnesting.
let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
let transformed_exprs = rewrite_recursive_unnest_bottom_up(
&input,
&mut unnest_placeholder_columns,
&mut inner_projection_exprs,
&select_expr2,
)?;
assert_eq!(
transformed_exprs,
vec![unnest(
col("unnest_placeholder(struct_list,depth=1)")
.alias("UNNEST(struct_list)")
.field("subfield2")
)]
);
column_unnests_eq(
vec![
"unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]",
],
&unnest_placeholder_columns,
);
assert_eq!(
inner_projection_exprs,
vec![col("struct_list").alias("unnest_placeholder(struct_list)")]
);
Ok(())
}
/// Positional references: a valid 1-based Int64 literal resolves to the
/// SELECT expression, out-of-range positions produce plan errors, and
/// non-integer expressions pass through unchanged.
#[test]
fn test_resolve_positions_to_exprs() -> Result<()> {
let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
// Valid position resolves to the first select expression.
let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
assert_eq!(resolved, col("c1"));
// Position below the valid range errors.
let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
assert!(resolved.is_err_and(|e| e.message().contains(
"Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
)));
// Position above the valid range errors.
let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
assert!(resolved.is_err_and(|e| e.message().contains(
"Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
)));
// Non-integer literals and columns are returned untouched.
let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
assert_eq!(resolved, lit("text"));
let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
assert_eq!(resolved, col("fake"));
Ok(())
}
}