datafusion_expr/udaf.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32use crate::expr::{
33 schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
34 schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay,
35 WindowFunctionParams,
36};
37use crate::function::{
38 AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39};
40use crate::groups_accumulator::GroupsAccumulator;
41use crate::utils::format_state_name;
42use crate::utils::AggregateOrderSensitivity;
43use crate::{expr_vec_fmt, Accumulator, Expr};
44use crate::{Documentation, Signature};
45
46/// Logical representation of a user-defined [aggregate function] (UDAF).
47///
48/// An aggregate function combines the values from multiple input rows
49/// into a single output "aggregate" (summary) row. It is different
50/// from a scalar function because it is stateful across batches. User
51/// defined aggregate functions can be used as normal SQL aggregate
52/// functions (`GROUP BY` clause) as well as window functions (`OVER`
53/// clause).
54///
55/// `AggregateUDF` provides DataFusion the information needed to plan and call
56/// aggregate functions, including name, type information, and a factory
57/// function to create an [`Accumulator`] instance, to perform the actual
58/// aggregation.
59///
60/// For more information, please see [the examples]:
61///
62/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
63///
64/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
65/// access (examples in [`advanced_udaf.rs`]).
66///
67/// # API Note
68/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
69/// compatibility with the older API.
70///
71/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
72/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
73/// [`Accumulator`]: crate::Accumulator
74/// [`create_udaf`]: crate::expr_fn::create_udaf
75/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
76/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
77#[derive(Debug, Clone, PartialOrd)]
78pub struct AggregateUDF {
79 inner: Arc<dyn AggregateUDFImpl>,
80}
81
82impl PartialEq for AggregateUDF {
83 fn eq(&self, other: &Self) -> bool {
84 self.inner.equals(other.inner.as_ref())
85 }
86}
87
88impl Eq for AggregateUDF {}
89
90impl Hash for AggregateUDF {
91 fn hash<H: Hasher>(&self, state: &mut H) {
92 self.inner.hash_value().hash(state)
93 }
94}
95
96impl fmt::Display for AggregateUDF {
97 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
98 write!(f, "{}", self.name())
99 }
100}
101
102/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103#[derive(Debug)]
104pub struct StatisticsArgs<'a> {
105 /// The statistics of the aggregate input
106 pub statistics: &'a Statistics,
107 /// The resolved return type of the aggregate function
108 pub return_type: &'a DataType,
109 /// Whether the aggregate function is distinct.
110 ///
111 /// ```sql
112 /// SELECT COUNT(DISTINCT column1) FROM t;
113 /// ```
114 pub is_distinct: bool,
115 /// The physical expression of arguments the aggregate function takes.
116 pub exprs: &'a [Arc<dyn PhysicalExpr>],
117}
118
119impl AggregateUDF {
120 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
121 ///
122 /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
123 pub fn new_from_impl<F>(fun: F) -> AggregateUDF
124 where
125 F: AggregateUDFImpl + 'static,
126 {
127 Self::new_from_shared_impl(Arc::new(fun))
128 }
129
130 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
131 pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
132 Self { inner: fun }
133 }
134
135 /// Return the underlying [`AggregateUDFImpl`] trait object for this function
136 pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
137 &self.inner
138 }
139
140 /// Adds additional names that can be used to invoke this function, in
141 /// addition to `name`
142 ///
143 /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
144 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
145 Self::new_from_impl(AliasedAggregateUDFImpl::new(
146 Arc::clone(&self.inner),
147 aliases,
148 ))
149 }
150
151 /// Creates an [`Expr`] that calls the aggregate function.
152 ///
153 /// This utility allows using the UDAF without requiring access to
154 /// the registry, such as with the DataFrame API.
155 pub fn call(&self, args: Vec<Expr>) -> Expr {
156 Expr::AggregateFunction(AggregateFunction::new_udf(
157 Arc::new(self.clone()),
158 args,
159 false,
160 None,
161 vec![],
162 None,
163 ))
164 }
165
166 /// Returns this function's name
167 ///
168 /// See [`AggregateUDFImpl::name`] for more details.
169 pub fn name(&self) -> &str {
170 self.inner.name()
171 }
172
173 /// Returns the aliases for this function.
174 pub fn aliases(&self) -> &[String] {
175 self.inner.aliases()
176 }
177
178 /// See [`AggregateUDFImpl::schema_name`] for more details.
179 pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
180 self.inner.schema_name(params)
181 }
182
183 /// Returns a human readable expression.
184 ///
185 /// See [`Expr::human_display`] for details.
186 pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
187 self.inner.human_display(params)
188 }
189
190 pub fn window_function_schema_name(
191 &self,
192 params: &WindowFunctionParams,
193 ) -> Result<String> {
194 self.inner.window_function_schema_name(params)
195 }
196
197 /// See [`AggregateUDFImpl::display_name`] for more details.
198 pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
199 self.inner.display_name(params)
200 }
201
202 pub fn window_function_display_name(
203 &self,
204 params: &WindowFunctionParams,
205 ) -> Result<String> {
206 self.inner.window_function_display_name(params)
207 }
208
209 pub fn is_nullable(&self) -> bool {
210 self.inner.is_nullable()
211 }
212
213 /// Returns this function's signature (what input types are accepted)
214 ///
215 /// See [`AggregateUDFImpl::signature`] for more details.
216 pub fn signature(&self) -> &Signature {
217 self.inner.signature()
218 }
219
220 /// Return the type of the function given its input types
221 ///
222 /// See [`AggregateUDFImpl::return_type`] for more details.
223 pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
224 self.inner.return_type(args)
225 }
226
227 /// Return the field of the function given its input fields
228 ///
229 /// See [`AggregateUDFImpl::return_field`] for more details.
230 pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
231 self.inner.return_field(args)
232 }
233
234 /// Return an accumulator the given aggregate, given its return datatype
235 pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
236 self.inner.accumulator(acc_args)
237 }
238
239 /// Return the fields used to store the intermediate state for this aggregator, given
240 /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
241 /// for more details.
242 ///
243 /// This is used to support multi-phase aggregations
244 pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
245 self.inner.state_fields(args)
246 }
247
248 /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
249 pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
250 self.inner.groups_accumulator_supported(args)
251 }
252
253 /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
254 pub fn create_groups_accumulator(
255 &self,
256 args: AccumulatorArgs,
257 ) -> Result<Box<dyn GroupsAccumulator>> {
258 self.inner.create_groups_accumulator(args)
259 }
260
261 pub fn create_sliding_accumulator(
262 &self,
263 args: AccumulatorArgs,
264 ) -> Result<Box<dyn Accumulator>> {
265 self.inner.create_sliding_accumulator(args)
266 }
267
268 pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
269 self.inner.coerce_types(arg_types)
270 }
271
272 /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
273 pub fn with_beneficial_ordering(
274 self,
275 beneficial_ordering: bool,
276 ) -> Result<Option<AggregateUDF>> {
277 self.inner
278 .with_beneficial_ordering(beneficial_ordering)
279 .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
280 }
281
282 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
283 /// for possible options.
284 pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
285 self.inner.order_sensitivity()
286 }
287
288 /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
289 /// generate same result with this `AggregateUDF` when iterated in reverse
290 /// order, and `None` if there is no such `AggregateUDF`).
291 pub fn reverse_udf(&self) -> ReversedUDAF {
292 self.inner.reverse_expr()
293 }
294
295 /// Do the function rewrite
296 ///
297 /// See [`AggregateUDFImpl::simplify`] for more details.
298 pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
299 self.inner.simplify()
300 }
301
302 /// Returns true if the function is max, false if the function is min
303 /// None in all other cases, used in certain optimizations for
304 /// or aggregate
305 pub fn is_descending(&self) -> Option<bool> {
306 self.inner.is_descending()
307 }
308
309 /// Return the value of this aggregate function if it can be determined
310 /// entirely from statistics and arguments.
311 ///
312 /// See [`AggregateUDFImpl::value_from_stats`] for more details.
313 pub fn value_from_stats(
314 &self,
315 statistics_args: &StatisticsArgs,
316 ) -> Option<ScalarValue> {
317 self.inner.value_from_stats(statistics_args)
318 }
319
320 /// See [`AggregateUDFImpl::default_value`] for more details.
321 pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
322 self.inner.default_value(data_type)
323 }
324
325 /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
326 pub fn supports_null_handling_clause(&self) -> bool {
327 self.inner.supports_null_handling_clause()
328 }
329
330 /// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details.
331 pub fn is_ordered_set_aggregate(&self) -> bool {
332 self.inner.is_ordered_set_aggregate()
333 }
334
335 /// Returns the documentation for this Aggregate UDF.
336 ///
337 /// Documentation can be accessed programmatically as well as
338 /// generating publicly facing documentation.
339 pub fn documentation(&self) -> Option<&Documentation> {
340 self.inner.documentation()
341 }
342}
343
344impl<F> From<F> for AggregateUDF
345where
346 F: AggregateUDFImpl + Send + Sync + 'static,
347{
348 fn from(fun: F) -> Self {
349 Self::new_from_impl(fun)
350 }
351}
352
353/// Trait for implementing [`AggregateUDF`].
354///
355/// This trait exposes the full API for implementing user defined aggregate functions and
356/// can be used to implement any function.
357///
358/// See [`advanced_udaf.rs`] for a full example with complete implementation and
359/// [`AggregateUDF`] for other available options.
360///
361/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
362///
363/// # Basic Example
364/// ```
365/// # use std::any::Any;
366/// # use std::sync::{Arc, LazyLock};
367/// # use arrow::datatypes::{DataType, FieldRef};
368/// # use datafusion_common::{DataFusionError, plan_err, Result};
369/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
370/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
371/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
372/// # use arrow::datatypes::Schema;
373/// # use arrow::datatypes::Field;
374///
375/// #[derive(Debug, Clone)]
376/// struct GeoMeanUdf {
377/// signature: Signature,
378/// }
379///
380/// impl GeoMeanUdf {
381/// fn new() -> Self {
382/// Self {
383/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
384/// }
385/// }
386/// }
387///
388/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
389/// Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
390/// .with_argument("arg1", "The Float64 number for the geometric mean")
391/// .build()
392/// });
393///
394/// fn get_doc() -> &'static Documentation {
395/// &DOCUMENTATION
396/// }
397///
398/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
399/// impl AggregateUDFImpl for GeoMeanUdf {
400/// fn as_any(&self) -> &dyn Any { self }
401/// fn name(&self) -> &str { "geo_mean" }
402/// fn signature(&self) -> &Signature { &self.signature }
403/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
404/// if !matches!(args.get(0), Some(&DataType::Float64)) {
405/// return plan_err!("geo_mean only accepts Float64 arguments");
406/// }
407/// Ok(DataType::Float64)
408/// }
409/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
410/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
411/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
412/// Ok(vec![
413/// Arc::new(args.return_field.as_ref().clone().with_name("value")),
414/// Arc::new(Field::new("ordering", DataType::UInt32, true))
415/// ])
416/// }
417/// fn documentation(&self) -> Option<&Documentation> {
418/// Some(get_doc())
419/// }
420/// }
421///
422/// // Create a new AggregateUDF from the implementation
423/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
424///
425/// // Call the function `geo_mean(col)`
426/// let expr = geometric_mean.call(vec![col("a")]);
427/// ```
428pub trait AggregateUDFImpl: Debug + Send + Sync {
    // Note: When adding any methods (with default implementations), remember to add them also
    // into the AliasedAggregateUDFImpl below!

    /// Returns this object as an [`Any`] trait object.
    ///
    /// The default [`Self::equals`] and [`Self::hash_value`] use the returned
    /// object's `type_id` to distinguish concrete implementation types.
    fn as_any(&self) -> &dyn Any;

    /// Returns this function's name.
    ///
    /// The default [`Self::schema_name`] and [`Self::display_name`] start with
    /// this name, and the default [`Self::return_field`] uses it as the name of
    /// the returned field.
    fn name(&self) -> &str;

    /// Returns any aliases (alternate names) for this function.
    ///
    /// Note: `aliases` should only include names other than [`Self::name`].
    /// Defaults to `[]` (no aliases)
    fn aliases(&self) -> &[String] {
        &[]
    }
445
446 /// Returns the name of the column this expression would create
447 ///
448 /// See [`Expr::schema_name`] for details
449 ///
450 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
451 fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
452 let AggregateFunctionParams {
453 args,
454 distinct,
455 filter,
456 order_by,
457 null_treatment,
458 } = params;
459
460 // exclude the first function argument(= column) in ordered set aggregate function,
461 // because it is duplicated with the WITHIN GROUP clause in schema name.
462 let args = if self.is_ordered_set_aggregate() {
463 &args[1..]
464 } else {
465 &args[..]
466 };
467
468 let mut schema_name = String::new();
469
470 schema_name.write_fmt(format_args!(
471 "{}({}{})",
472 self.name(),
473 if *distinct { "DISTINCT " } else { "" },
474 schema_name_from_exprs_comma_separated_without_space(args)?
475 ))?;
476
477 if let Some(null_treatment) = null_treatment {
478 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
479 }
480
481 if let Some(filter) = filter {
482 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
483 };
484
485 if !order_by.is_empty() {
486 let clause = match self.is_ordered_set_aggregate() {
487 true => "WITHIN GROUP",
488 false => "ORDER BY",
489 };
490
491 schema_name.write_fmt(format_args!(
492 " {} [{}]",
493 clause,
494 schema_name_from_sorts(order_by)?
495 ))?;
496 };
497
498 Ok(schema_name)
499 }
500
501 /// Returns a human readable expression.
502 ///
503 /// See [`Expr::human_display`] for details.
504 fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
505 let AggregateFunctionParams {
506 args,
507 distinct,
508 filter,
509 order_by,
510 null_treatment,
511 } = params;
512
513 let mut schema_name = String::new();
514
515 schema_name.write_fmt(format_args!(
516 "{}({}{})",
517 self.name(),
518 if *distinct { "DISTINCT " } else { "" },
519 ExprListDisplay::comma_separated(args.as_slice())
520 ))?;
521
522 if let Some(null_treatment) = null_treatment {
523 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
524 }
525
526 if let Some(filter) = filter {
527 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
528 };
529
530 if !order_by.is_empty() {
531 schema_name.write_fmt(format_args!(
532 " ORDER BY [{}]",
533 schema_name_from_sorts(order_by)?
534 ))?;
535 };
536
537 Ok(schema_name)
538 }
539
540 /// Returns the name of the column this expression would create
541 ///
542 /// See [`Expr::schema_name`] for details
543 ///
544 /// Different from `schema_name` in that it is used for window aggregate function
545 ///
546 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
547 fn window_function_schema_name(
548 &self,
549 params: &WindowFunctionParams,
550 ) -> Result<String> {
551 let WindowFunctionParams {
552 args,
553 partition_by,
554 order_by,
555 window_frame,
556 null_treatment,
557 } = params;
558
559 let mut schema_name = String::new();
560 schema_name.write_fmt(format_args!(
561 "{}({})",
562 self.name(),
563 schema_name_from_exprs(args)?
564 ))?;
565
566 if let Some(null_treatment) = null_treatment {
567 schema_name.write_fmt(format_args!(" {null_treatment}"))?;
568 }
569
570 if !partition_by.is_empty() {
571 schema_name.write_fmt(format_args!(
572 " PARTITION BY [{}]",
573 schema_name_from_exprs(partition_by)?
574 ))?;
575 }
576
577 if !order_by.is_empty() {
578 schema_name.write_fmt(format_args!(
579 " ORDER BY [{}]",
580 schema_name_from_sorts(order_by)?
581 ))?;
582 };
583
584 schema_name.write_fmt(format_args!(" {window_frame}"))?;
585
586 Ok(schema_name)
587 }
588
589 /// Returns the user-defined display name of function, given the arguments
590 ///
591 /// This can be used to customize the output column name generated by this
592 /// function.
593 ///
594 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
595 fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
596 let AggregateFunctionParams {
597 args,
598 distinct,
599 filter,
600 order_by,
601 null_treatment,
602 } = params;
603
604 let mut display_name = String::new();
605
606 display_name.write_fmt(format_args!(
607 "{}({}{})",
608 self.name(),
609 if *distinct { "DISTINCT " } else { "" },
610 expr_vec_fmt!(args)
611 ))?;
612
613 if let Some(nt) = null_treatment {
614 display_name.write_fmt(format_args!(" {nt}"))?;
615 }
616 if let Some(fe) = filter {
617 display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
618 }
619 if !order_by.is_empty() {
620 display_name.write_fmt(format_args!(
621 " ORDER BY [{}]",
622 order_by
623 .iter()
624 .map(|o| format!("{o}"))
625 .collect::<Vec<String>>()
626 .join(", ")
627 ))?;
628 }
629
630 Ok(display_name)
631 }
632
633 /// Returns the user-defined display name of function, given the arguments
634 ///
635 /// This can be used to customize the output column name generated by this
636 /// function.
637 ///
638 /// Different from `display_name` in that it is used for window aggregate function
639 ///
640 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
641 fn window_function_display_name(
642 &self,
643 params: &WindowFunctionParams,
644 ) -> Result<String> {
645 let WindowFunctionParams {
646 args,
647 partition_by,
648 order_by,
649 window_frame,
650 null_treatment,
651 } = params;
652
653 let mut display_name = String::new();
654
655 display_name.write_fmt(format_args!(
656 "{}({})",
657 self.name(),
658 expr_vec_fmt!(args)
659 ))?;
660
661 if let Some(null_treatment) = null_treatment {
662 display_name.write_fmt(format_args!(" {null_treatment}"))?;
663 }
664
665 if !partition_by.is_empty() {
666 display_name.write_fmt(format_args!(
667 " PARTITION BY [{}]",
668 expr_vec_fmt!(partition_by)
669 ))?;
670 }
671
672 if !order_by.is_empty() {
673 display_name
674 .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
675 };
676
677 display_name.write_fmt(format_args!(
678 " {} BETWEEN {} AND {}",
679 window_frame.units, window_frame.start_bound, window_frame.end_bound
680 ))?;
681
682 Ok(display_name)
683 }
684
    /// Returns the function's [`Signature`] for information about what input
    /// types are accepted and the function's Volatility.
    ///
    /// The default [`Self::equals`] and [`Self::hash_value`] include the signature.
    fn signature(&self) -> &Signature;

    /// What [`DataType`] will be returned by this function, given the types of
    /// the arguments.
    ///
    /// The default [`Self::return_field`] calls this with the data types of its
    /// input fields.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
692
693 /// What type will be returned by this function, given the arguments?
694 ///
695 /// By default, this function calls [`Self::return_type`] with the
696 /// types of each argument.
697 ///
698 /// # Notes
699 ///
700 /// Most UDFs should implement [`Self::return_type`] and not this
701 /// function as the output type for most functions only depends on the types
702 /// of their inputs (e.g. `sum(f64)` is always `f64`).
703 ///
704 /// This function can be used for more advanced cases such as:
705 ///
706 /// 1. specifying nullability
707 /// 2. return types based on the **values** of the arguments (rather than
708 /// their **types**.
709 /// 3. return types based on metadata within the fields of the inputs
710 fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
711 let arg_types: Vec<_> =
712 arg_fields.iter().map(|f| f.data_type()).cloned().collect();
713 let data_type = self.return_type(&arg_types)?;
714
715 Ok(Arc::new(Field::new(
716 self.name(),
717 data_type,
718 self.is_nullable(),
719 )))
720 }
721
    /// Whether the aggregate function is nullable.
    ///
    /// Nullable means that the function could return `null` for any inputs.
    /// For example, aggregate functions like `COUNT` always return a non-null value
    /// but others like `MIN` will return `NULL` if there is nullable input.
    /// Note that if the function is declared as *not* nullable, make sure the
    /// [`AggregateUDFImpl::default_value`] is `non-null`.
    ///
    /// Defaults to `true`. The default [`Self::return_field`] uses this value as
    /// the nullability of the returned field.
    fn is_nullable(&self) -> bool {
        true
    }

    /// Return a new [`Accumulator`] that aggregates values for a specific
    /// group during query execution.
    ///
    /// acc_args: [`AccumulatorArgs`] contains information about how the
    /// aggregate function was called.
    ///
    /// The default [`Self::create_sliding_accumulator`] delegates to this method.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
738
739 /// Return the fields used to store the intermediate state of this accumulator.
740 ///
741 /// See [`Accumulator::state`] for background information.
742 ///
743 /// args: [`StateFieldsArgs`] contains arguments passed to the
744 /// aggregate function's accumulator.
745 ///
746 /// # Notes:
747 ///
748 /// The default implementation returns a single state field named `name`
749 /// with the same type as `value_type`. This is suitable for aggregates such
750 /// as `SUM` or `MIN` where partial state can be combined by applying the
751 /// same aggregate.
752 ///
753 /// For aggregates such as `AVG` where the partial state is more complex
754 /// (e.g. a COUNT and a SUM), this method is used to define the additional
755 /// fields.
756 ///
757 /// The name of the fields must be unique within the query and thus should
758 /// be derived from `name`. See [`format_state_name`] for a utility function
759 /// to generate a unique name.
760 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
761 let fields = vec![args
762 .return_field
763 .as_ref()
764 .clone()
765 .with_name(format_state_name(args.name, "value"))];
766
767 Ok(fields
768 .into_iter()
769 .map(Arc::new)
770 .chain(args.ordering_fields.to_vec())
771 .collect())
772 }
773
774 /// If the aggregate expression has a specialized
775 /// [`GroupsAccumulator`] implementation. If this returns true,
776 /// `[Self::create_groups_accumulator]` will be called.
777 ///
778 /// # Notes
779 ///
780 /// Even if this function returns true, DataFusion will still use
781 /// [`Self::accumulator`] for certain queries, such as when this aggregate is
782 /// used as a window function or when there no GROUP BY columns in the
783 /// query.
784 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
785 false
786 }
787
788 /// Return a specialized [`GroupsAccumulator`] that manages state
789 /// for all groups.
790 ///
791 /// For maximum performance, a [`GroupsAccumulator`] should be
792 /// implemented in addition to [`Accumulator`].
793 fn create_groups_accumulator(
794 &self,
795 _args: AccumulatorArgs,
796 ) -> Result<Box<dyn GroupsAccumulator>> {
797 not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
798 }
799
800 /// Sliding accumulator is an alternative accumulator that can be used for
801 /// window functions. It has retract method to revert the previous update.
802 ///
803 /// See [retract_batch] for more details.
804 ///
805 /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
806 fn create_sliding_accumulator(
807 &self,
808 args: AccumulatorArgs,
809 ) -> Result<Box<dyn Accumulator>> {
810 self.accumulator(args)
811 }
812
813 /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
814 /// satisfied by its input. If this is not the case, UDFs with order
815 /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
816 /// the correct result with possibly more work internally.
817 ///
818 /// # Returns
819 ///
820 /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
821 /// If the expression can benefit from existing input ordering, but does
822 /// not implement the method, returns an error. Order insensitive and hard
823 /// requirement aggregators return `Ok(None)`.
824 fn with_beneficial_ordering(
825 self: Arc<Self>,
826 _beneficial_ordering: bool,
827 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
828 if self.order_sensitivity().is_beneficial() {
829 return exec_err!(
830 "Should implement with satisfied for aggregator :{:?}",
831 self.name()
832 );
833 }
834 Ok(None)
835 }
836
837 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
838 /// for possible options.
839 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
840 // We have hard ordering requirements by default, meaning that order
841 // sensitive UDFs need their input orderings to satisfy their ordering
842 // requirements to generate correct results.
843 AggregateOrderSensitivity::HardRequirement
844 }
845
846 /// Optionally apply per-UDaF simplification / rewrite rules.
847 ///
848 /// This can be used to apply function specific simplification rules during
849 /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
850 /// implementation does nothing.
851 ///
852 /// Note that DataFusion handles simplifying arguments and "constant
853 /// folding" (replacing a function call with constant arguments such as
854 /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
855 /// optimizations manually for specific UDFs.
856 ///
857 /// # Returns
858 ///
859 /// [None] if simplify is not defined or,
860 ///
861 /// Or, a closure with two arguments:
862 /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
863 /// * 'info': [crate::simplify::SimplifyInfo]
864 ///
865 /// closure returns simplified [Expr] or an error.
866 ///
867 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
868 None
869 }
870
871 /// Returns the reverse expression of the aggregate function.
872 fn reverse_expr(&self) -> ReversedUDAF {
873 ReversedUDAF::NotSupported
874 }
875
876 /// Coerce arguments of a function call to types that the function can evaluate.
877 ///
878 /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
879 /// UDAFs should return one of the other variants of `TypeSignature` which handle common
880 /// cases
881 ///
882 /// See the [type coercion module](crate::type_coercion)
883 /// documentation for more details on type coercion
884 ///
885 /// For example, if your function requires a floating point arguments, but the user calls
886 /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
887 /// to ensure the argument was cast to `1::double`
888 ///
889 /// # Parameters
890 /// * `arg_types`: The argument types of the arguments this function with
891 ///
892 /// # Return value
893 /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
894 /// arguments to these specific types.
895 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
896 not_impl_err!("Function {} does not implement coerce_types", self.name())
897 }
898
899 /// Return true if this aggregate UDF is equal to the other.
900 ///
901 /// Allows customizing the equality of aggregate UDFs.
902 /// *Must* be implemented explicitly if the UDF type has internal state.
903 /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
904 ///
905 /// - reflexive: `a.equals(a)`;
906 /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
907 /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
908 ///
909 /// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
910 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
911 self.as_any().type_id() == other.as_any().type_id()
912 && self.name() == other.name()
913 && self.aliases() == other.aliases()
914 && self.signature() == other.signature()
915 }
916
917 /// Returns a hash value for this aggregate UDF.
918 ///
919 /// Allows customizing the hash code of aggregate UDFs.
920 /// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
921 ///
922 /// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
923 /// their `hash_value`s must be the same.
924 ///
925 /// By default, it is consistent with default implementation of [`Self::equals`].
926 fn hash_value(&self) -> u64 {
927 let hasher = &mut DefaultHasher::new();
928 self.as_any().type_id().hash(hasher);
929 self.name().hash(hasher);
930 self.aliases().hash(hasher);
931 self.signature().hash(hasher);
932 hasher.finish()
933 }
934
935 /// If this function is max, return true
936 /// If the function is min, return false
937 /// Otherwise return None (the default)
938 ///
939 ///
940 /// Note: this is used to use special aggregate implementations in certain conditions
941 fn is_descending(&self) -> Option<bool> {
942 None
943 }
944
945 /// Return the value of this aggregate function if it can be determined
946 /// entirely from statistics and arguments.
947 ///
948 /// Using a [`ScalarValue`] rather than a runtime computation can significantly
949 /// improving query performance.
950 ///
951 /// For example, if the minimum value of column `x` is known to be `42` from
952 /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
953 fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
954 None
955 }
956
957 /// Returns default value of the function given the input is all `null`.
958 ///
959 /// Most of the aggregate function return Null if input is Null,
960 /// while `count` returns 0 if input is Null
961 fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
962 ScalarValue::try_from(data_type)
963 }
964
965 /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true
966 /// If the function does not, return false
967 fn supports_null_handling_clause(&self) -> bool {
968 true
969 }
970
971 /// If this function is ordered-set aggregate function, return true
972 /// If the function is not, return false
973 fn is_ordered_set_aggregate(&self) -> bool {
974 false
975 }
976
977 /// Returns the documentation for this Aggregate UDF.
978 ///
979 /// Documentation can be accessed programmatically as well as
980 /// generating publicly facing documentation.
981 fn documentation(&self) -> Option<&Documentation> {
982 None
983 }
984
985 /// Indicates whether the aggregation function is monotonic as a set
986 /// function. See [`SetMonotonicity`] for details.
987 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
988 SetMonotonicity::NotMonotonic
989 }
990}
991
992impl PartialEq for dyn AggregateUDFImpl {
993 fn eq(&self, other: &Self) -> bool {
994 self.equals(other)
995 }
996}
997
998// Manual implementation of `PartialOrd`
999// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
1000// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
1001impl PartialOrd for dyn AggregateUDFImpl {
1002 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1003 match self.name().partial_cmp(other.name()) {
1004 Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
1005 cmp => cmp,
1006 }
1007 }
1008}
1009
/// Result of asking an aggregate function for its reversed counterpart
/// (returned by `AggregateUDFImpl::reverse_expr`).
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression; carries the
    /// aggregate UDF to use instead
    Reversed(Arc<AggregateUDF>),
}
1018
/// AggregateUDF that adds an alias to the underlying function. It is better to
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
///
/// The wrapper reports its own `aliases` list and forwards the remaining
/// functionality to `inner`.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
    /// The wrapped function that provides the actual aggregation behavior.
    inner: Arc<dyn AggregateUDFImpl>,
    /// The inner function's own aliases followed by the extra aliases
    /// supplied to `AliasedAggregateUDFImpl::new`.
    aliases: Vec<String>,
}
1026
1027impl AliasedAggregateUDFImpl {
1028 pub fn new(
1029 inner: Arc<dyn AggregateUDFImpl>,
1030 new_aliases: impl IntoIterator<Item = &'static str>,
1031 ) -> Self {
1032 let mut aliases = inner.aliases().to_vec();
1033 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1034
1035 Self { inner, aliases }
1036 }
1037}
1038
1039impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1040 fn as_any(&self) -> &dyn Any {
1041 self
1042 }
1043
1044 fn name(&self) -> &str {
1045 self.inner.name()
1046 }
1047
1048 fn signature(&self) -> &Signature {
1049 self.inner.signature()
1050 }
1051
1052 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1053 self.inner.return_type(arg_types)
1054 }
1055
1056 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1057 self.inner.accumulator(acc_args)
1058 }
1059
1060 fn aliases(&self) -> &[String] {
1061 &self.aliases
1062 }
1063
1064 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1065 self.inner.state_fields(args)
1066 }
1067
1068 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1069 self.inner.groups_accumulator_supported(args)
1070 }
1071
1072 fn create_groups_accumulator(
1073 &self,
1074 args: AccumulatorArgs,
1075 ) -> Result<Box<dyn GroupsAccumulator>> {
1076 self.inner.create_groups_accumulator(args)
1077 }
1078
1079 fn create_sliding_accumulator(
1080 &self,
1081 args: AccumulatorArgs,
1082 ) -> Result<Box<dyn Accumulator>> {
1083 self.inner.accumulator(args)
1084 }
1085
1086 fn with_beneficial_ordering(
1087 self: Arc<Self>,
1088 beneficial_ordering: bool,
1089 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1090 Arc::clone(&self.inner)
1091 .with_beneficial_ordering(beneficial_ordering)
1092 .map(|udf| {
1093 udf.map(|udf| {
1094 Arc::new(AliasedAggregateUDFImpl {
1095 inner: udf,
1096 aliases: self.aliases.clone(),
1097 }) as Arc<dyn AggregateUDFImpl>
1098 })
1099 })
1100 }
1101
1102 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1103 self.inner.order_sensitivity()
1104 }
1105
1106 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1107 self.inner.simplify()
1108 }
1109
1110 fn reverse_expr(&self) -> ReversedUDAF {
1111 self.inner.reverse_expr()
1112 }
1113
1114 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1115 self.inner.coerce_types(arg_types)
1116 }
1117
1118 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
1119 if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
1120 self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
1121 } else {
1122 false
1123 }
1124 }
1125
1126 fn hash_value(&self) -> u64 {
1127 let hasher = &mut DefaultHasher::new();
1128 self.inner.hash_value().hash(hasher);
1129 self.aliases.hash(hasher);
1130 hasher.finish()
1131 }
1132
1133 fn is_descending(&self) -> Option<bool> {
1134 self.inner.is_descending()
1135 }
1136
1137 fn documentation(&self) -> Option<&Documentation> {
1138 self.inner.documentation()
1139 }
1140}
1141
1142// Aggregate UDF doc sections for use in public documentation
1143pub mod aggregate_doc_sections {
1144 use crate::DocSection;
1145
1146 pub fn doc_sections() -> Vec<DocSection> {
1147 vec![
1148 DOC_SECTION_GENERAL,
1149 DOC_SECTION_STATISTICAL,
1150 DOC_SECTION_APPROXIMATE,
1151 ]
1152 }
1153
1154 pub const DOC_SECTION_GENERAL: DocSection = DocSection {
1155 include: true,
1156 label: "General Functions",
1157 description: None,
1158 };
1159
1160 pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
1161 include: true,
1162 label: "Statistical Functions",
1163 description: None,
1164 };
1165
1166 pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
1167 include: true,
1168 label: "Approximate Functions",
1169 description: None,
1170 };
1171}
1172
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// increases (or stays the same) as its argument grows as a set. Formally,
/// `f` is a monotonically increasing set function if `f(S) >= f(T)` whenever
/// `S` is a superset of `T`, and monotonically decreasing if `f(S) <= f(T)`
/// whenever `S` is a superset of `T`.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1193
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Signature shared by both test UDFs: a single immutable `Float64` argument.
    fn float64_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Minimal UDAF named `"a"`; only its identity methods are meant to be called.
    #[derive(Debug, Clone)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            let signature = float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "a"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }
        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Minimal UDAF named `"b"`; identical to [`AMeanUdf`] apart from its name.
    #[derive(Debug, Clone)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            let signature = float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "b"
        }
        fn signature(&self) -> &Signature {
            &self.signature
        }
        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }
        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }
        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        // Smoke test: `AggregateUDF` ordering follows name and signature.
        // It is not meant to cover every combination.
        let first_a = AggregateUDF::from(AMeanUdf::new());
        let second_a = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(first_a.partial_cmp(&second_a), Some(Ordering::Equal));

        let only_b = AggregateUDF::from(BMeanUdf::new());
        assert!(first_a < only_b);
        assert!(first_a != only_b);
    }
}