Deprecated: The each() function is deprecated. This message will be suppressed on further calls in /home/zhenxiangba/zhenxiangba.com/public_html/phproxy-improved-master/index.php on line 456
udaf.rs - source
[go: Go Back, main page]

datafusion_expr/
udaf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field};
28
29use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32use crate::expr::{
33    schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
34    schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay,
35    WindowFunctionParams,
36};
37use crate::function::{
38    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39};
40use crate::groups_accumulator::GroupsAccumulator;
41use crate::utils::format_state_name;
42use crate::utils::AggregateOrderSensitivity;
43use crate::{expr_vec_fmt, Accumulator, Expr};
44use crate::{Documentation, Signature};
45
46/// Logical representation of a user-defined [aggregate function] (UDAF).
47///
48/// An aggregate function combines the values from multiple input rows
49/// into a single output "aggregate" (summary) row. It is different
50/// from a scalar function because it is stateful across batches. User
51/// defined aggregate functions can be used as normal SQL aggregate
52/// functions (`GROUP BY` clause) as well as window functions (`OVER`
53/// clause).
54///
55/// `AggregateUDF` provides DataFusion the information needed to plan and call
56/// aggregate functions, including name, type information, and a factory
57/// function to create an [`Accumulator`] instance, to perform the actual
58/// aggregation.
59///
60/// For more information, please see [the examples]:
61///
62/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
63///
64/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
65///    access (examples in [`advanced_udaf.rs`]).
66///
67/// # API Note
68/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
69/// compatibility with the older API.
70///
71/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
72/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
73/// [`Accumulator`]: crate::Accumulator
74/// [`create_udaf`]: crate::expr_fn::create_udaf
75/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
76/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
77#[derive(Debug, Clone, PartialOrd)]
78pub struct AggregateUDF {
79    inner: Arc<dyn AggregateUDFImpl>,
80}
81
82impl PartialEq for AggregateUDF {
83    fn eq(&self, other: &Self) -> bool {
84        self.inner.equals(other.inner.as_ref())
85    }
86}
87
88impl Eq for AggregateUDF {}
89
90impl Hash for AggregateUDF {
91    fn hash<H: Hasher>(&self, state: &mut H) {
92        self.inner.hash_value().hash(state)
93    }
94}
95
96impl fmt::Display for AggregateUDF {
97    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
98        write!(f, "{}", self.name())
99    }
100}
101
102/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103#[derive(Debug)]
104pub struct StatisticsArgs<'a> {
105    /// The statistics of the aggregate input
106    pub statistics: &'a Statistics,
107    /// The resolved return type of the aggregate function
108    pub return_type: &'a DataType,
109    /// Whether the aggregate function is distinct.
110    ///
111    /// ```sql
112    /// SELECT COUNT(DISTINCT column1) FROM t;
113    /// ```
114    pub is_distinct: bool,
115    /// The physical expression of arguments the aggregate function takes.
116    pub exprs: &'a [Arc<dyn PhysicalExpr>],
117}
118
119impl AggregateUDF {
120    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
121    ///
122    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
123    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
124    where
125        F: AggregateUDFImpl + 'static,
126    {
127        Self::new_from_shared_impl(Arc::new(fun))
128    }
129
130    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
131    pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
132        Self { inner: fun }
133    }
134
135    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
136    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
137        &self.inner
138    }
139
140    /// Adds additional names that can be used to invoke this function, in
141    /// addition to `name`
142    ///
143    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
144    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
145        Self::new_from_impl(AliasedAggregateUDFImpl::new(
146            Arc::clone(&self.inner),
147            aliases,
148        ))
149    }
150
151    /// Creates an [`Expr`] that calls the aggregate function.
152    ///
153    /// This utility allows using the UDAF without requiring access to
154    /// the registry, such as with the DataFrame API.
155    pub fn call(&self, args: Vec<Expr>) -> Expr {
156        Expr::AggregateFunction(AggregateFunction::new_udf(
157            Arc::new(self.clone()),
158            args,
159            false,
160            None,
161            None,
162            None,
163        ))
164    }
165
166    /// Returns this function's name
167    ///
168    /// See [`AggregateUDFImpl::name`] for more details.
169    pub fn name(&self) -> &str {
170        self.inner.name()
171    }
172
173    /// See [`AggregateUDFImpl::schema_name`] for more details.
174    pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
175        self.inner.schema_name(params)
176    }
177
178    /// Returns a human readable expression.
179    ///
180    /// See [`Expr::human_display`] for details.
181    pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
182        self.inner.human_display(params)
183    }
184
185    pub fn window_function_schema_name(
186        &self,
187        params: &WindowFunctionParams,
188    ) -> Result<String> {
189        self.inner.window_function_schema_name(params)
190    }
191
192    /// See [`AggregateUDFImpl::display_name`] for more details.
193    pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
194        self.inner.display_name(params)
195    }
196
197    pub fn window_function_display_name(
198        &self,
199        params: &WindowFunctionParams,
200    ) -> Result<String> {
201        self.inner.window_function_display_name(params)
202    }
203
204    pub fn is_nullable(&self) -> bool {
205        self.inner.is_nullable()
206    }
207
208    /// Returns the aliases for this function.
209    pub fn aliases(&self) -> &[String] {
210        self.inner.aliases()
211    }
212
213    /// Returns this function's signature (what input types are accepted)
214    ///
215    /// See [`AggregateUDFImpl::signature`] for more details.
216    pub fn signature(&self) -> &Signature {
217        self.inner.signature()
218    }
219
220    /// Return the type of the function given its input types
221    ///
222    /// See [`AggregateUDFImpl::return_type`] for more details.
223    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
224        self.inner.return_type(args)
225    }
226
227    /// Return an accumulator the given aggregate, given its return datatype
228    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
229        self.inner.accumulator(acc_args)
230    }
231
232    /// Return the fields used to store the intermediate state for this aggregator, given
233    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
234    /// for more details.
235    ///
236    /// This is used to support multi-phase aggregations
237    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
238        self.inner.state_fields(args)
239    }
240
241    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
242    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
243        self.inner.groups_accumulator_supported(args)
244    }
245
246    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
247    pub fn create_groups_accumulator(
248        &self,
249        args: AccumulatorArgs,
250    ) -> Result<Box<dyn GroupsAccumulator>> {
251        self.inner.create_groups_accumulator(args)
252    }
253
254    pub fn create_sliding_accumulator(
255        &self,
256        args: AccumulatorArgs,
257    ) -> Result<Box<dyn Accumulator>> {
258        self.inner.create_sliding_accumulator(args)
259    }
260
261    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
262        self.inner.coerce_types(arg_types)
263    }
264
265    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
266    pub fn with_beneficial_ordering(
267        self,
268        beneficial_ordering: bool,
269    ) -> Result<Option<AggregateUDF>> {
270        self.inner
271            .with_beneficial_ordering(beneficial_ordering)
272            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
273    }
274
275    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
276    /// for possible options.
277    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
278        self.inner.order_sensitivity()
279    }
280
281    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
282    /// generate same result with this `AggregateUDF` when iterated in reverse
283    /// order, and `None` if there is no such `AggregateUDF`).
284    pub fn reverse_udf(&self) -> ReversedUDAF {
285        self.inner.reverse_expr()
286    }
287
288    /// Do the function rewrite
289    ///
290    /// See [`AggregateUDFImpl::simplify`] for more details.
291    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
292        self.inner.simplify()
293    }
294
295    /// Returns true if the function is max, false if the function is min
296    /// None in all other cases, used in certain optimizations for
297    /// or aggregate
298    pub fn is_descending(&self) -> Option<bool> {
299        self.inner.is_descending()
300    }
301
302    /// Return the value of this aggregate function if it can be determined
303    /// entirely from statistics and arguments.
304    ///
305    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
306    pub fn value_from_stats(
307        &self,
308        statistics_args: &StatisticsArgs,
309    ) -> Option<ScalarValue> {
310        self.inner.value_from_stats(statistics_args)
311    }
312
313    /// See [`AggregateUDFImpl::default_value`] for more details.
314    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
315        self.inner.default_value(data_type)
316    }
317
318    /// Returns the documentation for this Aggregate UDF.
319    ///
320    /// Documentation can be accessed programmatically as well as
321    /// generating publicly facing documentation.
322    pub fn documentation(&self) -> Option<&Documentation> {
323        self.inner.documentation()
324    }
325}
326
327impl<F> From<F> for AggregateUDF
328where
329    F: AggregateUDFImpl + Send + Sync + 'static,
330{
331    fn from(fun: F) -> Self {
332        Self::new_from_impl(fun)
333    }
334}
335
336/// Trait for implementing [`AggregateUDF`].
337///
338/// This trait exposes the full API for implementing user defined aggregate functions and
339/// can be used to implement any function.
340///
341/// See [`advanced_udaf.rs`] for a full example with complete implementation and
342/// [`AggregateUDF`] for other available options.
343///
344/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
345///
346/// # Basic Example
347/// ```
348/// # use std::any::Any;
349/// # use std::sync::LazyLock;
350/// # use arrow::datatypes::DataType;
351/// # use datafusion_common::{DataFusionError, plan_err, Result};
352/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
353/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
354/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
355/// # use arrow::datatypes::Schema;
356/// # use arrow::datatypes::Field;
357///
358/// #[derive(Debug, Clone)]
359/// struct GeoMeanUdf {
360///   signature: Signature,
361/// }
362///
363/// impl GeoMeanUdf {
364///   fn new() -> Self {
365///     Self {
366///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
367///      }
368///   }
369/// }
370///
371/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
372///         Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
373///             .with_argument("arg1", "The Float64 number for the geometric mean")
374///             .build()
375///     });
376///
377/// fn get_doc() -> &'static Documentation {
378///     &DOCUMENTATION
379/// }
380///    
381/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
382/// impl AggregateUDFImpl for GeoMeanUdf {
383///    fn as_any(&self) -> &dyn Any { self }
384///    fn name(&self) -> &str { "geo_mean" }
385///    fn signature(&self) -> &Signature { &self.signature }
386///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
387///      if !matches!(args.get(0), Some(&DataType::Float64)) {
388///        return plan_err!("geo_mean only accepts Float64 arguments");
389///      }
390///      Ok(DataType::Float64)
391///    }
392///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
393///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
394///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
395///        Ok(vec![
396///             Field::new("value", args.return_type.clone(), true),
397///             Field::new("ordering", DataType::UInt32, true)
398///        ])
399///    }
400///    fn documentation(&self) -> Option<&Documentation> {
401///        Some(get_doc())  
402///    }
403/// }
404///
405/// // Create a new AggregateUDF from the implementation
406/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
407///
408/// // Call the function `geo_mean(col)`
409/// let expr = geometric_mean.call(vec![col("a")]);
410/// ```
411pub trait AggregateUDFImpl: Debug + Send + Sync {
412    // Note: When adding any methods (with default implementations), remember to add them also
413    // into the AliasedAggregateUDFImpl below!
414
415    /// Returns this object as an [`Any`] trait object
416    fn as_any(&self) -> &dyn Any;
417
418    /// Returns this function's name
419    fn name(&self) -> &str;
420
421    /// Returns the name of the column this expression would create
422    ///
423    /// See [`Expr::schema_name`] for details
424    ///
425    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
426    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
427        let AggregateFunctionParams {
428            args,
429            distinct,
430            filter,
431            order_by,
432            null_treatment,
433        } = params;
434
435        let mut schema_name = String::new();
436
437        schema_name.write_fmt(format_args!(
438            "{}({}{})",
439            self.name(),
440            if *distinct { "DISTINCT " } else { "" },
441            schema_name_from_exprs_comma_separated_without_space(args)?
442        ))?;
443
444        if let Some(null_treatment) = null_treatment {
445            schema_name.write_fmt(format_args!(" {}", null_treatment))?;
446        }
447
448        if let Some(filter) = filter {
449            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
450        };
451
452        if let Some(order_by) = order_by {
453            schema_name.write_fmt(format_args!(
454                " ORDER BY [{}]",
455                schema_name_from_sorts(order_by)?
456            ))?;
457        };
458
459        Ok(schema_name)
460    }
461
462    /// Returns a human readable expression.
463    ///
464    /// See [`Expr::human_display`] for details.
465    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
466        let AggregateFunctionParams {
467            args,
468            distinct,
469            filter,
470            order_by,
471            null_treatment,
472        } = params;
473
474        let mut schema_name = String::new();
475
476        schema_name.write_fmt(format_args!(
477            "{}({}{})",
478            self.name(),
479            if *distinct { "DISTINCT " } else { "" },
480            ExprListDisplay::comma_separated(args.as_slice())
481        ))?;
482
483        if let Some(null_treatment) = null_treatment {
484            schema_name.write_fmt(format_args!(" {}", null_treatment))?;
485        }
486
487        if let Some(filter) = filter {
488            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
489        };
490
491        if let Some(order_by) = order_by {
492            schema_name.write_fmt(format_args!(
493                " ORDER BY [{}]",
494                schema_name_from_sorts(order_by)?
495            ))?;
496        };
497
498        Ok(schema_name)
499    }
500
501    /// Returns the name of the column this expression would create
502    ///
503    /// See [`Expr::schema_name`] for details
504    ///
505    /// Different from `schema_name` in that it is used for window aggregate function
506    ///
507    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
508    fn window_function_schema_name(
509        &self,
510        params: &WindowFunctionParams,
511    ) -> Result<String> {
512        let WindowFunctionParams {
513            args,
514            partition_by,
515            order_by,
516            window_frame,
517            null_treatment,
518        } = params;
519
520        let mut schema_name = String::new();
521        schema_name.write_fmt(format_args!(
522            "{}({})",
523            self.name(),
524            schema_name_from_exprs(args)?
525        ))?;
526
527        if let Some(null_treatment) = null_treatment {
528            schema_name.write_fmt(format_args!(" {}", null_treatment))?;
529        }
530
531        if !partition_by.is_empty() {
532            schema_name.write_fmt(format_args!(
533                " PARTITION BY [{}]",
534                schema_name_from_exprs(partition_by)?
535            ))?;
536        }
537
538        if !order_by.is_empty() {
539            schema_name.write_fmt(format_args!(
540                " ORDER BY [{}]",
541                schema_name_from_sorts(order_by)?
542            ))?;
543        };
544
545        schema_name.write_fmt(format_args!(" {window_frame}"))?;
546
547        Ok(schema_name)
548    }
549
550    /// Returns the user-defined display name of function, given the arguments
551    ///
552    /// This can be used to customize the output column name generated by this
553    /// function.
554    ///
555    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
556    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
557        let AggregateFunctionParams {
558            args,
559            distinct,
560            filter,
561            order_by,
562            null_treatment,
563        } = params;
564
565        let mut display_name = String::new();
566
567        display_name.write_fmt(format_args!(
568            "{}({}{})",
569            self.name(),
570            if *distinct { "DISTINCT " } else { "" },
571            expr_vec_fmt!(args)
572        ))?;
573
574        if let Some(nt) = null_treatment {
575            display_name.write_fmt(format_args!(" {}", nt))?;
576        }
577        if let Some(fe) = filter {
578            display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
579        }
580        if let Some(ob) = order_by {
581            display_name.write_fmt(format_args!(
582                " ORDER BY [{}]",
583                ob.iter()
584                    .map(|o| format!("{o}"))
585                    .collect::<Vec<String>>()
586                    .join(", ")
587            ))?;
588        }
589
590        Ok(display_name)
591    }
592
593    /// Returns the user-defined display name of function, given the arguments
594    ///
595    /// This can be used to customize the output column name generated by this
596    /// function.
597    ///
598    /// Different from `display_name` in that it is used for window aggregate function
599    ///
600    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
601    fn window_function_display_name(
602        &self,
603        params: &WindowFunctionParams,
604    ) -> Result<String> {
605        let WindowFunctionParams {
606            args,
607            partition_by,
608            order_by,
609            window_frame,
610            null_treatment,
611        } = params;
612
613        let mut display_name = String::new();
614
615        display_name.write_fmt(format_args!(
616            "{}({})",
617            self.name(),
618            expr_vec_fmt!(args)
619        ))?;
620
621        if let Some(null_treatment) = null_treatment {
622            display_name.write_fmt(format_args!(" {}", null_treatment))?;
623        }
624
625        if !partition_by.is_empty() {
626            display_name.write_fmt(format_args!(
627                " PARTITION BY [{}]",
628                expr_vec_fmt!(partition_by)
629            ))?;
630        }
631
632        if !order_by.is_empty() {
633            display_name
634                .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
635        };
636
637        display_name.write_fmt(format_args!(
638            " {} BETWEEN {} AND {}",
639            window_frame.units, window_frame.start_bound, window_frame.end_bound
640        ))?;
641
642        Ok(display_name)
643    }
644
645    /// Returns the function's [`Signature`] for information about what input
646    /// types are accepted and the function's Volatility.
647    fn signature(&self) -> &Signature;
648
649    /// What [`DataType`] will be returned by this function, given the types of
650    /// the arguments
651    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
652
653    /// Whether the aggregate function is nullable.
654    ///
655    /// Nullable means that the function could return `null` for any inputs.
656    /// For example, aggregate functions like `COUNT` always return a non null value
657    /// but others like `MIN` will return `NULL` if there is nullable input.
658    /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
659    fn is_nullable(&self) -> bool {
660        true
661    }
662
663    /// Return a new [`Accumulator`] that aggregates values for a specific
664    /// group during query execution.
665    ///
666    /// acc_args: [`AccumulatorArgs`] contains information about how the
667    /// aggregate function was called.
668    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
669
670    /// Return the fields used to store the intermediate state of this accumulator.
671    ///
672    /// See [`Accumulator::state`] for background information.
673    ///
674    /// args:  [`StateFieldsArgs`] contains arguments passed to the
675    /// aggregate function's accumulator.
676    ///
677    /// # Notes:
678    ///
679    /// The default implementation returns a single state field named `name`
680    /// with the same type as `value_type`. This is suitable for aggregates such
681    /// as `SUM` or `MIN` where partial state can be combined by applying the
682    /// same aggregate.
683    ///
684    /// For aggregates such as `AVG` where the partial state is more complex
685    /// (e.g. a COUNT and a SUM), this method is used to define the additional
686    /// fields.
687    ///
688    /// The name of the fields must be unique within the query and thus should
689    /// be derived from `name`. See [`format_state_name`] for a utility function
690    /// to generate a unique name.
691    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
692        let fields = vec![Field::new(
693            format_state_name(args.name, "value"),
694            args.return_type.clone(),
695            true,
696        )];
697
698        Ok(fields
699            .into_iter()
700            .chain(args.ordering_fields.to_vec())
701            .collect())
702    }
703
704    /// If the aggregate expression has a specialized
705    /// [`GroupsAccumulator`] implementation. If this returns true,
706    /// `[Self::create_groups_accumulator]` will be called.
707    ///
708    /// # Notes
709    ///
710    /// Even if this function returns true, DataFusion will still use
711    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
712    /// used as a window function or when there no GROUP BY columns in the
713    /// query.
714    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
715        false
716    }
717
718    /// Return a specialized [`GroupsAccumulator`] that manages state
719    /// for all groups.
720    ///
721    /// For maximum performance, a [`GroupsAccumulator`] should be
722    /// implemented in addition to [`Accumulator`].
723    fn create_groups_accumulator(
724        &self,
725        _args: AccumulatorArgs,
726    ) -> Result<Box<dyn GroupsAccumulator>> {
727        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
728    }
729
730    /// Returns any aliases (alternate names) for this function.
731    ///
732    /// Note: `aliases` should only include names other than [`Self::name`].
733    /// Defaults to `[]` (no aliases)
734    fn aliases(&self) -> &[String] {
735        &[]
736    }
737
738    /// Sliding accumulator is an alternative accumulator that can be used for
739    /// window functions. It has retract method to revert the previous update.
740    ///
741    /// See [retract_batch] for more details.
742    ///
743    /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
744    fn create_sliding_accumulator(
745        &self,
746        args: AccumulatorArgs,
747    ) -> Result<Box<dyn Accumulator>> {
748        self.accumulator(args)
749    }
750
751    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
752    /// satisfied by its input. If this is not the case, UDFs with order
753    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
754    /// the correct result with possibly more work internally.
755    ///
756    /// # Returns
757    ///
758    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
759    /// If the expression can benefit from existing input ordering, but does
760    /// not implement the method, returns an error. Order insensitive and hard
761    /// requirement aggregators return `Ok(None)`.
762    fn with_beneficial_ordering(
763        self: Arc<Self>,
764        _beneficial_ordering: bool,
765    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
766        if self.order_sensitivity().is_beneficial() {
767            return exec_err!(
768                "Should implement with satisfied for aggregator :{:?}",
769                self.name()
770            );
771        }
772        Ok(None)
773    }
774
775    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
776    /// for possible options.
777    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
778        // We have hard ordering requirements by default, meaning that order
779        // sensitive UDFs need their input orderings to satisfy their ordering
780        // requirements to generate correct results.
781        AggregateOrderSensitivity::HardRequirement
782    }
783
784    /// Optionally apply per-UDaF simplification / rewrite rules.
785    ///
786    /// This can be used to apply function specific simplification rules during
787    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
788    /// implementation does nothing.
789    ///
790    /// Note that DataFusion handles simplifying arguments and  "constant
791    /// folding" (replacing a function call with constant arguments such as
792    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
793    /// optimizations manually for specific UDFs.
794    ///
795    /// # Returns
796    ///
797    /// [None] if simplify is not defined or,
798    ///
799    /// Or, a closure with two arguments:
800    /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
801    /// * 'info': [crate::simplify::SimplifyInfo]
802    ///
803    /// closure returns simplified [Expr] or an error.
804    ///
805    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
806        None
807    }
808
809    /// Returns the reverse expression of the aggregate function.
810    fn reverse_expr(&self) -> ReversedUDAF {
811        ReversedUDAF::NotSupported
812    }
813
814    /// Coerce arguments of a function call to types that the function can evaluate.
815    ///
816    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
817    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
818    /// cases
819    ///
820    /// See the [type coercion module](crate::type_coercion)
821    /// documentation for more details on type coercion
822    ///
823    /// For example, if your function requires a floating point arguments, but the user calls
824    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
825    /// to ensure the argument was cast to `1::double`
826    ///
827    /// # Parameters
828    /// * `arg_types`: The argument types of the arguments  this function with
829    ///
830    /// # Return value
831    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
832    /// arguments to these specific types.
833    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
834        not_impl_err!("Function {} does not implement coerce_types", self.name())
835    }
836
837    /// Return true if this aggregate UDF is equal to the other.
838    ///
839    /// Allows customizing the equality of aggregate UDFs.
840    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
841    ///
842    /// - reflexive: `a.equals(a)`;
843    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
844    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
845    ///
846    /// By default, compares [`Self::name`] and [`Self::signature`].
847    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
848        self.name() == other.name() && self.signature() == other.signature()
849    }
850
851    /// Returns a hash value for this aggregate UDF.
852    ///
853    /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
854    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
855    ///
856    /// By default, hashes [`Self::name`] and [`Self::signature`].
857    fn hash_value(&self) -> u64 {
858        let hasher = &mut DefaultHasher::new();
859        self.name().hash(hasher);
860        self.signature().hash(hasher);
861        hasher.finish()
862    }
863
    /// Reports whether this function is a `max`/`min`-style aggregate.
    ///
    /// If this function is max, return `Some(true)`.
    /// If the function is min, return `Some(false)`.
    /// Otherwise return `None` (the default).
    ///
    /// Note: this is used to use special aggregate implementations in certain conditions
    fn is_descending(&self) -> Option<bool> {
        None
    }
873
    /// Return the value of this aggregate function if it can be determined
    /// entirely from statistics and arguments.
    ///
    /// Using a [`ScalarValue`] rather than a runtime computation can
    /// significantly improve query performance.
    ///
    /// For example, if the minimum value of column `x` is known to be `42` from
    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`.
    ///
    /// The default implementation returns `None`, i.e. the value is never
    /// derived from statistics.
    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        None
    }
885
    /// Returns the default value of the function when all of its input is `null`.
    ///
    /// Most aggregate functions return `NULL` if all input is `NULL`,
    /// while `count` returns 0.
    ///
    /// The default implementation returns `ScalarValue::try_from(data_type)`,
    /// i.e. a null value of `data_type` (or an error if that type cannot be
    /// represented as a `ScalarValue`).
    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
        ScalarValue::try_from(data_type)
    }
893
    /// Returns the documentation for this Aggregate UDF.
    ///
    /// Documentation can be accessed programmatically as well as used to
    /// generate publicly facing documentation. The default implementation
    /// returns `None` (no documentation).
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
901
    /// Indicates whether the aggregation function is monotonic as a set
    /// function. See [`SetMonotonicity`] for details.
    ///
    /// The default implementation ignores the data type argument and returns
    /// [`SetMonotonicity::NotMonotonic`].
    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::NotMonotonic
    }
907}
908
909impl PartialEq for dyn AggregateUDFImpl {
910    fn eq(&self, other: &Self) -> bool {
911        self.equals(other)
912    }
913}
914
915// Manual implementation of `PartialOrd`
916// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
917// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
918impl PartialOrd for dyn AggregateUDFImpl {
919    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
920        match self.name().partial_cmp(other.name()) {
921            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
922            cmp => cmp,
923        }
924    }
925}
926
/// The result of [`AggregateUDFImpl::reverse_expr`]: describes the reverse
/// form of an aggregate function, if it has one.
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression; the payload
    /// is the function to use as the reverse.
    Reversed(Arc<AggregateUDF>),
}
935
/// AggregateUDF that adds an alias to the underlying function. It is better to
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
    /// The wrapped aggregate function.
    inner: Arc<dyn AggregateUDFImpl>,
    /// All aliases reported by this wrapper. When built via `new`, this is the
    /// inner function's own aliases followed by the newly added ones.
    aliases: Vec<String>,
}
943
944impl AliasedAggregateUDFImpl {
945    pub fn new(
946        inner: Arc<dyn AggregateUDFImpl>,
947        new_aliases: impl IntoIterator<Item = &'static str>,
948    ) -> Self {
949        let mut aliases = inner.aliases().to_vec();
950        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
951
952        Self { inner, aliases }
953    }
954}
955
956impl AggregateUDFImpl for AliasedAggregateUDFImpl {
957    fn as_any(&self) -> &dyn Any {
958        self
959    }
960
961    fn name(&self) -> &str {
962        self.inner.name()
963    }
964
965    fn signature(&self) -> &Signature {
966        self.inner.signature()
967    }
968
969    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
970        self.inner.return_type(arg_types)
971    }
972
973    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
974        self.inner.accumulator(acc_args)
975    }
976
977    fn aliases(&self) -> &[String] {
978        &self.aliases
979    }
980
981    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
982        self.inner.state_fields(args)
983    }
984
985    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
986        self.inner.groups_accumulator_supported(args)
987    }
988
989    fn create_groups_accumulator(
990        &self,
991        args: AccumulatorArgs,
992    ) -> Result<Box<dyn GroupsAccumulator>> {
993        self.inner.create_groups_accumulator(args)
994    }
995
996    fn create_sliding_accumulator(
997        &self,
998        args: AccumulatorArgs,
999    ) -> Result<Box<dyn Accumulator>> {
1000        self.inner.accumulator(args)
1001    }
1002
1003    fn with_beneficial_ordering(
1004        self: Arc<Self>,
1005        beneficial_ordering: bool,
1006    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1007        Arc::clone(&self.inner)
1008            .with_beneficial_ordering(beneficial_ordering)
1009            .map(|udf| {
1010                udf.map(|udf| {
1011                    Arc::new(AliasedAggregateUDFImpl {
1012                        inner: udf,
1013                        aliases: self.aliases.clone(),
1014                    }) as Arc<dyn AggregateUDFImpl>
1015                })
1016            })
1017    }
1018
1019    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1020        self.inner.order_sensitivity()
1021    }
1022
1023    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1024        self.inner.simplify()
1025    }
1026
1027    fn reverse_expr(&self) -> ReversedUDAF {
1028        self.inner.reverse_expr()
1029    }
1030
1031    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1032        self.inner.coerce_types(arg_types)
1033    }
1034
1035    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
1036        if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
1037            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
1038        } else {
1039            false
1040        }
1041    }
1042
1043    fn hash_value(&self) -> u64 {
1044        let hasher = &mut DefaultHasher::new();
1045        self.inner.hash_value().hash(hasher);
1046        self.aliases.hash(hasher);
1047        hasher.finish()
1048    }
1049
1050    fn is_descending(&self) -> Option<bool> {
1051        self.inner.is_descending()
1052    }
1053
1054    fn documentation(&self) -> Option<&Documentation> {
1055        self.inner.documentation()
1056    }
1057}
1058
1059// Aggregate UDF doc sections for use in public documentation
1060pub mod aggregate_doc_sections {
1061    use crate::DocSection;
1062
1063    pub fn doc_sections() -> Vec<DocSection> {
1064        vec![
1065            DOC_SECTION_GENERAL,
1066            DOC_SECTION_STATISTICAL,
1067            DOC_SECTION_APPROXIMATE,
1068        ]
1069    }
1070
1071    pub const DOC_SECTION_GENERAL: DocSection = DocSection {
1072        include: true,
1073        label: "General Functions",
1074        description: None,
1075    };
1076
1077    pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
1078        include: true,
1079        label: "Statistical Functions",
1080        description: None,
1081    };
1082
1083    pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
1084        include: true,
1085        label: "Approximate Functions",
1086        description: None,
1087    };
1088}
1089
/// Indicates whether an aggregation function is monotonic as a set
/// function. A set function is monotonically increasing if its value
/// increases as its argument grows (as a set). Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S`
/// is a superset of `T`.
///
/// For example `COUNT` and `MAX` are monotonically increasing as their
/// values always increase (or stay the same) as new values are seen. On
/// the other hand, `MIN` is monotonically decreasing as its value always
/// decreases or stays the same as new values are seen.
///
/// This is a plain fieldless enum, so it also derives `Eq` and `Hash` and can
/// be used as a key in hash maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1110
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Stub UDAF distinguished only by its name. Every evaluation method is
    /// left unimplemented because the tests here only exercise comparisons.
    #[derive(Debug, Clone)]
    struct StubMeanUdf {
        name: &'static str,
        signature: Signature,
    }

    impl StubMeanUdf {
        /// Builds a stub called `name` that takes a single `Float64` argument.
        fn named(name: &'static str) -> Self {
            let signature =
                Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable);
            Self { name, signature }
        }
    }

    impl AggregateUDFImpl for StubMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        // Checks that AggregateUDF ordering is driven by name and signature;
        // this is a smoke test, not an exhaustive one.
        let a1 = AggregateUDF::from(StubMeanUdf::named("a"));
        let a2 = AggregateUDF::from(StubMeanUdf::named("a"));
        assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

        let b1 = AggregateUDF::from(StubMeanUdf::named("b"));
        assert!(a1 < b1);
        assert!(a1 != b1);
    }
}