use crate::function::AccumulatorArgs;
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;
/// Logical representation of a user-defined aggregate function (UDAF).
///
/// This is a cheaply clonable handle around a shared `Arc<dyn AggregateUDFImpl>`;
/// the methods on this type delegate to that implementation.
#[derive(Debug, Clone)]
pub struct AggregateUDF {
    /// The shared implementation every method delegates to.
    inner: Arc<dyn AggregateUDFImpl>,
}

/// Two UDAFs are equal when both their names and their signatures match.
/// Aliases and the concrete implementation object are not compared.
impl PartialEq for AggregateUDF {
    fn eq(&self, other: &Self) -> bool {
        let same_name = self.name() == other.name();
        same_name && self.signature() == other.signature()
    }
}

impl Eq for AggregateUDF {}

/// Hashes exactly the fields compared by `PartialEq` (name, then signature),
/// keeping `Hash` consistent with `Eq`.
impl std::hash::Hash for AggregateUDF {
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        let (name, signature) = (self.name(), self.signature());
        name.hash(hasher);
        signature.hash(hasher);
    }
}
impl AggregateUDF {
    /// Creates a new `AggregateUDF` from the closure-based legacy API.
    ///
    /// The closures are wrapped in an adapter implementing [`AggregateUDFImpl`].
    /// New code should implement [`AggregateUDFImpl`] directly and call
    /// [`AggregateUDF::new_from_impl`].
    #[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")]
    pub fn new(
        name: &str,
        signature: &Signature,
        return_type: &ReturnTypeFunction,
        accumulator: &AccumulatorFactoryFunction,
    ) -> Self {
        Self::new_from_impl(AggregateUDFLegacyWrapper {
            name: name.to_owned(),
            signature: signature.clone(),
            return_type: return_type.clone(),
            accumulator: accumulator.clone(),
        })
    }

    /// Creates a new `AggregateUDF` backed by the given [`AggregateUDFImpl`].
    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
    where
        F: AggregateUDFImpl + 'static,
    {
        Self {
            inner: Arc::new(fun),
        }
    }

    /// Returns a new `Arc` handle to the underlying implementation.
    pub fn inner(&self) -> Arc<dyn AggregateUDFImpl> {
        self.inner.clone()
    }

    /// Returns a UDAF that also answers to `aliases`.
    ///
    /// The new aliases are appended after any aliases the function already has.
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
        // `self` is consumed, so move the `Arc` instead of bumping its refcount.
        Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner, aliases))
    }

    /// Builds an `Expr::AggregateFunction` that applies this UDAF to `args`.
    ///
    /// The trailing arguments to `AggregateFunction::new_udf` are passed as
    /// `false` / `None`. They are presumably distinct / filter / order-by —
    /// NOTE(review): confirm against `AggregateFunction::new_udf`.
    pub fn call(&self, args: Vec<Expr>) -> Expr {
        Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
            Arc::new(self.clone()),
            args,
            false,
            None,
            None,
            None,
        ))
    }

    /// Returns the function's name.
    pub fn name(&self) -> &str {
        self.inner.name()
    }

    /// Returns the additional names this function can be referred to by.
    pub fn aliases(&self) -> &[String] {
        self.inner.aliases()
    }

    /// Returns the function's [`Signature`].
    pub fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Returns the output type for the given argument types.
    ///
    /// Delegates to [`AggregateUDFImpl::return_type`] and propagates its errors.
    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
        self.inner.return_type(args)
    }

    /// Creates a new [`Accumulator`] for this aggregate.
    ///
    /// Delegates to [`AggregateUDFImpl::accumulator`].
    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        self.inner.accumulator(acc_args)
    }

    /// Returns the fields describing this aggregate's intermediate state.
    ///
    /// Delegates to [`AggregateUDFImpl::state_fields`].
    pub fn state_fields(
        &self,
        name: &str,
        value_type: DataType,
        ordering_fields: Vec<Field>,
    ) -> Result<Vec<Field>> {
        self.inner.state_fields(name, value_type, ordering_fields)
    }

    /// Returns `true` if [`Self::create_groups_accumulator`] is supported.
    pub fn groups_accumulator_supported(&self) -> bool {
        self.inner.groups_accumulator_supported()
    }

    /// Creates a [`GroupsAccumulator`] for this aggregate.
    ///
    /// Delegates to [`AggregateUDFImpl::create_groups_accumulator`].
    pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        self.inner.create_groups_accumulator()
    }
}
impl<F> From<F> for AggregateUDF
where
F: AggregateUDFImpl + Send + Sync + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Trait implemented by user-defined aggregate functions.
///
/// Wrap an implementation in [`AggregateUDF`] (via `AggregateUDF::new_from_impl`
/// or `From`) to use it as a logical aggregate function.
pub trait AggregateUDFImpl: Debug + Send + Sync {
    /// Returns this object as [`Any`] so callers can downcast it.
    fn as_any(&self) -> &dyn Any;

    /// Returns the function's name.
    fn name(&self) -> &str;

    /// Returns additional names this function can be referred to by.
    ///
    /// The default is no aliases.
    fn aliases(&self) -> &[String] {
        &[]
    }

    /// Returns the function's [`Signature`].
    fn signature(&self) -> &Signature;

    /// Returns the output type for the given argument types.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;

    /// Creates a new [`Accumulator`] for this aggregate.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;

    /// Returns the fields of the intermediate state.
    ///
    /// The default is a single nullable field named via
    /// `format_state_name(name, "value")` of type `value_type`, followed by
    /// `ordering_fields` in their original order.
    fn state_fields(
        &self,
        name: &str,
        value_type: DataType,
        ordering_fields: Vec<Field>,
    ) -> Result<Vec<Field>> {
        let mut fields = vec![Field::new(
            format_state_name(name, "value"),
            value_type,
            true,
        )];
        fields.extend(ordering_fields);
        Ok(fields)
    }

    /// Returns `true` if [`Self::create_groups_accumulator`] is implemented.
    ///
    /// The default is `false`.
    fn groups_accumulator_supported(&self) -> bool {
        false
    }

    /// Creates a [`GroupsAccumulator`] for this aggregate.
    ///
    /// The default returns a "not implemented" error.
    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
    }
}
#[derive(Debug)]
struct AliasedAggregateUDFImpl {
inner: Arc<dyn AggregateUDFImpl>,
aliases: Vec<String>,
}
impl AliasedAggregateUDFImpl {
pub fn new(
inner: Arc<dyn AggregateUDFImpl>,
new_aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
Self { inner, aliases }
}
}
impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}
fn signature(&self) -> &Signature {
self.inner.signature()
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
self.inner.accumulator(acc_args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}
/// Adapter implementing [`AggregateUDFImpl`] on top of the closure-based
/// arguments accepted by the deprecated `AggregateUDF::new` constructor.
pub struct AggregateUDFLegacyWrapper {
    /// Function name.
    name: String,
    /// The function's signature.
    signature: Signature,
    /// Closure that computes the return type from the argument types.
    return_type: ReturnTypeFunction,
    /// Closure that creates an accumulator.
    accumulator: AccumulatorFactoryFunction,
}

impl Debug for AggregateUDFLegacyWrapper {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // The stored closures are printed as a fixed placeholder.
        let mut builder = f.debug_struct("AggregateUDF");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.field("fun", &"<FUNC>");
        builder.finish()
    }
}

impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // The closure yields a shared (`as_ref`-able) DataType; return an owned copy.
        let shared_type = (self.return_type)(arg_types)?;
        let owned_type: DataType = (*shared_type).clone();
        Ok(owned_type)
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let make_accumulator = &self.accumulator;
        make_accumulator(acc_args)
    }
}