use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::{ArrowNativeType, FieldRef};
use arrow::{
array::ArrowNativeTypeOp,
compute::SortOptions,
datatypes::{
DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
ToByteSlice,
},
};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
/// Materializes an accumulator's intermediate state as single-element arrays.
///
/// Each `ScalarValue` in [`Accumulator::state`] is converted to a length-1
/// [`ArrayRef`]; the first conversion error (or a state error) is propagated.
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    let state = accum.state()?;
    state
        .iter()
        .map(|scalar| scalar.to_array_of_size(1))
        .collect::<Result<Vec<_>>>()
}
/// Adjusts an array to carry the metadata embedded in `data_type`.
///
/// For decimal arrays this re-applies precision/scale (which may fail
/// validation); for timestamp arrays it re-applies the timezone. Any other
/// data type returns the input array unchanged.
#[deprecated(since = "44.0.0", note = "use PrimitiveArray::with_datatype")]
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
    match data_type {
        DataType::Decimal128(precision, scale) => {
            let decimal = array
                .as_primitive::<Decimal128Type>()
                .clone()
                .with_precision_and_scale(*precision, *scale)?;
            Ok(Arc::new(decimal) as ArrayRef)
        }
        DataType::Timestamp(unit, tz) => {
            // The four time units only differ in the concrete primitive type;
            // each arm re-attaches the (possibly absent) timezone.
            let adjusted: ArrayRef = match unit {
                TimeUnit::Nanosecond => Arc::new(
                    array
                        .as_primitive::<TimestampNanosecondType>()
                        .clone()
                        .with_timezone_opt(tz.clone()),
                ),
                TimeUnit::Microsecond => Arc::new(
                    array
                        .as_primitive::<TimestampMicrosecondType>()
                        .clone()
                        .with_timezone_opt(tz.clone()),
                ),
                TimeUnit::Millisecond => Arc::new(
                    array
                        .as_primitive::<TimestampMillisecondType>()
                        .clone()
                        .with_timezone_opt(tz.clone()),
                ),
                TimeUnit::Second => Arc::new(
                    array
                        .as_primitive::<TimestampSecondType>()
                        .clone()
                        .with_timezone_opt(tz.clone()),
                ),
            };
            Ok(adjusted)
        }
        // All other types pass through untouched.
        _ => Ok(array),
    }
}
/// Builds one [`Field`] per ordering expression, pairing each sort expression
/// with its corresponding data type.
///
/// The field name is the display form of the sort expression. Fields are
/// always created as nullable. If the two inputs have different lengths,
/// the extra entries of the longer one are ignored (zip semantics).
pub fn ordering_fields(
    ordering_req: &LexOrdering,
    data_types: &[DataType],
) -> Vec<FieldRef> {
    let mut fields = Vec::new();
    for (sort_expr, dtype) in ordering_req.iter().zip(data_types.iter()) {
        let field = Field::new(sort_expr.expr.to_string(), dtype.clone(), true);
        fields.push(Arc::new(field));
    }
    fields
}
/// Extracts the [`SortOptions`] of each expression in the ordering requirement.
pub fn get_sort_options(ordering_req: &LexOrdering) -> Vec<SortOptions> {
    let mut options = Vec::new();
    for sort_expr in ordering_req.iter() {
        options.push(sort_expr.options);
    }
    options
}
/// Newtype wrapper that makes an Arrow native value usable as a key in
/// hash-based collections (`HashSet`/`HashMap`).
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);
impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the raw byte representation of the wrapped value.
        self.0.to_byte_slice().hash(state)
    }
}
impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    // Equality delegates to arrow's `is_eq`.
    // NOTE(review): the Hash/Eq contract requires that values comparing equal
    // via `is_eq` also have identical byte representations — confirm this
    // holds for every native type used with `Hashable` (e.g. float types).
    fn eq(&self, other: &Self) -> bool {
        self.0.is_eq(other.0)
    }
}
impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
/// Helper for computing decimal averages with correct rescaling from the
/// sum's scale to the target (output) scale.
pub struct DecimalAverager<T: DecimalType> {
    // 10^sum_scale: the multiplier implied by the scale of the input sums
    sum_mul: T::Native,
    // 10^target_scale: the multiplier implied by the scale of the output
    target_mul: T::Native,
    // precision used to validate computed averages
    target_precision: u8,
}
impl<T: DecimalType> DecimalAverager<T> {
    /// Creates a [`DecimalAverager`] for inputs summed at `sum_scale` and
    /// results produced at (`target_precision`, `target_scale`).
    ///
    /// # Errors
    /// Returns an internal error if 10 cannot be represented in the native
    /// type, and an execution error ("Arithmetic Overflow in AvgAccumulator")
    /// when the target scale is smaller than the sum scale.
    ///
    /// NOTE(review): `scale as u32` wraps for negative scales, making the
    /// power-of-ten multiplier meaningless — presumably callers never pass
    /// negative scales here; confirm upstream validation.
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        // 10^scale in the decimal's native representation.
        let pow_ten = |scale: i8| {
            T::Native::from_usize(10_usize).map(|ten| ten.pow_wrapping(scale as u32))
        };
        let sum_mul = pow_ten(sum_scale).ok_or_else(|| {
            DataFusionError::Internal(
                "Failed to compute sum_mul in DecimalAverager".to_string(),
            )
        })?;
        let target_mul = pow_ten(target_scale).ok_or_else(|| {
            DataFusionError::Internal(
                "Failed to compute target_mul in DecimalAverager".to_string(),
            )
        })?;
        // Rescaling in `avg` only works when the target scale is at least
        // the sum scale (i.e. target_mul >= sum_mul).
        if target_mul < sum_mul {
            return exec_err!("Arithmetic Overflow in AvgAccumulator");
        }
        Ok(Self {
            sum_mul,
            target_mul,
            target_precision,
        })
    }
    /// Returns `sum / count` rescaled to the target scale, validating the
    /// result against the target precision.
    ///
    /// # Errors
    /// Returns an execution error on multiplication overflow or when the
    /// result does not fit the target precision.
    #[inline(always)]
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
        // Scale the sum up to the target scale before dividing; the ratio is
        // exact because target_mul >= sum_mul was checked in `try_new`.
        let scale_factor = self.target_mul.div_wrapping(self.sum_mul);
        match sum.mul_checked(scale_factor) {
            Ok(scaled_sum) => {
                let average = scaled_sum.div_wrapping(count);
                match T::validate_decimal_precision(average, self.target_precision) {
                    Ok(_) => Ok(average),
                    Err(_) => exec_err!("Arithmetic Overflow in AvgAccumulator"),
                }
            }
            Err(_) => exec_err!("Arithmetic Overflow in AvgAccumulator"),
        }
    }
}