use std::collections::HashSet;
use std::fmt::{self, Display, Formatter, Write};
use std::hash::Hash;
use std::str::FromStr;
use std::sync::Arc;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::{expr_to_columns, find_out_reference_exprs};
use crate::window_frame;
use crate::{
aggregate_function, built_in_function, built_in_window_function, udaf,
BuiltinScalarFunction, ExprSchemable, Operator, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{
internal_err, plan_err, Column, DFSchema, OwnedTableReference, Result, ScalarValue,
};
use sqlparser::ast::NullTreatment;
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Expr {
Alias(Alias),
Column(Column),
ScalarVariable(DataType, Vec<String>),
Literal(ScalarValue),
BinaryExpr(BinaryExpr),
Like(Like),
SimilarTo(Like),
Not(Box<Expr>),
IsNotNull(Box<Expr>),
IsNull(Box<Expr>),
IsTrue(Box<Expr>),
IsFalse(Box<Expr>),
IsUnknown(Box<Expr>),
IsNotTrue(Box<Expr>),
IsNotFalse(Box<Expr>),
IsNotUnknown(Box<Expr>),
Negative(Box<Expr>),
GetIndexedField(GetIndexedField),
Between(Between),
Case(Case),
Cast(Cast),
TryCast(TryCast),
Sort(Sort),
ScalarFunction(ScalarFunction),
AggregateFunction(AggregateFunction),
WindowFunction(WindowFunction),
InList(InList),
Exists(Exists),
InSubquery(InSubquery),
ScalarSubquery(Subquery),
Wildcard { qualifier: Option<String> },
GroupingSet(GroupingSet),
Placeholder(Placeholder),
OuterReferenceColumn(DataType, Column),
Unnest(Unnest),
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Unnest {
pub exprs: Vec<Expr>,
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Alias {
pub expr: Box<Expr>,
pub relation: Option<OwnedTableReference>,
pub name: String,
}
impl Alias {
pub fn new(
expr: Expr,
relation: Option<impl Into<OwnedTableReference>>,
name: impl Into<String>,
) -> Self {
Self {
expr: Box::new(expr),
relation: relation.map(|r| r.into()),
name: name.into(),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BinaryExpr {
pub left: Box<Expr>,
pub op: Operator,
pub right: Box<Expr>,
}
impl BinaryExpr {
pub fn new(left: Box<Expr>, op: Operator, right: Box<Expr>) -> Self {
Self { left, op, right }
}
}
impl Display for BinaryExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
fn write_child(
f: &mut Formatter<'_>,
expr: &Expr,
precedence: u8,
) -> fmt::Result {
match expr {
Expr::BinaryExpr(child) => {
let p = child.op.precedence();
if p == 0 || p < precedence {
write!(f, "({child})")?;
} else {
write!(f, "{child}")?;
}
}
_ => write!(f, "{expr}")?,
}
Ok(())
}
let precedence = self.op.precedence();
write_child(f, self.left.as_ref(), precedence)?;
write!(f, " {} ", self.op)?;
write_child(f, self.right.as_ref(), precedence)
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Case {
pub expr: Option<Box<Expr>>,
pub when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
pub else_expr: Option<Box<Expr>>,
}
impl Case {
pub fn new(
expr: Option<Box<Expr>>,
when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
else_expr: Option<Box<Expr>>,
) -> Self {
Self {
expr,
when_then_expr,
else_expr,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Like {
pub negated: bool,
pub expr: Box<Expr>,
pub pattern: Box<Expr>,
pub escape_char: Option<char>,
pub case_insensitive: bool,
}
impl Like {
pub fn new(
negated: bool,
expr: Box<Expr>,
pattern: Box<Expr>,
escape_char: Option<char>,
case_insensitive: bool,
) -> Self {
Self {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Between {
pub expr: Box<Expr>,
pub negated: bool,
pub low: Box<Expr>,
pub high: Box<Expr>,
}
impl Between {
pub fn new(expr: Box<Expr>, negated: bool, low: Box<Expr>, high: Box<Expr>) -> Self {
Self {
expr,
negated,
low,
high,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScalarFunctionDefinition {
BuiltIn(BuiltinScalarFunction),
UDF(Arc<crate::ScalarUDF>),
Name(Arc<str>),
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
pub func_def: ScalarFunctionDefinition,
pub args: Vec<Expr>,
}
impl ScalarFunction {
pub fn name(&self) -> &str {
self.func_def.name()
}
}
impl ScalarFunctionDefinition {
pub fn name(&self) -> &str {
match self {
ScalarFunctionDefinition::BuiltIn(fun) => fun.name(),
ScalarFunctionDefinition::UDF(udf) => udf.name(),
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
}
}
pub fn is_volatile(&self) -> Result<bool> {
match self {
ScalarFunctionDefinition::BuiltIn(fun) => {
Ok(fun.volatility() == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::UDF(udf) => {
Ok(udf.signature().volatility == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::Name(func) => {
internal_err!(
"Cannot determine volatility of unresolved function: {func}"
)
}
}
}
}
impl ScalarFunction {
pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec<Expr>) -> Self {
Self {
func_def: ScalarFunctionDefinition::BuiltIn(fun),
args,
}
}
pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
Self {
func_def: ScalarFunctionDefinition::UDF(udf),
args,
}
}
pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec<Expr>) -> Self {
Self { func_def, args }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum GetFieldAccess {
NamedStructField { name: ScalarValue },
ListIndex { key: Box<Expr> },
ListRange {
start: Box<Expr>,
stop: Box<Expr>,
stride: Box<Expr>,
},
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct GetIndexedField {
pub expr: Box<Expr>,
pub field: GetFieldAccess,
}
impl GetIndexedField {
pub fn new(expr: Box<Expr>, field: GetFieldAccess) -> Self {
Self { expr, field }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Cast {
pub expr: Box<Expr>,
pub data_type: DataType,
}
impl Cast {
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
Self { expr, data_type }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct TryCast {
pub expr: Box<Expr>,
pub data_type: DataType,
}
impl TryCast {
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
Self { expr, data_type }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Sort {
pub expr: Box<Expr>,
pub asc: bool,
pub nulls_first: bool,
}
impl Sort {
pub fn new(expr: Box<Expr>, asc: bool, nulls_first: bool) -> Self {
Self {
expr,
asc,
nulls_first,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunctionDefinition {
BuiltIn(aggregate_function::AggregateFunction),
UDF(Arc<crate::AggregateUDF>),
Name(Arc<str>),
}
impl AggregateFunctionDefinition {
pub fn name(&self) -> &str {
match self {
AggregateFunctionDefinition::BuiltIn(fun) => fun.name(),
AggregateFunctionDefinition::UDF(udf) => udf.name(),
AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
pub func_def: AggregateFunctionDefinition,
pub args: Vec<Expr>,
pub distinct: bool,
pub filter: Option<Box<Expr>>,
pub order_by: Option<Vec<Expr>>,
pub null_treatment: Option<NullTreatment>,
}
impl AggregateFunction {
pub fn new(
fun: aggregate_function::AggregateFunction,
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::BuiltIn(fun),
args,
distinct,
filter,
order_by,
null_treatment,
}
}
pub fn new_udf(
udf: Arc<crate::AggregateUDF>,
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
distinct,
filter,
order_by,
null_treatment: None,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunctionDefinition {
AggregateFunction(aggregate_function::AggregateFunction),
BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction),
AggregateUDF(Arc<crate::AggregateUDF>),
WindowUDF(Arc<crate::WindowUDF>),
}
impl WindowFunctionDefinition {
pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
match self {
WindowFunctionDefinition::AggregateFunction(fun) => {
fun.return_type(input_expr_types)
}
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
fun.return_type(input_expr_types)
}
WindowFunctionDefinition::AggregateUDF(fun) => {
fun.return_type(input_expr_types)
}
WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types),
}
}
pub fn signature(&self) -> Signature {
match self {
WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(),
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(),
WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(),
WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(),
}
}
}
impl fmt::Display for WindowFunctionDefinition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f),
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f),
WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct WindowFunction {
pub fun: WindowFunctionDefinition,
pub args: Vec<Expr>,
pub partition_by: Vec<Expr>,
pub order_by: Vec<Expr>,
pub window_frame: window_frame::WindowFrame,
pub null_treatment: Option<NullTreatment>,
}
impl WindowFunction {
pub fn new(
fun: WindowFunctionDefinition,
args: Vec<Expr>,
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
fun,
args,
partition_by,
order_by,
window_frame,
null_treatment,
}
}
}
pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
let name = name.to_lowercase();
if let Ok(built_in_function) =
built_in_window_function::BuiltInWindowFunction::from_str(name.as_str())
{
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_function,
))
} else if let Ok(aggregate) =
aggregate_function::AggregateFunction::from_str(name.as_str())
{
Some(WindowFunctionDefinition::AggregateFunction(aggregate))
} else {
None
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Exists {
pub subquery: Subquery,
pub negated: bool,
}
impl Exists {
pub fn new(subquery: Subquery, negated: bool) -> Self {
Self { subquery, negated }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateUDF {
pub fun: Arc<udaf::AggregateUDF>,
pub args: Vec<Expr>,
pub filter: Option<Box<Expr>>,
pub order_by: Option<Vec<Expr>>,
}
impl AggregateUDF {
pub fn new(
fun: Arc<udaf::AggregateUDF>,
args: Vec<Expr>,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
) -> Self {
Self {
fun,
args,
filter,
order_by,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct InList {
pub expr: Box<Expr>,
pub list: Vec<Expr>,
pub negated: bool,
}
impl InList {
pub fn new(expr: Box<Expr>, list: Vec<Expr>, negated: bool) -> Self {
Self {
expr,
list,
negated,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct InSubquery {
pub expr: Box<Expr>,
pub subquery: Subquery,
pub negated: bool,
}
impl InSubquery {
pub fn new(expr: Box<Expr>, subquery: Subquery, negated: bool) -> Self {
Self {
expr,
subquery,
negated,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Placeholder {
pub id: String,
pub data_type: Option<DataType>,
}
impl Placeholder {
pub fn new(id: String, data_type: Option<DataType>) -> Self {
Self { id, data_type }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum GroupingSet {
Rollup(Vec<Expr>),
Cube(Vec<Expr>),
GroupingSets(Vec<Vec<Expr>>),
}
impl GroupingSet {
pub fn distinct_expr(&self) -> Vec<Expr> {
match self {
GroupingSet::Rollup(exprs) => exprs.clone(),
GroupingSet::Cube(exprs) => exprs.clone(),
GroupingSet::GroupingSets(groups) => {
let mut exprs: Vec<Expr> = vec![];
for exp in groups.iter().flatten() {
if !exprs.contains(exp) {
exprs.push(exp.clone());
}
}
exprs
}
}
}
}
const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
impl PartialOrd for Expr {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
let s = SEED.hash_one(self);
let o = SEED.hash_one(other);
Some(s.cmp(&o))
}
}
impl Expr {
pub fn display_name(&self) -> Result<String> {
create_name(self)
}
pub fn canonical_name(&self) -> String {
format!("{self}")
}
pub fn variant_name(&self) -> &str {
match self {
Expr::AggregateFunction { .. } => "AggregateFunction",
Expr::Alias(..) => "Alias",
Expr::Between { .. } => "Between",
Expr::BinaryExpr { .. } => "BinaryExpr",
Expr::Case { .. } => "Case",
Expr::Cast { .. } => "Cast",
Expr::Column(..) => "Column",
Expr::OuterReferenceColumn(_, _) => "Outer",
Expr::Exists { .. } => "Exists",
Expr::GetIndexedField { .. } => "GetIndexedField",
Expr::GroupingSet(..) => "GroupingSet",
Expr::InList { .. } => "InList",
Expr::InSubquery(..) => "InSubquery",
Expr::IsNotNull(..) => "IsNotNull",
Expr::IsNull(..) => "IsNull",
Expr::Like { .. } => "Like",
Expr::SimilarTo { .. } => "RLike",
Expr::IsTrue(..) => "IsTrue",
Expr::IsFalse(..) => "IsFalse",
Expr::IsUnknown(..) => "IsUnknown",
Expr::IsNotTrue(..) => "IsNotTrue",
Expr::IsNotFalse(..) => "IsNotFalse",
Expr::IsNotUnknown(..) => "IsNotUnknown",
Expr::Literal(..) => "Literal",
Expr::Negative(..) => "Negative",
Expr::Not(..) => "Not",
Expr::Placeholder(_) => "Placeholder",
Expr::ScalarFunction(..) => "ScalarFunction",
Expr::ScalarSubquery { .. } => "ScalarSubquery",
Expr::ScalarVariable(..) => "ScalarVariable",
Expr::Sort { .. } => "Sort",
Expr::TryCast { .. } => "TryCast",
Expr::WindowFunction { .. } => "WindowFunction",
Expr::Wildcard { .. } => "Wildcard",
Expr::Unnest { .. } => "Unnest",
}
}
pub fn eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::Eq, other)
}
pub fn not_eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::NotEq, other)
}
pub fn gt(self, other: Expr) -> Expr {
binary_expr(self, Operator::Gt, other)
}
pub fn gt_eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::GtEq, other)
}
pub fn lt(self, other: Expr) -> Expr {
binary_expr(self, Operator::Lt, other)
}
pub fn lt_eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::LtEq, other)
}
pub fn and(self, other: Expr) -> Expr {
binary_expr(self, Operator::And, other)
}
pub fn or(self, other: Expr) -> Expr {
binary_expr(self, Operator::Or, other)
}
pub fn like(self, other: Expr) -> Expr {
Expr::Like(Like::new(
false,
Box::new(self),
Box::new(other),
None,
false,
))
}
pub fn not_like(self, other: Expr) -> Expr {
Expr::Like(Like::new(
true,
Box::new(self),
Box::new(other),
None,
false,
))
}
pub fn ilike(self, other: Expr) -> Expr {
Expr::Like(Like::new(
false,
Box::new(self),
Box::new(other),
None,
true,
))
}
pub fn not_ilike(self, other: Expr) -> Expr {
Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true))
}
pub fn name_for_alias(&self) -> Result<String> {
match self {
Expr::Sort(Sort { expr, .. }) => expr.name_for_alias(),
expr => expr.display_name(),
}
}
pub fn alias_if_changed(self, original_name: String) -> Result<Expr> {
let new_name = self.name_for_alias()?;
if new_name == original_name {
return Ok(self);
}
Ok(self.alias(original_name))
}
pub fn alias(self, name: impl Into<String>) -> Expr {
match self {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)),
_ => Expr::Alias(Alias::new(self, None::<&str>, name.into())),
}
}
pub fn alias_qualified(
self,
relation: Option<impl Into<OwnedTableReference>>,
name: impl Into<String>,
) -> Expr {
match self {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => Expr::Sort(Sort::new(
Box::new(expr.alias_qualified(relation, name)),
asc,
nulls_first,
)),
_ => Expr::Alias(Alias::new(self, relation, name.into())),
}
}
pub fn unalias(self) -> Expr {
match self {
Expr::Alias(alias) => *alias.expr,
_ => self,
}
}
pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr {
Expr::InList(InList::new(Box::new(self), list, negated))
}
pub fn is_null(self) -> Expr {
Expr::IsNull(Box::new(self))
}
pub fn is_not_null(self) -> Expr {
Expr::IsNotNull(Box::new(self))
}
pub fn sort(self, asc: bool, nulls_first: bool) -> Expr {
Expr::Sort(Sort::new(Box::new(self), asc, nulls_first))
}
pub fn is_true(self) -> Expr {
Expr::IsTrue(Box::new(self))
}
pub fn is_not_true(self) -> Expr {
Expr::IsNotTrue(Box::new(self))
}
pub fn is_false(self) -> Expr {
Expr::IsFalse(Box::new(self))
}
pub fn is_not_false(self) -> Expr {
Expr::IsNotFalse(Box::new(self))
}
pub fn is_unknown(self) -> Expr {
Expr::IsUnknown(Box::new(self))
}
pub fn is_not_unknown(self) -> Expr {
Expr::IsNotUnknown(Box::new(self))
}
pub fn between(self, low: Expr, high: Expr) -> Expr {
Expr::Between(Between::new(
Box::new(self),
false,
Box::new(low),
Box::new(high),
))
}
pub fn not_between(self, low: Expr, high: Expr) -> Expr {
Expr::Between(Between::new(
Box::new(self),
true,
Box::new(low),
Box::new(high),
))
}
pub fn field(self, name: impl Into<String>) -> Self {
Expr::GetIndexedField(GetIndexedField {
expr: Box::new(self),
field: GetFieldAccess::NamedStructField {
name: ScalarValue::from(name.into()),
},
})
}
pub fn index(self, key: Expr) -> Self {
Expr::GetIndexedField(GetIndexedField {
expr: Box::new(self),
field: GetFieldAccess::ListIndex { key: Box::new(key) },
})
}
pub fn range(self, start: Expr, stop: Expr) -> Self {
Expr::GetIndexedField(GetIndexedField {
expr: Box::new(self),
field: GetFieldAccess::ListRange {
start: Box::new(start),
stop: Box::new(stop),
stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))),
},
})
}
pub fn try_into_col(&self) -> Result<Column> {
match self {
Expr::Column(it) => Ok(it.clone()),
_ => plan_err!("Could not coerce '{self}' into Column!"),
}
}
pub fn to_columns(&self) -> Result<HashSet<Column>> {
let mut using_columns = HashSet::new();
expr_to_columns(self, &mut using_columns)?;
Ok(using_columns)
}
pub fn contains_outer(&self) -> bool {
!find_out_reference_exprs(self).is_empty()
}
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<Expr> {
self.transform(&|mut expr| {
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
Ok(Transformed::yes(expr))
})
.data()
}
pub fn short_circuits(&self) -> bool {
match self {
Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce)
}
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
Expr::Case { .. } => true,
Expr::AggregateFunction(..)
| Expr::Alias(..)
| Expr::Between(..)
| Expr::Cast(..)
| Expr::Column(..)
| Expr::Exists(..)
| Expr::GetIndexedField(..)
| Expr::GroupingSet(..)
| Expr::InList(..)
| Expr::InSubquery(..)
| Expr::IsFalse(..)
| Expr::IsNotFalse(..)
| Expr::IsNotNull(..)
| Expr::IsNotTrue(..)
| Expr::IsNotUnknown(..)
| Expr::IsNull(..)
| Expr::IsTrue(..)
| Expr::IsUnknown(..)
| Expr::Like(..)
| Expr::ScalarSubquery(..)
| Expr::ScalarVariable(_, _)
| Expr::SimilarTo(..)
| Expr::Not(..)
| Expr::Negative(..)
| Expr::OuterReferenceColumn(_, _)
| Expr::TryCast(..)
| Expr::Unnest(..)
| Expr::Wildcard { .. }
| Expr::WindowFunction(..)
| Expr::Literal(..)
| Expr::Sort(..)
| Expr::Placeholder(..) => false,
}
}
}
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
if data_type.is_none() {
let other_dt = other.get_type(schema);
match other_dt {
Err(e) => {
Err(e.context(format!(
"Can not find type of {other} needed to infer type of {expr}"
)))?;
}
Ok(dt) => {
*data_type = Some(dt);
}
}
};
}
Ok(())
}
#[macro_export]
macro_rules! expr_vec_fmt {
( $ARRAY:expr ) => {{
$ARRAY
.iter()
.map(|e| format!("{e}"))
.collect::<Vec<String>>()
.join(", ")
}};
}
impl fmt::Display for Expr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"),
Expr::Column(c) => write!(f, "{c}"),
Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"),
Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
Expr::Literal(v) => write!(f, "{v:?}"),
Expr::Case(case) => {
write!(f, "CASE ")?;
if let Some(e) = &case.expr {
write!(f, "{e} ")?;
}
for (w, t) in &case.when_then_expr {
write!(f, "WHEN {w} THEN {t} ")?;
}
if let Some(e) = &case.else_expr {
write!(f, "ELSE {e} ")?;
}
write!(f, "END")
}
Expr::Cast(Cast { expr, data_type }) => {
write!(f, "CAST({expr} AS {data_type:?})")
}
Expr::TryCast(TryCast { expr, data_type }) => {
write!(f, "TRY_CAST({expr} AS {data_type:?})")
}
Expr::Not(expr) => write!(f, "NOT {expr}"),
Expr::Negative(expr) => write!(f, "(- {expr})"),
Expr::IsNull(expr) => write!(f, "{expr} IS NULL"),
Expr::IsNotNull(expr) => write!(f, "{expr} IS NOT NULL"),
Expr::IsTrue(expr) => write!(f, "{expr} IS TRUE"),
Expr::IsFalse(expr) => write!(f, "{expr} IS FALSE"),
Expr::IsUnknown(expr) => write!(f, "{expr} IS UNKNOWN"),
Expr::IsNotTrue(expr) => write!(f, "{expr} IS NOT TRUE"),
Expr::IsNotFalse(expr) => write!(f, "{expr} IS NOT FALSE"),
Expr::IsNotUnknown(expr) => write!(f, "{expr} IS NOT UNKNOWN"),
Expr::Exists(Exists {
subquery,
negated: true,
}) => write!(f, "NOT EXISTS ({subquery:?})"),
Expr::Exists(Exists {
subquery,
negated: false,
}) => write!(f, "EXISTS ({subquery:?})"),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: true,
}) => write!(f, "{expr} NOT IN ({subquery:?})"),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: false,
}) => write!(f, "{expr} IN ({subquery:?})"),
Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"),
Expr::BinaryExpr(expr) => write!(f, "{expr}"),
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => {
if *asc {
write!(f, "{expr} ASC")?;
} else {
write!(f, "{expr} DESC")?;
}
if *nulls_first {
write!(f, " NULLS FIRST")
} else {
write!(f, " NULLS LAST")
}
}
Expr::ScalarFunction(fun) => {
fmt_function(f, fun.name(), false, &fun.args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
fmt_function(f, &fun.to_string(), false, args, true)?;
if let Some(nt) = null_treatment {
write!(f, "{}", nt)?;
}
if !partition_by.is_empty() {
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
}
if !order_by.is_empty() {
write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?;
}
write!(
f,
" {} BETWEEN {} AND {}",
window_frame.units, window_frame.start_bound, window_frame.end_bound
)?;
Ok(())
}
Expr::AggregateFunction(AggregateFunction {
func_def,
distinct,
ref args,
filter,
order_by,
null_treatment,
..
}) => {
fmt_function(f, func_def.name(), *distinct, args, true)?;
if let Some(nt) = null_treatment {
write!(f, " {}", nt)?;
}
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
}
if let Some(ob) = order_by {
write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?;
}
Ok(())
}
Expr::Between(Between {
expr,
negated,
low,
high,
}) => {
if *negated {
write!(f, "{expr} NOT BETWEEN {low} AND {high}")
} else {
write!(f, "{expr} BETWEEN {low} AND {high}")
}
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => {
write!(f, "{expr}")?;
let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" };
if *negated {
write!(f, " NOT")?;
}
if let Some(char) = escape_char {
write!(f, " {op_name} {pattern} ESCAPE '{char}'")
} else {
write!(f, " {op_name} {pattern}")
}
}
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive: _,
}) => {
write!(f, "{expr}")?;
if *negated {
write!(f, " NOT")?;
}
if let Some(char) = escape_char {
write!(f, " SIMILAR TO {pattern} ESCAPE '{char}'")
} else {
write!(f, " SIMILAR TO {pattern}")
}
}
Expr::InList(InList {
expr,
list,
negated,
}) => {
if *negated {
write!(f, "{expr} NOT IN ([{}])", expr_vec_fmt!(list))
} else {
write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list))
}
}
Expr::Wildcard { qualifier } => match qualifier {
Some(qualifier) => write!(f, "{qualifier}.*"),
None => write!(f, "*"),
},
Expr::GetIndexedField(GetIndexedField { field, expr }) => match field {
GetFieldAccess::NamedStructField { name } => {
write!(f, "({expr})[{name}]")
}
GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"),
GetFieldAccess::ListRange {
start,
stop,
stride,
} => {
write!(f, "({expr})[{start}:{stop}:{stride}]")
}
},
Expr::GroupingSet(grouping_sets) => match grouping_sets {
GroupingSet::Rollup(exprs) => {
write!(f, "ROLLUP ({})", expr_vec_fmt!(exprs))
}
GroupingSet::Cube(exprs) => {
write!(f, "CUBE ({})", expr_vec_fmt!(exprs))
}
GroupingSet::GroupingSets(lists_of_exprs) => {
write!(
f,
"GROUPING SETS ({})",
lists_of_exprs
.iter()
.map(|exprs| format!("({})", expr_vec_fmt!(exprs)))
.collect::<Vec<String>>()
.join(", ")
)
}
},
Expr::Placeholder(Placeholder { id, .. }) => write!(f, "{id}"),
Expr::Unnest(Unnest { exprs }) => {
write!(f, "UNNEST({exprs:?})")
}
}
}
}
fn fmt_function(
f: &mut fmt::Formatter,
fun: &str,
distinct: bool,
args: &[Expr],
display: bool,
) -> fmt::Result {
let args: Vec<String> = match display {
true => args.iter().map(|arg| format!("{arg}")).collect(),
false => args.iter().map(|arg| format!("{arg:?}")).collect(),
};
let distinct_str = match distinct {
true => "DISTINCT ",
false => "",
};
write!(f, "{}({}{})", fun, distinct_str, args.join(", "))
}
fn create_function_name(fun: &str, distinct: bool, args: &[Expr]) -> Result<String> {
let names: Vec<String> = args.iter().map(create_name).collect::<Result<_>>()?;
let distinct_str = match distinct {
true => "DISTINCT ",
false => "",
};
Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
}
fn create_name(e: &Expr) -> Result<String> {
match e {
Expr::Alias(Alias { name, .. }) => Ok(name.clone()),
Expr::Column(c) => Ok(c.flat_name()),
Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())),
Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")),
Expr::Literal(value) => Ok(format!("{value:?}")),
Expr::BinaryExpr(binary_expr) => {
let left = create_name(binary_expr.left.as_ref())?;
let right = create_name(binary_expr.right.as_ref())?;
Ok(format!("{} {} {}", left, binary_expr.op, right))
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => {
let s = format!(
"{} {}{} {} {}",
expr,
if *negated { "NOT " } else { "" },
if *case_insensitive { "ILIKE" } else { "LIKE" },
pattern,
if let Some(char) = escape_char {
format!("CHAR '{char}'")
} else {
"".to_string()
}
);
Ok(s)
}
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive: _,
}) => {
let s = format!(
"{} {} {} {}",
expr,
if *negated {
"NOT SIMILAR TO"
} else {
"SIMILAR TO"
},
pattern,
if let Some(char) = escape_char {
format!("CHAR '{char}'")
} else {
"".to_string()
}
);
Ok(s)
}
Expr::Case(case) => {
let mut name = "CASE ".to_string();
if let Some(e) = &case.expr {
let e = create_name(e)?;
let _ = write!(name, "{e} ");
}
for (w, t) in &case.when_then_expr {
let when = create_name(w)?;
let then = create_name(t)?;
let _ = write!(name, "WHEN {when} THEN {then} ");
}
if let Some(e) = &case.else_expr {
let e = create_name(e)?;
let _ = write!(name, "ELSE {e} ");
}
name += "END";
Ok(name)
}
Expr::Cast(Cast { expr, .. }) => {
create_name(expr)
}
Expr::TryCast(TryCast { expr, .. }) => {
create_name(expr)
}
Expr::Not(expr) => {
let expr = create_name(expr)?;
Ok(format!("NOT {expr}"))
}
Expr::Negative(expr) => {
let expr = create_name(expr)?;
Ok(format!("(- {expr})"))
}
Expr::IsNull(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS NULL"))
}
Expr::IsNotNull(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS NOT NULL"))
}
Expr::IsTrue(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS TRUE"))
}
Expr::IsFalse(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS FALSE"))
}
Expr::IsUnknown(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS UNKNOWN"))
}
Expr::IsNotTrue(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS NOT TRUE"))
}
Expr::IsNotFalse(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS NOT FALSE"))
}
Expr::IsNotUnknown(expr) => {
let expr = create_name(expr)?;
Ok(format!("{expr} IS NOT UNKNOWN"))
}
Expr::Exists(Exists { negated: true, .. }) => Ok("NOT EXISTS".to_string()),
Expr::Exists(Exists { negated: false, .. }) => Ok("EXISTS".to_string()),
Expr::InSubquery(InSubquery { negated: true, .. }) => Ok("NOT IN".to_string()),
Expr::InSubquery(InSubquery { negated: false, .. }) => Ok("IN".to_string()),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).name().clone())
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
let expr = create_name(expr)?;
match field {
GetFieldAccess::NamedStructField { name } => {
Ok(format!("{expr}[{name}]"))
}
GetFieldAccess::ListIndex { key } => {
let key = create_name(key)?;
Ok(format!("{expr}[{key}]"))
}
GetFieldAccess::ListRange {
start,
stop,
stride,
} => {
let start = create_name(start)?;
let stop = create_name(stop)?;
let stride = create_name(stride)?;
Ok(format!("{expr}[{start}:{stop}:{stride}]"))
}
}
}
Expr::Unnest(Unnest { exprs }) => create_function_name("unnest", false, exprs),
Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args),
Expr::WindowFunction(WindowFunction {
fun,
args,
window_frame,
partition_by,
order_by,
null_treatment,
}) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];
if let Some(nt) = null_treatment {
parts.push(format!("{}", nt));
}
if !partition_by.is_empty() {
parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by)));
}
if !order_by.is_empty() {
parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by)));
}
parts.push(format!("{window_frame}"));
Ok(parts.join(" "))
}
Expr::AggregateFunction(AggregateFunction {
func_def,
distinct,
args,
filter,
order_by,
null_treatment,
}) => {
let name = match func_def {
AggregateFunctionDefinition::BuiltIn(..)
| AggregateFunctionDefinition::Name(..) => {
create_function_name(func_def.name(), *distinct, args)?
}
AggregateFunctionDefinition::UDF(..) => {
let names: Vec<String> =
args.iter().map(create_name).collect::<Result<_>>()?;
names.join(",")
}
};
let mut info = String::new();
if let Some(fe) = filter {
info += &format!(" FILTER (WHERE {fe})");
};
if let Some(order_by) = order_by {
info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by));
};
if let Some(nt) = null_treatment {
info += &format!(" {}", nt);
}
match func_def {
AggregateFunctionDefinition::BuiltIn(..)
| AggregateFunctionDefinition::Name(..) => {
Ok(format!("{}{}", name, info))
}
AggregateFunctionDefinition::UDF(fun) => {
Ok(format!("{}({}){}", fun.name(), name, info))
}
}
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Ok(format!("ROLLUP ({})", create_names(exprs.as_slice())?))
}
GroupingSet::Cube(exprs) => {
Ok(format!("CUBE ({})", create_names(exprs.as_slice())?))
}
GroupingSet::GroupingSets(lists_of_exprs) => {
let mut list_of_names = vec![];
for exprs in lists_of_exprs {
list_of_names.push(format!("({})", create_names(exprs.as_slice())?));
}
Ok(format!("GROUPING SETS ({})", list_of_names.join(", ")))
}
},
Expr::InList(InList {
expr,
list,
negated,
}) => {
let expr = create_name(expr)?;
let list = list.iter().map(create_name);
if *negated {
Ok(format!("{expr} NOT IN ({list:?})"))
} else {
Ok(format!("{expr} IN ({list:?})"))
}
}
Expr::Between(Between {
expr,
negated,
low,
high,
}) => {
let expr = create_name(expr)?;
let low = create_name(low)?;
let high = create_name(high)?;
if *negated {
Ok(format!("{expr} NOT BETWEEN {low} AND {high}"))
} else {
Ok(format!("{expr} BETWEEN {low} AND {high}"))
}
}
Expr::Sort { .. } => {
internal_err!("Create name does not support sort expression")
}
Expr::Wildcard { qualifier } => match qualifier {
Some(qualifier) => internal_err!(
"Create name does not support qualified wildcard, got {qualifier}"
),
None => Ok("*".to_string()),
},
Expr::Placeholder(Placeholder { id, .. }) => Ok((*id).to_string()),
}
}
fn create_names(exprs: &[Expr]) -> Result<String> {
Ok(exprs
.iter()
.map(create_name)
.collect::<Result<Vec<String>>>()?
.join(", "))
}
pub fn is_volatile(expr: &Expr) -> Result<bool> {
match expr {
Expr::ScalarFunction(func) => func.func_def.is_volatile(),
_ => Ok(false),
}
}
#[cfg(test)]
mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
use std::any::Any;
use std::sync::Arc;
#[test]
fn format_case_when() -> Result<()> {
let expr = case(col("a"))
.when(lit(1), lit(true))
.when(lit(0), lit(false))
.otherwise(lit(ScalarValue::Null))?;
let expected = "CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END";
assert_eq!(expected, expr.canonical_name());
assert_eq!(expected, format!("{expr}"));
assert_eq!(expected, expr.display_name()?);
Ok(())
}
#[test]
fn format_cast() -> Result<()> {
let expr = Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))),
data_type: DataType::Utf8,
});
let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
assert_eq!(expected_canonical, expr.canonical_name());
assert_eq!(expected_canonical, format!("{expr}"));
assert_eq!("Float32(1.23)", expr.display_name()?);
Ok(())
}
#[test]
fn test_partial_ord() {
let exp1 = col("a") + lit(1);
let exp2 = col("a") + lit(2);
let exp3 = !(col("a") + lit(2));
let greater_or_equal = exp1 >= exp2;
assert_eq!(exp1 < exp2, !greater_or_equal);
let greater_or_equal = exp3 >= exp2;
assert_eq!(exp3 < exp2, !greater_or_equal);
}
#[test]
fn test_collect_expr() -> Result<()> {
{
let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
let columns = expr.to_columns()?;
assert_eq!(1, columns.len());
assert!(columns.contains(&Column::from_name("a")));
}
{
let expr = col("a") + col("b") + lit(1);
let columns = expr.to_columns()?;
assert_eq!(2, columns.len());
assert!(columns.contains(&Column::from_name("a")));
assert!(columns.contains(&Column::from_name("b")));
}
Ok(())
}
#[test]
fn test_logical_ops() {
assert_eq!(
format!("{}", lit(1u32).eq(lit(2u32))),
"UInt32(1) = UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).not_eq(lit(2u32))),
"UInt32(1) != UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).gt(lit(2u32))),
"UInt32(1) > UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).gt_eq(lit(2u32))),
"UInt32(1) >= UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).lt(lit(2u32))),
"UInt32(1) < UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).lt_eq(lit(2u32))),
"UInt32(1) <= UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).and(lit(2u32))),
"UInt32(1) AND UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).or(lit(2u32))),
"UInt32(1) OR UInt32(2)"
);
}
#[test]
fn test_is_volatile_scalar_func_definition() {
assert!(
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
.is_volatile()
.unwrap()
);
#[derive(Debug)]
struct TestScalarUDF {
signature: Signature,
}
impl ScalarUDFImpl for TestScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"TestScalarUDF"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
}
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
}
}
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
}));
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(
1,
vec![DataType::Float32],
Volatility::Volatile,
),
}));
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
.is_volatile()
.expect_err("Shouldn't determine volatility of unresolved function");
}
use super::*;
#[test]
fn test_count_return_type() -> Result<()> {
let fun = find_df_window_func("count").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Int64, observed);
let observed = fun.return_type(&[DataType::UInt64])?;
assert_eq!(DataType::Int64, observed);
Ok(())
}
#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
let observed = fun.return_type(&[DataType::UInt64])?;
assert_eq!(DataType::UInt64, observed);
Ok(())
}
#[test]
fn test_last_value_return_type() -> Result<()> {
let fun = find_df_window_func("last_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
let observed = fun.return_type(&[DataType::Float64])?;
assert_eq!(DataType::Float64, observed);
Ok(())
}
#[test]
fn test_lead_return_type() -> Result<()> {
let fun = find_df_window_func("lead").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
let observed = fun.return_type(&[DataType::Float64])?;
assert_eq!(DataType::Float64, observed);
Ok(())
}
#[test]
fn test_lag_return_type() -> Result<()> {
let fun = find_df_window_func("lag").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Utf8, observed);
let observed = fun.return_type(&[DataType::Float64])?;
assert_eq!(DataType::Float64, observed);
Ok(())
}
#[test]
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?;
assert_eq!(DataType::Utf8, observed);
let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?;
assert_eq!(DataType::Float64, observed);
Ok(())
}
#[test]
fn test_percent_rank_return_type() -> Result<()> {
let fun = find_df_window_func("percent_rank").unwrap();
let observed = fun.return_type(&[])?;
assert_eq!(DataType::Float64, observed);
Ok(())
}
#[test]
fn test_cume_dist_return_type() -> Result<()> {
let fun = find_df_window_func("cume_dist").unwrap();
let observed = fun.return_type(&[])?;
assert_eq!(DataType::Float64, observed);
Ok(())
}
#[test]
fn test_ntile_return_type() -> Result<()> {
let fun = find_df_window_func("ntile").unwrap();
let observed = fun.return_type(&[DataType::Int16])?;
assert_eq!(DataType::UInt64, observed);
Ok(())
}
#[test]
fn test_window_function_case_insensitive() -> Result<()> {
let names = vec![
"row_number",
"rank",
"dense_rank",
"percent_rank",
"cume_dist",
"ntile",
"lag",
"lead",
"first_value",
"last_value",
"nth_value",
"min",
"max",
"count",
"avg",
"sum",
];
for name in names {
let fun = find_df_window_func(name).unwrap();
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
assert_eq!(fun.to_string(), name.to_uppercase());
}
Ok(())
}
#[test]
fn test_find_df_window_function() {
assert_eq!(
find_df_window_func("max"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Max
))
);
assert_eq!(
find_df_window_func("min"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Min
))
);
assert_eq!(
find_df_window_func("avg"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Avg
))
);
assert_eq!(
find_df_window_func("cume_dist"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::CumeDist
))
);
assert_eq!(
find_df_window_func("first_value"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::FirstValue
))
);
assert_eq!(
find_df_window_func("LAST_value"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::LastValue
))
);
assert_eq!(
find_df_window_func("LAG"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::Lag
))
);
assert_eq!(
find_df_window_func("LEAD"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_window_function::BuiltInWindowFunction::Lead
))
);
assert_eq!(find_df_window_func("not_exist"), None)
}
}