use crate::error::{DataFusionError, Result};
use crate::physical_plan::aggregates::AggregateFunction;
use crate::physical_plan::expressions::{
is_avg_support_arg_type, is_correlation_support_arg_type,
is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type,
is_variance_support_arg_type, try_cast,
};
use crate::physical_plan::functions::{Signature, TypeSignature};
use crate::physical_plan::PhysicalExpr;
use crate::{
arrow::datatypes::Schema,
physical_plan::expressions::is_approx_percentile_cont_supported_arg_type,
};
use arrow::datatypes::DataType;
use std::ops::Deref;
use std::sync::Arc;
pub(crate) fn coerce_types(
agg_fun: &AggregateFunction,
input_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Ok(input_types.to_vec())
}
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
get_min_max_result_type(input_types)
}
AggregateFunction::Sum => {
if !is_sum_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Avg => {
if !is_avg_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Variance => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Covariance => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Stddev => {
if !is_stddev_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxPercentileCont => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
if !matches!(input_types[1], DataType::Float64) {
return Err(DataFusionError::Plan(format!(
"The percentile argument for {:?} must be Float64, not {:?}.",
agg_fun, input_types[1]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
}
}
/// Checks that the number of `input_types` is compatible with `signature`.
///
/// * `Uniform` and `Any` carry the expected count directly.
/// * `Exact` expects as many arguments as it lists types.
/// * `OneOf` is satisfied as soon as any of its variants accepts the count.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` when the count does not match, and
/// `DataFusionError::Internal` for any other kind of `TypeSignature`.
fn check_arg_count(
    agg_fun: &AggregateFunction,
    input_types: &[DataType],
    signature: &TypeSignature,
) -> Result<()> {
    let provided = input_types.len();

    let expected = match signature {
        TypeSignature::Uniform(count, _) | TypeSignature::Any(count) => *count,
        TypeSignature::Exact(types) => types.len(),
        TypeSignature::OneOf(variants) => {
            // Accept on the first variant whose own count check succeeds.
            for variant in variants {
                if check_arg_count(agg_fun, input_types, variant).is_ok() {
                    return Ok(());
                }
            }
            return Err(DataFusionError::Plan(format!(
                "The function {:?} does not accept {:?} function arguments.",
                agg_fun, provided
            )));
        }
        _ => {
            return Err(DataFusionError::Internal(format!(
                "Aggregate functions do not support this {:?}",
                signature
            )));
        }
    };

    if provided == expected {
        Ok(())
    } else {
        Err(DataFusionError::Plan(format!(
            "The function {:?} expects {:?} arguments, but {:?} were provided",
            agg_fun, expected, provided
        )))
    }
}
fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
assert_eq!(input_types.len(), 1);
match &input_types[0] {
DataType::Dictionary(_, dict_value_type) => {
Ok(vec![dict_value_type.deref().clone()])
}
_ => Ok(input_types.to_vec()),
}
}
/// Wraps each of `input_exprs` in a cast to the type required by `agg_fun`.
///
/// The data type of every expression is resolved against `schema`, the types
/// are run through `coerce_types`, and each expression is then passed to
/// `try_cast` together with its coerced type. An empty `input_exprs` yields an
/// empty vector without consulting the signature.
///
/// # Errors
///
/// Returns the first error produced while resolving a data type, coercing the
/// types, or building a cast.
pub(crate) fn coerce_exprs(
    agg_fun: &AggregateFunction,
    input_exprs: &[Arc<dyn PhysicalExpr>],
    schema: &Schema,
    signature: &Signature,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    if input_exprs.is_empty() {
        return Ok(vec![]);
    }

    // Resolve the current type of every argument expression.
    let mut current_types = Vec::with_capacity(input_exprs.len());
    for expr in input_exprs {
        current_types.push(expr.data_type(schema)?);
    }

    let target_types = coerce_types(agg_fun, &current_types, signature)?;

    // Pair every expression with its target type and build the cast.
    let mut casted = Vec::with_capacity(input_exprs.len());
    for (expr, target_type) in input_exprs.iter().zip(target_types) {
        casted.push(try_cast(Arc::clone(expr), schema, target_type)?);
    }
    Ok(casted)
}
#[cfg(test)]
mod tests {
    use crate::physical_plan::aggregates;
    use crate::physical_plan::aggregates::AggregateFunction;
    use crate::physical_plan::coercion_rule::aggregate_rule::coerce_types;
    use arrow::datatypes::DataType;

    /// Exercises `coerce_types` for argument-count errors, unsupported input
    /// types, and the inputs that must come back unchanged.
    #[test]
    fn test_aggregate_coerce_types() {
        // Asserts that coercion succeeds and leaves the argument types as-is.
        let assert_unchanged = |fun: &AggregateFunction, args: &[DataType]| {
            let signature = aggregates::signature(fun);
            let coerced = coerce_types(fun, args, &signature).unwrap();
            assert_eq!(args, coerced.as_slice());
        };

        // Wrong number of arguments.
        let min = AggregateFunction::Min;
        let two_args = vec![DataType::Int64, DataType::Int32];
        let err = coerce_types(&min, &two_args, &aggregates::signature(&min)).unwrap_err();
        assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", err.to_string());

        // Unsupported input type for the numeric aggregates.
        let utf8_arg = vec![DataType::Utf8];
        let rejected = vec![
            (
                AggregateFunction::Sum,
                "Error during planning: The function Sum does not support inputs of type Utf8.",
            ),
            (
                AggregateFunction::Avg,
                "Error during planning: The function Avg does not support inputs of type Utf8.",
            ),
        ];
        for (fun, expected_message) in rejected {
            let signature = aggregates::signature(&fun);
            let err = coerce_types(&fun, &utf8_arg, &signature).unwrap_err();
            assert_eq!(expected_message, err.to_string());
        }

        // Aggregates that accept both numeric and string input unchanged.
        let pass_through_funs = vec![
            AggregateFunction::Count,
            AggregateFunction::ArrayAgg,
            AggregateFunction::ApproxDistinct,
            AggregateFunction::Min,
            AggregateFunction::Max,
        ];
        let pass_through_args = vec![vec![DataType::Int32], vec![DataType::Utf8]];
        for fun in &pass_through_funs {
            for args in &pass_through_args {
                assert_unchanged(fun, args);
            }
        }

        // Sum and Avg keep supported numeric input types unchanged.
        let numeric_funs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
        let numeric_args = vec![
            vec![DataType::Int32],
            vec![DataType::Float32],
            vec![DataType::Decimal(20, 3)],
        ];
        for fun in &numeric_funs {
            for args in &numeric_args {
                assert_unchanged(fun, args);
            }
        }

        // ApproxPercentileCont: any numeric value column plus a Float64 percentile.
        let value_types = vec![
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::UInt8,
            DataType::UInt16,
            DataType::UInt32,
            DataType::UInt64,
            DataType::Float32,
            DataType::Float64,
        ];
        for value_type in &value_types {
            let args = vec![value_type.clone(), DataType::Float64];
            assert_unchanged(&AggregateFunction::ApproxPercentileCont, &args);
        }
    }
}