use std::any::Any;
use std::cmp::Ordering;
use std::fmt::{self, Debug, Formatter, Write};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::vec;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use crate::expr::{
schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay,
WindowFunctionParams,
};
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{expr_vec_fmt, Accumulator, Expr};
use crate::{Documentation, Signature};
#[derive(Debug, Clone, PartialOrd)]
pub struct AggregateUDF {
inner: Arc<dyn AggregateUDFImpl>,
}
impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.inner.equals(other.inner.as_ref())
}
}
impl Eq for AggregateUDF {}
impl Hash for AggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}
impl fmt::Display for AggregateUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.name())
}
}
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`] (and the
/// forwarding [`AggregateUDF::value_from_stats`]).
#[derive(Debug)]
pub struct StatisticsArgs<'a> {
    /// Statistics available for answering the aggregate.
    /// NOTE(review): presumably the statistics of the aggregate's input — confirm against callers.
    pub statistics: &'a Statistics,
    /// The aggregate's return type (per the field name; confirm against callers).
    pub return_type: &'a DataType,
    /// Presumably whether the aggregate was invoked with `DISTINCT` — confirm.
    pub is_distinct: bool,
    /// The physical argument expressions of the aggregate.
    pub exprs: &'a [Arc<dyn PhysicalExpr>],
}
impl AggregateUDF {
pub fn new_from_impl<F>(fun: F) -> AggregateUDF
where
F: AggregateUDFImpl + 'static,
{
Self::new_from_shared_impl(Arc::new(fun))
}
pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
Self { inner: fun }
}
pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
&self.inner
}
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedAggregateUDFImpl::new(
Arc::clone(&self.inner),
aliases,
))
}
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(self.clone()),
args,
false,
None,
None,
None,
))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
self.inner.schema_name(params)
}
pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
self.inner.human_display(params)
}
pub fn window_function_schema_name(
&self,
params: &WindowFunctionParams,
) -> Result<String> {
self.inner.window_function_schema_name(params)
}
pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
self.inner.display_name(params)
}
pub fn window_function_display_name(
&self,
params: &WindowFunctionParams,
) -> Result<String> {
self.inner.window_function_display_name(params)
}
pub fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}
pub fn signature(&self) -> &Signature {
self.inner.signature()
}
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
self.inner.return_type(args)
}
pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
self.inner.return_field(args)
}
pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(acc_args)
}
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
self.inner.state_fields(args)
}
pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
self.inner.groups_accumulator_supported(args)
}
pub fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator(args)
}
pub fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
self.inner.create_sliding_accumulator(args)
}
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
pub fn with_beneficial_ordering(
self,
beneficial_ordering: bool,
) -> Result<Option<AggregateUDF>> {
self.inner
.with_beneficial_ordering(beneficial_ordering)
.map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
}
pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
self.inner.order_sensitivity()
}
pub fn reverse_udf(&self) -> ReversedUDAF {
self.inner.reverse_expr()
}
pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
self.inner.simplify()
}
pub fn is_descending(&self) -> Option<bool> {
self.inner.is_descending()
}
pub fn value_from_stats(
&self,
statistics_args: &StatisticsArgs,
) -> Option<ScalarValue> {
self.inner.value_from_stats(statistics_args)
}
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
self.inner.default_value(data_type)
}
pub fn supports_null_handling_clause(&self) -> bool {
self.inner.supports_null_handling_clause()
}
pub fn is_ordered_set_aggregate(&self) -> bool {
self.inner.is_ordered_set_aggregate()
}
pub fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
impl<F> From<F> for AggregateUDF
where
F: AggregateUDFImpl + Send + Sync + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Trait for implementing user-defined aggregate functions.
///
/// Only [`as_any`](Self::as_any), [`name`](Self::name),
/// [`signature`](Self::signature), [`return_type`](Self::return_type) and
/// [`accumulator`](Self::accumulator) must be implemented; every other method
/// has a default that may be overridden.
///
/// If [`equals`](Self::equals) is overridden, [`hash_value`](Self::hash_value)
/// must be overridden consistently: values that compare equal must produce
/// the same hash.
pub trait AggregateUDFImpl: Debug + Send + Sync {
    /// Returns this object as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the function's name.
    fn name(&self) -> &str;

    /// Returns the name this aggregate call gets in the output schema.
    ///
    /// The default renders `name([DISTINCT ]args)` followed, when present, by
    /// the null-treatment clause, ` FILTER (WHERE ...)`, and ` ORDER BY [...]`.
    /// For ordered-set aggregates the clause is ` WITHIN GROUP [...]` instead,
    /// and the first argument is left out of the argument list.
    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
        let AggregateFunctionParams {
            args,
            distinct,
            filter,
            order_by,
            null_treatment,
        } = params;

        // Ordered-set aggregates omit their first argument. Use `get` so an
        // (invalid) empty argument list renders as `name()` rather than
        // panicking on an out-of-range slice.
        let args = if self.is_ordered_set_aggregate() {
            args.get(1..).unwrap_or(&[])
        } else {
            &args[..]
        };

        let mut schema_name = String::new();
        schema_name.write_fmt(format_args!(
            "{}({}{})",
            self.name(),
            if *distinct { "DISTINCT " } else { "" },
            schema_name_from_exprs_comma_separated_without_space(args)?
        ))?;
        if let Some(null_treatment) = null_treatment {
            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
        }
        if let Some(filter) = filter {
            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
        };
        if let Some(order_by) = order_by {
            let clause = match self.is_ordered_set_aggregate() {
                true => "WITHIN GROUP",
                false => "ORDER BY",
            };
            schema_name.write_fmt(format_args!(
                " {} [{}]",
                clause,
                schema_name_from_sorts(order_by)?
            ))?;
        };
        Ok(schema_name)
    }

    /// Returns a human-readable rendering of this aggregate call.
    ///
    /// The default has the same shape as [`schema_name`](Self::schema_name),
    /// with two differences:
    /// - arguments are formatted with [`ExprListDisplay`] and are never skipped;
    /// - any ordering is always rendered as ` ORDER BY [...]`.
    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
        let AggregateFunctionParams {
            args,
            distinct,
            filter,
            order_by,
            null_treatment,
        } = params;
        let mut schema_name = String::new();
        schema_name.write_fmt(format_args!(
            "{}({}{})",
            self.name(),
            if *distinct { "DISTINCT " } else { "" },
            ExprListDisplay::comma_separated(args.as_slice())
        ))?;
        if let Some(null_treatment) = null_treatment {
            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
        }
        if let Some(filter) = filter {
            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
        };
        if let Some(order_by) = order_by {
            schema_name.write_fmt(format_args!(
                " ORDER BY [{}]",
                schema_name_from_sorts(order_by)?
            ))?;
        };
        Ok(schema_name)
    }

    /// Returns the schema name of this aggregate when used as a window function.
    ///
    /// The default renders `name(args)` and then, in order:
    /// - the optional null-treatment clause;
    /// - ` PARTITION BY [...]` and ` ORDER BY [...]`, each only when non-empty;
    /// - the window frame.
    fn window_function_schema_name(
        &self,
        params: &WindowFunctionParams,
    ) -> Result<String> {
        let WindowFunctionParams {
            args,
            partition_by,
            order_by,
            window_frame,
            null_treatment,
        } = params;
        let mut schema_name = String::new();
        schema_name.write_fmt(format_args!(
            "{}({})",
            self.name(),
            schema_name_from_exprs(args)?
        ))?;
        if let Some(null_treatment) = null_treatment {
            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
        }
        if !partition_by.is_empty() {
            schema_name.write_fmt(format_args!(
                " PARTITION BY [{}]",
                schema_name_from_exprs(partition_by)?
            ))?;
        }
        if !order_by.is_empty() {
            schema_name.write_fmt(format_args!(
                " ORDER BY [{}]",
                schema_name_from_sorts(order_by)?
            ))?;
        };
        schema_name.write_fmt(format_args!(" {window_frame}"))?;
        Ok(schema_name)
    }

    /// Returns the display name of this aggregate call.
    ///
    /// The default renders `name([DISTINCT ]args)` using the expressions'
    /// `Display` output, then the optional null-treatment, filter and
    /// ` ORDER BY [...]` clauses.
    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
        let AggregateFunctionParams {
            args,
            distinct,
            filter,
            order_by,
            null_treatment,
        } = params;
        let mut display_name = String::new();
        display_name.write_fmt(format_args!(
            "{}({}{})",
            self.name(),
            if *distinct { "DISTINCT " } else { "" },
            expr_vec_fmt!(args)
        ))?;
        if let Some(nt) = null_treatment {
            display_name.write_fmt(format_args!(" {nt}"))?;
        }
        if let Some(fe) = filter {
            display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
        }
        if let Some(ob) = order_by {
            display_name.write_fmt(format_args!(
                " ORDER BY [{}]",
                ob.iter()
                    .map(|o| format!("{o}"))
                    .collect::<Vec<String>>()
                    .join(", ")
            ))?;
        }
        Ok(display_name)
    }

    /// Returns the display name of this aggregate when used as a window function.
    ///
    /// The default ends with `{units} BETWEEN {start_bound} AND {end_bound}`,
    /// taken from the window frame.
    fn window_function_display_name(
        &self,
        params: &WindowFunctionParams,
    ) -> Result<String> {
        let WindowFunctionParams {
            args,
            partition_by,
            order_by,
            window_frame,
            null_treatment,
        } = params;
        let mut display_name = String::new();
        display_name.write_fmt(format_args!(
            "{}({})",
            self.name(),
            expr_vec_fmt!(args)
        ))?;
        if let Some(null_treatment) = null_treatment {
            display_name.write_fmt(format_args!(" {null_treatment}"))?;
        }
        if !partition_by.is_empty() {
            display_name.write_fmt(format_args!(
                " PARTITION BY [{}]",
                expr_vec_fmt!(partition_by)
            ))?;
        }
        if !order_by.is_empty() {
            display_name
                .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
        };
        display_name.write_fmt(format_args!(
            " {} BETWEEN {} AND {}",
            window_frame.units, window_frame.start_bound, window_frame.end_bound
        ))?;
        Ok(display_name)
    }

    /// Returns the function's signature (accepted argument types).
    fn signature(&self) -> &Signature;

    /// Returns the output type for the given argument types.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

    /// Returns the output field for the given argument fields.
    ///
    /// The default builds a field with these properties:
    /// - named [`name`](Self::name);
    /// - typed by [`return_type`](Self::return_type) applied to the argument data types;
    /// - with nullability [`is_nullable`](Self::is_nullable).
    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
        let arg_types: Vec<_> =
            arg_fields.iter().map(|f| f.data_type()).cloned().collect();
        let data_type = self.return_type(&arg_types)?;
        Ok(Arc::new(Field::new(
            self.name(),
            data_type,
            self.is_nullable(),
        )))
    }

    /// Whether the aggregate's output may be null. Defaults to `true`.
    fn is_nullable(&self) -> bool {
        true
    }

    /// Creates a new [`Accumulator`] for this aggregate.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

    /// Returns the fields of the intermediate state.
    ///
    /// The default is one field: a copy of `args.return_field` renamed to
    /// `format_state_name(args.name, "value")`, followed by `args.ordering_fields`.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let fields = vec![args
            .return_field
            .as_ref()
            .clone()
            .with_name(format_state_name(args.name, "value"))];
        Ok(fields
            .into_iter()
            .map(Arc::new)
            .chain(args.ordering_fields.to_vec())
            .collect())
    }

    /// Whether [`create_groups_accumulator`](Self::create_groups_accumulator)
    /// is supported for these arguments. Defaults to `false`.
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        false
    }

    /// Creates a [`GroupsAccumulator`]. The default returns a "not implemented" error.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
    }

    /// Returns additional names this function can be referred to by. Defaults to none.
    fn aliases(&self) -> &[String] {
        &[]
    }

    /// Creates an accumulator for sliding-window use.
    /// The default falls back to [`accumulator`](Self::accumulator).
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        self.accumulator(args)
    }

    /// Returns a variant of this implementation adapted to `beneficial_ordering`,
    /// or `Ok(None)` if there is none.
    ///
    /// The default errors when [`order_sensitivity`](Self::order_sensitivity)
    /// reports a beneficial ordering, because such implementations are expected
    /// to override this method. Otherwise it returns `Ok(None)`.
    fn with_beneficial_ordering(
        self: Arc<Self>,
        _beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        if self.order_sensitivity().is_beneficial() {
            return exec_err!(
                "Should implement with satisfied for aggregator :{:?}",
                self.name()
            );
        }
        Ok(None)
    }

    /// How the aggregate depends on input ordering.
    /// Defaults to [`AggregateOrderSensitivity::HardRequirement`].
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::HardRequirement
    }

    /// Optional simplification hook. Defaults to `None`.
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        None
    }

    /// Returns the reversed form of this aggregate.
    /// Defaults to [`ReversedUDAF::NotSupported`].
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::NotSupported
    }

    /// Coerces argument types. The default returns a "not implemented" error.
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }

    /// Returns whether `self` and `other` represent the same function.
    /// The default compares name and signature.
    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
        self.name() == other.name() && self.signature() == other.signature()
    }

    /// Returns a hash consistent with [`equals`](Self::equals).
    /// The default hashes name and signature, matching the default `equals`.
    fn hash_value(&self) -> u64 {
        let hasher = &mut DefaultHasher::new();
        self.name().hash(hasher);
        self.signature().hash(hasher);
        hasher.finish()
    }

    /// Optional descending-order hint. Defaults to `None`.
    fn is_descending(&self) -> Option<bool> {
        None
    }

    /// Optionally computes the aggregate's value directly from statistics.
    /// Defaults to `None`.
    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        None
    }

    /// Returns the value produced by `ScalarValue::try_from(data_type)` by default.
    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
        ScalarValue::try_from(data_type)
    }

    /// Whether a null-handling clause is supported. Defaults to `true`.
    fn supports_null_handling_clause(&self) -> bool {
        true
    }

    /// Whether this is an ordered-set aggregate (`WITHIN GROUP` syntax).
    /// Defaults to `false`.
    fn is_ordered_set_aggregate(&self) -> bool {
        false
    }

    /// Optional user-facing documentation. Defaults to `None`.
    fn documentation(&self) -> Option<&Documentation> {
        None
    }

    /// Monotonicity of the aggregate for the given output type.
    /// Defaults to [`SetMonotonicity::NotMonotonic`].
    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::NotMonotonic
    }
}
impl PartialEq for dyn AggregateUDFImpl {
fn eq(&self, other: &Self) -> bool {
self.equals(other)
}
}
/// Orders implementations by name, then by signature.
///
/// `PartialOrd` must agree with `PartialEq`: `partial_cmp` may only return
/// `Some(Equal)` when `==` holds. Implementations can override
/// [`AggregateUDFImpl::equals`], for example to also compare aliases. So when
/// name and signature tie but `equals` is false, the values are reported as
/// incomparable (`None`).
impl PartialOrd for dyn AggregateUDFImpl {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = match self.name().partial_cmp(other.name()) {
            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
            cmp => cmp,
        };
        ordering.filter(|cmp| *cmp != Ordering::Equal || self.equals(other))
    }
}
/// Result of asking an aggregate for its reversed form
/// (see [`AggregateUDFImpl::reverse_expr`], whose default is `NotSupported`).
pub enum ReversedUDAF {
    /// NOTE(review): presumably the reversed aggregate is the same function — confirm.
    Identical,
    /// No reversed form is available.
    NotSupported,
    /// The reversed form is the given aggregate function.
    Reversed(Arc<AggregateUDF>),
}
/// [`AggregateUDFImpl`] wrapper that attaches extra aliases to an existing
/// implementation.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
    /// The wrapped implementation.
    inner: Arc<dyn AggregateUDFImpl>,
    /// The inner implementation's aliases followed by the newly added ones.
    aliases: Vec<String>,
}

impl AliasedAggregateUDFImpl {
    /// Wraps `inner`, keeping its existing aliases and appending `new_aliases`.
    pub fn new(
        inner: Arc<dyn AggregateUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let aliases = inner
            .aliases()
            .iter()
            .cloned()
            .chain(new_aliases.into_iter().map(String::from))
            .collect();
        Self { inner, aliases }
    }
}
impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}
fn signature(&self) -> &Signature {
self.inner.signature()
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(acc_args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
self.inner.state_fields(args)
}
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
self.inner.groups_accumulator_supported(args)
}
fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator(args)
}
fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(args)
}
fn with_beneficial_ordering(
self: Arc<Self>,
beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
Arc::clone(&self.inner)
.with_beneficial_ordering(beneficial_ordering)
.map(|udf| {
udf.map(|udf| {
Arc::new(AliasedAggregateUDFImpl {
inner: udf,
aliases: self.aliases.clone(),
}) as Arc<dyn AggregateUDFImpl>
})
})
}
fn order_sensitivity(&self) -> AggregateOrderSensitivity {
self.inner.order_sensitivity()
}
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
self.inner.simplify()
}
fn reverse_expr(&self) -> ReversedUDAF {
self.inner.reverse_expr()
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
} else {
false
}
}
fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.inner.hash_value().hash(hasher);
self.aliases.hash(hasher);
hasher.finish()
}
fn is_descending(&self) -> Option<bool> {
self.inner.is_descending()
}
fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
/// Documentation sections that aggregate functions are grouped under.
pub mod aggregate_doc_sections {
    use crate::DocSection;

    /// Returns all aggregate documentation sections, in the order listed below.
    pub fn doc_sections() -> Vec<DocSection> {
        Vec::from([
            DOC_SECTION_GENERAL,
            DOC_SECTION_STATISTICAL,
            DOC_SECTION_APPROXIMATE,
        ])
    }

    /// Section for general-purpose aggregate functions.
    pub const DOC_SECTION_GENERAL: DocSection = DocSection {
        label: "General Functions",
        include: true,
        description: None,
    };

    /// Section for statistical aggregate functions.
    pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
        label: "Statistical Functions",
        include: true,
        description: None,
    };

    /// Section for approximate aggregate functions.
    pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
        label: "Approximate Functions",
        include: true,
        description: None,
    };
}
/// Monotonicity reported by [`AggregateUDFImpl::set_monotonicity`]
/// (whose default is `NotMonotonic`).
/// NOTE(review): presumably describes how the aggregate's result moves as
/// values are added to its input set — confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum SetMonotonicity {
    /// The result only increases (presumably as input grows — confirm).
    Increasing,
    /// The result only decreases (presumably as input grows — confirm).
    Decreasing,
    /// No monotonicity guarantee.
    NotMonotonic,
}
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Signature shared by both test UDAFs: a single `Float64` argument.
    fn float64_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Minimal UDAF named `"a"`; only the identity-related methods are usable.
    #[derive(Debug, Clone)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            Self {
                signature: float64_signature(),
            }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "a"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Minimal UDAF named `"b"`, identical to [`AMeanUdf`] apart from its name.
    #[derive(Debug, Clone)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            Self {
                signature: float64_signature(),
            }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "b"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        let a1 = AggregateUDF::from(AMeanUdf::new());
        let a2 = AggregateUDF::from(AMeanUdf::new());
        // Same name and signature: equal under the default `equals`.
        assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

        // Different names order lexicographically and are not equal.
        let b1 = AggregateUDF::from(BMeanUdf::new());
        assert_eq!(a1.partial_cmp(&b1), Some(Ordering::Less));
        assert_ne!(a1, b1);
    }
}