use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow_schema::DataType;
use datafusion_common::{
internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
Dependency, Result,
};
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::{
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
};
use datafusion_expr::{
expr::{ScalarFunction, Unnest},
BuiltInWindowFunction,
};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType,
};
use std::str::FromStr;
use strum::IntoEnumIterator;
/// Suggests the known function name that is closest to an unrecognized
/// `input_function_name`.
///
/// The candidate list depends on `is_window_func`:
/// * `true`: built-in aggregate functions, the aggregate UDF names reported by
///   `ctx`, built-in window functions, and the window UDF names reported by `ctx`;
/// * `false`: the scalar UDF names reported by `ctx`, built-in aggregate
///   functions, and the aggregate UDF names reported by `ctx`.
///
/// The order of the candidates is significant: `find_closest_match` returns the
/// first candidate among those with the smallest edit distance.
pub fn suggest_valid_function(
    input_function_name: &str,
    is_window_func: bool,
    ctx: &dyn ContextProvider,
) -> String {
    let builtin_aggregates = AggregateFunction::iter().map(|func| func.to_string());

    let candidates: Vec<String> = if is_window_func {
        // Aggregates (built-in, then user-defined) followed by window
        // functions (built-in, then user-defined).
        builtin_aggregates
            .chain(ctx.udafs_names())
            .chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
            .chain(ctx.udwfs_names())
            .collect()
    } else {
        // Scalar UDFs followed by aggregates (built-in, then user-defined).
        ctx.udfs_names()
            .into_iter()
            .chain(builtin_aggregates)
            .chain(ctx.udafs_names())
            .collect()
    };

    find_closest_match(candidates, input_function_name)
}
/// Returns the entry of `candidates` with the smallest case-insensitive
/// Levenshtein distance to `target`.
///
/// When several candidates share the smallest distance, the one that appears
/// first in `candidates` is returned.
///
/// # Panics
///
/// Panics if `candidates` is empty.
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
    use datafusion_common::utils::datafusion_strsim::levenshtein;

    let needle = target.to_lowercase();
    candidates
        .into_iter()
        // Pair every candidate with its distance so the comparison key is
        // computed exactly once per candidate.
        .map(|candidate| (levenshtein(&candidate.to_lowercase(), &needle), candidate))
        .min_by_key(|(distance, _)| *distance)
        .map(|(_, closest)| closest)
        .expect("No candidates provided.")
}
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Plans a parsed SQL function call into a logical [`Expr`].
///
/// The function name is resolved in this order:
/// 1. a scalar function returned by the context provider's
///    `get_function_meta` (so it wins over everything below);
/// 2. the special-cased `unnest` function;
/// 3. if the call carries an `OVER (...)` window specification, whatever
///    [`Self::find_window_func`] resolves the name to;
/// 4. otherwise an aggregate returned by `get_aggregate_meta`, and finally a
///    built-in [`AggregateFunction`] parsed from the name.
///
/// # Errors
///
/// Returns an error when
/// * any argument, `PARTITION BY`, `ORDER BY`, `FILTER` or window-frame
///   expression fails to plan or validate,
/// * `unnest` is called with a number of arguments other than one, or with an
///   argument type rejected by [`Self::check_unnest_arg`],
/// * a call with an `OVER` clause also has an aggregate `ORDER BY`,
/// * the name cannot be resolved at all; the message then includes the closest
///   known function name.
pub(super) fn sql_function_to_expr(
    &self,
    function: SQLFunction,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    let SQLFunction {
        name,
        args,
        over,
        distinct,
        filter,
        null_treatment,
        special: _,
        order_by,
    } = function;

    // A call is treated as a window function as soon as it has an OVER clause.
    let is_function_window = over.is_some();

    let name = if name.0.len() > 1 {
        // Multi-part names (e.g. `foo.bar`) are used verbatim, without
        // identifier normalization.
        name.to_string()
    } else {
        // NOTE(review): indexing assumes the parser never yields an empty
        // object name -- confirm against sqlparser.
        crate::utils::normalize_ident(name.0[0].clone())
    };

    // Scalar functions known to the context provider take precedence.
    if let Some(fm) = self.context_provider.get_function_meta(&name) {
        let args = self.function_args_to_expr(args, schema, planner_context)?;
        return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args)));
    }

    // `unnest` is planned into a dedicated expression. Every path through this
    // branch returns, so `args` can be moved here without cloning.
    if name.eq("unnest") {
        let mut exprs = self.function_args_to_expr(args, schema, planner_context)?;
        if exprs.len() != 1 {
            return plan_err!("unnest() requires exactly one argument");
        }
        let expr = exprs.swap_remove(0);
        Self::check_unnest_arg(&expr, schema)?;
        return Ok(Expr::Unnest(Unnest::new(expr)));
    }

    // The ordering of a window function must come from its OVER clause, not
    // from an ORDER BY among the function arguments.
    if !order_by.is_empty() && is_function_window {
        return plan_err!(
            "Aggregate ORDER BY is not implemented for window functions"
        );
    }

    // NOTE(review): an `OVER` that is not a `WindowType::WindowSpec` falls
    // through to the aggregate branch below; presumably named windows are
    // resolved into window specs before this point -- confirm against callers.
    if let Some(WindowType::WindowSpec(window)) = over {
        let partition_by = window
            .partition_by
            .into_iter()
            // Literal values are dropped from PARTITION BY: a constant cannot
            // split rows into different partitions.
            .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },))
            .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context))
            .collect::<Result<Vec<_>>>()?;
        // NOTE(review): the `false` flag is interpreted by
        // `order_by_to_sort_expr`; presumably it makes numeric literals plain
        // constants rather than column positions -- confirm.
        let mut order_by = self.order_by_to_sort_expr(
            &window.order_by,
            schema,
            planner_context,
            false,
            None,
        )?;

        let func_deps = schema.functional_dependencies();
        // The ordering admits no ties as soon as ANY sort key is a column that
        // is, on its own, the source of a `Dependency::Single` functional
        // dependency. Every key must be inspected: deciding on the first key
        // alone would misclassify e.g. `ORDER BY a, pk` as non-strict.
        let is_ordering_strict = order_by.iter().any(|orderby_expr| {
            if let Expr::Sort(sort_expr) = orderby_expr {
                if let Expr::Column(col) = sort_expr.expr.as_ref() {
                    if let Ok(idx) = schema.index_of_column(col) {
                        return func_deps.iter().any(|dep| {
                            dep.source_indices == [idx]
                                && dep.mode == Dependency::Single
                        });
                    }
                }
            }
            false
        });

        // Convert and validate an explicitly written frame, if there is one.
        let window_frame = window
            .window_frame
            .as_ref()
            .map(|window_frame| {
                let window_frame = window_frame.clone().try_into()?;
                check_window_frame(&window_frame, order_by.len())
                    .map(|_| window_frame)
            })
            .transpose()?;

        let window_frame = if let Some(window_frame) = window_frame {
            regularize_window_order_by(&window_frame, &mut order_by)?;
            window_frame
        } else if is_ordering_strict {
            // No explicit frame and ties are impossible.
            WindowFrame::new(Some(true))
        } else {
            // No explicit frame: `Some(false)` when there is an ORDER BY whose
            // keys may tie, `None` when there is no ORDER BY at all.
            WindowFrame::new((!order_by.is_empty()).then_some(false))
        };

        if let Ok(fun) = self.find_window_func(&name) {
            let args = self.function_args_to_expr(args, schema, planner_context)?;
            return Ok(Expr::WindowFunction(expr::WindowFunction::new(
                fun,
                args,
                partition_by,
                order_by,
                window_frame,
                null_treatment,
            )));
        }
    } else {
        // Aggregates known to the context provider are tried before the
        // built-in aggregate functions.
        if let Some(fm) = self.context_provider.get_aggregate_meta(&name) {
            let order_by = self.order_by_to_sort_expr(
                &order_by,
                schema,
                planner_context,
                true,
                None,
            )?;
            let order_by = (!order_by.is_empty()).then_some(order_by);
            let args = self.function_args_to_expr(args, schema, planner_context)?;
            // NOTE(review): `distinct` and `filter` from the SQL call are not
            // forwarded on this path (`false` / `None` are passed instead).
            return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
                fm,
                args,
                false,
                None,
                order_by,
                null_treatment,
            )));
        }

        // Built-in aggregate functions.
        if let Ok(fun) = AggregateFunction::from_str(&name) {
            let order_by = self.order_by_to_sort_expr(
                &order_by,
                schema,
                planner_context,
                true,
                None,
            )?;
            let order_by = (!order_by.is_empty()).then_some(order_by);
            let args = self.function_args_to_expr(args, schema, planner_context)?;
            let filter: Option<Box<Expr>> = filter
                .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context))
                .transpose()?
                .map(Box::new);
            return Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
                fun,
                args,
                distinct,
                filter,
                order_by,
                null_treatment,
            )));
        };
    }

    // Nothing matched: fail with the closest known function name as a hint.
    let suggested_func_name =
        suggest_valid_function(&name, is_function_window, self.context_provider);
    plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?")
}
/// Plans `expr` as the single argument of the scalar function registered under
/// `fn_name`.
///
/// # Errors
///
/// Returns an internal error if the context provider has no function named
/// `fn_name`, or propagates the error from planning `expr`. The function
/// lookup happens before `expr` is planned.
pub(super) fn sql_fn_name_to_expr(
    &self,
    expr: SQLExpr,
    fn_name: &str,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    let udf = match self.context_provider.get_function_meta(fn_name) {
        Some(udf) => udf,
        None => {
            return Err(internal_datafusion_err!(
                "Unable to find expected '{fn_name}' function"
            ))
        }
    };
    let argument = self.sql_expr_to_logical_expr(expr, schema, planner_context)?;
    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
        udf,
        vec![argument],
    )))
}
/// Resolves `name` to a window function definition.
///
/// Lookup order: `expr::find_df_window_func`, then the context provider's
/// aggregate functions (`get_aggregate_meta`), then its window functions
/// (`get_window_meta`). The first hit wins.
///
/// # Errors
///
/// Returns a planning error if none of the lookups knows `name`.
pub(super) fn find_window_func(
    &self,
    name: &str,
) -> Result<WindowFunctionDefinition> {
    if let Some(definition) = expr::find_df_window_func(name) {
        return Ok(definition);
    }
    if let Some(udaf) = self.context_provider.get_aggregate_meta(name) {
        return Ok(WindowFunctionDefinition::AggregateUDF(udaf));
    }
    if let Some(udwf) = self.context_provider.get_window_meta(name) {
        return Ok(WindowFunctionDefinition::WindowUDF(udwf));
    }
    Err(plan_datafusion_err!(
        "There is no window function named {name}"
    ))
}
/// Plans a single SQL function argument into a logical [`Expr`].
///
/// Named and unnamed arguments are handled the same way; the name and
/// operator of a named argument are ignored. An expression argument is planned
/// with `sql_expr_to_logical_expr`, and a bare `*` becomes an unqualified
/// [`Expr::Wildcard`].
///
/// # Errors
///
/// Returns a "not implemented" error for any other argument form, or
/// propagates the error from planning the argument expression.
fn sql_fn_arg_to_logical_expr(
    &self,
    sql: FunctionArg,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    match sql {
        // `f(name => expr)` and `f(expr)`
        FunctionArg::Named {
            arg: FunctionArgExpr::Expr(arg),
            ..
        }
        | FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => {
            self.sql_expr_to_logical_expr(arg, schema, planner_context)
        }
        // `f(name => *)` and `f(*)`
        FunctionArg::Named {
            arg: FunctionArgExpr::Wildcard,
            ..
        }
        | FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
            Ok(Expr::Wildcard { qualifier: None })
        }
        _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"),
    }
}
/// Plans every SQL function argument in `args`, preserving their order.
///
/// # Errors
///
/// Stops at, and returns, the first error produced while planning an argument.
pub(super) fn function_args_to_expr(
    &self,
    args: Vec<FunctionArg>,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Vec<Expr>> {
    let mut planned = Vec::with_capacity(args.len());
    for arg in args {
        planned.push(self.sql_fn_arg_to_logical_expr(arg, schema, planner_context)?);
    }
    Ok(planned)
}
/// Checks that `arg` has a type `unnest()` can currently be applied to.
///
/// Only the list types (`List`, `LargeList`, `FixedSizeList`) are accepted.
///
/// # Errors
///
/// * a "not implemented" error for `Struct` and `Null` arguments,
/// * a planning error for every other non-list type,
/// * any error raised while computing the type of `arg` against `schema`.
pub(crate) fn check_unnest_arg(arg: &Expr, schema: &DFSchema) -> Result<()> {
    let arg_type = arg.get_type(schema)?;

    let is_list = matches!(
        arg_type,
        DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
    );
    if is_list {
        return Ok(());
    }

    match arg_type {
        DataType::Struct(_) => not_impl_err!("unnest() does not support struct yet"),
        DataType::Null => not_impl_err!("unnest() does not support null yet"),
        _ => plan_err!("unnest() can only be applied to array, struct and null"),
    }
}
}