use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{
expr::{Between, BinaryExpr},
expr_fn::{and, concat_ws, or},
lit, BuiltinScalarFunction, Expr, Operator,
};
/// Powers of ten `10^0` through `10^37`, indexed by exponent.
///
/// [`is_one`] uses this table to recognise the decimal literal `1`: at scale
/// `s`, the unscaled representation of `1` is `10^s`.
///
/// The table is filled at compile time. The final multiplication produces
/// `10^38`, which still fits in an `i128`, so the loop cannot overflow.
pub static POWS_OF_TEN: [i128; 38] = {
    let mut table = [0i128; 38];
    let mut exponent = 0;
    let mut power: i128 = 1;
    while exponent < table.len() {
        table[exponent] = power;
        power *= 10;
        exponent += 1;
    }
    table
};
/// Returns `true` if `needle` occurs among the operands of the tree of
/// `search_op` binary expressions rooted at `expr`.
///
/// Binary expressions whose operator is `search_op` are descended into (left
/// side first, short-circuiting). Any other node is compared to `needle`
/// directly.
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
        if *op == search_op {
            let found_left = expr_contains(left, needle, search_op);
            return found_left || expr_contains(right, needle, search_op);
        }
    }
    expr == needle
}
/// Returns `true` if `s` is a non-null numeric literal equal to zero.
///
/// Covers the signed/unsigned integer types, `Float32`/`Float64`, and
/// `Decimal128` (zero at any precision/scale). Every other expression,
/// including NULL literals, yields `false`.
pub fn is_zero(s: &Expr) -> bool {
    match s {
        Expr::Literal(scalar) => match scalar {
            ScalarValue::Int8(Some(n)) => *n == 0,
            ScalarValue::Int16(Some(n)) => *n == 0,
            ScalarValue::Int32(Some(n)) => *n == 0,
            ScalarValue::Int64(Some(n)) => *n == 0,
            ScalarValue::UInt8(Some(n)) => *n == 0,
            ScalarValue::UInt16(Some(n)) => *n == 0,
            ScalarValue::UInt32(Some(n)) => *n == 0,
            ScalarValue::UInt64(Some(n)) => *n == 0,
            ScalarValue::Float32(Some(f)) => *f == 0.0,
            ScalarValue::Float64(Some(f)) => *f == 0.0,
            ScalarValue::Decimal128(Some(unscaled), _precision, _scale) => *unscaled == 0,
            _ => false,
        },
        _ => false,
    }
}
/// Returns `true` if `s` is a non-null numeric literal equal to one.
///
/// Integer and float literals are compared directly. A `Decimal128` literal
/// with scale `scale` represents one when its unscaled value equals
/// `10^scale`. Negative scales, and scales past the end of [`POWS_OF_TEN`],
/// are reported as not-one.
pub fn is_one(s: &Expr) -> bool {
    let scalar = match s {
        Expr::Literal(scalar) => scalar,
        _ => return false,
    };
    match scalar {
        ScalarValue::Int8(Some(n)) => *n == 1,
        ScalarValue::Int16(Some(n)) => *n == 1,
        ScalarValue::Int32(Some(n)) => *n == 1,
        ScalarValue::Int64(Some(n)) => *n == 1,
        ScalarValue::UInt8(Some(n)) => *n == 1,
        ScalarValue::UInt16(Some(n)) => *n == 1,
        ScalarValue::UInt32(Some(n)) => *n == 1,
        ScalarValue::UInt64(Some(n)) => *n == 1,
        ScalarValue::Float32(Some(f)) => *f == 1.0,
        ScalarValue::Float64(Some(f)) => *f == 1.0,
        ScalarValue::Decimal128(Some(unscaled), _precision, scale) => {
            // Check the sign first so the `as usize` cast never sees a negative scale.
            *scale >= 0 && POWS_OF_TEN.get(*scale as usize) == Some(unscaled)
        }
        _ => false,
    }
}
/// Returns `true` only for the boolean literal `true` (not NULL, not `false`).
pub fn is_true(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(ScalarValue::Boolean(Some(true))))
}
/// Returns `true` if `expr` is a boolean literal of any value, including NULL.
pub fn is_bool_lit(expr: &Expr) -> bool {
    if let Expr::Literal(value) = expr {
        matches!(value, ScalarValue::Boolean(_))
    } else {
        false
    }
}
/// Builds a NULL literal of boolean type.
pub fn lit_bool_null() -> Expr {
    let null_boolean = ScalarValue::Boolean(None);
    Expr::Literal(null_boolean)
}
/// Returns `true` if `expr` is a literal whose scalar value is NULL (of any type).
pub fn is_null(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(value) if value.is_null())
}
/// Returns `true` only for the boolean literal `false` (not NULL, not `true`).
pub fn is_false(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(ScalarValue::Boolean(Some(false))))
}
/// Returns `true` if `haystack` is a binary expression using `target_op` and
/// `needle` is directly one of its two operands (no recursion).
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
    match haystack {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == target_op => {
            needle == left.as_ref() || needle == right.as_ref()
        }
        _ => false,
    }
}
/// Returns `true` if `not_expr` is exactly `NOT expr`.
pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool {
    if let Expr::Not(negated) = not_expr {
        expr == negated.as_ref()
    } else {
        false
    }
}
/// Extracts the value of a boolean literal.
///
/// Returns `Ok(Some(b))` for a `true`/`false` literal and `Ok(None)` for a
/// boolean NULL literal.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if `expr` is not a boolean literal.
pub fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {
    if let Expr::Literal(ScalarValue::Boolean(value)) = expr {
        return Ok(value);
    }
    let message = format!("Expected boolean literal, got {:?}", expr);
    Err(DataFusionError::Internal(message))
}
/// Returns an expression equivalent to `NOT expr`, pushing the negation inward
/// where a direct rewrite exists.
///
/// * Binary expressions whose operator has a negated form (`op.negate()`) get
///   that operator.
/// * Otherwise, `AND`/`OR` are rewritten with De Morgan's laws, recursing into
///   both sides.
/// * `NOT x` becomes `x`; `IS NULL` and `IS NOT NULL` swap.
/// * `IN` lists and `BETWEEN` have their `negated` flag flipped.
/// * Anything else is wrapped in `Expr::Not`.
pub fn negate_clause(expr: Expr) -> Expr {
    match expr {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match (op.negate(), op) {
            // A directly negatable operator takes priority over De Morgan.
            (Some(flipped), _) => Expr::BinaryExpr(BinaryExpr::new(left, flipped, right)),
            // NOT (A AND B) ==> (NOT A) OR (NOT B)
            (None, Operator::And) => or(negate_clause(*left), negate_clause(*right)),
            // NOT (A OR B) ==> (NOT A) AND (NOT B)
            (None, Operator::Or) => and(negate_clause(*left), negate_clause(*right)),
            (None, other_op) => {
                let inner = Expr::BinaryExpr(BinaryExpr::new(left, other_op, right));
                Expr::Not(Box::new(inner))
            }
        },
        Expr::Not(inner) => *inner,
        Expr::IsNotNull(inner) => inner.is_null(),
        Expr::IsNull(inner) => inner.is_not_null(),
        Expr::InList {
            expr: value,
            list,
            negated,
        } => value.in_list(list, !negated),
        Expr::Between(b) => {
            let flipped = !b.negated;
            Expr::Between(Between::new(b.expr, flipped, b.low, b.high))
        }
        other => Expr::Not(Box::new(other)),
    }
}
/// Simplifies the argument list of a `concat` call.
///
/// * NULL string literals are dropped.
/// * Runs of adjacent string literals are merged into one literal; a merged
///   run that is empty is dropped entirely.
/// * Non-literal arguments are kept, in their original order.
///
/// If no arguments remain (every input was a NULL or empty string literal),
/// the empty-string literal is returned rather than an argument-less
/// `concat()` call. `concat` ignores NULL inputs, so `''` is what such a call
/// evaluates to, and a zero-argument `concat()` is not a valid call.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` if an argument is a literal of a
/// non-string type. Type coercion is expected to have cast all literals to
/// strings before this runs.
pub fn simpl_concat(args: Vec<Expr>) -> Result<Expr> {
    let mut new_args = Vec::with_capacity(args.len());
    let mut contiguous_scalar = String::new();
    for arg in args {
        match arg {
            // NULL arguments contribute nothing to the result.
            Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
            Expr::Literal(
                ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)),
            ) => contiguous_scalar.push_str(&v),
            Expr::Literal(x) => {
                return Err(DataFusionError::Internal(format!(
                    "The scalar {} should be casted to string type during the type coercion.",
                    x
                )))
            }
            arg => {
                // Flush the pending literal run before a non-literal argument.
                if !contiguous_scalar.is_empty() {
                    new_args.push(lit(std::mem::take(&mut contiguous_scalar)));
                }
                new_args.push(arg);
            }
        }
    }
    if !contiguous_scalar.is_empty() {
        new_args.push(lit(contiguous_scalar));
    }
    // Everything folded away: emit the value concat would produce, not `concat()`.
    if new_args.is_empty() {
        return Ok(lit(""));
    }
    Ok(Expr::ScalarFunction {
        fun: BuiltinScalarFunction::Concat,
        args: new_args,
    })
}
pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<Expr> {
match delimiter {
Expr::Literal(
ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter),
) => {
match delimiter {
Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()),
Some(delimiter) => {
let mut new_args = Vec::with_capacity(args.len());
new_args.push(lit(delimiter));
let mut contiguous_scalar = None;
for arg in args {
match arg {
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => {
match contiguous_scalar {
None => contiguous_scalar = Some(v.to_string()),
Some(mut pre) => {
pre += delimiter;
pre += v;
contiguous_scalar = Some(pre)
}
}
}
Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))),
arg => {
if let Some(val) = contiguous_scalar {
new_args.push(lit(val));
}
new_args.push(arg.clone());
contiguous_scalar = None;
}
}
}
if let Some(val) = contiguous_scalar {
new_args.push(lit(val));
}
Ok(Expr::ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args: new_args,
})
}
None => Ok(Expr::Literal(ScalarValue::Utf8(None))),
}
}
Expr::Literal(d) => Err(DataFusionError::Internal(format!(
"The scalar {} should be casted to string type during the type coercion.",
d
))),
d => Ok(concat_ws(
d.clone(),
args.iter()
.cloned()
.filter(|x| !is_null(x))
.collect::<Vec<Expr>>(),
)),
}
}
#[cfg(test)]
pub mod for_test {
    use arrow::datatypes::DataType;
    use datafusion_expr::{call_fn, lit, Cast, Expr};

    /// Builds a call to `now()` with no arguments.
    pub fn now_expr() -> Expr {
        let no_args = Vec::new();
        call_fn("now", no_args).unwrap()
    }

    /// Wraps `expr` in a cast to `Int64`.
    pub fn cast_to_int64_expr(expr: Expr) -> Expr {
        let cast = Cast::new(Box::new(expr), DataType::Int64);
        Expr::Cast(cast)
    }

    /// Builds `to_timestamp(<arg>)` with `arg` passed as a string literal.
    pub fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
        let text: String = arg.into();
        call_fn("to_timestamp", vec![lit(text)]).unwrap()
    }
}