use fmt::Debug;
use std::any::Any;
use std::fmt;
use arrow::{
datatypes::Field,
datatypes::{DataType, Schema},
};
use super::{expressions::format_state_name, Accumulator, AggregateExpr};
use crate::physical_plan::PhysicalExpr;
use datafusion_common::{DataFusionError, Result};
pub use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
use std::sync::Arc;
/// Builds a physical [`AggregateExpr`] that evaluates the user-defined
/// aggregate `fun` over the argument expressions `input_phy_exprs`.
///
/// The data type of every argument is resolved against `input_schema`, and
/// the aggregate's output type is obtained by passing those argument types
/// to the UDAF's `return_type` function. `name` becomes the output name.
///
/// # Errors
///
/// Returns an error if any argument's data type cannot be resolved against
/// `input_schema` (the first failure is returned), or if the UDAF's
/// `return_type` function fails.
pub fn create_aggregate_expr(
    fun: &AggregateUDF,
    input_phy_exprs: &[Arc<dyn PhysicalExpr>],
    input_schema: &Schema,
    name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
    // Resolve argument types in order; `?` stops at the first error.
    let mut arg_types = Vec::new();
    for arg_expr in input_phy_exprs {
        arg_types.push(arg_expr.data_type(input_schema)?);
    }

    let output_type = (fun.return_type)(&arg_types)?.as_ref().clone();

    let expr = AggregateFunctionExpr {
        fun: fun.clone(),
        args: input_phy_exprs.to_vec(),
        data_type: output_type,
        name: name.into(),
    };
    Ok(Arc::new(expr))
}
/// Physical aggregate expression that evaluates a user-defined aggregate
/// function ([`AggregateUDF`]) over a list of argument expressions.
///
/// Constructed by [`create_aggregate_expr`], which computes `data_type`
/// once, at construction time.
#[derive(Debug)]
pub struct AggregateFunctionExpr {
    /// The user-defined aggregate whose `return_type`, `state_type` and
    /// `accumulator` functions this expression invokes.
    fun: AggregateUDF,
    /// Physical expressions that produce the aggregate's input arguments.
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Output type of the aggregate, as returned by `fun.return_type`. It is
    /// also the type handed to the `state_type` and `accumulator` factories.
    data_type: DataType,
    /// Name of the output field; also passed to `format_state_name` when the
    /// intermediate state fields are named.
    name: String,
}
impl AggregateFunctionExpr {
    /// Returns the user-defined aggregate function this expression evaluates.
    pub fn fun(&self) -> &AggregateUDF {
        let Self { fun, .. } = self;
        fun
    }
}
impl AggregateExpr for AggregateFunctionExpr {
    /// Exposes `self` as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the argument expressions (cheap `Arc` clones).
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.args.iter().map(Arc::clone).collect()
    }

    /// Describes the intermediate state columns of this aggregate.
    ///
    /// One nullable field is produced per type returned by the UDAF's
    /// `state_type` function (which is given the output type). Each field is
    /// named with `format_state_name(name, "<index>")`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the UDAF's `state_type` function.
    fn state_fields(&self) -> Result<Vec<Field>> {
        let state_types = (self.fun.state_type)(&self.data_type)?;
        let mut fields = Vec::new();
        for (idx, state_type) in state_types.iter().enumerate() {
            let field_name = format_state_name(&self.name, &idx.to_string());
            fields.push(Field::new(field_name, state_type.clone(), true));
        }
        Ok(fields)
    }

    /// Returns the nullable output field, named after this expression.
    fn field(&self) -> Result<Field> {
        Ok(Field::new(self.name.as_str(), self.data_type.clone(), true))
    }

    /// Creates a fresh accumulator through the UDAF's `accumulator` factory,
    /// which is given the aggregate's output type.
    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        (self.fun.accumulator)(&self.data_type)
    }

    /// Creates an accumulator for sliding (retracting) aggregation.
    ///
    /// # Errors
    ///
    /// Propagates factory errors. Returns `DataFusionError::NotImplemented`
    /// when the created accumulator reports that it does not support
    /// `retract_batch`.
    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let accumulator = self.create_accumulator()?;
        // Sliding use is only permitted for accumulators that can retract;
        // anything else is rejected instead of silently producing results.
        if accumulator.supports_retract_batch() {
            Ok(accumulator)
        } else {
            Err(DataFusionError::NotImplemented(format!(
                "Aggregate can not be used as a sliding accumulator because \
                 `retract_batch` is not implemented: {}",
                self.name
            )))
        }
    }

    /// Returns the name given at construction time.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
impl PartialEq<dyn Any> for AggregateFunctionExpr {
    /// Structural equality against a type-erased value.
    ///
    /// `other` is first unwrapped with `down_cast_any_ref`. The result is
    /// `false` unless it is an `AggregateFunctionExpr` with the same name,
    /// output type, UDAF and pairwise-equal argument expressions.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(that) => {
                self.name == that.name
                    && self.data_type == that.data_type
                    && self.fun == that.fun
                    // `zip` stops at the shorter list, so lengths are checked first.
                    && self.args.len() == that.args.len()
                    && self
                        .args
                        .iter()
                        .zip(&that.args)
                        .all(|(lhs, rhs)| lhs.eq(rhs))
            }
            None => false,
        }
    }
}