Deprecated: The each() function is deprecated. This message will be suppressed on further calls in /home/zhenxiangba/zhenxiangba.com/public_html/phproxy-improved-master/index.php on line 456
udaf.rs - source
[go: Go Back, main page]

datafusion_expr/
udaf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field, FieldRef};
28
29use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32use crate::expr::{
33    schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
34    schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay,
35    WindowFunctionParams,
36};
37use crate::function::{
38    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39};
40use crate::groups_accumulator::GroupsAccumulator;
41use crate::utils::format_state_name;
42use crate::utils::AggregateOrderSensitivity;
43use crate::{expr_vec_fmt, Accumulator, Expr};
44use crate::{Documentation, Signature};
45
46/// Logical representation of a user-defined [aggregate function] (UDAF).
47///
48/// An aggregate function combines the values from multiple input rows
49/// into a single output "aggregate" (summary) row. It is different
50/// from a scalar function because it is stateful across batches. User
51/// defined aggregate functions can be used as normal SQL aggregate
52/// functions (`GROUP BY` clause) as well as window functions (`OVER`
53/// clause).
54///
55/// `AggregateUDF` provides DataFusion the information needed to plan and call
56/// aggregate functions, including name, type information, and a factory
57/// function to create an [`Accumulator`] instance, to perform the actual
58/// aggregation.
59///
60/// For more information, please see [the examples]:
61///
62/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
63///
64/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
65///    access (examples in [`advanced_udaf.rs`]).
66///
67/// # API Note
68/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
69/// compatibility with the older API.
70///
71/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
72/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
73/// [`Accumulator`]: crate::Accumulator
74/// [`create_udaf`]: crate::expr_fn::create_udaf
75/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
76/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
77#[derive(Debug, Clone, PartialOrd)]
78pub struct AggregateUDF {
79    inner: Arc<dyn AggregateUDFImpl>,
80}
81
82impl PartialEq for AggregateUDF {
83    fn eq(&self, other: &Self) -> bool {
84        self.inner.equals(other.inner.as_ref())
85    }
86}
87
88impl Eq for AggregateUDF {}
89
90impl Hash for AggregateUDF {
91    fn hash<H: Hasher>(&self, state: &mut H) {
92        self.inner.hash_value().hash(state)
93    }
94}
95
96impl fmt::Display for AggregateUDF {
97    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
98        write!(f, "{}", self.name())
99    }
100}
101
102/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103#[derive(Debug)]
104pub struct StatisticsArgs<'a> {
105    /// The statistics of the aggregate input
106    pub statistics: &'a Statistics,
107    /// The resolved return type of the aggregate function
108    pub return_type: &'a DataType,
109    /// Whether the aggregate function is distinct.
110    ///
111    /// ```sql
112    /// SELECT COUNT(DISTINCT column1) FROM t;
113    /// ```
114    pub is_distinct: bool,
115    /// The physical expression of arguments the aggregate function takes.
116    pub exprs: &'a [Arc<dyn PhysicalExpr>],
117}
118
119impl AggregateUDF {
120    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
121    ///
122    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
123    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
124    where
125        F: AggregateUDFImpl + 'static,
126    {
127        Self::new_from_shared_impl(Arc::new(fun))
128    }
129
130    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
131    pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
132        Self { inner: fun }
133    }
134
135    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
136    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
137        &self.inner
138    }
139
140    /// Adds additional names that can be used to invoke this function, in
141    /// addition to `name`
142    ///
143    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
144    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
145        Self::new_from_impl(AliasedAggregateUDFImpl::new(
146            Arc::clone(&self.inner),
147            aliases,
148        ))
149    }
150
151    /// Creates an [`Expr`] that calls the aggregate function.
152    ///
153    /// This utility allows using the UDAF without requiring access to
154    /// the registry, such as with the DataFrame API.
155    pub fn call(&self, args: Vec<Expr>) -> Expr {
156        Expr::AggregateFunction(AggregateFunction::new_udf(
157            Arc::new(self.clone()),
158            args,
159            false,
160            None,
161            None,
162            None,
163        ))
164    }
165
166    /// Returns this function's name
167    ///
168    /// See [`AggregateUDFImpl::name`] for more details.
169    pub fn name(&self) -> &str {
170        self.inner.name()
171    }
172
173    /// See [`AggregateUDFImpl::schema_name`] for more details.
174    pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
175        self.inner.schema_name(params)
176    }
177
178    /// Returns a human readable expression.
179    ///
180    /// See [`Expr::human_display`] for details.
181    pub fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
182        self.inner.human_display(params)
183    }
184
185    pub fn window_function_schema_name(
186        &self,
187        params: &WindowFunctionParams,
188    ) -> Result<String> {
189        self.inner.window_function_schema_name(params)
190    }
191
192    /// See [`AggregateUDFImpl::display_name`] for more details.
193    pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
194        self.inner.display_name(params)
195    }
196
197    pub fn window_function_display_name(
198        &self,
199        params: &WindowFunctionParams,
200    ) -> Result<String> {
201        self.inner.window_function_display_name(params)
202    }
203
204    pub fn is_nullable(&self) -> bool {
205        self.inner.is_nullable()
206    }
207
208    /// Returns the aliases for this function.
209    pub fn aliases(&self) -> &[String] {
210        self.inner.aliases()
211    }
212
213    /// Returns this function's signature (what input types are accepted)
214    ///
215    /// See [`AggregateUDFImpl::signature`] for more details.
216    pub fn signature(&self) -> &Signature {
217        self.inner.signature()
218    }
219
220    /// Return the type of the function given its input types
221    ///
222    /// See [`AggregateUDFImpl::return_type`] for more details.
223    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
224        self.inner.return_type(args)
225    }
226
227    /// Return the field of the function given its input fields
228    ///
229    /// See [`AggregateUDFImpl::return_field`] for more details.
230    pub fn return_field(&self, args: &[FieldRef]) -> Result<FieldRef> {
231        self.inner.return_field(args)
232    }
233
234    /// Return an accumulator the given aggregate, given its return datatype
235    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
236        self.inner.accumulator(acc_args)
237    }
238
239    /// Return the fields used to store the intermediate state for this aggregator, given
240    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
241    /// for more details.
242    ///
243    /// This is used to support multi-phase aggregations
244    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
245        self.inner.state_fields(args)
246    }
247
248    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
249    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
250        self.inner.groups_accumulator_supported(args)
251    }
252
253    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
254    pub fn create_groups_accumulator(
255        &self,
256        args: AccumulatorArgs,
257    ) -> Result<Box<dyn GroupsAccumulator>> {
258        self.inner.create_groups_accumulator(args)
259    }
260
261    pub fn create_sliding_accumulator(
262        &self,
263        args: AccumulatorArgs,
264    ) -> Result<Box<dyn Accumulator>> {
265        self.inner.create_sliding_accumulator(args)
266    }
267
268    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
269        self.inner.coerce_types(arg_types)
270    }
271
272    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
273    pub fn with_beneficial_ordering(
274        self,
275        beneficial_ordering: bool,
276    ) -> Result<Option<AggregateUDF>> {
277        self.inner
278            .with_beneficial_ordering(beneficial_ordering)
279            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
280    }
281
282    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
283    /// for possible options.
284    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
285        self.inner.order_sensitivity()
286    }
287
288    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
289    /// generate same result with this `AggregateUDF` when iterated in reverse
290    /// order, and `None` if there is no such `AggregateUDF`).
291    pub fn reverse_udf(&self) -> ReversedUDAF {
292        self.inner.reverse_expr()
293    }
294
295    /// Do the function rewrite
296    ///
297    /// See [`AggregateUDFImpl::simplify`] for more details.
298    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
299        self.inner.simplify()
300    }
301
302    /// Returns true if the function is max, false if the function is min
303    /// None in all other cases, used in certain optimizations for
304    /// or aggregate
305    pub fn is_descending(&self) -> Option<bool> {
306        self.inner.is_descending()
307    }
308
309    /// Return the value of this aggregate function if it can be determined
310    /// entirely from statistics and arguments.
311    ///
312    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
313    pub fn value_from_stats(
314        &self,
315        statistics_args: &StatisticsArgs,
316    ) -> Option<ScalarValue> {
317        self.inner.value_from_stats(statistics_args)
318    }
319
320    /// See [`AggregateUDFImpl::default_value`] for more details.
321    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
322        self.inner.default_value(data_type)
323    }
324
325    /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
326    pub fn supports_null_handling_clause(&self) -> bool {
327        self.inner.supports_null_handling_clause()
328    }
329
330    /// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details.
331    pub fn is_ordered_set_aggregate(&self) -> bool {
332        self.inner.is_ordered_set_aggregate()
333    }
334
335    /// Returns the documentation for this Aggregate UDF.
336    ///
337    /// Documentation can be accessed programmatically as well as
338    /// generating publicly facing documentation.
339    pub fn documentation(&self) -> Option<&Documentation> {
340        self.inner.documentation()
341    }
342}
343
344impl<F> From<F> for AggregateUDF
345where
346    F: AggregateUDFImpl + Send + Sync + 'static,
347{
348    fn from(fun: F) -> Self {
349        Self::new_from_impl(fun)
350    }
351}
352
353/// Trait for implementing [`AggregateUDF`].
354///
355/// This trait exposes the full API for implementing user defined aggregate functions and
356/// can be used to implement any function.
357///
358/// See [`advanced_udaf.rs`] for a full example with complete implementation and
359/// [`AggregateUDF`] for other available options.
360///
361/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
362///
363/// # Basic Example
364/// ```
365/// # use std::any::Any;
366/// # use std::sync::{Arc, LazyLock};
367/// # use arrow::datatypes::{DataType, FieldRef};
368/// # use datafusion_common::{DataFusionError, plan_err, Result};
369/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
370/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
371/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
372/// # use arrow::datatypes::Schema;
373/// # use arrow::datatypes::Field;
374///
375/// #[derive(Debug, Clone)]
376/// struct GeoMeanUdf {
377///   signature: Signature,
378/// }
379///
380/// impl GeoMeanUdf {
381///   fn new() -> Self {
382///     Self {
383///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
384///      }
385///   }
386/// }
387///
388/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
389///         Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
390///             .with_argument("arg1", "The Float64 number for the geometric mean")
391///             .build()
392///     });
393///
394/// fn get_doc() -> &'static Documentation {
395///     &DOCUMENTATION
396/// }
397///    
398/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
399/// impl AggregateUDFImpl for GeoMeanUdf {
400///    fn as_any(&self) -> &dyn Any { self }
401///    fn name(&self) -> &str { "geo_mean" }
402///    fn signature(&self) -> &Signature { &self.signature }
403///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
404///      if !matches!(args.get(0), Some(&DataType::Float64)) {
405///        return plan_err!("geo_mean only accepts Float64 arguments");
406///      }
407///      Ok(DataType::Float64)
408///    }
409///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
410///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
411///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
412///        Ok(vec![
413///             Arc::new(args.return_field.as_ref().clone().with_name("value")),
414///             Arc::new(Field::new("ordering", DataType::UInt32, true))
415///        ])
416///    }
417///    fn documentation(&self) -> Option<&Documentation> {
418///        Some(get_doc())  
419///    }
420/// }
421///
422/// // Create a new AggregateUDF from the implementation
423/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
424///
425/// // Call the function `geo_mean(col)`
426/// let expr = geometric_mean.call(vec![col("a")]);
427/// ```
428pub trait AggregateUDFImpl: Debug + Send + Sync {
429    // Note: When adding any methods (with default implementations), remember to add them also
430    // into the AliasedAggregateUDFImpl below!
431
432    /// Returns this object as an [`Any`] trait object
433    fn as_any(&self) -> &dyn Any;
434
435    /// Returns this function's name
436    fn name(&self) -> &str;
437
438    /// Returns the name of the column this expression would create
439    ///
440    /// See [`Expr::schema_name`] for details
441    ///
442    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
443    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
444        let AggregateFunctionParams {
445            args,
446            distinct,
447            filter,
448            order_by,
449            null_treatment,
450        } = params;
451
452        // exclude the first function argument(= column) in ordered set aggregate function,
453        // because it is duplicated with the WITHIN GROUP clause in schema name.
454        let args = if self.is_ordered_set_aggregate() {
455            &args[1..]
456        } else {
457            &args[..]
458        };
459
460        let mut schema_name = String::new();
461
462        schema_name.write_fmt(format_args!(
463            "{}({}{})",
464            self.name(),
465            if *distinct { "DISTINCT " } else { "" },
466            schema_name_from_exprs_comma_separated_without_space(args)?
467        ))?;
468
469        if let Some(null_treatment) = null_treatment {
470            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
471        }
472
473        if let Some(filter) = filter {
474            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
475        };
476
477        if let Some(order_by) = order_by {
478            let clause = match self.is_ordered_set_aggregate() {
479                true => "WITHIN GROUP",
480                false => "ORDER BY",
481            };
482
483            schema_name.write_fmt(format_args!(
484                " {} [{}]",
485                clause,
486                schema_name_from_sorts(order_by)?
487            ))?;
488        };
489
490        Ok(schema_name)
491    }
492
493    /// Returns a human readable expression.
494    ///
495    /// See [`Expr::human_display`] for details.
496    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
497        let AggregateFunctionParams {
498            args,
499            distinct,
500            filter,
501            order_by,
502            null_treatment,
503        } = params;
504
505        let mut schema_name = String::new();
506
507        schema_name.write_fmt(format_args!(
508            "{}({}{})",
509            self.name(),
510            if *distinct { "DISTINCT " } else { "" },
511            ExprListDisplay::comma_separated(args.as_slice())
512        ))?;
513
514        if let Some(null_treatment) = null_treatment {
515            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
516        }
517
518        if let Some(filter) = filter {
519            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
520        };
521
522        if let Some(order_by) = order_by {
523            schema_name.write_fmt(format_args!(
524                " ORDER BY [{}]",
525                schema_name_from_sorts(order_by)?
526            ))?;
527        };
528
529        Ok(schema_name)
530    }
531
532    /// Returns the name of the column this expression would create
533    ///
534    /// See [`Expr::schema_name`] for details
535    ///
536    /// Different from `schema_name` in that it is used for window aggregate function
537    ///
538    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
539    fn window_function_schema_name(
540        &self,
541        params: &WindowFunctionParams,
542    ) -> Result<String> {
543        let WindowFunctionParams {
544            args,
545            partition_by,
546            order_by,
547            window_frame,
548            null_treatment,
549        } = params;
550
551        let mut schema_name = String::new();
552        schema_name.write_fmt(format_args!(
553            "{}({})",
554            self.name(),
555            schema_name_from_exprs(args)?
556        ))?;
557
558        if let Some(null_treatment) = null_treatment {
559            schema_name.write_fmt(format_args!(" {null_treatment}"))?;
560        }
561
562        if !partition_by.is_empty() {
563            schema_name.write_fmt(format_args!(
564                " PARTITION BY [{}]",
565                schema_name_from_exprs(partition_by)?
566            ))?;
567        }
568
569        if !order_by.is_empty() {
570            schema_name.write_fmt(format_args!(
571                " ORDER BY [{}]",
572                schema_name_from_sorts(order_by)?
573            ))?;
574        };
575
576        schema_name.write_fmt(format_args!(" {window_frame}"))?;
577
578        Ok(schema_name)
579    }
580
581    /// Returns the user-defined display name of function, given the arguments
582    ///
583    /// This can be used to customize the output column name generated by this
584    /// function.
585    ///
586    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
587    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
588        let AggregateFunctionParams {
589            args,
590            distinct,
591            filter,
592            order_by,
593            null_treatment,
594        } = params;
595
596        let mut display_name = String::new();
597
598        display_name.write_fmt(format_args!(
599            "{}({}{})",
600            self.name(),
601            if *distinct { "DISTINCT " } else { "" },
602            expr_vec_fmt!(args)
603        ))?;
604
605        if let Some(nt) = null_treatment {
606            display_name.write_fmt(format_args!(" {nt}"))?;
607        }
608        if let Some(fe) = filter {
609            display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
610        }
611        if let Some(ob) = order_by {
612            display_name.write_fmt(format_args!(
613                " ORDER BY [{}]",
614                ob.iter()
615                    .map(|o| format!("{o}"))
616                    .collect::<Vec<String>>()
617                    .join(", ")
618            ))?;
619        }
620
621        Ok(display_name)
622    }
623
624    /// Returns the user-defined display name of function, given the arguments
625    ///
626    /// This can be used to customize the output column name generated by this
627    /// function.
628    ///
629    /// Different from `display_name` in that it is used for window aggregate function
630    ///
631    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
632    fn window_function_display_name(
633        &self,
634        params: &WindowFunctionParams,
635    ) -> Result<String> {
636        let WindowFunctionParams {
637            args,
638            partition_by,
639            order_by,
640            window_frame,
641            null_treatment,
642        } = params;
643
644        let mut display_name = String::new();
645
646        display_name.write_fmt(format_args!(
647            "{}({})",
648            self.name(),
649            expr_vec_fmt!(args)
650        ))?;
651
652        if let Some(null_treatment) = null_treatment {
653            display_name.write_fmt(format_args!(" {null_treatment}"))?;
654        }
655
656        if !partition_by.is_empty() {
657            display_name.write_fmt(format_args!(
658                " PARTITION BY [{}]",
659                expr_vec_fmt!(partition_by)
660            ))?;
661        }
662
663        if !order_by.is_empty() {
664            display_name
665                .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
666        };
667
668        display_name.write_fmt(format_args!(
669            " {} BETWEEN {} AND {}",
670            window_frame.units, window_frame.start_bound, window_frame.end_bound
671        ))?;
672
673        Ok(display_name)
674    }
675
676    /// Returns the function's [`Signature`] for information about what input
677    /// types are accepted and the function's Volatility.
678    fn signature(&self) -> &Signature;
679
680    /// What [`DataType`] will be returned by this function, given the types of
681    /// the arguments
682    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
683
684    /// What type will be returned by this function, given the arguments?
685    ///
686    /// By default, this function calls [`Self::return_type`] with the
687    /// types of each argument.
688    ///
689    /// # Notes
690    ///
691    /// Most UDFs should implement [`Self::return_type`] and not this
692    /// function as the output type for most functions only depends on the types
693    /// of their inputs (e.g. `sum(f64)` is always `f64`).
694    ///
695    /// This function can be used for more advanced cases such as:
696    ///
697    /// 1. specifying nullability
698    /// 2. return types based on the **values** of the arguments (rather than
699    ///    their **types**.
700    /// 3. return types based on metadata within the fields of the inputs
701    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
702        let arg_types: Vec<_> =
703            arg_fields.iter().map(|f| f.data_type()).cloned().collect();
704        let data_type = self.return_type(&arg_types)?;
705
706        Ok(Arc::new(Field::new(
707            self.name(),
708            data_type,
709            self.is_nullable(),
710        )))
711    }
712
713    /// Whether the aggregate function is nullable.
714    ///
715    /// Nullable means that the function could return `null` for any inputs.
716    /// For example, aggregate functions like `COUNT` always return a non null value
717    /// but others like `MIN` will return `NULL` if there is nullable input.
718    /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
719    fn is_nullable(&self) -> bool {
720        true
721    }
722
723    /// Return a new [`Accumulator`] that aggregates values for a specific
724    /// group during query execution.
725    ///
726    /// acc_args: [`AccumulatorArgs`] contains information about how the
727    /// aggregate function was called.
728    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
729
730    /// Return the fields used to store the intermediate state of this accumulator.
731    ///
732    /// See [`Accumulator::state`] for background information.
733    ///
734    /// args:  [`StateFieldsArgs`] contains arguments passed to the
735    /// aggregate function's accumulator.
736    ///
737    /// # Notes:
738    ///
739    /// The default implementation returns a single state field named `name`
740    /// with the same type as `value_type`. This is suitable for aggregates such
741    /// as `SUM` or `MIN` where partial state can be combined by applying the
742    /// same aggregate.
743    ///
744    /// For aggregates such as `AVG` where the partial state is more complex
745    /// (e.g. a COUNT and a SUM), this method is used to define the additional
746    /// fields.
747    ///
748    /// The name of the fields must be unique within the query and thus should
749    /// be derived from `name`. See [`format_state_name`] for a utility function
750    /// to generate a unique name.
751    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
752        let fields = vec![args
753            .return_field
754            .as_ref()
755            .clone()
756            .with_name(format_state_name(args.name, "value"))];
757
758        Ok(fields
759            .into_iter()
760            .map(Arc::new)
761            .chain(args.ordering_fields.to_vec())
762            .collect())
763    }
764
765    /// If the aggregate expression has a specialized
766    /// [`GroupsAccumulator`] implementation. If this returns true,
767    /// `[Self::create_groups_accumulator]` will be called.
768    ///
769    /// # Notes
770    ///
771    /// Even if this function returns true, DataFusion will still use
772    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
773    /// used as a window function or when there no GROUP BY columns in the
774    /// query.
775    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
776        false
777    }
778
779    /// Return a specialized [`GroupsAccumulator`] that manages state
780    /// for all groups.
781    ///
782    /// For maximum performance, a [`GroupsAccumulator`] should be
783    /// implemented in addition to [`Accumulator`].
784    fn create_groups_accumulator(
785        &self,
786        _args: AccumulatorArgs,
787    ) -> Result<Box<dyn GroupsAccumulator>> {
788        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
789    }
790
791    /// Returns any aliases (alternate names) for this function.
792    ///
793    /// Note: `aliases` should only include names other than [`Self::name`].
794    /// Defaults to `[]` (no aliases)
795    fn aliases(&self) -> &[String] {
796        &[]
797    }
798
799    /// Sliding accumulator is an alternative accumulator that can be used for
800    /// window functions. It has retract method to revert the previous update.
801    ///
802    /// See [retract_batch] for more details.
803    ///
804    /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
805    fn create_sliding_accumulator(
806        &self,
807        args: AccumulatorArgs,
808    ) -> Result<Box<dyn Accumulator>> {
809        self.accumulator(args)
810    }
811
812    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
813    /// satisfied by its input. If this is not the case, UDFs with order
814    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
815    /// the correct result with possibly more work internally.
816    ///
817    /// # Returns
818    ///
819    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
820    /// If the expression can benefit from existing input ordering, but does
821    /// not implement the method, returns an error. Order insensitive and hard
822    /// requirement aggregators return `Ok(None)`.
823    fn with_beneficial_ordering(
824        self: Arc<Self>,
825        _beneficial_ordering: bool,
826    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
827        if self.order_sensitivity().is_beneficial() {
828            return exec_err!(
829                "Should implement with satisfied for aggregator :{:?}",
830                self.name()
831            );
832        }
833        Ok(None)
834    }
835
836    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
837    /// for possible options.
838    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
839        // We have hard ordering requirements by default, meaning that order
840        // sensitive UDFs need their input orderings to satisfy their ordering
841        // requirements to generate correct results.
842        AggregateOrderSensitivity::HardRequirement
843    }
844
845    /// Optionally apply per-UDaF simplification / rewrite rules.
846    ///
847    /// This can be used to apply function specific simplification rules during
848    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
849    /// implementation does nothing.
850    ///
851    /// Note that DataFusion handles simplifying arguments and  "constant
852    /// folding" (replacing a function call with constant arguments such as
853    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
854    /// optimizations manually for specific UDFs.
855    ///
856    /// # Returns
857    ///
858    /// [None] if simplify is not defined or,
859    ///
860    /// Or, a closure with two arguments:
861    /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
862    /// * 'info': [crate::simplify::SimplifyInfo]
863    ///
864    /// closure returns simplified [Expr] or an error.
865    ///
866    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
867        None
868    }
869
870    /// Returns the reverse expression of the aggregate function.
871    fn reverse_expr(&self) -> ReversedUDAF {
872        ReversedUDAF::NotSupported
873    }
874
875    /// Coerce arguments of a function call to types that the function can evaluate.
876    ///
877    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
878    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
879    /// cases
880    ///
881    /// See the [type coercion module](crate::type_coercion)
882    /// documentation for more details on type coercion
883    ///
884    /// For example, if your function requires a floating point arguments, but the user calls
885    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
886    /// to ensure the argument was cast to `1::double`
887    ///
888    /// # Parameters
889    /// * `arg_types`: The argument types of the arguments  this function with
890    ///
891    /// # Return value
892    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
893    /// arguments to these specific types.
894    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
895        not_impl_err!("Function {} does not implement coerce_types", self.name())
896    }
897
898    /// Return true if this aggregate UDF is equal to the other.
899    ///
900    /// Allows customizing the equality of aggregate UDFs.
901    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
902    ///
903    /// - reflexive: `a.equals(a)`;
904    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
905    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
906    ///
907    /// By default, compares [`Self::name`] and [`Self::signature`].
908    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
909        self.name() == other.name() && self.signature() == other.signature()
910    }
911
912    /// Returns a hash value for this aggregate UDF.
913    ///
914    /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
915    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
916    ///
917    /// By default, hashes [`Self::name`] and [`Self::signature`].
918    fn hash_value(&self) -> u64 {
919        let hasher = &mut DefaultHasher::new();
920        self.name().hash(hasher);
921        self.signature().hash(hasher);
922        hasher.finish()
923    }
924
925    /// If this function is max, return true
926    /// If the function is min, return false
927    /// Otherwise return None (the default)
928    ///
929    ///
930    /// Note: this is used to use special aggregate implementations in certain conditions
931    fn is_descending(&self) -> Option<bool> {
932        None
933    }
934
935    /// Return the value of this aggregate function if it can be determined
936    /// entirely from statistics and arguments.
937    ///
938    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
939    /// improving query performance.
940    ///
941    /// For example, if the minimum value of column `x` is known to be `42` from
942    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
943    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
944        None
945    }
946
947    /// Returns default value of the function given the input is all `null`.
948    ///
949    /// Most of the aggregate function return Null if input is Null,
950    /// while `count` returns 0 if input is Null
951    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
952        ScalarValue::try_from(data_type)
953    }
954
955    /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true
956    /// If the function does not, return false
957    fn supports_null_handling_clause(&self) -> bool {
958        true
959    }
960
961    /// If this function is ordered-set aggregate function, return true
962    /// If the function is not, return false
963    fn is_ordered_set_aggregate(&self) -> bool {
964        false
965    }
966
967    /// Returns the documentation for this Aggregate UDF.
968    ///
969    /// Documentation can be accessed programmatically as well as
970    /// generating publicly facing documentation.
971    fn documentation(&self) -> Option<&Documentation> {
972        None
973    }
974
975    /// Indicates whether the aggregation function is monotonic as a set
976    /// function. See [`SetMonotonicity`] for details.
977    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
978        SetMonotonicity::NotMonotonic
979    }
980}
981
982impl PartialEq for dyn AggregateUDFImpl {
983    fn eq(&self, other: &Self) -> bool {
984        self.equals(other)
985    }
986}
987
988// Manual implementation of `PartialOrd`
989// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
990// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
991impl PartialOrd for dyn AggregateUDFImpl {
992    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
993        match self.name().partial_cmp(other.name()) {
994            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
995            cmp => cmp,
996        }
997    }
998}
999
1000pub enum ReversedUDAF {
1001    /// The expression is the same as the original expression, like SUM, COUNT
1002    Identical,
1003    /// The expression does not support reverse calculation
1004    NotSupported,
1005    /// The expression is different from the original expression
1006    Reversed(Arc<AggregateUDF>),
1007}
1008
1009/// AggregateUDF that adds an alias to the underlying function. It is better to
1010/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
1011#[derive(Debug)]
1012struct AliasedAggregateUDFImpl {
1013    inner: Arc<dyn AggregateUDFImpl>,
1014    aliases: Vec<String>,
1015}
1016
1017impl AliasedAggregateUDFImpl {
1018    pub fn new(
1019        inner: Arc<dyn AggregateUDFImpl>,
1020        new_aliases: impl IntoIterator<Item = &'static str>,
1021    ) -> Self {
1022        let mut aliases = inner.aliases().to_vec();
1023        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1024
1025        Self { inner, aliases }
1026    }
1027}
1028
1029impl AggregateUDFImpl for AliasedAggregateUDFImpl {
1030    fn as_any(&self) -> &dyn Any {
1031        self
1032    }
1033
1034    fn name(&self) -> &str {
1035        self.inner.name()
1036    }
1037
1038    fn signature(&self) -> &Signature {
1039        self.inner.signature()
1040    }
1041
1042    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1043        self.inner.return_type(arg_types)
1044    }
1045
1046    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1047        self.inner.accumulator(acc_args)
1048    }
1049
1050    fn aliases(&self) -> &[String] {
1051        &self.aliases
1052    }
1053
1054    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1055        self.inner.state_fields(args)
1056    }
1057
1058    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1059        self.inner.groups_accumulator_supported(args)
1060    }
1061
1062    fn create_groups_accumulator(
1063        &self,
1064        args: AccumulatorArgs,
1065    ) -> Result<Box<dyn GroupsAccumulator>> {
1066        self.inner.create_groups_accumulator(args)
1067    }
1068
1069    fn create_sliding_accumulator(
1070        &self,
1071        args: AccumulatorArgs,
1072    ) -> Result<Box<dyn Accumulator>> {
1073        self.inner.accumulator(args)
1074    }
1075
1076    fn with_beneficial_ordering(
1077        self: Arc<Self>,
1078        beneficial_ordering: bool,
1079    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1080        Arc::clone(&self.inner)
1081            .with_beneficial_ordering(beneficial_ordering)
1082            .map(|udf| {
1083                udf.map(|udf| {
1084                    Arc::new(AliasedAggregateUDFImpl {
1085                        inner: udf,
1086                        aliases: self.aliases.clone(),
1087                    }) as Arc<dyn AggregateUDFImpl>
1088                })
1089            })
1090    }
1091
1092    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1093        self.inner.order_sensitivity()
1094    }
1095
1096    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
1097        self.inner.simplify()
1098    }
1099
1100    fn reverse_expr(&self) -> ReversedUDAF {
1101        self.inner.reverse_expr()
1102    }
1103
1104    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1105        self.inner.coerce_types(arg_types)
1106    }
1107
1108    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
1109        if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
1110            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
1111        } else {
1112            false
1113        }
1114    }
1115
1116    fn hash_value(&self) -> u64 {
1117        let hasher = &mut DefaultHasher::new();
1118        self.inner.hash_value().hash(hasher);
1119        self.aliases.hash(hasher);
1120        hasher.finish()
1121    }
1122
1123    fn is_descending(&self) -> Option<bool> {
1124        self.inner.is_descending()
1125    }
1126
1127    fn documentation(&self) -> Option<&Documentation> {
1128        self.inner.documentation()
1129    }
1130}
1131
1132// Aggregate UDF doc sections for use in public documentation
1133pub mod aggregate_doc_sections {
1134    use crate::DocSection;
1135
1136    pub fn doc_sections() -> Vec<DocSection> {
1137        vec![
1138            DOC_SECTION_GENERAL,
1139            DOC_SECTION_STATISTICAL,
1140            DOC_SECTION_APPROXIMATE,
1141        ]
1142    }
1143
1144    pub const DOC_SECTION_GENERAL: DocSection = DocSection {
1145        include: true,
1146        label: "General Functions",
1147        description: None,
1148    };
1149
1150    pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
1151        include: true,
1152        label: "Statistical Functions",
1153        description: None,
1154    };
1155
1156    pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
1157        include: true,
1158        label: "Approximate Functions",
1159        description: None,
1160    };
1161}
1162
/// Describes whether an aggregation function is monotonic as a set function.
///
/// A set function `f` is monotonically increasing if `f(S) >= f(T)` whenever
/// `S` is a superset of `T`, i.e. its value never drops as its input set grows.
/// It is monotonically decreasing if `f(S) <= f(T)` under the same condition.
///
/// `COUNT` and `MAX` are monotonically increasing: seeing more values can only
/// raise their result or leave it unchanged. `MIN` is monotonically decreasing:
/// more values can only lower its result or leave it unchanged.
#[derive(Debug, Clone, PartialEq)]
pub enum SetMonotonicity {
    /// The aggregate value grows or stays the same as the input set grows.
    Increasing,
    /// The aggregate value shrinks or stays the same as the input set grows.
    Decreasing,
    /// The aggregate value may grow, shrink, or stay the same as the input set
    /// grows.
    NotMonotonic,
}
1183
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, FieldRef};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Minimal UDAF stub whose only distinguishing property is its name. Every
    /// instance shares the same single-`Float64` signature, and none of the
    /// execution methods are meant to be called.
    #[derive(Debug, Clone)]
    struct NamedMeanUdf {
        name: &'static str,
        signature: Signature,
    }

    impl NamedMeanUdf {
        fn new(name: &'static str) -> Self {
            let signature =
                Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable);
            Self { name, signature }
        }
    }

    impl AggregateUDFImpl for NamedMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        // Checks that AggregateUDF ordering is driven by name and signature;
        // this is not meant to exhaustively cover every combination.
        let first_a = AggregateUDF::from(NamedMeanUdf::new("a"));
        let second_a = AggregateUDF::from(NamedMeanUdf::new("a"));
        assert_eq!(first_a.partial_cmp(&second_a), Some(Ordering::Equal));

        let b = AggregateUDF::from(NamedMeanUdf::new("b"));
        assert!(first_a < b);
        assert!(first_a != b);
    }
}