use crate::expr::{
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
Placeholder, TryCast, Unnest,
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery,
AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF,
Signature, Volatility,
};
use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
use arrow::compute::kernels::cast_utils::{
parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{Column, Result, ScalarValue};
use std::any::Any;
use std::fmt::Debug;
use std::ops::Not;
use std::sync::Arc;
/// Builds an [`Expr::Column`] from anything convertible into a [`Column`].
///
/// How `ident` is interpreted (for example, whether qualifiers are parsed)
/// is decided entirely by the `Into<Column>` implementation of the argument.
pub fn col(ident: impl Into<Column>) -> Expr {
    let column: Column = ident.into();
    Expr::Column(column)
}
/// Builds an [`Expr::OuterReferenceColumn`] carrying the given data type and
/// the column obtained by converting `ident` into a [`Column`].
pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
    let column: Column = ident.into();
    Expr::OuterReferenceColumn(dt, column)
}
/// Builds an [`Expr::Column`] whose column is created with
/// [`Column::from_name`], i.e. directly from `name` rather than through an
/// `Into<Column>` conversion.
pub fn ident(name: impl Into<String>) -> Expr {
    let column = Column::from_name(name);
    Expr::Column(column)
}
/// Builds an [`Expr::Placeholder`] with the given identifier.
///
/// The placeholder's `data_type` is left as `None`.
pub fn placeholder(id: impl Into<String>) -> Expr {
    let id: String = id.into();
    let data_type = None;
    Expr::Placeholder(Placeholder { id, data_type })
}
/// Builds an [`Expr::Wildcard`] without a qualifier.
pub fn wildcard() -> Expr {
    let qualifier = None;
    Expr::Wildcard { qualifier }
}
/// Builds an [`Expr::BinaryExpr`] applying `op` to `left` and `right`.
///
/// Both operands are boxed before being handed to [`BinaryExpr::new`].
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
    let lhs = Box::new(left);
    let rhs = Box::new(right);
    Expr::BinaryExpr(BinaryExpr::new(lhs, op, rhs))
}
/// Builds the binary expression `left AND right` using [`Operator::And`].
pub fn and(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let conjunction = BinaryExpr::new(lhs, Operator::And, rhs);
    Expr::BinaryExpr(conjunction)
}
/// Builds the binary expression `left OR right` using [`Operator::Or`].
pub fn or(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let disjunction = BinaryExpr::new(lhs, Operator::Or, rhs);
    Expr::BinaryExpr(disjunction)
}
/// Negates `expr` by delegating to the [`Not`] implementation of [`Expr`].
pub fn not(expr: Expr) -> Expr {
    Not::not(expr)
}
/// Builds an [`Expr::AggregateFunction`] for the built-in `Min` aggregate
/// over the single argument `expr`.
///
/// The remaining constructor arguments are `false` followed by three `None`s.
// NOTE(review): presumably distinct / filter / order-by / null-treatment —
// confirm against `AggregateFunction::new`.
pub fn min(expr: Expr) -> Expr {
    let kind = aggregate_function::AggregateFunction::Min;
    let arguments = vec![expr];
    let aggregate = AggregateFunction::new(kind, arguments, false, None, None, None);
    Expr::AggregateFunction(aggregate)
}
/// Builds an [`Expr::AggregateFunction`] for the built-in `Max` aggregate
/// over the single argument `expr`.
///
/// The remaining constructor arguments are `false` followed by three `None`s.
// NOTE(review): presumably distinct / filter / order-by / null-treatment —
// confirm against `AggregateFunction::new`.
pub fn max(expr: Expr) -> Expr {
    let kind = aggregate_function::AggregateFunction::Max;
    let arguments = vec![expr];
    let aggregate = AggregateFunction::new(kind, arguments, false, None, None, None);
    Expr::AggregateFunction(aggregate)
}
/// Builds an [`Expr::AggregateFunction`] for the built-in `ArrayAgg`
/// aggregate over the single argument `expr`.
///
/// The remaining constructor arguments are `false` followed by three `None`s.
// NOTE(review): presumably distinct / filter / order-by / null-treatment —
// confirm against `AggregateFunction::new`.
pub fn array_agg(expr: Expr) -> Expr {
    let kind = aggregate_function::AggregateFunction::ArrayAgg;
    let arguments = vec![expr];
    let aggregate = AggregateFunction::new(kind, arguments, false, None, None, None);
    Expr::AggregateFunction(aggregate)
}
/// Builds a binary expression combining `left` and `right` with
/// [`Operator::BitwiseAnd`].
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let operation = BinaryExpr::new(lhs, Operator::BitwiseAnd, rhs);
    Expr::BinaryExpr(operation)
}
/// Builds a binary expression combining `left` and `right` with
/// [`Operator::BitwiseOr`].
pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let operation = BinaryExpr::new(lhs, Operator::BitwiseOr, rhs);
    Expr::BinaryExpr(operation)
}
/// Builds a binary expression combining `left` and `right` with
/// [`Operator::BitwiseXor`].
pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let operation = BinaryExpr::new(lhs, Operator::BitwiseXor, rhs);
    Expr::BinaryExpr(operation)
}
/// Builds a binary expression combining `left` and `right` with
/// [`Operator::BitwiseShiftRight`].
pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let operation = BinaryExpr::new(lhs, Operator::BitwiseShiftRight, rhs);
    Expr::BinaryExpr(operation)
}
/// Builds a binary expression combining `left` and `right` with
/// [`Operator::BitwiseShiftLeft`].
pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
    let (lhs, rhs) = (Box::new(left), Box::new(right));
    let operation = BinaryExpr::new(lhs, Operator::BitwiseShiftLeft, rhs);
    Expr::BinaryExpr(operation)
}
/// Builds an [`Expr::InList`] testing `expr` against `list`.
///
/// `negated` is forwarded unchanged to [`InList::new`].
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
    let operand = Box::new(expr);
    let membership = InList::new(operand, list, negated);
    Expr::InList(membership)
}
/// Builds a non-negated [`Expr::Exists`] over `subquery`.
///
/// The outer reference columns are taken from the plan's
/// `all_out_ref_exprs()`.
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
    // Collect the outer references before the plan is moved into `Subquery`.
    let outer_ref_columns = subquery.all_out_ref_exprs();
    let subquery = Subquery {
        subquery,
        outer_ref_columns,
    };
    Expr::Exists(Exists {
        subquery,
        negated: false,
    })
}
/// Builds a negated [`Expr::Exists`] over `subquery`.
///
/// The outer reference columns are taken from the plan's
/// `all_out_ref_exprs()`.
pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
    // Collect the outer references before the plan is moved into `Subquery`.
    let outer_ref_columns = subquery.all_out_ref_exprs();
    let subquery = Subquery {
        subquery,
        outer_ref_columns,
    };
    Expr::Exists(Exists {
        subquery,
        negated: true,
    })
}
/// Builds a non-negated [`Expr::InSubquery`] testing `expr` against
/// `subquery`.
///
/// The outer reference columns are taken from the plan's
/// `all_out_ref_exprs()`.
pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
    // Collect the outer references before the plan is moved into `Subquery`.
    let outer_ref_columns = subquery.all_out_ref_exprs();
    let wrapped = Subquery {
        subquery,
        outer_ref_columns,
    };
    let operand = Box::new(expr);
    Expr::InSubquery(InSubquery::new(operand, wrapped, false))
}
/// Builds a negated [`Expr::InSubquery`] testing `expr` against `subquery`.
///
/// The outer reference columns are taken from the plan's
/// `all_out_ref_exprs()`.
pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
    // Collect the outer references before the plan is moved into `Subquery`.
    let outer_ref_columns = subquery.all_out_ref_exprs();
    let wrapped = Subquery {
        subquery,
        outer_ref_columns,
    };
    let operand = Box::new(expr);
    Expr::InSubquery(InSubquery::new(operand, wrapped, true))
}
/// Builds an [`Expr::ScalarSubquery`] over `subquery`.
///
/// The outer reference columns are taken from the plan's
/// `all_out_ref_exprs()`.
pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
    // Collect the outer references before the plan is moved into `Subquery`.
    let outer_ref_columns = subquery.all_out_ref_exprs();
    let wrapped = Subquery {
        subquery,
        outer_ref_columns,
    };
    Expr::ScalarSubquery(wrapped)
}
/// Builds an [`Expr::GroupingSet`] holding `exprs` as
/// [`GroupingSet::GroupingSets`].
pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
    let sets = GroupingSet::GroupingSets(exprs);
    Expr::GroupingSet(sets)
}
/// Builds an [`Expr::GroupingSet`] holding `exprs` as [`GroupingSet::Cube`].
pub fn cube(exprs: Vec<Expr>) -> Expr {
    let grouping = GroupingSet::Cube(exprs);
    Expr::GroupingSet(grouping)
}
/// Builds an [`Expr::GroupingSet`] holding `exprs` as
/// [`GroupingSet::Rollup`].
pub fn rollup(exprs: Vec<Expr>) -> Expr {
    let grouping = GroupingSet::Rollup(exprs);
    Expr::GroupingSet(grouping)
}
/// Builds an [`Expr::Cast`] of `expr` to `data_type`.
pub fn cast(expr: Expr, data_type: DataType) -> Expr {
    let operand = Box::new(expr);
    let conversion = Cast::new(operand, data_type);
    Expr::Cast(conversion)
}
/// Builds an [`Expr::TryCast`] of `expr` to `data_type`.
pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
    let operand = Box::new(expr);
    let conversion = TryCast::new(operand, data_type);
    Expr::TryCast(conversion)
}
/// Wraps `expr` in [`Expr::IsNull`].
pub fn is_null(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsNull(operand)
}
/// Wraps `expr` in [`Expr::IsTrue`].
pub fn is_true(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsTrue(operand)
}
/// Wraps `expr` in [`Expr::IsNotTrue`].
pub fn is_not_true(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsNotTrue(operand)
}
/// Wraps `expr` in [`Expr::IsFalse`].
pub fn is_false(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsFalse(operand)
}
/// Wraps `expr` in [`Expr::IsNotFalse`].
pub fn is_not_false(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsNotFalse(operand)
}
/// Wraps `expr` in [`Expr::IsUnknown`].
pub fn is_unknown(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsUnknown(operand)
}
/// Wraps `expr` in [`Expr::IsNotUnknown`].
pub fn is_not_unknown(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsNotUnknown(operand)
}
/// Starts a [`CaseBuilder`] that has `expr` as its base expression.
///
/// The builder starts with empty when/then lists and no else branch.
pub fn case(expr: Expr) -> CaseBuilder {
    let subject = Some(Box::new(expr));
    CaseBuilder::new(subject, Vec::new(), Vec::new(), None)
}
/// Starts a [`CaseBuilder`] without a base expression, seeded with a single
/// `when` / `then` pair and no else branch.
pub fn when(when: Expr, then: Expr) -> CaseBuilder {
    let conditions = vec![when];
    let outcomes = vec![then];
    CaseBuilder::new(None, conditions, outcomes, None)
}
/// Wraps `expr` in an [`Expr::Unnest`].
pub fn unnest(expr: Expr) -> Expr {
    let expr = Box::new(expr);
    Expr::Unnest(Unnest { expr })
}
/// Convenience constructor for a [`ScalarUDF`] backed by a
/// [`SimpleScalarUDF`].
///
/// `return_type` arrives behind an `Arc`; it is moved out when this is the
/// only reference and cloned otherwise. All other arguments are forwarded
/// unchanged to [`SimpleScalarUDF::new`].
pub fn create_udf(
    name: &str,
    input_types: Vec<DataType>,
    return_type: Arc<DataType>,
    volatility: Volatility,
    fun: ScalarFunctionImplementation,
) -> ScalarUDF {
    let return_type = match Arc::try_unwrap(return_type) {
        Ok(owned) => owned,
        Err(shared) => (*shared).clone(),
    };
    let udf = SimpleScalarUDF::new(name, input_types, return_type, volatility, fun);
    ScalarUDF::from(udf)
}
/// A scalar UDF implementation with a fixed signature and a fixed return
/// type, whose work is delegated to a stored function.
pub struct SimpleScalarUDF {
    /// Name reported by `name()`.
    name: String,
    /// Signature reported by `signature()`.
    signature: Signature,
    /// Return type reported by `return_type()`, regardless of argument types.
    return_type: DataType,
    /// Function called by `invoke()` with the argument values.
    fun: ScalarFunctionImplementation,
}
impl Debug for SimpleScalarUDF {
    /// Formats the UDF under the struct name `ScalarUDF`, showing its name
    /// and signature. The function is shown as the placeholder `<FUNC>`, and
    /// the return type is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("ScalarUDF");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.field("fun", &"<FUNC>");
        builder.finish()
    }
}
impl SimpleScalarUDF {
    /// Creates a scalar UDF whose signature is `Signature::exact` over
    /// `input_types` with the given `volatility`.
    ///
    /// `return_type` and `fun` are stored as given.
    pub fn new(
        name: impl Into<String>,
        input_types: Vec<DataType>,
        return_type: DataType,
        volatility: Volatility,
        fun: ScalarFunctionImplementation,
    ) -> Self {
        Self {
            name: name.into(),
            signature: Signature::exact(input_types, volatility),
            return_type,
            fun,
        }
    }
}
impl ScalarUDFImpl for SimpleScalarUDF {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored signature.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns a clone of the stored return type; the argument types are
    /// ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let declared = self.return_type.clone();
        Ok(declared)
    }

    /// Calls the stored function with `args` and returns its result.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let implementation = &self.fun;
        implementation(args)
    }
}
/// Convenience constructor for an [`AggregateUDF`] backed by a
/// [`SimpleAggregateUDF`].
///
/// `return_type` and `state_type` arrive behind `Arc`s; each is moved out
/// when it is the only reference and cloned otherwise. Every state type
/// becomes a nullable [`Field`] named after its zero-based position
/// (`"0"`, `"1"`, ...).
pub fn create_udaf(
    name: &str,
    input_type: Vec<DataType>,
    return_type: Arc<DataType>,
    volatility: Volatility,
    accumulator: AccumulatorFactoryFunction,
    state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
    let return_type = match Arc::try_unwrap(return_type) {
        Ok(owned) => owned,
        Err(shared) => (*shared).clone(),
    };
    let state_types = match Arc::try_unwrap(state_type) {
        Ok(owned) => owned,
        Err(shared) => (*shared).clone(),
    };

    let mut state_fields = Vec::with_capacity(state_types.len());
    for (position, data_type) in state_types.into_iter().enumerate() {
        state_fields.push(Field::new(position.to_string(), data_type, true));
    }

    let udaf = SimpleAggregateUDF::new(
        name,
        input_type,
        return_type,
        volatility,
        accumulator,
        state_fields,
    );
    AggregateUDF::from(udaf)
}
/// An aggregate UDF implementation with a fixed signature, a fixed return
/// type, fixed state fields and a stored accumulator factory.
pub struct SimpleAggregateUDF {
    /// Name reported by `name()`.
    name: String,
    /// Signature reported by `signature()`.
    signature: Signature,
    /// Return type reported by `return_type()`, regardless of argument types.
    return_type: DataType,
    /// Factory called by `accumulator()` to build a new accumulator.
    accumulator: AccumulatorFactoryFunction,
    /// Fields reported by `state_fields()`.
    state_fields: Vec<Field>,
}
impl Debug for SimpleAggregateUDF {
    /// Formats the UDAF under the struct name `AggregateUDF`, showing its
    /// name and signature, plus a `fun` entry shown as the placeholder
    /// `<FUNC>`. The return type and state fields are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("AggregateUDF");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.field("fun", &"<FUNC>");
        builder.finish()
    }
}
impl SimpleAggregateUDF {
    /// Creates an aggregate UDF whose signature is `Signature::exact` over
    /// `input_type` with the given `volatility`.
    ///
    /// The remaining arguments are stored as given.
    pub fn new(
        name: impl Into<String>,
        input_type: Vec<DataType>,
        return_type: DataType,
        volatility: Volatility,
        accumulator: AccumulatorFactoryFunction,
        state_fields: Vec<Field>,
    ) -> Self {
        Self::new_with_signature(
            name,
            Signature::exact(input_type, volatility),
            return_type,
            accumulator,
            state_fields,
        )
    }

    /// Creates an aggregate UDF from a caller-supplied `signature`.
    ///
    /// All arguments are stored as given (`name` after conversion to
    /// `String`).
    pub fn new_with_signature(
        name: impl Into<String>,
        signature: Signature,
        return_type: DataType,
        accumulator: AccumulatorFactoryFunction,
        state_fields: Vec<Field>,
    ) -> Self {
        Self {
            name: name.into(),
            signature,
            return_type,
            accumulator,
            state_fields,
        }
    }
}
impl AggregateUDFImpl for SimpleAggregateUDF {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored signature.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns a clone of the stored return type; the argument types are
    /// ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let declared = self.return_type.clone();
        Ok(declared)
    }

    /// Calls the stored accumulator factory with `acc_args` and returns its
    /// result.
    fn accumulator(
        &self,
        acc_args: AccumulatorArgs,
    ) -> Result<Box<dyn crate::Accumulator>> {
        let factory = &self.accumulator;
        factory(acc_args)
    }

    /// Returns a clone of the stored state fields; `_args` is ignored.
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
        let fields = self.state_fields.clone();
        Ok(fields)
    }
}
/// Convenience constructor for a [`WindowUDF`] backed by a
/// [`SimpleWindowUDF`].
///
/// `return_type` arrives behind an `Arc`; it is moved out when this is the
/// only reference and cloned otherwise. All other arguments are forwarded
/// unchanged to [`SimpleWindowUDF::new`].
pub fn create_udwf(
    name: &str,
    input_type: DataType,
    return_type: Arc<DataType>,
    volatility: Volatility,
    partition_evaluator_factory: PartitionEvaluatorFactory,
) -> WindowUDF {
    let return_type = match Arc::try_unwrap(return_type) {
        Ok(owned) => owned,
        Err(shared) => (*shared).clone(),
    };
    let udwf = SimpleWindowUDF::new(
        name,
        input_type,
        return_type,
        volatility,
        partition_evaluator_factory,
    );
    WindowUDF::from(udwf)
}
/// A window UDF implementation with a fixed signature, a fixed return type
/// and a stored partition-evaluator factory.
pub struct SimpleWindowUDF {
    /// Name reported by `name()`.
    name: String,
    /// Signature reported by `signature()`.
    signature: Signature,
    /// Return type reported by `return_type()`, regardless of argument types.
    return_type: DataType,
    /// Factory called by `partition_evaluator()` to build a new evaluator.
    partition_evaluator_factory: PartitionEvaluatorFactory,
}
impl Debug for SimpleWindowUDF {
    /// Formats the UDWF under the struct name `WindowUDF`, showing its name,
    /// signature and return type. The partition-evaluator factory is shown
    /// as the placeholder `<FUNC>`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("WindowUDF")
            .field("name", &self.name)
            .field("signature", &self.signature)
            // `return_type` is a plain `DataType`, not a function, so print
            // its real value instead of a placeholder string.
            .field("return_type", &self.return_type)
            .field("partition_evaluator_factory", &"<FUNC>")
            .finish()
    }
}
impl SimpleWindowUDF {
    /// Creates a window UDF whose signature is `Signature::exact` over the
    /// single `input_type` with the given `volatility`.
    ///
    /// `return_type` and `partition_evaluator_factory` are stored as given.
    pub fn new(
        name: impl Into<String>,
        input_type: DataType,
        return_type: DataType,
        volatility: Volatility,
        partition_evaluator_factory: PartitionEvaluatorFactory,
    ) -> Self {
        let argument_types = vec![input_type];
        Self {
            name: name.into(),
            signature: Signature::exact(argument_types, volatility),
            return_type,
            partition_evaluator_factory,
        }
    }
}
impl WindowUDFImpl for SimpleWindowUDF {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored signature.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns a clone of the stored return type; the argument types are
    /// ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let declared = self.return_type.clone();
        Ok(declared)
    }

    /// Calls the stored factory and returns the evaluator it produces.
    fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
        let factory = &self.partition_evaluator_factory;
        factory()
    }
}
/// Builds a literal [`ScalarValue::IntervalYearMonth`] by parsing `value`
/// with `parse_interval_year_month`.
///
/// A parse failure is not reported: the error is discarded and the literal
/// holds `None` instead.
pub fn interval_year_month_lit(value: &str) -> Expr {
    let parsed = match parse_interval_year_month(value) {
        Ok(interval) => Some(interval),
        Err(_) => None,
    };
    Expr::Literal(ScalarValue::IntervalYearMonth(parsed))
}
/// Builds a literal [`ScalarValue::IntervalDayTime`] by parsing `value`
/// with `parse_interval_day_time`.
///
/// A parse failure is not reported: the error is discarded and the literal
/// holds `None` instead.
pub fn interval_datetime_lit(value: &str) -> Expr {
    let parsed = match parse_interval_day_time(value) {
        Ok(interval) => Some(interval),
        Err(_) => None,
    };
    Expr::Literal(ScalarValue::IntervalDayTime(parsed))
}
/// Builds a literal [`ScalarValue::IntervalMonthDayNano`] by parsing `value`
/// with `parse_interval_month_day_nano`.
///
/// A parse failure is not reported: the error is discarded and the literal
/// holds `None` instead.
pub fn interval_month_day_nano_lit(value: &str) -> Expr {
    let parsed = match parse_interval_month_day_nano(value) {
        Ok(interval) => Some(interval),
        Err(_) => None,
    };
    Expr::Literal(ScalarValue::IntervalMonthDayNano(parsed))
}
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the display form of `IS NULL` / `IS NOT NULL` built on columns
    /// created with `col` and `ident`.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let first = col("col1");
        let second = ident("col2");

        let rendered_null = first.is_null().to_string();
        assert_eq!(rendered_null, "col1 IS NULL");

        let rendered_not_null = second.is_not_null().to_string();
        assert_eq!(rendered_not_null, "col2 IS NOT NULL");
    }
}