use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::sync::Arc;
/// Builds a column-reference expression from `ident`.
///
/// The string is converted into the column type via `Into`; any parsing of
/// `ident` (e.g. qualifiers) is done by that conversion, not here.
pub fn col(ident: &str) -> Expr {
    let column = ident.into();
    Expr::Column(column)
}
/// Builds the binary expression `l <op> r`.
///
/// Both operands are boxed and stored unchanged; no type checking or
/// simplification is performed.
pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
    let (left, right) = (Box::new(l), Box::new(r));
    Expr::BinaryExpr { left, op, right }
}
pub fn and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
left: Box::new(left),
op: Operator::And,
right: Box::new(right),
}
}
pub fn or(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
left: Box::new(left),
op: Operator::Or,
right: Box::new(right),
}
}
/// Builds a non-distinct `MIN(expr)` aggregate expression.
pub fn min(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::Min,
    }
}
/// Builds a non-distinct `MAX(expr)` aggregate expression.
pub fn max(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::Max,
    }
}
/// Builds a non-distinct `SUM(expr)` aggregate expression.
pub fn sum(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::Sum,
    }
}
/// Builds a non-distinct `AVG(expr)` aggregate expression.
pub fn avg(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::Avg,
    }
}
/// Builds a non-distinct `COUNT(expr)` aggregate expression.
pub fn count(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::Count,
    }
}
/// Builds a `COUNT(DISTINCT expr)` aggregate expression.
pub fn count_distinct(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: true,
        fun: aggregate_function::AggregateFunction::Count,
    }
}
/// Builds `expr IN (list…)`, or `expr NOT IN (list…)` when `negated` is true.
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
    let expr = Box::new(expr);
    Expr::InList { negated, list, expr }
}
/// Builds a `Concat` scalar function call over a cloned copy of `args`.
pub fn concat(args: &[Expr]) -> Expr {
    let owned_args: Vec<Expr> = Vec::from(args);
    Expr::ScalarFunction {
        args: owned_args,
        fun: built_in_function::BuiltinScalarFunction::Concat,
    }
}
/// Builds a `ConcatWithSeparator` scalar function call.
///
/// The separator becomes a string literal placed as the first argument,
/// followed by clones of every expression in `values`.
pub fn concat_ws(sep: impl Into<String>, values: &[Expr]) -> Expr {
    let separator: String = sep.into();
    let args: Vec<Expr> = std::iter::once(lit(separator))
        .chain(values.iter().cloned())
        .collect();
    Expr::ScalarFunction {
        fun: built_in_function::BuiltinScalarFunction::ConcatWithSeparator,
        args,
    }
}
/// Builds a zero-argument `Random` scalar function call.
pub fn random() -> Expr {
    let fun = built_in_function::BuiltinScalarFunction::Random;
    Expr::ScalarFunction {
        fun,
        args: Vec::new(),
    }
}
/// Builds a non-distinct `ApproxDistinct(expr)` aggregate expression.
pub fn approx_distinct(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::ApproxDistinct,
    }
}
/// Builds a non-distinct `ApproxPercentileCont` aggregate.
///
/// Arguments are passed in the order `[expr, percentile]`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
    let args = vec![expr, percentile];
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun: aggregate_function::AggregateFunction::ApproxPercentileCont,
    }
}
/// Builds a non-distinct `ApproxPercentileContWithWeight` aggregate.
///
/// Arguments are passed in the order `[expr, weight_expr, percentile]`.
pub fn approx_percentile_cont_with_weight(
    expr: Expr,
    weight_expr: Expr,
    percentile: Expr,
) -> Expr {
    let args = vec![expr, weight_expr, percentile];
    let fun = aggregate_function::AggregateFunction::ApproxPercentileContWithWeight;
    Expr::AggregateFunction {
        args,
        distinct: false,
        fun,
    }
}
/// Builds `EXISTS (subquery)`.
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
    let subquery = Subquery { subquery };
    Expr::Exists {
        negated: false,
        subquery,
    }
}
/// Builds `NOT EXISTS (subquery)`.
pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
    let subquery = Subquery { subquery };
    Expr::Exists {
        negated: true,
        subquery,
    }
}
/// Builds `expr IN (subquery)`.
pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
    let boxed = Box::new(expr);
    Expr::InSubquery {
        negated: false,
        subquery: Subquery { subquery },
        expr: boxed,
    }
}
/// Builds `expr NOT IN (subquery)`.
pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
    let boxed = Box::new(expr);
    Expr::InSubquery {
        negated: true,
        subquery: Subquery { subquery },
        expr: boxed,
    }
}
/// Wraps `subquery` as a scalar subquery expression.
pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
    let wrapped = Subquery { subquery };
    Expr::ScalarSubquery(wrapped)
}
/// Generates `pub fn $FUNC(e: Expr) -> Expr`, which wraps its single argument
/// in an `Expr::ScalarFunction` for the built-in `$ENUM`.
macro_rules! unary_scalar_expr {
    ($ENUM:ident, $FUNC:ident) => {
        #[doc = concat!("Unary scalar function definition for ", stringify!($FUNC) ) ]
        pub fn $FUNC(e: Expr) -> Expr {
            let fun = built_in_function::BuiltinScalarFunction::$ENUM;
            let args = vec![e];
            Expr::ScalarFunction { fun, args }
        }
    };
}
/// Generates `pub fn $FUNC($arg: Expr, …) -> Expr` with one `Expr` parameter per
/// listed `$arg`. The arguments are forwarded, in order, to the built-in `$ENUM`.
macro_rules! scalar_expr {
    ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
        #[doc = concat!("Scalar function definition for ", stringify!($FUNC) ) ]
        pub fn $FUNC($($arg: Expr),*) -> Expr {
            let fun = built_in_function::BuiltinScalarFunction::$ENUM;
            let args = vec![$($arg),*];
            Expr::ScalarFunction { fun, args }
        }
    };
}
/// Generates `pub fn $FUNC(args: Vec<Expr>) -> Expr`, which passes a
/// caller-supplied argument list straight through to the built-in `$ENUM`.
macro_rules! nary_scalar_expr {
    ($ENUM:ident, $FUNC:ident) => {
        #[doc = concat!("Scalar function definition for ", stringify!($FUNC) ) ]
        pub fn $FUNC(args: Vec<Expr>) -> Expr {
            let fun = built_in_function::BuiltinScalarFunction::$ENUM;
            Expr::ScalarFunction { fun, args }
        }
    };
}
// ----- Numeric / math functions --------------------------------------------
unary_scalar_expr!(Abs, abs);
unary_scalar_expr!(Signum, signum);
unary_scalar_expr!(Sqrt, sqrt);
unary_scalar_expr!(Exp, exp);
unary_scalar_expr!(Ln, ln);
unary_scalar_expr!(Log2, log2);
unary_scalar_expr!(Log10, log10);
scalar_expr!(Power, power, base, exponent);
unary_scalar_expr!(Floor, floor);
unary_scalar_expr!(Ceil, ceil);
unary_scalar_expr!(Round, round);
unary_scalar_expr!(Trunc, trunc);
unary_scalar_expr!(Sin, sin);
unary_scalar_expr!(Cos, cos);
unary_scalar_expr!(Tan, tan);
unary_scalar_expr!(Asin, asin);
unary_scalar_expr!(Acos, acos);
unary_scalar_expr!(Atan, atan);

// ----- Conditional functions -----------------------------------------------
// NOTE(review): SQL NULLIF conventionally compares two arguments, but this
// wrapper accepts only one — confirm the arity the `NullIf` built-in expects.
unary_scalar_expr!(NullIf, nullif);

// ----- String functions ----------------------------------------------------
scalar_expr!(Ascii, ascii, string);
scalar_expr!(BitLength, bit_length, string);
scalar_expr!(CharacterLength, character_length, string);
// `length` is a second name for the same `CharacterLength` built-in.
scalar_expr!(CharacterLength, length, string);
scalar_expr!(OctetLength, octet_length, string);
scalar_expr!(Chr, chr, string);
scalar_expr!(InitCap, initcap, string);
scalar_expr!(Lower, lower, string);
scalar_expr!(Upper, upper, string);
scalar_expr!(Left, left, string, count);
scalar_expr!(Right, right, string, count);
scalar_expr!(Repeat, repeat, string, count);
scalar_expr!(Reverse, reverse, string);
scalar_expr!(Replace, replace, string, from, to);
scalar_expr!(Translate, translate, string, from, to);
scalar_expr!(SplitPart, split_part, expr, delimiter, index);
scalar_expr!(StartsWith, starts_with, string, characters);
scalar_expr!(Strpos, strpos, string, substring);
scalar_expr!(Substr, substr, string, position);
scalar_expr!(Trim, trim, string);
scalar_expr!(Ltrim, ltrim, string);
scalar_expr!(Rtrim, rtrim, string);
scalar_expr!(ToHex, to_hex, string);
// Variable-arity string functions: the caller supplies the argument list.
nary_scalar_expr!(Btrim, btrim);
nary_scalar_expr!(Lpad, lpad);
nary_scalar_expr!(Rpad, rpad);
nary_scalar_expr!(ConcatWithSeparator, concat_ws_expr);
nary_scalar_expr!(Concat, concat_expr);

// ----- Hashing functions ---------------------------------------------------
scalar_expr!(Digest, digest, string, algorithm);
scalar_expr!(MD5, md5, string);
scalar_expr!(SHA224, sha224, string);
scalar_expr!(SHA256, sha256, string);
scalar_expr!(SHA384, sha384, string);
scalar_expr!(SHA512, sha512, string);

// ----- Regular-expression functions ----------------------------------------
nary_scalar_expr!(RegexpMatch, regexp_match);
nary_scalar_expr!(RegexpReplace, regexp_replace);

// ----- Date / time functions -----------------------------------------------
// NOTE(review): `now` takes exactly one argument while `now_expr` accepts any
// number — confirm which arity the `Now` built-in actually expects.
unary_scalar_expr!(Now, now);
nary_scalar_expr!(Now, now_expr);
scalar_expr!(DatePart, date_part, part, date);
scalar_expr!(DateTrunc, date_trunc, part, date);
scalar_expr!(ToTimestampMillis, to_timestamp_millis, date);
scalar_expr!(ToTimestampMicros, to_timestamp_micros, date);
scalar_expr!(ToTimestampSeconds, to_timestamp_seconds, date);
/// Builds an `Array` scalar function call over `args`, in the given order.
pub fn array(args: Vec<Expr>) -> Expr {
    let fun = built_in_function::BuiltinScalarFunction::Array;
    Expr::ScalarFunction { fun, args }
}
/// Builds a `Coalesce` scalar function call over `args`, in the given order.
pub fn coalesce(args: Vec<Expr>) -> Expr {
    let fun = built_in_function::BuiltinScalarFunction::Coalesce;
    Expr::ScalarFunction { fun, args }
}
/// Starts a `CASE expr WHEN …` builder.
///
/// `expr` is the base expression; no WHEN/THEN branches and no ELSE are set yet.
pub fn case(expr: Expr) -> CaseBuilder {
    let base = Some(Box::new(expr));
    CaseBuilder::new(base, Vec::new(), Vec::new(), None)
}
/// Starts a searched `CASE WHEN when THEN then …` builder.
///
/// The builder has no base expression and no ELSE branch.
pub fn when(when: Expr, then: Expr) -> CaseBuilder {
    let when_exprs = vec![when];
    let then_exprs = vec![then];
    CaseBuilder::new(None, when_exprs, then_exprs, None)
}
pub fn combine_filters(filters: &[Expr]) -> Option<Expr> {
if filters.is_empty() {
return None;
}
let combined_filter = filters
.iter()
.skip(1)
.fold(filters[0].clone(), |acc, filter| and(acc, filter.clone()));
Some(combined_filter)
}
/// Strips every layer of `Expr::Alias` from `expr`.
///
/// The innermost non-alias expression is returned. Any other expression is
/// returned as-is.
#[inline]
pub fn unalias(expr: Expr) -> Expr {
    let mut current = expr;
    loop {
        match current {
            Expr::Alias(inner, _) => current = *inner,
            other => return other,
        }
    }
}
/// Creates a scalar UDF named `name`.
///
/// The UDF has an exact signature over `input_types` with the given
/// `volatility`. Its return-type function ignores the argument types and
/// always yields `return_type`. `fun` is the implementation.
pub fn create_udf(
    name: &str,
    input_types: Vec<DataType>,
    return_type: Arc<DataType>,
    volatility: Volatility,
    fun: ScalarFunctionImplementation,
) -> ScalarUDF {
    let signature = Signature::exact(input_types, volatility);
    let return_type_fn: ReturnTypeFunction =
        Arc::new(move |_| Ok(Arc::clone(&return_type)));
    ScalarUDF::new(name, &signature, &return_type_fn, &fun)
}
/// Creates an aggregate UDF named `name`.
///
/// The UDAF has an exact single-argument signature of `input_type` with the
/// given `volatility`. Its return-type and state-type functions ignore their
/// input and always yield `return_type` / `state_type`. `accumulator` is the
/// accumulator factory.
// `state_type` is an `Arc<Vec<DataType>>`, which clippy's `rc_buffer` lint
// flags; it is part of the public signature, so the lint is silenced here.
#[allow(clippy::rc_buffer)]
pub fn create_udaf(
    name: &str,
    input_type: DataType,
    return_type: Arc<DataType>,
    volatility: Volatility,
    accumulator: AccumulatorFunctionImplementation,
    state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
    let signature = Signature::exact(vec![input_type], volatility);
    let return_type_fn: ReturnTypeFunction =
        Arc::new(move |_| Ok(Arc::clone(&return_type)));
    let state_type_fn: StateTypeFunction =
        Arc::new(move |_| Ok(Arc::clone(&state_type)));
    AggregateUDF::new(
        name,
        &signature,
        &return_type_fn,
        &accumulator,
        &state_type_fn,
    )
}
/// Builds a call to the built-in scalar function named `name` with `args`.
///
/// # Errors
///
/// Returns the error produced by parsing `name` as a
/// `BuiltinScalarFunction` when it does not name a known built-in.
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
    name.as_ref()
        .parse::<BuiltinScalarFunction>()
        .map(|fun| Expr::ScalarFunction { fun, args })
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn filter_is_null_and_is_not_null() {
        let col_null = col("col1");
        let col_not_null = col("col2");
        assert_eq!(format!("{:?}", col_null.is_null()), "#col1 IS NULL");
        assert_eq!(
            format!("{:?}", col_not_null.is_not_null()),
            "#col2 IS NOT NULL"
        );
    }

    /// Asserts that the one-argument builder `$FUNC` produces the built-in
    /// `$ENUM` with exactly one argument.
    macro_rules! test_unary_scalar_expr {
        ($ENUM:ident, $FUNC:ident) => {{
            if let Expr::ScalarFunction { fun, args } = $FUNC(col("tableA.a")) {
                let name = built_in_function::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(1, args.len());
            } else {
                panic!("unexpected");
            }
        }};
    }

    /// Asserts that the fixed-arity builder `$FUNC`, called with one column per
    /// listed `$arg`, produces the built-in `$ENUM` with that many arguments.
    macro_rules! test_scalar_expr {
        ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {{
            let expected = vec![$(stringify!($arg)),*];
            // Each argument is a column named after the macro argument identifier.
            let result = $FUNC($(col(stringify!($arg))),*);
            if let Expr::ScalarFunction { fun, args } = result {
                let name = built_in_function::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(expected.len(), args.len());
            } else {
                panic!("unexpected: {:?}", result);
            }
        }};
    }

    /// Like `test_scalar_expr!`, but for builders that take a `Vec<Expr>`.
    macro_rules! test_nary_scalar_expr {
        ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {{
            let expected = vec![$(stringify!($arg)),*];
            let result = $FUNC(vec![$(col(stringify!($arg))),*]);
            if let Expr::ScalarFunction { fun, args } = result {
                let name = built_in_function::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(expected.len(), args.len());
            } else {
                panic!("unexpected: {:?}", result);
            }
        }};
    }

    #[test]
    fn scalar_function_definitions() {
        test_unary_scalar_expr!(Sqrt, sqrt);
        test_unary_scalar_expr!(Sin, sin);
        test_unary_scalar_expr!(Cos, cos);
        test_unary_scalar_expr!(Tan, tan);
        test_unary_scalar_expr!(Asin, asin);
        test_unary_scalar_expr!(Acos, acos);
        test_unary_scalar_expr!(Atan, atan);
        test_unary_scalar_expr!(Floor, floor);
        test_unary_scalar_expr!(Ceil, ceil);
        test_unary_scalar_expr!(Now, now);
        test_unary_scalar_expr!(Round, round);
        test_unary_scalar_expr!(Trunc, trunc);
        test_unary_scalar_expr!(Abs, abs);
        test_unary_scalar_expr!(Signum, signum);
        test_unary_scalar_expr!(Exp, exp);
        test_unary_scalar_expr!(Log2, log2);
        test_unary_scalar_expr!(Log10, log10);
        test_unary_scalar_expr!(Ln, ln);
        test_scalar_expr!(Power, power, base, exponent);
        test_scalar_expr!(Ascii, ascii, input);
        test_scalar_expr!(BitLength, bit_length, string);
        test_nary_scalar_expr!(Btrim, btrim, string);
        test_nary_scalar_expr!(Btrim, btrim, string, characters);
        test_scalar_expr!(CharacterLength, character_length, string);
        test_scalar_expr!(CharacterLength, length, string);
        test_scalar_expr!(Chr, chr, string);
        test_nary_scalar_expr!(Concat, concat_expr, a, b);
        test_nary_scalar_expr!(ConcatWithSeparator, concat_ws_expr, sep, a, b);
        test_scalar_expr!(Digest, digest, string, algorithm);
        test_scalar_expr!(InitCap, initcap, string);
        test_scalar_expr!(Left, left, string, count);
        test_scalar_expr!(Lower, lower, string);
        test_nary_scalar_expr!(Lpad, lpad, string, count);
        test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
        test_scalar_expr!(Ltrim, ltrim, string);
        test_scalar_expr!(MD5, md5, string);
        test_scalar_expr!(OctetLength, octet_length, string);
        test_nary_scalar_expr!(RegexpMatch, regexp_match, string, pattern);
        test_nary_scalar_expr!(RegexpMatch, regexp_match, string, pattern, flags);
        test_nary_scalar_expr!(
            RegexpReplace,
            regexp_replace,
            string,
            pattern,
            replacement
        );
        test_nary_scalar_expr!(
            RegexpReplace,
            regexp_replace,
            string,
            pattern,
            replacement,
            flags
        );
        test_scalar_expr!(Replace, replace, string, from, to);
        test_scalar_expr!(Repeat, repeat, string, count);
        test_scalar_expr!(Reverse, reverse, string);
        test_scalar_expr!(Right, right, string, count);
        test_nary_scalar_expr!(Rpad, rpad, string, count);
        test_nary_scalar_expr!(Rpad, rpad, string, count, characters);
        test_scalar_expr!(Rtrim, rtrim, string);
        test_scalar_expr!(SHA224, sha224, string);
        test_scalar_expr!(SHA256, sha256, string);
        test_scalar_expr!(SHA384, sha384, string);
        test_scalar_expr!(SHA512, sha512, string);
        test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
        test_scalar_expr!(StartsWith, starts_with, string, characters);
        test_scalar_expr!(Strpos, strpos, string, substring);
        test_scalar_expr!(Substr, substr, string, position);
        test_scalar_expr!(ToHex, to_hex, string);
        test_scalar_expr!(Translate, translate, string, from, to);
        test_scalar_expr!(Trim, trim, string);
        test_scalar_expr!(Upper, upper, string);
        test_scalar_expr!(DatePart, date_part, part, date);
        test_scalar_expr!(DateTrunc, date_trunc, part, date);
        test_scalar_expr!(ToTimestampMillis, to_timestamp_millis, date);
        test_scalar_expr!(ToTimestampMicros, to_timestamp_micros, date);
        test_scalar_expr!(ToTimestampSeconds, to_timestamp_seconds, date);
    }

    #[test]
    fn digest_function_definitions() {
        if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) {
            let name = BuiltinScalarFunction::Digest;
            assert_eq!(name, fun);
            assert_eq!(2, args.len());
        } else {
            unreachable!();
        }
    }

    #[test]
    fn concat_ws_prepends_separator_literal() {
        let result = concat_ws("-", &[col("a"), col("b")]);
        if let Expr::ScalarFunction { fun, args } = result {
            assert_eq!(BuiltinScalarFunction::ConcatWithSeparator, fun);
            assert_eq!(vec![lit(String::from("-")), col("a"), col("b")], args);
        } else {
            panic!("unexpected: {:?}", result);
        }
    }

    #[test]
    fn call_fn_resolves_builtin_name() {
        let result = call_fn("sqrt", vec![col("a")]).unwrap();
        if let Expr::ScalarFunction { fun, args } = result {
            assert_eq!(BuiltinScalarFunction::Sqrt, fun);
            assert_eq!(1, args.len());
        } else {
            panic!("unexpected: {:?}", result);
        }
    }

    #[test]
    fn call_fn_rejects_unknown_name() {
        assert!(call_fn("no_such_function", vec![]).is_err());
    }

    #[test]
    fn combine_zero_filters() {
        let result = combine_filters(&[]);
        assert_eq!(result, None);
    }

    #[test]
    fn combine_one_filter() {
        let filter = binary_expr(col("c1"), Operator::Lt, lit(1));
        let result = combine_filters(&[filter.clone()]);
        assert_eq!(result, Some(filter));
    }

    #[test]
    fn combine_multiple_filters() {
        let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1));
        let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2));
        let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3));
        let result =
            combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
        assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
    }
}