use crate::groups_accumulator::GroupsAccumulator;
use crate::{Accumulator, Expr};
use crate::{
AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
};
use arrow::datatypes::DataType;
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
/// Handle to a user-defined aggregate function (UDAF).
///
/// The struct holds a single shared [`AggregateUDFImpl`] trait object, so the
/// derived `Clone` copies the `Arc` pointer rather than the implementation.
/// The accessor methods on this type forward to that implementation.
#[derive(Debug, Clone)]
pub struct AggregateUDF {
// Shared implementation that the methods on this type delegate to.
inner: Arc<dyn AggregateUDFImpl>,
}
impl PartialEq for AggregateUDF {
    /// Two UDAFs are equal when both their names and their signatures match.
    ///
    /// Nothing else about the underlying implementation takes part in the
    /// comparison.
    fn eq(&self, other: &Self) -> bool {
        // The name is checked first; the signature is only compared when
        // the names already agree.
        if self.name() != other.name() {
            return false;
        }
        self.signature() == other.signature()
    }
}
// Marker impl: declares the name-and-signature comparison in `PartialEq`
// to be a full equivalence relation.
impl Eq for AggregateUDF {}
impl std::hash::Hash for AggregateUDF {
    /// Feeds the name and then the signature into `state`.
    ///
    /// These are the same two properties that `PartialEq` compares, which
    /// keeps `Hash` consistent with `Eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(self.name(), state);
        std::hash::Hash::hash(self.signature(), state);
    }
}
impl AggregateUDF {
#[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")]
pub fn new(
name: &str,
signature: &Signature,
return_type: &ReturnTypeFunction,
accumulator: &AccumulatorFactoryFunction,
state_type: &StateTypeFunction,
) -> Self {
Self::new_from_impl(AggregateUDFLegacyWrapper {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
accumulator: accumulator.clone(),
state_type: state_type.clone(),
})
}
pub fn new_from_impl<F>(fun: F) -> AggregateUDF
where
F: AggregateUDFImpl + 'static,
{
Self {
inner: Arc::new(fun),
}
}
pub fn inner(&self) -> Arc<dyn AggregateUDFImpl> {
self.inner.clone()
}
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), aliases))
}
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
Arc::new(self.clone()),
args,
false,
None,
None,
))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}
pub fn signature(&self) -> &Signature {
self.inner.signature()
}
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
self.inner.return_type(args)
}
pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(return_type)
}
pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
self.inner.state_type(return_type)
}
pub fn groups_accumulator_supported(&self) -> bool {
self.inner.groups_accumulator_supported()
}
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator()
}
}
impl<F> From<F> for AggregateUDF
where
F: AggregateUDFImpl + Send + Sync + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Trait implemented by user-defined aggregate functions (UDAFs).
///
/// An implementation is wrapped in an [`AggregateUDF`] (see
/// [`AggregateUDF::new_from_impl`]), whose accessor methods forward to the
/// methods declared here. Implementors must be `Debug + Send + Sync`.
pub trait AggregateUDFImpl: Debug + Send + Sync {
/// Returns this object as [`Any`] so callers can downcast to the concrete
/// type.
fn as_any(&self) -> &dyn Any;
/// Returns the name of this function.
fn name(&self) -> &str;
/// Returns the [`Signature`] of this function (presumably the argument
/// types it accepts — confirm against `Signature`).
fn signature(&self) -> &Signature;
/// Returns the [`DataType`] this function produces for arguments of types
/// `arg_types`, or an error.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
/// Creates a new [`Accumulator`].
///
/// [`AggregateUDF::accumulator`] forwards its `return_type` argument as
/// `arg`.
fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;
/// Returns the data types of this function's state for the given
/// `return_type` (presumably the accumulator's intermediate state — confirm
/// against `Accumulator`).
fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>>;
/// Reports whether [`Self::create_groups_accumulator`] is implemented.
///
/// The default is `false`.
fn groups_accumulator_supported(&self) -> bool {
false
}
/// Creates a new [`GroupsAccumulator`].
///
/// The default implementation returns a "not implemented" error whose
/// message includes the `Debug` rendering of `self`.
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
}
/// Returns the aliases of this function.
///
/// The default is an empty slice.
fn aliases(&self) -> &[String] {
&[]
}
}
/// [`AggregateUDFImpl`] wrapper that pairs an existing implementation with
/// its own alias list. Built by [`AggregateUDF::with_aliases`].
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
// The wrapped implementation that the trait methods delegate to.
inner: Arc<dyn AggregateUDFImpl>,
// Aliases returned by `aliases()`: the wrapped implementation's aliases
// followed by the ones supplied at construction.
aliases: Vec<String>,
}
impl AliasedAggregateUDFImpl {
    /// Wraps `inner`, exposing its existing aliases followed by
    /// `new_aliases`.
    ///
    /// Order is preserved and duplicates are not removed.
    pub fn new(
        inner: Arc<dyn AggregateUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let aliases: Vec<String> = inner
            .aliases()
            .iter()
            .cloned()
            .chain(new_aliases.into_iter().map(String::from))
            .collect();
        Self { inner, aliases }
    }
}
impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}
fn signature(&self) -> &Signature {
self.inner.signature()
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(arg)
}
fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
self.inner.state_type(return_type)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}
/// [`AggregateUDFImpl`] adapter built from a name, a signature and three
/// function objects. This is what the deprecated [`AggregateUDF::new`]
/// constructs.
pub struct AggregateUDFLegacyWrapper {
// Name returned by `name()`.
name: String,
// Signature returned by `signature()`.
signature: Signature,
// Called with the argument types by `return_type()`.
return_type: ReturnTypeFunction,
// Called by `accumulator()` to build a new accumulator.
accumulator: AccumulatorFactoryFunction,
// Called with the return type by `state_type()`.
state_type: StateTypeFunction,
}
impl Debug for AggregateUDFLegacyWrapper {
    /// Formats the wrapper as a struct labelled `AggregateUDF`, showing the
    /// name and signature.
    ///
    /// The three function objects are not printed; a single `fun` field
    /// holding the placeholder `"<FUNC>"` stands in for them.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("AggregateUDF");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.field("fun", &"<FUNC>");
        builder.finish()
    }
}
/// Implements [`AggregateUDFImpl`] by calling the stored function objects.
impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
    /// Returns the wrapper itself for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the stored name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored signature.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Calls the stored return-type function with `arg_types` and returns an
    /// owned copy of the type it yields. Errors are propagated.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok((self.return_type)(arg_types)?.as_ref().clone())
    }

    /// Calls the stored accumulator factory with `arg`.
    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
        let factory = &self.accumulator;
        factory(arg)
    }

    /// Calls the stored state-type function with `return_type` and returns an
    /// owned copy of the list it yields. Errors are propagated.
    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
        Ok((self.state_type)(return_type)?.as_ref().clone())
    }
}