datafusion_expr/udaf.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field};
28
29use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32use crate::expr::{
33 schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
34 schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay,
35 WindowFunctionParams,
36};
37use crate::function::{
38 AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39};
40use crate::groups_accumulator::GroupsAccumulator;
41use crate::utils::format_state_name;
42use crate::utils::AggregateOrderSensitivity;
43use crate::{expr_vec_fmt, Accumulator, Expr};
44use crate::{Documentation, Signature};
45
46/// Logical representation of a user-defined [aggregate function] (UDAF).
47///
48/// An aggregate function combines the values from multiple input rows
49/// into a single output "aggregate" (summary) row. It is different
50/// from a scalar function because it is stateful across batches. User
51/// defined aggregate functions can be used as normal SQL aggregate
52/// functions (`GROUP BY` clause) as well as window functions (`OVER`
53/// clause).
54///
55/// `AggregateUDF` provides DataFusion the information needed to plan and call
56/// aggregate functions, including name, type information, and a factory
57/// function to create an [`Accumulator`] instance, to perform the actual
58/// aggregation.
59///
60/// For more information, please see [the examples]:
61///
62/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
63///
64/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
65/// access (examples in [`advanced_udaf.rs`]).
66///
67/// # API Note
68/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
69/// compatibility with the older API.
70///
71/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
72/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
73/// [`Accumulator`]: crate::Accumulator
74/// [`create_udaf`]: crate::expr_fn::create_udaf
75/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
76/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
77#[derive(Debug, Clone, PartialOrd)]
78pub struct AggregateUDF {
79 inner: Arc<dyn AggregateUDFImpl>,
80}
81
82impl PartialEq for AggregateUDF {
83 fn eq(&self, other: &Self) -> bool {
84 self.inner.equals(other.inner.as_ref())
85 }
86}
87
88impl Eq for AggregateUDF {}
89
90impl Hash for AggregateUDF {
91 fn hash<H: Hasher>(&self, state: &mut H) {
92 self.inner.hash_value().hash(state)
93 }
94}
95
96impl fmt::Display for AggregateUDF {
97 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
98 write!(f, "{}", self.name())
99 }
100}
101
102/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103#[derive(Debug)]
104pub struct StatisticsArgs<'a> {
105 /// The statistics of the aggregate input
106 pub statistics: &'a Statistics,
107 /// The resolved return type of the aggregate function
108 pub return_type: &'a DataType,
109 /// Whether the aggregate function is distinct.
110 ///
111 /// ```sql
112 /// SELECT COUNT(DISTINCT column1) FROM t;
113 /// ```
114 pub is_distinct: bool,
115 /// The physical expression of arguments the aggregate function takes.
116 pub exprs: &'a [Arc<dyn PhysicalExpr>],
117}
118
119impl AggregateUDF {
120 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
121 ///
122 /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
123 pub fn new_from_impl<F>(fun: F) -> AggregateUDF
124 where
125 F: AggregateUDFImpl + 'static,
126 {
127 Self::new_from_shared_impl(Arc::new(fun))
128 }
129
130 /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
131 pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
132 Self { inner: fun }
133 }
134
135 /// Return the underlying [`AggregateUDFImpl`] trait object for this function
136 pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
137 &self.inner
138 }
139
140 /// Adds additional names that can be used to invoke this function, in
141 /// addition to `name`
142 ///
143 /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
144 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
145 Self::new_from_impl(AliasedAggregateUDFImpl::new(
146 Arc::clone(&self.inner),
147 aliases,
148 ))
149 }
150
151 /// Creates an [`Expr`] that calls the aggregate function.
152 ///
153 /// This utility allows using the UDAF without requiring access to
154 /// the registry, such as with the DataFrame API.
155 pub fn call(&self, args: Vec<Expr>) -> Expr {
156 Expr::AggregateFunction(AggregateFunction::new_udf(
157 Arc::new(self.clone()),
158 args,
159 false,
160 None,
161 None,
162 None,
163 ))
164 }
165
166 /// Returns this function's name
167 ///
168 /// See [`AggregateUDFImpl::name`] for more details.
169 pub fn name(&self) -> &str {
170 self.inner.name()
171 }
172
173 /// See [`AggregateUDFImpl::schema_name`] for more details.
174 pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
175 self.inner.schema_name(params)
176 }
177
178 /// Returns a human readable expression.
179 ///
180 /// See [`Expr::human_display`] for details.
181 pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
182 self.inner.human_display(params)
183 }
184
185 pub fn window_function_schema_name(
186 &self,
187 params: &WindowFunctionParams,
188 ) -> Result<String> {
189 self.inner.window_function_schema_name(params)
190 }
191
192 /// See [`AggregateUDFImpl::display_name`] for more details.
193 pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
194 self.inner.display_name(params)
195 }
196
197 pub fn window_function_display_name(
198 &self,
199 params: &WindowFunctionParams,
200 ) -> Result<String> {
201 self.inner.window_function_display_name(params)
202 }
203
204 pub fn is_nullable(&self) -> bool {
205 self.inner.is_nullable()
206 }
207
208 /// Returns the aliases for this function.
209 pub fn aliases(&self) -> &[String] {
210 self.inner.aliases()
211 }
212
213 /// Returns this function's signature (what input types are accepted)
214 ///
215 /// See [`AggregateUDFImpl::signature`] for more details.
216 pub fn signature(&self) -> &Signature {
217 self.inner.signature()
218 }
219
220 /// Return the type of the function given its input types
221 ///
222 /// See [`AggregateUDFImpl::return_type`] for more details.
223 pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
224 self.inner.return_type(args)
225 }
226
227 /// Return an accumulator the given aggregate, given its return datatype
228 pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
229 self.inner.accumulator(acc_args)
230 }
231
232 /// Return the fields used to store the intermediate state for this aggregator, given
233 /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
234 /// for more details.
235 ///
236 /// This is used to support multi-phase aggregations
237 pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
238 self.inner.state_fields(args)
239 }
240
241 /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
242 pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
243 self.inner.groups_accumulator_supported(args)
244 }
245
246 /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
247 pub fn create_groups_accumulator(
248 &self,
249 args: AccumulatorArgs,
250 ) -> Result<Box<dyn GroupsAccumulator>> {
251 self.inner.create_groups_accumulator(args)
252 }
253
254 pub fn create_sliding_accumulator(
255 &self,
256 args: AccumulatorArgs,
257 ) -> Result<Box<dyn Accumulator>> {
258 self.inner.create_sliding_accumulator(args)
259 }
260
261 pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
262 self.inner.coerce_types(arg_types)
263 }
264
265 /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
266 pub fn with_beneficial_ordering(
267 self,
268 beneficial_ordering: bool,
269 ) -> Result<Option<AggregateUDF>> {
270 self.inner
271 .with_beneficial_ordering(beneficial_ordering)
272 .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
273 }
274
275 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
276 /// for possible options.
277 pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
278 self.inner.order_sensitivity()
279 }
280
281 /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
282 /// generate same result with this `AggregateUDF` when iterated in reverse
283 /// order, and `None` if there is no such `AggregateUDF`).
284 pub fn reverse_udf(&self) -> ReversedUDAF {
285 self.inner.reverse_expr()
286 }
287
288 /// Do the function rewrite
289 ///
290 /// See [`AggregateUDFImpl::simplify`] for more details.
291 pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
292 self.inner.simplify()
293 }
294
295 /// Returns true if the function is max, false if the function is min
296 /// None in all other cases, used in certain optimizations for
297 /// or aggregate
298 pub fn is_descending(&self) -> Option<bool> {
299 self.inner.is_descending()
300 }
301
302 /// Return the value of this aggregate function if it can be determined
303 /// entirely from statistics and arguments.
304 ///
305 /// See [`AggregateUDFImpl::value_from_stats`] for more details.
306 pub fn value_from_stats(
307 &self,
308 statistics_args: &StatisticsArgs,
309 ) -> Option<ScalarValue> {
310 self.inner.value_from_stats(statistics_args)
311 }
312
313 /// See [`AggregateUDFImpl::default_value`] for more details.
314 pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
315 self.inner.default_value(data_type)
316 }
317
318 /// Returns the documentation for this Aggregate UDF.
319 ///
320 /// Documentation can be accessed programmatically as well as
321 /// generating publicly facing documentation.
322 pub fn documentation(&self) -> Option<&Documentation> {
323 self.inner.documentation()
324 }
325}
326
327impl<F> From<F> for AggregateUDF
328where
329 F: AggregateUDFImpl + Send + Sync + 'static,
330{
331 fn from(fun: F) -> Self {
332 Self::new_from_impl(fun)
333 }
334}
335
336/// Trait for implementing [`AggregateUDF`].
337///
338/// This trait exposes the full API for implementing user defined aggregate functions and
339/// can be used to implement any function.
340///
341/// See [`advanced_udaf.rs`] for a full example with complete implementation and
342/// [`AggregateUDF`] for other available options.
343///
344/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
345///
346/// # Basic Example
347/// ```
348/// # use std::any::Any;
349/// # use std::sync::LazyLock;
350/// # use arrow::datatypes::DataType;
351/// # use datafusion_common::{DataFusionError, plan_err, Result};
352/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
353/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
354/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
355/// # use arrow::datatypes::Schema;
356/// # use arrow::datatypes::Field;
357///
358/// #[derive(Debug, Clone)]
359/// struct GeoMeanUdf {
360/// signature: Signature,
361/// }
362///
363/// impl GeoMeanUdf {
364/// fn new() -> Self {
365/// Self {
366/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
367/// }
368/// }
369/// }
370///
371/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
372/// Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
373/// .with_argument("arg1", "The Float64 number for the geometric mean")
374/// .build()
375/// });
376///
377/// fn get_doc() -> &'static Documentation {
378/// &DOCUMENTATION
379/// }
380///
381/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
382/// impl AggregateUDFImpl for GeoMeanUdf {
383/// fn as_any(&self) -> &dyn Any { self }
384/// fn name(&self) -> &str { "geo_mean" }
385/// fn signature(&self) -> &Signature { &self.signature }
386/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
387/// if !matches!(args.get(0), Some(&DataType::Float64)) {
388/// return plan_err!("geo_mean only accepts Float64 arguments");
389/// }
390/// Ok(DataType::Float64)
391/// }
392/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
393/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
394/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
395/// Ok(vec![
396/// Field::new("value", args.return_type.clone(), true),
397/// Field::new("ordering", DataType::UInt32, true)
398/// ])
399/// }
400/// fn documentation(&self) -> Option<&Documentation> {
401/// Some(get_doc())
402/// }
403/// }
404///
405/// // Create a new AggregateUDF from the implementation
406/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
407///
408/// // Call the function `geo_mean(col)`
409/// let expr = geometric_mean.call(vec![col("a")]);
410/// ```
411pub trait AggregateUDFImpl: Debug + Send + Sync {
412 // Note: When adding any methods (with default implementations), remember to add them also
413 // into the AliasedAggregateUDFImpl below!
414
415 /// Returns this object as an [`Any`] trait object
416 fn as_any(&self) -> &dyn Any;
417
418 /// Returns this function's name
419 fn name(&self) -> &str;
420
421 /// Returns the name of the column this expression would create
422 ///
423 /// See [`Expr::schema_name`] for details
424 ///
425 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
426 fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
427 let AggregateFunctionParams {
428 args,
429 distinct,
430 filter,
431 order_by,
432 null_treatment,
433 } = params;
434
435 let mut schema_name = String::new();
436
437 schema_name.write_fmt(format_args!(
438 "{}({}{})",
439 self.name(),
440 if *distinct { "DISTINCT " } else { "" },
441 schema_name_from_exprs_comma_separated_without_space(args)?
442 ))?;
443
444 if let Some(null_treatment) = null_treatment {
445 schema_name.write_fmt(format_args!(" {}", null_treatment))?;
446 }
447
448 if let Some(filter) = filter {
449 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
450 };
451
452 if let Some(order_by) = order_by {
453 schema_name.write_fmt(format_args!(
454 " ORDER BY [{}]",
455 schema_name_from_sorts(order_by)?
456 ))?;
457 };
458
459 Ok(schema_name)
460 }
461
462 /// Returns a human readable expression.
463 ///
464 /// See [`Expr::human_display`] for details.
465 fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
466 let AggregateFunctionParams {
467 args,
468 distinct,
469 filter,
470 order_by,
471 null_treatment,
472 } = params;
473
474 let mut schema_name = String::new();
475
476 schema_name.write_fmt(format_args!(
477 "{}({}{})",
478 self.name(),
479 if *distinct { "DISTINCT " } else { "" },
480 ExprListDisplay::comma_separated(args.as_slice())
481 ))?;
482
483 if let Some(null_treatment) = null_treatment {
484 schema_name.write_fmt(format_args!(" {}", null_treatment))?;
485 }
486
487 if let Some(filter) = filter {
488 schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
489 };
490
491 if let Some(order_by) = order_by {
492 schema_name.write_fmt(format_args!(
493 " ORDER BY [{}]",
494 schema_name_from_sorts(order_by)?
495 ))?;
496 };
497
498 Ok(schema_name)
499 }
500
501 /// Returns the name of the column this expression would create
502 ///
503 /// See [`Expr::schema_name`] for details
504 ///
505 /// Different from `schema_name` in that it is used for window aggregate function
506 ///
507 /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
508 fn window_function_schema_name(
509 &self,
510 params: &WindowFunctionParams,
511 ) -> Result<String> {
512 let WindowFunctionParams {
513 args,
514 partition_by,
515 order_by,
516 window_frame,
517 null_treatment,
518 } = params;
519
520 let mut schema_name = String::new();
521 schema_name.write_fmt(format_args!(
522 "{}({})",
523 self.name(),
524 schema_name_from_exprs(args)?
525 ))?;
526
527 if let Some(null_treatment) = null_treatment {
528 schema_name.write_fmt(format_args!(" {}", null_treatment))?;
529 }
530
531 if !partition_by.is_empty() {
532 schema_name.write_fmt(format_args!(
533 " PARTITION BY [{}]",
534 schema_name_from_exprs(partition_by)?
535 ))?;
536 }
537
538 if !order_by.is_empty() {
539 schema_name.write_fmt(format_args!(
540 " ORDER BY [{}]",
541 schema_name_from_sorts(order_by)?
542 ))?;
543 };
544
545 schema_name.write_fmt(format_args!(" {window_frame}"))?;
546
547 Ok(schema_name)
548 }
549
550 /// Returns the user-defined display name of function, given the arguments
551 ///
552 /// This can be used to customize the output column name generated by this
553 /// function.
554 ///
555 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
556 fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
557 let AggregateFunctionParams {
558 args,
559 distinct,
560 filter,
561 order_by,
562 null_treatment,
563 } = params;
564
565 let mut display_name = String::new();
566
567 display_name.write_fmt(format_args!(
568 "{}({}{})",
569 self.name(),
570 if *distinct { "DISTINCT " } else { "" },
571 expr_vec_fmt!(args)
572 ))?;
573
574 if let Some(nt) = null_treatment {
575 display_name.write_fmt(format_args!(" {}", nt))?;
576 }
577 if let Some(fe) = filter {
578 display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
579 }
580 if let Some(ob) = order_by {
581 display_name.write_fmt(format_args!(
582 " ORDER BY [{}]",
583 ob.iter()
584 .map(|o| format!("{o}"))
585 .collect::<Vec<String>>()
586 .join(", ")
587 ))?;
588 }
589
590 Ok(display_name)
591 }
592
593 /// Returns the user-defined display name of function, given the arguments
594 ///
595 /// This can be used to customize the output column name generated by this
596 /// function.
597 ///
598 /// Different from `display_name` in that it is used for window aggregate function
599 ///
600 /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
601 fn window_function_display_name(
602 &self,
603 params: &WindowFunctionParams,
604 ) -> Result<String> {
605 let WindowFunctionParams {
606 args,
607 partition_by,
608 order_by,
609 window_frame,
610 null_treatment,
611 } = params;
612
613 let mut display_name = String::new();
614
615 display_name.write_fmt(format_args!(
616 "{}({})",
617 self.name(),
618 expr_vec_fmt!(args)
619 ))?;
620
621 if let Some(null_treatment) = null_treatment {
622 display_name.write_fmt(format_args!(" {}", null_treatment))?;
623 }
624
625 if !partition_by.is_empty() {
626 display_name.write_fmt(format_args!(
627 " PARTITION BY [{}]",
628 expr_vec_fmt!(partition_by)
629 ))?;
630 }
631
632 if !order_by.is_empty() {
633 display_name
634 .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
635 };
636
637 display_name.write_fmt(format_args!(
638 " {} BETWEEN {} AND {}",
639 window_frame.units, window_frame.start_bound, window_frame.end_bound
640 ))?;
641
642 Ok(display_name)
643 }
644
645 /// Returns the function's [`Signature`] for information about what input
646 /// types are accepted and the function's Volatility.
647 fn signature(&self) -> &Signature;
648
649 /// What [`DataType`] will be returned by this function, given the types of
650 /// the arguments
651 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
652
653 /// Whether the aggregate function is nullable.
654 ///
655 /// Nullable means that the function could return `null` for any inputs.
656 /// For example, aggregate functions like `COUNT` always return a non null value
657 /// but others like `MIN` will return `NULL` if there is nullable input.
658 /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
659 fn is_nullable(&self) -> bool {
660 true
661 }
662
663 /// Return a new [`Accumulator`] that aggregates values for a specific
664 /// group during query execution.
665 ///
666 /// acc_args: [`AccumulatorArgs`] contains information about how the
667 /// aggregate function was called.
668 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
669
670 /// Return the fields used to store the intermediate state of this accumulator.
671 ///
672 /// See [`Accumulator::state`] for background information.
673 ///
674 /// args: [`StateFieldsArgs`] contains arguments passed to the
675 /// aggregate function's accumulator.
676 ///
677 /// # Notes:
678 ///
679 /// The default implementation returns a single state field named `name`
680 /// with the same type as `value_type`. This is suitable for aggregates such
681 /// as `SUM` or `MIN` where partial state can be combined by applying the
682 /// same aggregate.
683 ///
684 /// For aggregates such as `AVG` where the partial state is more complex
685 /// (e.g. a COUNT and a SUM), this method is used to define the additional
686 /// fields.
687 ///
688 /// The name of the fields must be unique within the query and thus should
689 /// be derived from `name`. See [`format_state_name`] for a utility function
690 /// to generate a unique name.
691 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
692 let fields = vec![Field::new(
693 format_state_name(args.name, "value"),
694 args.return_type.clone(),
695 true,
696 )];
697
698 Ok(fields
699 .into_iter()
700 .chain(args.ordering_fields.to_vec())
701 .collect())
702 }
703
704 /// If the aggregate expression has a specialized
705 /// [`GroupsAccumulator`] implementation. If this returns true,
706 /// `[Self::create_groups_accumulator]` will be called.
707 ///
708 /// # Notes
709 ///
710 /// Even if this function returns true, DataFusion will still use
711 /// [`Self::accumulator`] for certain queries, such as when this aggregate is
712 /// used as a window function or when there no GROUP BY columns in the
713 /// query.
714 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
715 false
716 }
717
718 /// Return a specialized [`GroupsAccumulator`] that manages state
719 /// for all groups.
720 ///
721 /// For maximum performance, a [`GroupsAccumulator`] should be
722 /// implemented in addition to [`Accumulator`].
723 fn create_groups_accumulator(
724 &self,
725 _args: AccumulatorArgs,
726 ) -> Result<Box<dyn GroupsAccumulator>> {
727 not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
728 }
729
730 /// Returns any aliases (alternate names) for this function.
731 ///
732 /// Note: `aliases` should only include names other than [`Self::name`].
733 /// Defaults to `[]` (no aliases)
734 fn aliases(&self) -> &[String] {
735 &[]
736 }
737
738 /// Sliding accumulator is an alternative accumulator that can be used for
739 /// window functions. It has retract method to revert the previous update.
740 ///
741 /// See [retract_batch] for more details.
742 ///
743 /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
744 fn create_sliding_accumulator(
745 &self,
746 args: AccumulatorArgs,
747 ) -> Result<Box<dyn Accumulator>> {
748 self.accumulator(args)
749 }
750
751 /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
752 /// satisfied by its input. If this is not the case, UDFs with order
753 /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
754 /// the correct result with possibly more work internally.
755 ///
756 /// # Returns
757 ///
758 /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
759 /// If the expression can benefit from existing input ordering, but does
760 /// not implement the method, returns an error. Order insensitive and hard
761 /// requirement aggregators return `Ok(None)`.
762 fn with_beneficial_ordering(
763 self: Arc<Self>,
764 _beneficial_ordering: bool,
765 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
766 if self.order_sensitivity().is_beneficial() {
767 return exec_err!(
768 "Should implement with satisfied for aggregator :{:?}",
769 self.name()
770 );
771 }
772 Ok(None)
773 }
774
775 /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
776 /// for possible options.
777 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
778 // We have hard ordering requirements by default, meaning that order
779 // sensitive UDFs need their input orderings to satisfy their ordering
780 // requirements to generate correct results.
781 AggregateOrderSensitivity::HardRequirement
782 }
783
784 /// Optionally apply per-UDaF simplification / rewrite rules.
785 ///
786 /// This can be used to apply function specific simplification rules during
787 /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
788 /// implementation does nothing.
789 ///
790 /// Note that DataFusion handles simplifying arguments and "constant
791 /// folding" (replacing a function call with constant arguments such as
792 /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
793 /// optimizations manually for specific UDFs.
794 ///
795 /// # Returns
796 ///
797 /// [None] if simplify is not defined or,
798 ///
799 /// Or, a closure with two arguments:
800 /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
801 /// * 'info': [crate::simplify::SimplifyInfo]
802 ///
803 /// closure returns simplified [Expr] or an error.
804 ///
805 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
806 None
807 }
808
809 /// Returns the reverse expression of the aggregate function.
810 fn reverse_expr(&self) -> ReversedUDAF {
811 ReversedUDAF::NotSupported
812 }
813
814 /// Coerce arguments of a function call to types that the function can evaluate.
815 ///
816 /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
817 /// UDAFs should return one of the other variants of `TypeSignature` which handle common
818 /// cases
819 ///
820 /// See the [type coercion module](crate::type_coercion)
821 /// documentation for more details on type coercion
822 ///
823 /// For example, if your function requires a floating point arguments, but the user calls
824 /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
825 /// to ensure the argument was cast to `1::double`
826 ///
827 /// # Parameters
828 /// * `arg_types`: The argument types of the arguments this function with
829 ///
830 /// # Return value
831 /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
832 /// arguments to these specific types.
833 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
834 not_impl_err!("Function {} does not implement coerce_types", self.name())
835 }
836
837 /// Return true if this aggregate UDF is equal to the other.
838 ///
839 /// Allows customizing the equality of aggregate UDFs.
840 /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
841 ///
842 /// - reflexive: `a.equals(a)`;
843 /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
844 /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
845 ///
846 /// By default, compares [`Self::name`] and [`Self::signature`].
847 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
848 self.name() == other.name() && self.signature() == other.signature()
849 }
850
851 /// Returns a hash value for this aggregate UDF.
852 ///
853 /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
854 /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
855 ///
856 /// By default, hashes [`Self::name`] and [`Self::signature`].
857 fn hash_value(&self) -> u64 {
858 let hasher = &mut DefaultHasher::new();
859 self.name().hash(hasher);
860 self.signature().hash(hasher);
861 hasher.finish()
862 }
863
864 /// If this function is max, return true
865 /// If the function is min, return false
866 /// Otherwise return None (the default)
867 ///
868 ///
869 /// Note: this is used to use special aggregate implementations in certain conditions
870 fn is_descending(&self) -> Option<bool> {
871 None
872 }
873
874 /// Return the value of this aggregate function if it can be determined
875 /// entirely from statistics and arguments.
876 ///
877 /// Using a [`ScalarValue`] rather than a runtime computation can significantly
878 /// improving query performance.
879 ///
880 /// For example, if the minimum value of column `x` is known to be `42` from
881 /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
882 fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
883 None
884 }
885
886 /// Returns default value of the function given the input is all `null`.
887 ///
888 /// Most of the aggregate function return Null if input is Null,
889 /// while `count` returns 0 if input is Null
890 fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
891 ScalarValue::try_from(data_type)
892 }
893
894 /// Returns the documentation for this Aggregate UDF.
895 ///
896 /// Documentation can be accessed programmatically as well as
897 /// generating publicly facing documentation.
898 fn documentation(&self) -> Option<&Documentation> {
899 None
900 }
901
902 /// Indicates whether the aggregation function is monotonic as a set
903 /// function. See [`SetMonotonicity`] for details.
904 fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
905 SetMonotonicity::NotMonotonic
906 }
907}
908
909impl PartialEq for dyn AggregateUDFImpl {
910 fn eq(&self, other: &Self) -> bool {
911 self.equals(other)
912 }
913}
914
915// Manual implementation of `PartialOrd`
916// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
917// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
918impl PartialOrd for dyn AggregateUDFImpl {
919 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
920 match self.name().partial_cmp(other.name()) {
921 Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
922 cmp => cmp,
923 }
924 }
925}
926
927pub enum ReversedUDAF {
928 /// The expression is the same as the original expression, like SUM, COUNT
929 Identical,
930 /// The expression does not support reverse calculation
931 NotSupported,
932 /// The expression is different from the original expression
933 Reversed(Arc<AggregateUDF>),
934}
935
/// AggregateUDF that adds aliases to the underlying function. If possible, it is
/// better to implement [`AggregateUDFImpl`] directly, since it supports aliases.
///
/// Most trait methods are forwarded to `inner`; this wrapper only owns the
/// combined alias list.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
    /// The wrapped function that calls are forwarded to.
    inner: Arc<dyn AggregateUDFImpl>,
    /// The inner function's own aliases, followed by the newly added ones.
    aliases: Vec<String>,
}
943
944impl AliasedAggregateUDFImpl {
945 pub fn new(
946 inner: Arc<dyn AggregateUDFImpl>,
947 new_aliases: impl IntoIterator<Item = &'static str>,
948 ) -> Self {
949 let mut aliases = inner.aliases().to_vec();
950 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
951
952 Self { inner, aliases }
953 }
954}
955
956impl AggregateUDFImpl for AliasedAggregateUDFImpl {
957 fn as_any(&self) -> &dyn Any {
958 self
959 }
960
961 fn name(&self) -> &str {
962 self.inner.name()
963 }
964
965 fn signature(&self) -> &Signature {
966 self.inner.signature()
967 }
968
969 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
970 self.inner.return_type(arg_types)
971 }
972
973 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
974 self.inner.accumulator(acc_args)
975 }
976
977 fn aliases(&self) -> &[String] {
978 &self.aliases
979 }
980
981 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
982 self.inner.state_fields(args)
983 }
984
985 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
986 self.inner.groups_accumulator_supported(args)
987 }
988
989 fn create_groups_accumulator(
990 &self,
991 args: AccumulatorArgs,
992 ) -> Result<Box<dyn GroupsAccumulator>> {
993 self.inner.create_groups_accumulator(args)
994 }
995
996 fn create_sliding_accumulator(
997 &self,
998 args: AccumulatorArgs,
999 ) -> Result<Box<dyn Accumulator>> {
1000 self.inner.accumulator(args)
1001 }
1002
1003 fn with_beneficial_ordering(
1004 self: Arc<Self>,
1005 beneficial_ordering: bool,
1006 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1007 Arc::clone(&self.inner)
1008 .with_beneficial_ordering(beneficial_ordering)
1009 .map(|udf| {
1010 udf.map(|udf| {
1011 Arc::new(AliasedAggregateUDFImpl {
1012 inner: udf,
1013 aliases: self.aliases.clone(),
1014 }) as Arc<dyn AggregateUDFImpl>
1015 })
1016 })
1017 }
1018
1019 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1020 self.inner.order_sensitivity()
1021 }
1022
1023 fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1024 self.inner.simplify()
1025 }
1026
1027 fn reverse_expr(&self) -> ReversedUDAF {
1028 self.inner.reverse_expr()
1029 }
1030
1031 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1032 self.inner.coerce_types(arg_types)
1033 }
1034
1035 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
1036 if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
1037 self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
1038 } else {
1039 false
1040 }
1041 }
1042
1043 fn hash_value(&self) -> u64 {
1044 let hasher = &mut DefaultHasher::new();
1045 self.inner.hash_value().hash(hasher);
1046 self.aliases.hash(hasher);
1047 hasher.finish()
1048 }
1049
1050 fn is_descending(&self) -> Option<bool> {
1051 self.inner.is_descending()
1052 }
1053
1054 fn documentation(&self) -> Option<&Documentation> {
1055 self.inner.documentation()
1056 }
1057}
1058
1059// Aggregate UDF doc sections for use in public documentation
1060pub mod aggregate_doc_sections {
1061 use crate::DocSection;
1062
1063 pub fn doc_sections() -> Vec<DocSection> {
1064 vec![
1065 DOC_SECTION_GENERAL,
1066 DOC_SECTION_STATISTICAL,
1067 DOC_SECTION_APPROXIMATE,
1068 ]
1069 }
1070
1071 pub const DOC_SECTION_GENERAL: DocSection = DocSection {
1072 include: true,
1073 label: "General Functions",
1074 description: None,
1075 };
1076
1077 pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
1078 include: true,
1079 label: "Statistical Functions",
1080 description: None,
1081 };
1082
1083 pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
1084 include: true,
1085 label: "Approximate Functions",
1086 description: None,
1087 };
1088}
1089
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// never decreases as its argument grows (as a set). Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S`
/// is a superset of `T`.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1110
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Declares a distinct stub UDAF type `$ty` whose only meaningful
    /// properties are its name (`$udf_name`) and a one-argument `Float64`
    /// signature. Everything that would require real evaluation panics.
    macro_rules! mean_udf_stub {
        ($ty:ident, $udf_name:literal) => {
            #[derive(Debug, Clone)]
            struct $ty {
                signature: Signature,
            }

            impl $ty {
                fn new() -> Self {
                    Self {
                        signature: Signature::uniform(
                            1,
                            vec![DataType::Float64],
                            Volatility::Immutable,
                        ),
                    }
                }
            }

            impl AggregateUDFImpl for $ty {
                fn as_any(&self) -> &dyn Any {
                    self
                }
                fn name(&self) -> &str {
                    $udf_name
                }
                fn signature(&self) -> &Signature {
                    &self.signature
                }
                fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
                    unimplemented!()
                }
                fn accumulator(
                    &self,
                    _acc_args: AccumulatorArgs,
                ) -> Result<Box<dyn Accumulator>> {
                    unimplemented!()
                }
                fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
                    unimplemented!()
                }
            }
        };
    }

    mean_udf_stub!(AMeanUdf, "a");
    mean_udf_stub!(BMeanUdf, "b");

    #[test]
    fn test_partial_ord() {
        // Checks that `AggregateUDF` ordering is driven by name and signature;
        // this is a smoke test, not an exhaustive one.
        let first_a = AggregateUDF::from(AMeanUdf::new());
        let second_a = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(first_a.partial_cmp(&second_a), Some(Ordering::Equal));

        let only_b = AggregateUDF::from(BMeanUdf::new());
        assert!(first_a < only_b);
        assert!(first_a != only_b);
    }
}