use crate::expr::AggregateFunction;
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;
/// Handle to a user-defined aggregate function.
///
/// Holds a shared, dynamically dispatched [`AggregateUDFImpl`]. Because the
/// implementation sits behind an `Arc`, the derived `Clone` copies only the
/// pointer, not the implementation itself.
#[derive(Debug, Clone)]
pub struct AggregateUDF {
// The wrapped implementation.
inner: Arc<dyn AggregateUDFImpl>,
}
/// Equality for [`AggregateUDF`] is defined purely by name and signature; no
/// other part of the wrapped implementation takes part in the comparison.
impl PartialEq for AggregateUDF {
    fn eq(&self, other: &Self) -> bool {
        // Signatures are only compared once the names are known to match.
        if self.name() != other.name() {
            return false;
        }
        self.signature() == other.signature()
    }
}
// Marker impl: declares the name + signature comparison implemented by
// `PartialEq` for this type to be a full equivalence relation.
impl Eq for AggregateUDF {}
/// Hashes the same two properties that equality compares: first the name,
/// then the signature.
impl std::hash::Hash for AggregateUDF {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(self.name(), state);
        std::hash::Hash::hash(self.signature(), state);
    }
}
/// Displays an [`AggregateUDF`] as its function name.
impl std::fmt::Display for AggregateUDF {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.write_str(self.name())
    }
}
impl AggregateUDF {
    /// Creates an [`AggregateUDF`] from a name, a signature and two shared
    /// closures (return-type resolver and accumulator factory).
    ///
    /// The pieces are stored in an [`AggregateUDFLegacyWrapper`], which is
    /// then wrapped via [`Self::new_from_impl`].
    #[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")]
    pub fn new(
        name: &str,
        signature: &Signature,
        return_type: &ReturnTypeFunction,
        accumulator: &AccumulatorFactoryFunction,
    ) -> Self {
        let legacy = AggregateUDFLegacyWrapper {
            name: name.to_string(),
            signature: signature.clone(),
            return_type: Arc::clone(return_type),
            accumulator: Arc::clone(accumulator),
        };
        Self::new_from_impl(legacy)
    }

    /// Wraps any [`AggregateUDFImpl`] value in a new [`AggregateUDF`].
    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
    where
        F: AggregateUDFImpl + 'static,
    {
        let inner: Arc<dyn AggregateUDFImpl> = Arc::new(fun);
        Self { inner }
    }

    /// Returns the shared pointer to the wrapped implementation.
    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
        &self.inner
    }

    /// Returns a new [`AggregateUDF`] whose alias list is the current aliases
    /// followed by `aliases`.
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
        let aliased = AliasedAggregateUDFImpl::new(Arc::clone(&self.inner), aliases);
        Self::new_from_impl(aliased)
    }

    /// Builds an [`Expr::AggregateFunction`] that invokes a clone of this
    /// function on `args`; the remaining `new_udf` arguments are passed as
    /// `false` / `None`.
    pub fn call(&self, args: Vec<Expr>) -> Expr {
        let udf = Arc::new(self.clone());
        let aggregate = AggregateFunction::new_udf(udf, args, false, None, None, None);
        Expr::AggregateFunction(aggregate)
    }

    /// Name reported by the wrapped implementation.
    pub fn name(&self) -> &str {
        self.inner.name()
    }

    /// Aliases reported by the wrapped implementation.
    pub fn aliases(&self) -> &[String] {
        self.inner.aliases()
    }

    /// Signature reported by the wrapped implementation.
    pub fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Forwards to [`AggregateUDFImpl::return_type`].
    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
        self.inner.return_type(args)
    }

    /// Forwards to [`AggregateUDFImpl::accumulator`].
    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        self.inner.accumulator(acc_args)
    }

    /// Forwards to [`AggregateUDFImpl::state_fields`].
    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        self.inner.state_fields(args)
    }

    /// Forwards to [`AggregateUDFImpl::groups_accumulator_supported`].
    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        self.inner.groups_accumulator_supported(args)
    }

    /// Forwards to [`AggregateUDFImpl::create_groups_accumulator`].
    pub fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        self.inner.create_groups_accumulator(args)
    }

    /// Forwards to [`AggregateUDFImpl::create_sliding_accumulator`].
    pub fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        self.inner.create_sliding_accumulator(args)
    }

    /// Forwards to [`AggregateUDFImpl::coerce_types`].
    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    /// Forwards to [`AggregateUDFImpl::with_beneficial_ordering`].
    ///
    /// When the implementation returns an updated implementation, it is
    /// wrapped in a fresh [`AggregateUDF`]; `None` and errors are passed
    /// through unchanged.
    pub fn with_beneficial_ordering(
        self,
        beneficial_ordering: bool,
    ) -> Result<Option<AggregateUDF>> {
        let updated = self.inner.with_beneficial_ordering(beneficial_ordering)?;
        Ok(updated.map(|inner| Self { inner }))
    }

    /// Forwards to [`AggregateUDFImpl::order_sensitivity`].
    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        self.inner.order_sensitivity()
    }

    /// Forwards to [`AggregateUDFImpl::reverse_expr`].
    pub fn reverse_udf(&self) -> ReversedUDAF {
        self.inner.reverse_expr()
    }

    /// Forwards to [`AggregateUDFImpl::simplify`].
    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        self.inner.simplify()
    }
}
impl<F> From<F> for AggregateUDF
where
F: AggregateUDFImpl + Send + Sync + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Interface implemented by user-defined aggregate functions.
///
/// Five methods are required (`as_any`, `name`, `signature`, `return_type`,
/// `accumulator`); every other method has a default that implementors may
/// override.
pub trait AggregateUDFImpl: Debug + Send + Sync {
    /// Returns `self` as [`Any`].
    fn as_any(&self) -> &dyn Any;

    /// Returns the function's name.
    fn name(&self) -> &str;

    /// Returns the function's [`Signature`].
    fn signature(&self) -> &Signature;

    /// Returns the result type for the given argument types.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

    /// Creates an [`Accumulator`] for the given arguments.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

    /// Returns the fields describing this function's intermediate state.
    ///
    /// The default is one nullable field named by
    /// `format_state_name(args.name, "value")` with type `args.return_type`,
    /// followed by a copy of `args.ordering_fields`.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        let value_field = Field::new(
            format_state_name(args.name, "value"),
            args.return_type.clone(),
            true,
        );
        let mut state = vec![value_field];
        state.extend(args.ordering_fields.to_vec());
        Ok(state)
    }

    /// Whether [`Self::create_groups_accumulator`] may be used; `false` by
    /// default.
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        false
    }

    /// Creates a [`GroupsAccumulator`]; the default returns a
    /// "not implemented" error.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
    }

    /// Alternative names for the function; empty by default.
    fn aliases(&self) -> &[String] {
        &[]
    }

    /// Creates the accumulator returned by
    /// [`AggregateUDF::create_sliding_accumulator`]; the default reuses
    /// [`Self::accumulator`].
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        self.accumulator(args)
    }

    /// Optionally returns an updated implementation for the given
    /// `beneficial_ordering` flag.
    ///
    /// The default returns `Ok(None)`, unless
    /// `order_sensitivity().is_beneficial()` holds: such implementations are
    /// expected to override this method, so the default reports an execution
    /// error for them.
    fn with_beneficial_ordering(
        self: Arc<Self>,
        _beneficial_ordering: bool,
    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
        if !self.order_sensitivity().is_beneficial() {
            return Ok(None);
        }
        exec_err!(
            "Should implement with satisfied for aggregator :{:?}",
            self.name()
        )
    }

    /// How this function relates to input ordering; the default is
    /// [`AggregateOrderSensitivity::HardRequirement`].
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::HardRequirement
    }

    /// Optional simplification hook; `None` by default.
    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
        None
    }

    /// Describes the reversed form of this function; the default is
    /// [`ReversedUDAF::NotSupported`].
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::NotSupported
    }

    /// Coerces the argument types; the default returns a "not implemented"
    /// error naming the function.
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }
}
/// Result of `AggregateUDFImpl::reverse_expr` / `AggregateUDF::reverse_udf`.
pub enum ReversedUDAF {
// NOTE(review): presumably means the reversed function is the function
// itself — confirm against callers.
Identical,
// Value returned by the default `AggregateUDFImpl::reverse_expr`.
NotSupported,
// Carries another aggregate function; presumably the one to use in the
// reversed direction — TODO confirm.
Reversed(Arc<AggregateUDF>),
}
/// Wraps another [`AggregateUDFImpl`] and reports its own alias list in
/// place of the wrapped implementation's.
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
// The wrapped implementation.
inner: Arc<dyn AggregateUDFImpl>,
// Aliases returned from `aliases()`.
aliases: Vec<String>,
}
impl AliasedAggregateUDFImpl {
    /// Wraps `inner`, using as the alias list the aliases `inner` already
    /// reports followed by `new_aliases` (in iteration order).
    pub fn new(
        inner: Arc<dyn AggregateUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let aliases: Vec<String> = inner
            .aliases()
            .iter()
            .cloned()
            .chain(new_aliases.into_iter().map(String::from))
            .collect();
        Self { inner, aliases }
    }
}
impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}
fn signature(&self) -> &Signature {
self.inner.signature()
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(acc_args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}
/// [`AggregateUDFImpl`] built from a name, a signature and two function
/// values; constructed by the deprecated `AggregateUDF::new`.
pub struct AggregateUDFLegacyWrapper {
// Name returned by `name()`.
name: String,
// Signature returned by `signature()`.
signature: Signature,
// Called by `return_type()` with the argument types.
return_type: ReturnTypeFunction,
// Called by `accumulator()` with the accumulator arguments.
accumulator: AccumulatorFactoryFunction,
}
/// Manual `Debug`: prints the name and signature under the struct name
/// `AggregateUDF`, with the placeholder `"<FUNC>"` standing in for the
/// stored function values.
impl Debug for AggregateUDFLegacyWrapper {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut out = f.debug_struct("AggregateUDF");
        out.field("name", &self.name);
        out.field("signature", &self.signature);
        out.field("fun", &"<FUNC>");
        out.finish()
    }
}
/// [`AggregateUDFImpl`] backed by the stored name, signature and function
/// values; all optional trait methods keep their defaults.
impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Calls the stored return-type function and returns an owned copy of
    /// the type it produced.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        (self.return_type)(arg_types).map(|data_type| data_type.as_ref().clone())
    }

    /// Calls the stored accumulator factory.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let make_accumulator = &self.accumulator;
        make_accumulator(acc_args)
    }
}
/// Extension methods that start an [`AggregateBuilder`].
///
/// Each method consumes `self` and returns a builder with the corresponding
/// option recorded.
pub trait AggregateExt {
/// Starts a builder with the given `ORDER BY` expressions.
fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder;
/// Starts a builder with the given filter expression.
fn filter(self, filter: Expr) -> AggregateBuilder;
/// Starts a builder with the distinct flag set.
fn distinct(self) -> AggregateBuilder;
/// Starts a builder with the given [`NullTreatment`].
fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder;
}
/// Collects optional settings for an aggregate function expression before it
/// is turned back into an [`Expr`] by `build`.
#[derive(Debug, Clone)]
pub struct AggregateBuilder {
// The aggregate being configured; `None` makes `build` return an error.
udaf: Option<AggregateFunction>,
// `ORDER BY` expressions; `build` requires each one to be `Expr::Sort`.
order_by: Option<Vec<Expr>>,
// Optional filter expression.
filter: Option<Expr>,
// Distinct flag copied onto the aggregate by `build`.
distinct: bool,
// Optional null treatment copied onto the aggregate by `build`.
null_treatment: Option<NullTreatment>,
}
impl AggregateBuilder {
fn new(udaf: Option<AggregateFunction>) -> Self {
Self {
udaf,
order_by: None,
filter: None,
distinct: false,
null_treatment: None,
}
}
pub fn build(self) -> Result<Expr> {
let Self {
udaf,
order_by,
filter,
distinct,
null_treatment,
} = self;
let Some(mut udaf) = udaf else {
return plan_err!(
"AggregateExt can only be used with Expr::AggregateFunction"
);
};
if let Some(order_by) = &order_by {
for expr in order_by.iter() {
if !matches!(expr, Expr::Sort(_)) {
return plan_err!(
"ORDER BY expressions must be Expr::Sort, found {expr:?}"
);
}
}
}
udaf.order_by = order_by;
udaf.filter = filter.map(Box::new);
udaf.distinct = distinct;
udaf.null_treatment = null_treatment;
Ok(Expr::AggregateFunction(udaf))
}
pub fn order_by(mut self, order_by: Vec<Expr>) -> AggregateBuilder {
self.order_by = Some(order_by);
self
}
pub fn filter(mut self, filter: Expr) -> AggregateBuilder {
self.filter = Some(filter);
self
}
pub fn distinct(mut self) -> AggregateBuilder {
self.distinct = true;
self
}
pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder {
self.null_treatment = Some(null_treatment);
self
}
}
/// [`AggregateExt`] for [`Expr`].
///
/// Only [`Expr::AggregateFunction`] can be configured. Any other expression
/// yields an empty builder (the supplied option is discarded), so the failure
/// is reported later by [`AggregateBuilder::build`].
impl AggregateExt for Expr {
    fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder {
        let Expr::AggregateFunction(udaf) = self else {
            return AggregateBuilder::new(None);
        };
        AggregateBuilder::new(Some(udaf)).order_by(order_by)
    }

    fn filter(self, filter: Expr) -> AggregateBuilder {
        let Expr::AggregateFunction(udaf) = self else {
            return AggregateBuilder::new(None);
        };
        AggregateBuilder::new(Some(udaf)).filter(filter)
    }

    fn distinct(self) -> AggregateBuilder {
        let Expr::AggregateFunction(udaf) = self else {
            return AggregateBuilder::new(None);
        };
        AggregateBuilder::new(Some(udaf)).distinct()
    }

    fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder {
        let Expr::AggregateFunction(udaf) = self else {
            return AggregateBuilder::new(None);
        };
        AggregateBuilder::new(Some(udaf)).null_treatment(null_treatment)
    }
}