pub use super::Operator;
use crate::error::Result;
use crate::logical_plan::ExprSchemable;
use crate::logical_plan::{DFField, DFSchema};
use arrow::datatypes::DataType;
use datafusion_common::DataFusionError;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
};
use std::collections::HashSet;
use std::sync::Arc;
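
/// Helper struct for building `Expr::Case` expressions.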
pub struct CaseBuilder {
expr: Option<Box<Expr>>,
when_expr: Vec<Expr>,
then_expr: Vec<Expr>,
else_expr: Option<Box<Expr>>,
}
impl CaseBuilder {
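    /// Add a `WHEN <when> THEN <then>` branch and return the updated builder.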
pub fn when(&mut self, when: Expr, then: Expr) -> CaseBuilder {
self.when_expr.push(when);
self.then_expr.push(then);
CaseBuilder {
expr: self.expr.clone(),
when_expr: self.when_expr.clone(),
then_expr: self.then_expr.clone(),
else_expr: self.else_expr.clone(),
}
}
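    /// Set the `ELSE` expression and build the final CASE expression.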
pub fn otherwise(&mut self, else_expr: Expr) -> Result<Expr> {
self.else_expr = Some(Box::new(else_expr));
self.build()
}
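    /// Build the CASE expression without an `ELSE` clause.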
pub fn end(&self) -> Result<Expr> {
self.build()
}
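    /// Assemble the CASE expression, first verifying that all literal `THEN`
    /// (and `ELSE`) values share a single data type.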
fn build(&self) -> Result<Expr> {
let mut then_expr = self.then_expr.clone();
if let Some(e) = &self.else_expr {
then_expr.push(e.as_ref().to_owned());
}
let then_types: Vec<DataType> = then_expr
.iter()
.map(|e| match e {
Expr::Literal(_) => e.get_type(&DFSchema::empty()),
_ => Ok(DataType::Null),
})
.collect::<Result<Vec<_>>>()?;
        // Non-literal `THEN` arms (and null literals) were mapped to
        // `DataType::Null` above; their types cannot be verified until
        // execution, so only check when every arm has a known literal type.
        if !then_types.contains(&DataType::Null) {
            let unique_types: HashSet<&DataType> = then_types.iter().collect();
            if unique_types.len() != 1 {
                return Err(DataFusionError::Plan(format!(
                    "CASE expression 'then' values had multiple data types: {:?}",
                    unique_types
                )));
            }
        }
Ok(Expr::Case {
expr: self.expr.clone(),
when_then_expr: self
.when_expr
.iter()
.zip(self.then_expr.iter())
.map(|(w, t)| (Box::new(w.clone()), Box::new(t.clone())))
.collect(),
else_expr: self.else_expr.clone(),
})
}
}
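
/// Create a CASE WHEN statement with literal WHEN expressions for comparison
/// against the base expression.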
pub fn case(expr: Expr) -> CaseBuilder {
CaseBuilder {
expr: Some(Box::new(expr)),
when_expr: vec![],
then_expr: vec![],
else_expr: None,
}
}
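
/// Create a CASE WHEN statement with boolean WHEN expressions and no base
/// expression.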
pub fn when(when: Expr, then: Expr) -> CaseBuilder {
CaseBuilder {
expr: None,
when_expr: vec![when],
then_expr: vec![then],
else_expr: None,
}
}
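
/// Combine the given filter expressions into a single conjunctive (`AND`)
/// predicate, or return `None` if `filters` is empty.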
pub fn combine_filters(filters: &[Expr]) -> Option<Expr> {
if filters.is_empty() {
return None;
}
let combined_filter = filters
.iter()
.skip(1)
.fold(filters[0].clone(), |acc, filter| and(acc, filter.clone()));
Some(combined_filter)
}
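
/// Rewrite an expression as an `Expr::Column` reference when `input_schema`
/// already contains a field with the same unqualified name, preserving any
/// outer alias; otherwise return the expression unchanged.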
pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
match e {
Expr::Column(_) => e,
Expr::Alias(inner_expr, name) => {
Expr::Alias(Box::new(columnize_expr(*inner_expr, input_schema)), name)
}
_ => match e.name(input_schema) {
Ok(name) => match input_schema.field_with_unqualified_name(&name) {
Ok(field) => Expr::Column(field.qualified_column()),
Err(_) => e,
},
Err(_) => e,
},
}
}
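
/// Recursively strip aliases from an expression, returning the innermost
/// non-alias expression.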
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}
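
/// Convenience helper to create a `ScalarUDF` with an exact signature and a
/// fixed return type.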
pub fn create_udf(
name: &str,
input_types: Vec<DataType>,
return_type: Arc<DataType>,
volatility: Volatility,
fun: ScalarFunctionImplementation,
) -> ScalarUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
ScalarUDF::new(
name,
&Signature::exact(input_types, volatility),
&return_type,
&fun,
)
}
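
/// Convenience helper to create an `AggregateUDF` with an exact
/// single-argument signature, a fixed return type, and a fixed intermediate
/// state type.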
#[allow(clippy::rc_buffer)]
pub fn create_udaf(
name: &str,
input_type: DataType,
return_type: Arc<DataType>,
volatility: Volatility,
accumulator: AccumulatorFunctionImplementation,
state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
AggregateUDF::new(
name,
&Signature::exact(vec![input_type], volatility),
&return_type,
&accumulator,
&state_type,
)
}
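
/// Create field metadata for each expression, for use when building a result
/// set schema.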
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
input_schema: &DFSchema,
) -> Result<Vec<DFField>> {
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
}
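
/// Create an `Expr::ScalarFunction` that calls a built-in scalar function by
/// name, e.g. `call_fn("sin", vec![col("x")])`; returns an error if the name
/// does not match any `BuiltinScalarFunction`.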
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
    let fun = name.as_ref().parse::<BuiltinScalarFunction>()?;
    Ok(Expr::ScalarFunction { fun, args })
}
#[cfg(test)]
mod tests {
use super::super::{col, lit, when};
use super::*;
use datafusion_expr::expr_fn::binary_expr;
#[test]
fn case_when_same_literal_then_types() -> Result<()> {
let _ = when(col("state").eq(lit("CO")), lit(303))
.when(col("state").eq(lit("NY")), lit(212))
.end()?;
Ok(())
}
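    // A minimal sketch of the `case(...)` builder path, building
    // `CASE state WHEN 'CO' THEN 303 ELSE 0 END`; assumes the literal
    // `then`/`else` values share a data type so the builder check passes.
    #[test]
    fn case_with_base_expr_and_otherwise() -> Result<()> {
        let expr = case(col("state"))
            .when(lit("CO"), lit(303))
            .otherwise(lit(0))?;
        if let Expr::Case { expr: base, when_then_expr, else_expr } = expr {
            assert!(base.is_some());
            assert_eq!(1, when_then_expr.len());
            assert!(else_expr.is_some());
        } else {
            unreachable!();
        }
        Ok(())
    }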
#[test]
fn case_when_different_literal_then_types() {
let maybe_expr = when(col("state").eq(lit("CO")), lit(303))
.when(col("state").eq(lit("NY")), lit("212"))
.end();
assert!(maybe_expr.is_err());
}
#[test]
fn digest_function_definitions() {
if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) {
let name = BuiltinScalarFunction::Digest;
assert_eq!(name, fun);
assert_eq!(2, args.len());
} else {
unreachable!();
}
}
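    // A minimal sketch of `call_fn`, assuming "sin" parses to
    // `BuiltinScalarFunction::Sin` and an unknown name yields an error.
    #[test]
    fn call_fn_definitions() -> Result<()> {
        if let Expr::ScalarFunction { fun, args } = call_fn("sin", vec![col("x")])? {
            assert_eq!(BuiltinScalarFunction::Sin, fun);
            assert_eq!(1, args.len());
        } else {
            unreachable!();
        }
        assert!(call_fn("not_a_registered_function", vec![col("x")]).is_err());
        Ok(())
    }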
#[test]
fn combine_zero_filters() {
let result = combine_filters(&[]);
assert_eq!(result, None);
}
#[test]
fn combine_one_filter() {
let filter = binary_expr(col("c1"), Operator::Lt, lit(1));
let result = combine_filters(&[filter.clone()]);
assert_eq!(result, Some(filter));
}
#[test]
fn combine_multiple_filters() {
let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1));
let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2));
let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3));
let result =
combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
}
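    // A minimal sketch of `unalias`, assuming `Expr::alias` wraps the column
    // in nested `Expr::Alias` nodes that `unalias` then strips back off.
    #[test]
    fn unalias_nested_alias() {
        let aliased = col("c1").alias("a").alias("b");
        assert_eq!(unalias(aliased), col("c1"));
    }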
#[test]
fn expr_schema_nullability() {
let expr = col("foo").eq(lit(1));
assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
assert!(expr
.nullable(&MockExprSchema::new().with_nullable(true))
.unwrap());
}
#[test]
fn expr_schema_data_type() {
let expr = col("foo");
assert_eq!(
DataType::Utf8,
expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
.unwrap()
);
}
struct MockExprSchema {
nullable: bool,
data_type: DataType,
}
impl MockExprSchema {
fn new() -> Self {
Self {
nullable: false,
data_type: DataType::Null,
}
}
fn with_nullable(mut self, nullable: bool) -> Self {
self.nullable = nullable;
self
}
fn with_data_type(mut self, data_type: DataType) -> Self {
self.data_type = data_type;
self
}
}
impl ExprSchema for MockExprSchema {
fn nullable(&self, _col: &Column) -> Result<bool> {
Ok(self.nullable)
}
fn data_type(&self, _col: &Column) -> Result<&DataType> {
Ok(&self.data_type)
}
}
}