use crate::expr::GroupingSet;
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::sync::Arc;
/// Creates an `Expr::Column` expression from `ident`, converted via `Into`.
pub fn col(ident: &str) -> Expr {
    let column = ident.into();
    Expr::Column(column)
}
/// Builds the binary expression `l op r`, boxing both operands.
pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
    let (left, right) = (Box::new(l), Box::new(r));
    Expr::BinaryExpr { left, op, right }
}
pub fn and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
left: Box::new(left),
op: Operator::And,
right: Box::new(right),
}
}
pub fn or(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
left: Box::new(left),
op: Operator::Or,
right: Box::new(right),
}
}
/// Creates a non-distinct, unfiltered `Min` aggregate over `expr`.
pub fn min(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::Min,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `Max` aggregate over `expr`.
pub fn max(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::Max,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `Sum` aggregate over `expr`.
pub fn sum(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::Sum,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `Avg` aggregate over `expr`.
pub fn avg(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::Avg,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `Count` aggregate over `expr`.
pub fn count(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::Count,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates an unfiltered `Count` aggregate over `expr` with `distinct` set.
pub fn count_distinct(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::Count,
        args,
        distinct: true,
        filter: None,
    }
}
/// Builds an `Expr::InList` testing `expr` against `list`; `negated` is
/// stored on the node unchanged.
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
    let expr = Box::new(expr);
    Expr::InList {
        expr,
        list,
        negated,
    }
}
/// Builds a call to the built-in `Concat` function over a cloned copy of `args`.
pub fn concat(args: &[Expr]) -> Expr {
    let fun = built_in_function::BuiltinScalarFunction::Concat;
    Expr::ScalarFunction {
        fun,
        args: args.iter().cloned().collect(),
    }
}
/// Builds a call to the built-in `ConcatWithSeparator` function.
///
/// The separator becomes the first argument (as a literal), followed by
/// clones of `values` in order.
pub fn concat_ws(sep: impl Into<String>, values: &[Expr]) -> Expr {
    let separator: String = sep.into();
    let args: Vec<Expr> = std::iter::once(lit(separator))
        .chain(values.iter().cloned())
        .collect();
    Expr::ScalarFunction {
        fun: built_in_function::BuiltinScalarFunction::ConcatWithSeparator,
        args,
    }
}
/// Builds a call to the built-in `Random` function with no arguments.
pub fn random() -> Expr {
    let args = Vec::new();
    Expr::ScalarFunction {
        fun: built_in_function::BuiltinScalarFunction::Random,
        args,
    }
}
/// Creates a non-distinct, unfiltered `ApproxDistinct` aggregate over `expr`.
pub fn approx_distinct(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::ApproxDistinct,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `ApproxMedian` aggregate over `expr`.
pub fn approx_median(expr: Expr) -> Expr {
    let args = vec![expr];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::ApproxMedian,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `ApproxPercentileCont` aggregate with
/// arguments `[expr, percentile]`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
    let args = vec![expr, percentile];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::ApproxPercentileCont,
        args,
        distinct: false,
        filter: None,
    }
}
/// Creates a non-distinct, unfiltered `ApproxPercentileContWithWeight`
/// aggregate with arguments `[expr, weight_expr, percentile]`, in that order.
pub fn approx_percentile_cont_with_weight(
    expr: Expr,
    weight_expr: Expr,
    percentile: Expr,
) -> Expr {
    let args = vec![expr, weight_expr, percentile];
    Expr::AggregateFunction {
        fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
        args,
        distinct: false,
        filter: None,
    }
}
/// Builds a non-negated `Expr::Exists` over the given subquery plan.
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
    let subquery = Subquery { subquery };
    Expr::Exists {
        subquery,
        negated: false,
    }
}
/// Builds a negated `Expr::Exists` over the given subquery plan.
pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
    let subquery = Subquery { subquery };
    Expr::Exists {
        subquery,
        negated: true,
    }
}
/// Builds a non-negated `Expr::InSubquery` testing `expr` against the subquery.
pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
    let expr = Box::new(expr);
    let subquery = Subquery { subquery };
    Expr::InSubquery {
        expr,
        subquery,
        negated: false,
    }
}
/// Builds a negated `Expr::InSubquery` testing `expr` against the subquery.
pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
    let expr = Box::new(expr);
    let subquery = Subquery { subquery };
    Expr::InSubquery {
        expr,
        subquery,
        negated: true,
    }
}
/// Wraps the subquery plan in an `Expr::ScalarSubquery`.
pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
    let wrapped = Subquery { subquery };
    Expr::ScalarSubquery(wrapped)
}
/// Wraps `exprs` as a `GroupingSet::GroupingSets` grouping expression.
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
    let sets = GroupingSet::GroupingSets(exprs);
    Expr::GroupingSet(sets)
}
/// Wraps `exprs` as a `GroupingSet::Cube` grouping expression.
pub fn cube(exprs: Vec<Expr>) -> Expr {
    let sets = GroupingSet::Cube(exprs);
    Expr::GroupingSet(sets)
}
/// Wraps `exprs` as a `GroupingSet::Rollup` grouping expression.
pub fn rollup(exprs: Vec<Expr>) -> Expr {
    let sets = GroupingSet::Rollup(exprs);
    Expr::GroupingSet(sets)
}
/// Builds an `Expr::Cast` of `expr` to `data_type`.
pub fn cast(expr: Expr, data_type: DataType) -> Expr {
    let expr = Box::new(expr);
    Expr::Cast { expr, data_type }
}
/// Builds an `Expr::TryCast` of `expr` to `data_type`.
pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
    let expr = Box::new(expr);
    Expr::TryCast { expr, data_type }
}
/// Builds an `Expr::IsNull` around `expr`.
pub fn is_null(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsNull(boxed)
}
/// Builds an `Expr::IsTrue` around `expr`.
pub fn is_true(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsTrue(boxed)
}
/// Builds an `Expr::IsNotTrue` around `expr`.
pub fn is_not_true(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsNotTrue(boxed)
}
/// Builds an `Expr::IsFalse` around `expr`.
pub fn is_false(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsFalse(boxed)
}
/// Builds an `Expr::IsNotFalse` around `expr`.
pub fn is_not_false(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsNotFalse(boxed)
}
/// Builds an `Expr::IsUnknown` around `expr`.
pub fn is_unknown(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsUnknown(boxed)
}
/// Builds an `Expr::IsNotUnknown` around `expr`.
pub fn is_not_unknown(expr: Expr) -> Expr {
    let boxed = Box::new(expr);
    Expr::IsNotUnknown(boxed)
}
// Generates `pub fn $NAME(e: Expr) -> Expr`, which builds a one-argument call
// to `BuiltinScalarFunction::$VARIANT`. `$DOC` becomes the function's rustdoc.
macro_rules! unary_scalar_expr {
    ($VARIANT:ident, $NAME:ident, $DOC:expr) => {
        #[doc = $DOC]
        pub fn $NAME(e: Expr) -> Expr {
            let args = vec![e];
            Expr::ScalarFunction {
                fun: built_in_function::BuiltinScalarFunction::$VARIANT,
                args,
            }
        }
    };
}
// Generates `pub fn $NAME(p1: Expr, p2: Expr, ...) -> Expr` with one `Expr`
// parameter per listed identifier; the parameters are passed, in order, as
// the arguments of `BuiltinScalarFunction::$VARIANT`.
macro_rules! scalar_expr {
    ($VARIANT:ident, $NAME:ident, $($param:ident),*) => {
        #[doc = concat!("Scalar function definition for ", stringify!($NAME) ) ]
        pub fn $NAME($($param: Expr),*) -> Expr {
            let args = vec![$($param),*];
            Expr::ScalarFunction {
                fun: built_in_function::BuiltinScalarFunction::$VARIANT,
                args,
            }
        }
    };
}
// Generates `pub fn $NAME(args: Vec<Expr>) -> Expr` for variadic built-ins:
// the caller-supplied vector is used as the argument list unchanged.
macro_rules! nary_scalar_expr {
    ($VARIANT:ident, $NAME:ident) => {
        #[doc = concat!("Scalar function definition for ", stringify!($NAME) ) ]
        pub fn $NAME(args: Vec<Expr>) -> Expr {
            let fun = built_in_function::BuiltinScalarFunction::$VARIANT;
            Expr::ScalarFunction { fun, args }
        }
    };
}
// Single-argument math functions.
unary_scalar_expr!(Sqrt, sqrt, "square root of a number");
unary_scalar_expr!(Sin, sin, "sine");
unary_scalar_expr!(Cos, cos, "cosine");
unary_scalar_expr!(Tan, tan, "tangent");
unary_scalar_expr!(Asin, asin, "inverse sine");
unary_scalar_expr!(Acos, acos, "inverse cosine");
unary_scalar_expr!(Atan, atan, "inverse tangent");
unary_scalar_expr!(
    Floor,
    floor,
    "nearest integer less than or equal to argument"
);
unary_scalar_expr!(
    Ceil,
    ceil,
    "nearest integer greater than or equal to argument"
);
unary_scalar_expr!(Round, round, "round to nearest integer");
unary_scalar_expr!(Trunc, trunc, "truncate toward zero");
unary_scalar_expr!(Abs, abs, "absolute value");
unary_scalar_expr!(Signum, signum, "sign of the argument (-1, 0, +1)");
unary_scalar_expr!(
    Exp,
    exp,
    "exponential (e raised to the power of the argument)"
);
unary_scalar_expr!(Log2, log2, "base 2 logarithm");
unary_scalar_expr!(Log10, log10, "base 10 logarithm");
unary_scalar_expr!(Ln, ln, "natural logarithm");

// Fixed-arity functions; the listed identifiers become the parameter names.
scalar_expr!(NullIf, nullif, arg_1, arg_2);
scalar_expr!(Power, power, base, exponent);
scalar_expr!(Atan2, atan2, y, x);
scalar_expr!(Ascii, ascii, string);
scalar_expr!(BitLength, bit_length, string);
scalar_expr!(CharacterLength, character_length, string);
// `length` is a second builder name for the same `CharacterLength` function.
scalar_expr!(CharacterLength, length, string);
scalar_expr!(Chr, chr, string);
scalar_expr!(Digest, digest, input, algorithm);
scalar_expr!(InitCap, initcap, string);
scalar_expr!(Left, left, string, count);
scalar_expr!(Lower, lower, string);
scalar_expr!(Ltrim, ltrim, string);
scalar_expr!(MD5, md5, string);
scalar_expr!(OctetLength, octet_length, string);
scalar_expr!(Replace, replace, string, from, to);
scalar_expr!(Repeat, repeat, string, count);
scalar_expr!(Reverse, reverse, string);
scalar_expr!(Right, right, string, count);
scalar_expr!(Rtrim, rtrim, string);
scalar_expr!(SHA224, sha224, string);
scalar_expr!(SHA256, sha256, string);
scalar_expr!(SHA384, sha384, string);
scalar_expr!(SHA512, sha512, string);
scalar_expr!(SplitPart, split_part, expr, delimiter, index);
scalar_expr!(StartsWith, starts_with, string, characters);
scalar_expr!(Strpos, strpos, string, substring);
scalar_expr!(Substr, substr, string, position);
scalar_expr!(ToHex, to_hex, string);
scalar_expr!(Translate, translate, string, from, to);
scalar_expr!(Trim, trim, string);
scalar_expr!(Upper, upper, string);

// Variadic functions; callers pass the full argument list as a `Vec<Expr>`.
nary_scalar_expr!(Lpad, lpad);
nary_scalar_expr!(Rpad, rpad);
nary_scalar_expr!(RegexpReplace, regexp_replace);
nary_scalar_expr!(RegexpMatch, regexp_match);
nary_scalar_expr!(Btrim, btrim);
nary_scalar_expr!(ConcatWithSeparator, concat_ws_expr);
nary_scalar_expr!(Concat, concat_expr);

// Date/time functions.
scalar_expr!(DatePart, date_part, part, date);
scalar_expr!(DateTrunc, date_trunc, part, date);
scalar_expr!(DateBin, date_bin, stride, source, origin);
scalar_expr!(ToTimestampMillis, to_timestamp_millis, date);
scalar_expr!(ToTimestampMicros, to_timestamp_micros, date);
scalar_expr!(ToTimestampSeconds, to_timestamp_seconds, date);
scalar_expr!(FromUnixtime, from_unixtime, unixtime);

// Type introspection.
unary_scalar_expr!(
    ArrowTypeof,
    arrow_typeof,
    "Arrow data type of the argument"
);
/// Builds a call to the built-in `MakeArray` function with `args`.
pub fn array(args: Vec<Expr>) -> Expr {
    let fun = built_in_function::BuiltinScalarFunction::MakeArray;
    Expr::ScalarFunction { fun, args }
}
/// Builds a call to the built-in `Coalesce` function with `args`.
pub fn coalesce(args: Vec<Expr>) -> Expr {
    let fun = built_in_function::BuiltinScalarFunction::Coalesce;
    Expr::ScalarFunction { fun, args }
}
/// Builds a call to the built-in `Now` function with no arguments.
pub fn now() -> Expr {
    Expr::ScalarFunction {
        args: Vec::new(),
        fun: BuiltinScalarFunction::Now,
    }
}
/// Starts a `CaseBuilder` whose operand is `expr`, with no `WHEN`/`THEN`
/// branches and no `ELSE` yet.
pub fn case(expr: Expr) -> CaseBuilder {
    let operand = Some(Box::new(expr));
    CaseBuilder::new(operand, vec![], vec![], None)
}
/// Starts an operand-less `CaseBuilder` with a single `when`/`then` branch
/// and no `ELSE` yet.
pub fn when(when: Expr, then: Expr) -> CaseBuilder {
    let (when_exprs, then_exprs) = (vec![when], vec![then]);
    CaseBuilder::new(None, when_exprs, then_exprs, None)
}
pub fn combine_filters(filters: &[Expr]) -> Option<Expr> {
if filters.is_empty() {
return None;
}
let combined_filter = filters
.iter()
.skip(1)
.fold(filters[0].clone(), |acc, filter| and(acc, filter.clone()));
Some(combined_filter)
}
/// Splits `filter` on its `AND` operators, returning the conjuncts in
/// left-to-right order.
///
/// Nested `AND`s on either side are flattened. Any other expression (including
/// an `OR`) is returned as a single element, without looking inside it.
pub fn uncombine_filter(filter: Expr) -> Vec<Expr> {
    // Appends the conjuncts of `expr` to `out`, visiting the left side first.
    fn collect_conjuncts(expr: Expr, out: &mut Vec<Expr>) {
        match expr {
            Expr::BinaryExpr {
                left,
                op: Operator::And,
                right,
            } => {
                collect_conjuncts(*left, out);
                collect_conjuncts(*right, out);
            }
            other => out.push(other),
        }
    }

    let mut conjuncts = Vec::new();
    collect_conjuncts(filter, &mut conjuncts);
    conjuncts
}
/// Combines `filters` into a single disjunction, folded from the left:
/// `[a, b, c]` becomes `(a OR b) OR c`. Returns `None` when `filters` is empty.
pub fn combine_filters_disjunctive(filters: &[Expr]) -> Option<Expr> {
    let mut remaining = filters.iter().cloned();
    // An empty slice has no first element, so `?` returns `None` here.
    let first = remaining.next()?;
    Some(remaining.fold(first, or))
}
/// Strips every enclosing `Expr::Alias` layer, returning the innermost
/// non-alias expression. Non-alias input is returned unchanged.
#[inline]
pub fn unalias(expr: Expr) -> Expr {
    let mut current = expr;
    // Loop rather than recurse, so arbitrarily deep alias chains use no
    // extra stack.
    while let Expr::Alias(inner, _) = current {
        current = *inner;
    }
    current
}
/// Creates a scalar UDF named `name`.
///
/// The UDF has an exact signature over `input_types` with the given
/// `volatility`, and uses `fun` as its implementation. Its return-type
/// function ignores its input and always yields a clone of `return_type`.
pub fn create_udf(
    name: &str,
    input_types: Vec<DataType>,
    return_type: Arc<DataType>,
    volatility: Volatility,
    fun: ScalarFunctionImplementation,
) -> ScalarUDF {
    let fixed_return_type: ReturnTypeFunction =
        Arc::new(move |_| Ok(return_type.clone()));
    let signature = Signature::exact(input_types, volatility);
    ScalarUDF::new(name, &signature, &fixed_return_type, &fun)
}
/// Creates an aggregate UDF named `name` over a single `input_type`.
///
/// The signature is exact with the given `volatility`, and `accumulator` is
/// the accumulator factory. The return-type and state-type functions ignore
/// their inputs and always yield clones of `return_type` and `state_type`.
// `state_type` is an `Arc<Vec<DataType>>`, which clippy's `rc_buffer` lint
// flags. It is kept because it is part of this function's public signature.
#[allow(clippy::rc_buffer)]
pub fn create_udaf(
    name: &str,
    input_type: DataType,
    return_type: Arc<DataType>,
    volatility: Volatility,
    accumulator: AccumulatorFunctionImplementation,
    state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
    let fixed_return_type: ReturnTypeFunction =
        Arc::new(move |_| Ok(return_type.clone()));
    let fixed_state_type: StateTypeFunction =
        Arc::new(move |_| Ok(state_type.clone()));
    let signature = Signature::exact(vec![input_type], volatility);
    AggregateUDF::new(
        name,
        &signature,
        &fixed_return_type,
        &accumulator,
        &fixed_state_type,
    )
}
/// Builds a call to the built-in scalar function named `name` with `args`.
///
/// # Errors
///
/// Returns the error produced by parsing `name` as a
/// [`BuiltinScalarFunction`] when it does not name a known built-in.
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
    let fun = name.as_ref().parse::<BuiltinScalarFunction>()?;
    Ok(Expr::ScalarFunction { fun, args })
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn filter_is_null_and_is_not_null() {
        let col_null = col("col1");
        let col_not_null = col("col2");
        assert_eq!(format!("{:?}", col_null.is_null()), "col1 IS NULL");
        assert_eq!(
            format!("{:?}", col_not_null.is_not_null()),
            "col2 IS NOT NULL"
        );
    }

    // Checks that a unary builder produces a one-argument call to `$ENUM`.
    macro_rules! test_unary_scalar_expr {
        ($ENUM:ident, $FUNC:ident) => {{
            if let Expr::ScalarFunction { fun, args } = $FUNC(col("tableA.a")) {
                let name = built_in_function::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(1, args.len());
            } else {
                panic!("unexpected");
            }
        }};
    }

    // Checks that a fixed-arity builder produces a call to `$ENUM` with one
    // argument per listed identifier. Each identifier is also used as the
    // column name of the corresponding argument.
    macro_rules! test_scalar_expr {
        ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
            let expected = [$(stringify!($arg)),*];
            let result = $FUNC($(col(stringify!($arg))),*);
            if let Expr::ScalarFunction { fun, args } = result {
                let name = built_in_function::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(expected.len(), args.len());
            } else {
                panic!("unexpected: {:?}", result);
            }
        };
    }

    // Like `test_scalar_expr!`, but for variadic builders taking a `Vec<Expr>`.
    macro_rules! test_nary_scalar_expr {
        ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
            let expected = [$(stringify!($arg)),*];
            let result = $FUNC(vec![$(col(stringify!($arg))),*]);
            if let Expr::ScalarFunction { fun, args } = result {
                let name = built_in_function::BuiltinScalarFunction::$ENUM;
                assert_eq!(name, fun);
                assert_eq!(expected.len(), args.len());
            } else {
                panic!("unexpected: {:?}", result);
            }
        };
    }

    #[test]
    fn scalar_function_definitions() {
        test_unary_scalar_expr!(Sqrt, sqrt);
        test_unary_scalar_expr!(Sin, sin);
        test_unary_scalar_expr!(Cos, cos);
        test_unary_scalar_expr!(Tan, tan);
        test_unary_scalar_expr!(Asin, asin);
        test_unary_scalar_expr!(Acos, acos);
        test_unary_scalar_expr!(Atan, atan);
        test_unary_scalar_expr!(Floor, floor);
        test_unary_scalar_expr!(Ceil, ceil);
        test_unary_scalar_expr!(Round, round);
        test_unary_scalar_expr!(Trunc, trunc);
        test_unary_scalar_expr!(Abs, abs);
        test_unary_scalar_expr!(Signum, signum);
        test_unary_scalar_expr!(Exp, exp);
        test_unary_scalar_expr!(Log2, log2);
        test_unary_scalar_expr!(Log10, log10);
        test_unary_scalar_expr!(Ln, ln);
        test_scalar_expr!(Atan2, atan2, y, x);
        test_scalar_expr!(Ascii, ascii, input);
        test_scalar_expr!(BitLength, bit_length, string);
        test_nary_scalar_expr!(Btrim, btrim, string);
        test_nary_scalar_expr!(Btrim, btrim, string, characters);
        test_scalar_expr!(CharacterLength, character_length, string);
        test_scalar_expr!(CharacterLength, length, string);
        test_scalar_expr!(Chr, chr, string);
        test_scalar_expr!(Digest, digest, string, algorithm);
        test_scalar_expr!(InitCap, initcap, string);
        test_scalar_expr!(Left, left, string, count);
        test_scalar_expr!(Lower, lower, string);
        test_nary_scalar_expr!(Lpad, lpad, string, count);
        test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
        test_scalar_expr!(Ltrim, ltrim, string);
        test_scalar_expr!(MD5, md5, string);
        test_scalar_expr!(OctetLength, octet_length, string);
        test_nary_scalar_expr!(RegexpMatch, regexp_match, string, pattern);
        test_nary_scalar_expr!(RegexpMatch, regexp_match, string, pattern, flags);
        test_nary_scalar_expr!(
            RegexpReplace,
            regexp_replace,
            string,
            pattern,
            replacement
        );
        test_nary_scalar_expr!(
            RegexpReplace,
            regexp_replace,
            string,
            pattern,
            replacement,
            flags
        );
        test_scalar_expr!(Replace, replace, string, from, to);
        test_scalar_expr!(Repeat, repeat, string, count);
        test_scalar_expr!(Reverse, reverse, string);
        test_scalar_expr!(Right, right, string, count);
        test_nary_scalar_expr!(Rpad, rpad, string, count);
        test_nary_scalar_expr!(Rpad, rpad, string, count, characters);
        test_scalar_expr!(Rtrim, rtrim, string);
        test_scalar_expr!(SHA224, sha224, string);
        test_scalar_expr!(SHA256, sha256, string);
        test_scalar_expr!(SHA384, sha384, string);
        test_scalar_expr!(SHA512, sha512, string);
        test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
        test_scalar_expr!(StartsWith, starts_with, string, characters);
        test_scalar_expr!(Strpos, strpos, string, substring);
        test_scalar_expr!(Substr, substr, string, position);
        test_scalar_expr!(ToHex, to_hex, string);
        test_scalar_expr!(Translate, translate, string, from, to);
        test_scalar_expr!(Trim, trim, string);
        test_scalar_expr!(Upper, upper, string);
        test_scalar_expr!(DatePart, date_part, part, date);
        test_scalar_expr!(DateTrunc, date_trunc, part, date);
        test_scalar_expr!(DateBin, date_bin, stride, source, origin);
        test_scalar_expr!(FromUnixtime, from_unixtime, unixtime);
        test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
    }

    #[test]
    fn digest_function_definitions() {
        if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) {
            let name = BuiltinScalarFunction::Digest;
            assert_eq!(name, fun);
            assert_eq!(2, args.len());
        } else {
            unreachable!();
        }
    }

    #[test]
    fn combine_zero_filters() {
        let result = combine_filters(&[]);
        assert_eq!(result, None);
    }

    #[test]
    fn combine_one_filter() {
        let filter = binary_expr(col("c1"), Operator::Lt, lit(1));
        let result = combine_filters(&[filter.clone()]);
        assert_eq!(result, Some(filter));
    }

    #[test]
    fn combine_multiple_filters() {
        let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1));
        let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2));
        let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3));
        let result =
            combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
        assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
    }

    #[test]
    fn combine_zero_filters_disjunctive() {
        let result = combine_filters_disjunctive(&[]);
        assert_eq!(result, None);
    }

    #[test]
    fn combine_multiple_filters_disjunctive() {
        let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1));
        let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2));
        let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3));
        let result = combine_filters_disjunctive(&[
            filter1.clone(),
            filter2.clone(),
            filter3.clone(),
        ]);
        assert_eq!(result, Some(or(or(filter1, filter2), filter3)));
    }

    // Asserts that `actual` and `expected` hold the same predicates,
    // ignoring order.
    fn assert_predicates(actual: Vec<Expr>, expected: Vec<Expr>) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "Predicates are not equal, found {} predicates but expected {}",
            actual.len(),
            expected.len()
        );
        for expr in expected.into_iter() {
            assert!(
                actual.contains(&expr),
                "Predicates are not equal, predicate {:?} not found in {:?}",
                expr,
                actual
            );
        }
    }

    #[test]
    fn test_uncombine_filter() {
        let expr = col("a").eq(lit("s"));
        let actual = uncombine_filter(expr);
        assert_predicates(actual, vec![col("a").eq(lit("s"))]);
    }

    #[test]
    fn test_uncombine_filter_recursively() {
        let expr = and(col("a"), col("b"));
        let actual = uncombine_filter(expr);
        assert_predicates(actual, vec![col("a"), col("b")]);
        // A top-level OR is not split, even though it contains an AND.
        let expr = col("a").and(col("b")).or(col("c"));
        let actual = uncombine_filter(expr.clone());
        assert_predicates(actual, vec![expr]);
    }
}