Deprecated: The each() function is deprecated. This message will be suppressed on further calls in /home/zhenxiangba/zhenxiangba.com/public_html/phproxy-improved-master/index.php on line 456
udaf.rs - source
[go: Go Back, main page]

datafusion_expr/
udaf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32use crate::expr::{
33    schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
34    schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay,
35    WindowFunctionParams,
36};
37use crate::function::{
38    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39};
40use crate::groups_accumulator::GroupsAccumulator;
41use crate::utils::format_state_name;
42use crate::utils::AggregateOrderSensitivity;
43use crate::{expr_vec_fmt, Accumulator, Expr};
44use crate::{Documentation, Signature};
45
46/// Logical representation of a user-defined [aggregate function] (UDAF).
47///
48/// An aggregate function combines the values from multiple input rows
49/// into a single output "aggregate" (summary) row. It is different
50/// from a scalar function because it is stateful across batches. User
51/// defined aggregate functions can be used as normal SQL aggregate
52/// functions (`GROUP BY` clause) as well as window functions (`OVER`
53/// clause).
54///
55/// `AggregateUDF` provides DataFusion the information needed to plan and call
56/// aggregate functions, including name, type information, and a factory
57/// function to create an [`Accumulator`] instance, to perform the actual
58/// aggregation.
59///
60/// For more information, please see [the examples]:
61///
62/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
63///
64/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
65///    access (examples in [`advanced_udaf.rs`]).
66///
67/// # API Note
68/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
69/// compatibility with the older API.
70///
71/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
72/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
73/// [`Accumulator`]: crate::Accumulator
74/// [`create_udaf`]: crate::expr_fn::create_udaf
75/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
76/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
77#[derive(Debug, Clone, PartialOrd)]
78pub struct AggregateUDF {
79    inner: Arc<dyn AggregateUDFImpl>,
80}
81
82impl PartialEq for AggregateUDF {
83    fn eq(&self, other: &Self) -> bool {
84        self.inner.equals(other.inner.as_ref())
85    }
86}
87
88impl Eq for AggregateUDF {}
89
90impl Hash for AggregateUDF {
91    fn hash<H: Hasher>(&self, state: &mut H) {
92        self.inner.hash_value().hash(state)
93    }
94}
95
96impl fmt::Display for AggregateUDF {
97    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
98        write!(f, "{}", self.name())
99    }
100}
101
102/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103#[derive(Debug)]
104pub struct StatisticsArgs<'a> {
105    /// The statistics of the aggregate input
106    pub statistics: &'a Statistics,
107    /// The resolved return type of the aggregate function
108    pub return_type: &'a DataType,
109    /// Whether the aggregate function is distinct.
110    ///
111    /// ```sql
112    /// SELECT COUNT(DISTINCT column1) FROM t;
113    /// ```
114    pub is_distinct: bool,
115    /// The physical expression of arguments the aggregate function takes.
116    pub exprs: &'a [Arc<dyn PhysicalExpr>],
117}
118
119impl AggregateUDF {
120    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
121    ///
122    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
123    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
124    where
125        F: AggregateUDFImpl + 'static,
126    {
127        Self::new_from_shared_impl(Arc::new(fun))
128    }
129
130    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
131    pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
132        Self { inner: fun }
133    }
134
135    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
136    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
137        &self.inner
138    }
139
140    /// Adds additional names that can be used to invoke this function, in
141    /// addition to `name`
142    ///
143    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
144    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
145        Self::new_from_impl(AliasedAggregateUDFImpl::new(
146            Arc::clone(&self.inner),
147            aliases,
148        ))
149    }
150
151    /// Creates an [`Expr`] that calls the aggregate function.
152    ///
153    /// This utility allows using the UDAF without requiring access to
154    /// the registry, such as with the DataFrame API.
155    pub fn call(&self, args: Vec<Expr>) -> Expr {
156        Expr::AggregateFunction(AggregateFunction::new_udf(
157            Arc::new(self.clone()),
158            args,
159            false,
160            None,
161            vec![],
162            None,
163        ))
164    }
165
166    /// Returns this function's name
167    ///
168    /// See [`AggregateUDFImpl::name`] for more details.
169    pub fn name(&self) -> &str {
170        self.inner.name()
171    }
172
173    /// Returns the aliases for this function.
174    pub fn aliases(&self) -> &[String] {
175        self.inner.aliases()
176    }
177
178    /// See [`AggregateUDFImpl::schema_name`] for more details.
179    pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
180        self.inner.schema_name(params)
181    }
182
183    /// Returns a human readable expression.
184    ///
185    /// See [`Expr::human_display`] for details.
186    pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
187        self.inner.human_display(params)
188    }
189
190    pub fn window_function_schema_name(
191        &self,
192        params: &WindowFunctionParams,
193    ) -> Result<String> {
194        self.inner.window_function_schema_name(params)
195    }
196
197    /// See [`AggregateUDFImpl::display_name`] for more details.
198    pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
199        self.inner.display_name(params)
200    }
201
202    pub fn window_function_display_name(
203        &self,
204        params: &WindowFunctionParams,
205    ) -> Result<String> {
206        self.inner.window_function_display_name(params)
207    }
208
209    pub fn is_nullable(&self) -> bool {
210        self.inner.is_nullable()
211    }
212
213    /// Returns this function's signature (what input types are accepted)
214    ///
215    /// See [`AggregateUDFImpl::signature`] for more details.
216    pub fn signature(&self) -> &Signature {
217        self.inner.signature()
218    }
219
220    /// Return the type of the function given its input types
221    ///
222    /// See [`AggregateUDFImpl::return_type`] for more details.
223    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
224        self.inner.return_type(args)
225    }
226
227    /// Return the field of the function given its input fields
228    ///
229    /// See [`AggregateUDFImpl::return_field`] for more details.
230    pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
231        self.inner.return_field(args)
232    }
233
234    /// Return an accumulator the given aggregate, given its return datatype
235    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
236        self.inner.accumulator(acc_args)
237    }
238
239    /// Return the fields used to store the intermediate state for this aggregator, given
240    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
241    /// for more details.
242    ///
243    /// This is used to support multi-phase aggregations
244    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
245        self.inner.state_fields(args)
246    }
247
248    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
249    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
250        self.inner.groups_accumulator_supported(args)
251    }
252
253    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
254    pub fn create_groups_accumulator(
255        &self,
256        args: AccumulatorArgs,
257    ) -> Result<Box<dyn GroupsAccumulator>> {
258        self.inner.create_groups_accumulator(args)
259    }
260
261    pub fn create_sliding_accumulator(
262        &self,
263        args: AccumulatorArgs,
264    ) -> Result<Box<dyn Accumulator>> {
265        self.inner.create_sliding_accumulator(args)
266    }
267
268    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
269        self.inner.coerce_types(arg_types)
270    }
271
272    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
273    pub fn with_beneficial_ordering(
274        self,
275        beneficial_ordering: bool,
276    ) -> Result<Option<AggregateUDF>> {
277        self.inner
278            .with_beneficial_ordering(beneficial_ordering)
279            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
280    }
281
282    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
283    /// for possible options.
284    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
285        self.inner.order_sensitivity()
286    }
287
288    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
289    /// generate same result with this `AggregateUDF` when iterated in reverse
290    /// order, and `None` if there is no such `AggregateUDF`).
291    pub fn reverse_udf(&self) -> ReversedUDAF {
292        self.inner.reverse_expr()
293    }
294
295    /// Do the function rewrite
296    ///
297    /// See [`AggregateUDFImpl::simplify`] for more details.
298    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
299        self.inner.simplify()
300    }
301
302    /// Returns true if the function is max, false if the function is min
303    /// None in all other cases, used in certain optimizations for
304    /// or aggregate
305    pub fn is_descending(&self) -> Option<bool> {
306        self.inner.is_descending()
307    }
308
309    /// Return the value of this aggregate function if it can be determined
310    /// entirely from statistics and arguments.
311    ///
312    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
313    pub fn value_from_stats(
314        &self,
315        statistics_args: &StatisticsArgs,
316    ) -> Option<ScalarValue> {
317        self.inner.value_from_stats(statistics_args)
318    }
319
320    /// See [`AggregateUDFImpl::default_value`] for more details.
321    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
322        self.inner.default_value(data_type)
323    }
324
325    /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
326    pub fn supports_null_handling_clause(&self) -> bool {
327        self.inner.supports_null_handling_clause()
328    }
329
330    /// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details.
331    pub fn is_ordered_set_aggregate(&self) -> bool {
332        self.inner.is_ordered_set_aggregate()
333    }
334
335    /// Returns the documentation for this Aggregate UDF.
336    ///
337    /// Documentation can be accessed programmatically as well as
338    /// generating publicly facing documentation.
339    pub fn documentation(&self) -> Option<&Documentation> {
340        self.inner.documentation()
341    }
342}
343
344impl<F> From<F> for AggregateUDF
345where
346    F: AggregateUDFImpl + Send + Sync + 'static,
347{
348    fn from(fun: F) -> Self {
349        Self::new_from_impl(fun)
350    }
351}
352
353/// Trait for implementing [`AggregateUDF`].
354///
355/// This trait exposes the full API for implementing user defined aggregate functions and
356/// can be used to implement any function.
357///
358/// See [`advanced_udaf.rs`] for a full example with complete implementation and
359/// [`AggregateUDF`] for other available options.
360///
361/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
362///
363/// # Basic Example
364/// ```
365/// # use std::any::Any;
366/// # use std::sync::{Arc, LazyLock};
367/// # use arrow::datatypes::{DataType, FieldRef};
368/// # use datafusion_common::{DataFusionError, plan_err, Result};
369/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
370/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
371/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
372/// # use arrow::datatypes::Schema;
373/// # use arrow::datatypes::Field;
374///
375/// #[derive(Debug, Clone)]
376/// struct GeoMeanUdf {
377///   signature: Signature,
378/// }
379///
380/// impl GeoMeanUdf {
381///   fn new() -> Self {
382///     Self {
383///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
384///      }
385///   }
386/// }
387///
388/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
389///         Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
390///             .with_argument("arg1", "The Float64 number for the geometric mean")
391///             .build()
392///     });
393///
394/// fn get_doc() -> &'static Documentation {
395///     &DOCUMENTATION
396/// }
397///
398/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
399/// impl AggregateUDFImpl for GeoMeanUdf {
400///    fn as_any(&self) -> &dyn Any { self }
401///    fn name(&self) -> &str { "geo_mean" }
402///    fn signature(&self) -> &Signature { &self.signature }
403///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
404///      if !matches!(args.get(0), Some(&DataType::Float64)) {
405///        return plan_err!("geo_mean only accepts Float64 arguments");
406///      }
407///      Ok(DataType::Float64)
408///    }
409///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
410///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
411///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
412///        Ok(vec![
413///             Arc::new(args.return_field.as_ref().clone().with_name("value")),
414///             Arc::new(Field::new("ordering", DataType::UInt32, true))
415///        ])
416///    }
417///    fn documentation(&self) -> Option<&Documentation> {
418///        Some(get_doc())
419///    }
420/// }
421///
422/// // Create a new AggregateUDF from the implementation
423/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
424///
425/// // Call the function `geo_mean(col)`
426/// let expr = geometric_mean.call(vec![col("a")]);
427/// ```
428pub trait AggregateUDFImpl: Debug + Send + Sync {
429    // Note: When adding any methods (with default implementations), remember to add them also
430    // into the AliasedAggregateUDFImpl below!
431
432    /// Returns this object as an [`Any`] trait object
433    fn as_any(&self) -> &dyn Any;
434
435    /// Returns this function's name
436    fn name(&self) -> &str;
437
438    /// Returns any aliases (alternate names) for this function.
439    ///
440    /// Note: `aliases` should only include names other than [`Self::name`].
441    /// Defaults to `[]` (no aliases)
442    fn aliases(&self) -> &[String] {
443        &[]
444    }
445
446    /// Returns the name of the column this expression would create
447    ///
448    /// See [`Expr::schema_name`] for details
449    ///
450    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
451    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
452        let AggregateFunctionParams {
453            args,
454            distinct,
455            filter,
456            order_by,
457            null_treatment,
458        } = params;
459
460        // exclude the first function argument(= column) in ordered set aggregate function,
461        // because it is duplicated with the WITHIN GROUP clause in schema name.
462        let args = if self.is_ordered_set_aggregate() {
463            &args[1..]
464        } else {
465            &args[..]
466        };
467
468        let mut schema_name = String::new();
469
470        schema_name.write_fmt(format_args!(
471            "{}({}{})",
472            self.name(),
473            if *distinct { "DISTINCT " } else { "" },
474            schema_name_from_exprs_comma_separated_without_space(args)?
475        ))?;
476
477        if let Some(null_treatment) = null_treatment {
478            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
479        }
480
481        if let Some(filter) = filter {
482            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
483        };
484
485        if !order_by.is_empty() {
486            let clause = match self.is_ordered_set_aggregate() {
487                true => "WITHIN GROUP",
488                false => "ORDER BY",
489            };
490
491            schema_name.write_fmt(format_args!(
492                " {} [{}]",
493                clause,
494                schema_name_from_sorts(order_by)?
495            ))?;
496        };
497
498        Ok(schema_name)
499    }
500
501    /// Returns a human readable expression.
502    ///
503    /// See [`Expr::human_display`] for details.
504    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
505        let AggregateFunctionParams {
506            args,
507            distinct,
508            filter,
509            order_by,
510            null_treatment,
511        } = params;
512
513        let mut schema_name = String::new();
514
515        schema_name.write_fmt(format_args!(
516            "{}({}{})",
517            self.name(),
518            if *distinct { "DISTINCT " } else { "" },
519            ExprListDisplay::comma_separated(args.as_slice())
520        ))?;
521
522        if let Some(null_treatment) = null_treatment {
523            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
524        }
525
526        if let Some(filter) = filter {
527            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
528        };
529
530        if !order_by.is_empty() {
531            schema_name.write_fmt(format_args!(
532                " ORDER BY [{}]",
533                schema_name_from_sorts(order_by)?
534            ))?;
535        };
536
537        Ok(schema_name)
538    }
539
540    /// Returns the name of the column this expression would create
541    ///
542    /// See [`Expr::schema_name`] for details
543    ///
544    /// Different from `schema_name` in that it is used for window aggregate function
545    ///
546    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
547    fn window_function_schema_name(
548        &self,
549        params: &WindowFunctionParams,
550    ) -> Result<String> {
551        let WindowFunctionParams {
552            args,
553            partition_by,
554            order_by,
555            window_frame,
556            null_treatment,
557        } = params;
558
559        let mut schema_name = String::new();
560        schema_name.write_fmt(format_args!(
561            "{}({})",
562            self.name(),
563            schema_name_from_exprs(args)?
564        ))?;
565
566        if let Some(null_treatment) = null_treatment {
567            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
568        }
569
570        if !partition_by.is_empty() {
571            schema_name.write_fmt(format_args!(
572                " PARTITION BY [{}]",
573                schema_name_from_exprs(partition_by)?
574            ))?;
575        }
576
577        if !order_by.is_empty() {
578            schema_name.write_fmt(format_args!(
579                " ORDER BY [{}]",
580                schema_name_from_sorts(order_by)?
581            ))?;
582        };
583
584        schema_name.write_fmt(format_args!(" {window_frame}"))?;
585
586        Ok(schema_name)
587    }
588
589    /// Returns the user-defined display name of function, given the arguments
590    ///
591    /// This can be used to customize the output column name generated by this
592    /// function.
593    ///
594    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
595    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
596        let AggregateFunctionParams {
597            args,
598            distinct,
599            filter,
600            order_by,
601            null_treatment,
602        } = params;
603
604        let mut display_name = String::new();
605
606        display_name.write_fmt(format_args!(
607            "{}({}{})",
608            self.name(),
609            if *distinct { "DISTINCT " } else { "" },
610            expr_vec_fmt!(args)
611        ))?;
612
613        if let Some(nt) = null_treatment {
614            display_name.write_fmt(format_args!(" {nt}"))?;
615        }
616        if let Some(fe) = filter {
617            display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
618        }
619        if !order_by.is_empty() {
620            display_name.write_fmt(format_args!(
621                " ORDER BY [{}]",
622                order_by
623                    .iter()
624                    .map(|o| format!("{o}"))
625                    .collect::<Vec<String>>()
626                    .join(", ")
627            ))?;
628        }
629
630        Ok(display_name)
631    }
632
633    /// Returns the user-defined display name of function, given the arguments
634    ///
635    /// This can be used to customize the output column name generated by this
636    /// function.
637    ///
638    /// Different from `display_name` in that it is used for window aggregate function
639    ///
640    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
641    fn window_function_display_name(
642        &self,
643        params: &WindowFunctionParams,
644    ) -> Result<String> {
645        let WindowFunctionParams {
646            args,
647            partition_by,
648            order_by,
649            window_frame,
650            null_treatment,
651        } = params;
652
653        let mut display_name = String::new();
654
655        display_name.write_fmt(format_args!(
656            "{}({})",
657            self.name(),
658            expr_vec_fmt!(args)
659        ))?;
660
661        if let Some(null_treatment) = null_treatment {
662            display_name.write_fmt(format_args!(" {null_treatment}"))?;
663        }
664
665        if !partition_by.is_empty() {
666            display_name.write_fmt(format_args!(
667                " PARTITION BY [{}]",
668                expr_vec_fmt!(partition_by)
669            ))?;
670        }
671
672        if !order_by.is_empty() {
673            display_name
674                .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
675        };
676
677        display_name.write_fmt(format_args!(
678            " {} BETWEEN {} AND {}",
679            window_frame.units, window_frame.start_bound, window_frame.end_bound
680        ))?;
681
682        Ok(display_name)
683    }
684
685    /// Returns the function's [`Signature`] for information about what input
686    /// types are accepted and the function's Volatility.
687    fn signature(&self) -> &Signature;
688
689    /// What [`DataType`] will be returned by this function, given the types of
690    /// the arguments
691    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
692
693    /// What type will be returned by this function, given the arguments?
694    ///
695    /// By default, this function calls [`Self::return_type`] with the
696    /// types of each argument.
697    ///
698    /// # Notes
699    ///
700    /// Most UDFs should implement [`Self::return_type`] and not this
701    /// function as the output type for most functions only depends on the types
702    /// of their inputs (e.g. `sum(f64)` is always `f64`).
703    ///
704    /// This function can be used for more advanced cases such as:
705    ///
706    /// 1. specifying nullability
707    /// 2. return types based on the **values** of the arguments (rather than
708    ///    their **types**.
709    /// 3. return types based on metadata within the fields of the inputs
710    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
711        let arg_types: Vec<_> =
712            arg_fields.iter().map(|f| f.data_type()).cloned().collect();
713        let data_type = self.return_type(&arg_types)?;
714
715        Ok(Arc::new(Field::new(
716            self.name(),
717            data_type,
718            self.is_nullable(),
719        )))
720    }
721
722    /// Whether the aggregate function is nullable.
723    ///
724    /// Nullable means that the function could return `null` for any inputs.
725    /// For example, aggregate functions like `COUNT` always return a non null value
726    /// but others like `MIN` will return `NULL` if there is nullable input.
727    /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
728    fn is_nullable(&self) -> bool {
729        true
730    }
731
732    /// Return a new [`Accumulator`] that aggregates values for a specific
733    /// group during query execution.
734    ///
735    /// acc_args: [`AccumulatorArgs`] contains information about how the
736    /// aggregate function was called.
737    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
738
739    /// Return the fields used to store the intermediate state of this accumulator.
740    ///
741    /// See [`Accumulator::state`] for background information.
742    ///
743    /// args:  [`StateFieldsArgs`] contains arguments passed to the
744    /// aggregate function's accumulator.
745    ///
746    /// # Notes:
747    ///
748    /// The default implementation returns a single state field named `name`
749    /// with the same type as `value_type`. This is suitable for aggregates such
750    /// as `SUM` or `MIN` where partial state can be combined by applying the
751    /// same aggregate.
752    ///
753    /// For aggregates such as `AVG` where the partial state is more complex
754    /// (e.g. a COUNT and a SUM), this method is used to define the additional
755    /// fields.
756    ///
757    /// The name of the fields must be unique within the query and thus should
758    /// be derived from `name`. See [`format_state_name`] for a utility function
759    /// to generate a unique name.
760    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
761        let fields = vec![args
762            .return_field
763            .as_ref()
764            .clone()
765            .with_name(format_state_name(args.name, "value"))];
766
767        Ok(fields
768            .into_iter()
769            .map(Arc::new)
770            .chain(args.ordering_fields.to_vec())
771            .collect())
772    }
773
774    /// If the aggregate expression has a specialized
775    /// [`GroupsAccumulator`] implementation. If this returns true,
776    /// `[Self::create_groups_accumulator]` will be called.
777    ///
778    /// # Notes
779    ///
780    /// Even if this function returns true, DataFusion will still use
781    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
782    /// used as a window function or when there no GROUP BY columns in the
783    /// query.
784    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
785        false
786    }
787
788    /// Return a specialized [`GroupsAccumulator`] that manages state
789    /// for all groups.
790    ///
791    /// For maximum performance, a [`GroupsAccumulator`] should be
792    /// implemented in addition to [`Accumulator`].
793    fn create_groups_accumulator(
794        &self,
795        _args: AccumulatorArgs,
796    ) -> Result<Box<dyn GroupsAccumulator>> {
797        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
798    }
799
800    /// Sliding accumulator is an alternative accumulator that can be used for
801    /// window functions. It has retract method to revert the previous update.
802    ///
803    /// See [retract_batch] for more details.
804    ///
805    /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
806    fn create_sliding_accumulator(
807        &self,
808        args: AccumulatorArgs,
809    ) -> Result<Box<dyn Accumulator>> {
810        self.accumulator(args)
811    }
812
813    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
814    /// satisfied by its input. If this is not the case, UDFs with order
815    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
816    /// the correct result with possibly more work internally.
817    ///
818    /// # Returns
819    ///
820    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
821    /// If the expression can benefit from existing input ordering, but does
822    /// not implement the method, returns an error. Order insensitive and hard
823    /// requirement aggregators return `Ok(None)`.
824    fn with_beneficial_ordering(
825        self: Arc<Self>,
826        _beneficial_ordering: bool,
827    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
828        if self.order_sensitivity().is_beneficial() {
829            return exec_err!(
830                "Should implement with satisfied for aggregator :{:?}",
831                self.name()
832            );
833        }
834        Ok(None)
835    }
836
837    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
838    /// for possible options.
839    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
840        // We have hard ordering requirements by default, meaning that order
841        // sensitive UDFs need their input orderings to satisfy their ordering
842        // requirements to generate correct results.
843        AggregateOrderSensitivity::HardRequirement
844    }
845
846    /// Optionally apply per-UDaF simplification / rewrite rules.
847    ///
848    /// This can be used to apply function specific simplification rules during
849    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
850    /// implementation does nothing.
851    ///
852    /// Note that DataFusion handles simplifying arguments and  "constant
853    /// folding" (replacing a function call with constant arguments such as
854    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
855    /// optimizations manually for specific UDFs.
856    ///
857    /// # Returns
858    ///
859    /// [None] if simplify is not defined or,
860    ///
861    /// Or, a closure with two arguments:
862    /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
863    /// * 'info': [crate::simplify::SimplifyInfo]
864    ///
865    /// closure returns simplified [Expr] or an error.
866    ///
867    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
868        None
869    }
870
871    /// Returns the reverse expression of the aggregate function.
872    fn reverse_expr(&self) -> ReversedUDAF {
873        ReversedUDAF::NotSupported
874    }
875
876    /// Coerce arguments of a function call to types that the function can evaluate.
877    ///
878    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
879    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
880    /// cases
881    ///
882    /// See the [type coercion module](crate::type_coercion)
883    /// documentation for more details on type coercion
884    ///
885    /// For example, if your function requires a floating point arguments, but the user calls
886    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
887    /// to ensure the argument was cast to `1::double`
888    ///
889    /// # Parameters
890    /// * `arg_types`: The argument types of the arguments  this function with
891    ///
892    /// # Return value
893    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
894    /// arguments to these specific types.
895    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
896        not_impl_err!("Function {} does not implement coerce_types", self.name())
897    }
898
899    /// Return true if this aggregate UDF is equal to the other.
900    ///
901    /// Allows customizing the equality of aggregate UDFs.
902    /// *Must* be implemented explicitly if the UDF type has internal state.
903    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
904    ///
905    /// - reflexive: `a.equals(a)`;
906    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
907    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
908    ///
909    /// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
910    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
911        self.as_any().type_id() == other.as_any().type_id()
912            && self.name() == other.name()
913            && self.aliases() == other.aliases()
914            && self.signature() == other.signature()
915    }
916
917    /// Returns a hash value for this aggregate UDF.
918    ///
919    /// Allows customizing the hash code of aggregate UDFs.
920    /// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
921    ///
922    /// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
923    /// their `hash_value`s must be the same.
924    ///
925    /// By default, it is consistent with default implementation of [`Self::equals`].
926    fn hash_value(&self) -> u64 {
927        let hasher = &mut DefaultHasher::new();
928        self.as_any().type_id().hash(hasher);
929        self.name().hash(hasher);
930        self.aliases().hash(hasher);
931        self.signature().hash(hasher);
932        hasher.finish()
933    }
934
935    /// If this function is max, return true
936    /// If the function is min, return false
937    /// Otherwise return None (the default)
938    ///
939    ///
940    /// Note: this is used to use special aggregate implementations in certain conditions
941    fn is_descending(&self) -> Option<bool> {
942        None
943    }
944
945    /// Return the value of this aggregate function if it can be determined
946    /// entirely from statistics and arguments.
947    ///
948    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
949    /// improving query performance.
950    ///
951    /// For example, if the minimum value of column `x` is known to be `42` from
952    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
953    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
954        None
955    }
956
957    /// Returns default value of the function given the input is all `null`.
958    ///
959    /// Most of the aggregate function return Null if input is Null,
960    /// while `count` returns 0 if input is Null
961    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
962        ScalarValue::try_from(data_type)
963    }
964
965    /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true
966    /// If the function does not, return false
967    fn supports_null_handling_clause(&self) -> bool {
968        true
969    }
970
971    /// If this function is ordered-set aggregate function, return true
972    /// If the function is not, return false
973    fn is_ordered_set_aggregate(&self) -> bool {
974        false
975    }
976
977    /// Returns the documentation for this Aggregate UDF.
978    ///
979    /// Documentation can be accessed programmatically as well as
980    /// generating publicly facing documentation.
981    fn documentation(&self) -> Option<&Documentation> {
982        None
983    }
984
985    /// Indicates whether the aggregation function is monotonic as a set
986    /// function. See [`SetMonotonicity`] for details.
987    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
988        SetMonotonicity::NotMonotonic
989    }
990}
991
992impl PartialEq for dyn AggregateUDFImpl {
993    fn eq(&self, other: &Self) -> bool {
994        self.equals(other)
995    }
996}
997
998// Manual implementation of `PartialOrd`
999// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
1000// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
1001impl PartialOrd for dyn AggregateUDFImpl {
1002    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1003        match self.name().partial_cmp(other.name()) {
1004            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
1005            cmp => cmp,
1006        }
1007    }
1008}
1009
/// The reverse expression of an aggregate function, as returned by
/// [`AggregateUDFImpl::reverse_expr`] (which defaults to
/// [`ReversedUDAF::NotSupported`]).
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression; the reverse
    /// expression is the contained [`AggregateUDF`].
    Reversed(Arc<AggregateUDF>),
}
1018
/// AggregateUDF that adds one or more aliases to the underlying function. It
/// is better to implement [`AggregateUDFImpl`], which supports aliases,
/// directly if possible.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
    /// The wrapped function implementation.
    inner: Arc<dyn AggregateUDFImpl>,
    /// `inner`'s own aliases followed by the aliases added in `new`.
    aliases: Vec<String>,
}
1026
1027impl AliasedAggregateUDFImpl {
1028    pub fn new(
1029        inner: Arc<dyn AggregateUDFImpl>,
1030        new_aliases: impl IntoIterator<Item = &'static str>,
1031    ) -> Self {
1032        let mut aliases = inner.aliases().to_vec();
1033        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1034
1035        Self { inner, aliases }
1036    }
1037}
1038
1039impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1040    fn as_any(&self) -> &dyn Any {
1041        self
1042    }
1043
1044    fn name(&self) -> &str {
1045        self.inner.name()
1046    }
1047
1048    fn signature(&self) -> &Signature {
1049        self.inner.signature()
1050    }
1051
1052    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1053        self.inner.return_type(arg_types)
1054    }
1055
1056    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1057        self.inner.accumulator(acc_args)
1058    }
1059
1060    fn aliases(&self) -> &[String] {
1061        &self.aliases
1062    }
1063
1064    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1065        self.inner.state_fields(args)
1066    }
1067
1068    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1069        self.inner.groups_accumulator_supported(args)
1070    }
1071
1072    fn create_groups_accumulator(
1073        &self,
1074        args: AccumulatorArgs,
1075    ) -> Result<Box<dyn GroupsAccumulator>> {
1076        self.inner.create_groups_accumulator(args)
1077    }
1078
1079    fn create_sliding_accumulator(
1080        &self,
1081        args: AccumulatorArgs,
1082    ) -> Result<Box<dyn Accumulator>> {
1083        self.inner.accumulator(args)
1084    }
1085
1086    fn with_beneficial_ordering(
1087        self: Arc<Self>,
1088        beneficial_ordering: bool,
1089    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1090        Arc::clone(&self.inner)
1091            .with_beneficial_ordering(beneficial_ordering)
1092            .map(|udf| {
1093                udf.map(|udf| {
1094                    Arc::new(AliasedAggregateUDFImpl {
1095                        inner: udf,
1096                        aliases: self.aliases.clone(),
1097                    }) as Arc<dyn AggregateUDFImpl>
1098                })
1099            })
1100    }
1101
1102    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1103        self.inner.order_sensitivity()
1104    }
1105
1106    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1107        self.inner.simplify()
1108    }
1109
1110    fn reverse_expr(&self) -> ReversedUDAF {
1111        self.inner.reverse_expr()
1112    }
1113
1114    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1115        self.inner.coerce_types(arg_types)
1116    }
1117
1118    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
1119        if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
1120            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
1121        } else {
1122            false
1123        }
1124    }
1125
1126    fn hash_value(&self) -> u64 {
1127        let hasher = &mut DefaultHasher::new();
1128        self.inner.hash_value().hash(hasher);
1129        self.aliases.hash(hasher);
1130        hasher.finish()
1131    }
1132
1133    fn is_descending(&self) -> Option<bool> {
1134        self.inner.is_descending()
1135    }
1136
1137    fn documentation(&self) -> Option<&Documentation> {
1138        self.inner.documentation()
1139    }
1140}
1141
1142// Aggregate UDF doc sections for use in public documentation
1143pub mod aggregate_doc_sections {
1144    use crate::DocSection;
1145
1146    pub fn doc_sections() -> Vec<DocSection> {
1147        vec![
1148            DOC_SECTION_GENERAL,
1149            DOC_SECTION_STATISTICAL,
1150            DOC_SECTION_APPROXIMATE,
1151        ]
1152    }
1153
1154    pub const DOC_SECTION_GENERAL: DocSection = DocSection {
1155        include: true,
1156        label: "General Functions",
1157        description: None,
1158    };
1159
1160    pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
1161        include: true,
1162        label: "Statistical Functions",
1163        description: None,
1164    };
1165
1166    pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
1167        include: true,
1168        label: "Approximate Functions",
1169        description: None,
1170    };
1171}
1172
/// Indicates whether an aggregation function is monotonic as a set
/// function, as reported by [`AggregateUDFImpl::set_monotonicity`].
///
/// A set function is monotonically increasing if its value increases (or
/// stays the same) as its argument grows as a set. Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S` is a
/// superset of `T`. Likewise, it is monotonically decreasing if
/// `f(S) <= f(T)` whenever `S` is a superset of `T`.
///
/// For example, `COUNT` and `MAX` are monotonically increasing because their
/// values always increase (or stay the same) as new values are seen. On the
/// other hand, `MIN` is monotonically decreasing because its value always
/// decreases or stays the same as new values are seen.
#[derive(Debug, Clone, PartialEq)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1193
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Signature shared by both test UDAFs: a single `Float64` argument.
    fn float64_unary_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Minimal UDAF named `"a"`; only the metadata used for comparisons works.
    #[derive(Debug, Clone)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            Self {
                signature: float64_unary_signature(),
            }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "a"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    /// Same as `AMeanUdf` but named `"b"`, so it orders after it.
    #[derive(Debug, Clone)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            Self {
                signature: float64_unary_signature(),
            }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "b"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        // Checks that AggregateUDF ordering is driven by name and signature;
        // this is not meant to cover every possible case.
        let first_a = AggregateUDF::from(AMeanUdf::new());
        let second_a = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(first_a.partial_cmp(&second_a), Some(Ordering::Equal));

        let b = AggregateUDF::from(BMeanUdf::new());
        assert!(first_a < b);
        assert!(first_a != b);
    }
}