use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray};
use arrow::compute::try_binary;
use arrow::datatypes::{DataType, Int64Type};
use arrow::error::ArrowError;
use std::any::Any;
use std::mem::swap;
use std::sync::Arc;
use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_macros::user_doc;
/// Scalar UDF implementing `gcd(expression_x, expression_y)` over `Int64` values.
///
/// The user-facing documentation (section, description, syntax example and
/// argument descriptions) is declared in the `user_doc` attribute below.
#[user_doc(
doc_section(label = "Math Functions"),
description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.",
syntax_example = "gcd(expression_x, expression_y)",
standard_argument(name = "expression_x", prefix = "First numeric"),
standard_argument(name = "expression_y", prefix = "Second numeric")
)]
#[derive(Debug)]
pub struct GcdFunc {
/// Accepted argument signature; built once in `GcdFunc::new` and handed out
/// by reference from `ScalarUDFImpl::signature`.
signature: Signature,
}
impl Default for GcdFunc {
fn default() -> Self {
Self::new()
}
}
impl GcdFunc {
    /// Creates the `gcd` function definition.
    ///
    /// The signature is uniform over exactly two arguments, with `Int64` as the
    /// only listed type, and the function is declared immutable.
    pub fn new() -> Self {
        let signature =
            Signature::uniform(2, vec![DataType::Int64], Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for GcdFunc {
    /// Exposes this implementation as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The function's name: `gcd`.
    fn name(&self) -> &str {
        "gcd"
    }

    /// Returns the signature constructed in [`GcdFunc::new`].
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type is always `Int64`; the argument types are not inspected.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    /// Evaluates `gcd` for one batch of arguments.
    ///
    /// Exactly two arguments are required. Each may be an array or an `Int64`
    /// scalar:
    ///
    /// * array / array   -> element-wise GCD, returned as an array;
    /// * array / scalar (either order) -> GCD of each element with the scalar,
    ///   returned as an array (all-null when the scalar is null);
    /// * scalar / scalar -> a single scalar, null if either input is null.
    ///
    /// # Errors
    ///
    /// * an internal error if the argument count is not 2;
    /// * an execution error for any other argument combination (for example a
    ///   scalar that is not `Int64`);
    /// * whatever [`compute_gcd`] reports when a result does not fit in `i64`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let operands = <[ColumnarValue; 2]>::try_from(args.args).map_err(|_| {
            internal_datafusion_err!("Expected 2 arguments for function gcd")
        })?;

        match operands {
            [ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)] => {
                compute_gcd_for_arrays(&lhs, &rhs)
            }
            // The array is always passed first to the helper, whichever side
            // it appeared on in the call.
            [ColumnarValue::Array(values), ColumnarValue::Scalar(ScalarValue::Int64(constant))]
            | [ColumnarValue::Scalar(ScalarValue::Int64(constant)), ColumnarValue::Array(values)] => {
                compute_gcd_with_scalar(&values, constant)
            }
            [ColumnarValue::Scalar(ScalarValue::Int64(Some(x))), ColumnarValue::Scalar(ScalarValue::Int64(Some(y)))] => {
                let gcd = compute_gcd(x, y)?;
                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(gcd))))
            }
            // At least one of the two Int64 scalars is null.
            [ColumnarValue::Scalar(ScalarValue::Int64(_)), ColumnarValue::Scalar(ScalarValue::Int64(_))] => {
                Ok(ColumnarValue::Scalar(ScalarValue::Int64(None)))
            }
            _ => exec_err!("Unsupported argument types for function gcd"),
        }
    }

    /// Returns the function's documentation by delegating to `self.doc()`.
    // NOTE(review): `doc()` is not defined in the visible code; presumably it is
    // generated by the `#[user_doc]` attribute on `GcdFunc` — confirm.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Computes the GCD of two arrays pairwise and returns the result as an array.
///
/// Both inputs are viewed as `Int64` primitive arrays and combined with
/// `try_binary`, applying [`compute_gcd`] to each pair of values.
///
/// # Errors
///
/// Any `ArrowError` produced by `try_binary` or [`compute_gcd`] is converted
/// into this module's error type and returned.
fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
    // NOTE(review): assumes both arrays really are Int64 — confirm callers
    // only reach this after type coercion.
    let lhs = a.as_primitive::<Int64Type>();
    let rhs = b.as_primitive::<Int64Type>();

    let gcds: PrimitiveArray<Int64Type> = try_binary(lhs, rhs, compute_gcd)?;
    Ok(ColumnarValue::Array(Arc::new(gcds) as ArrayRef))
}
/// Computes the GCD of every element of `arr` with a single scalar operand.
///
/// * If `scalar` is `None`, the result is an all-null `Int64` array with the
///   same length as `arr`.
/// * Otherwise each non-null element `v` becomes `compute_gcd(v, scalar)` and
///   each null element stays null.
///
/// # Errors
///
/// Returns the first error reported by [`compute_gcd`], if any.
fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
    let operand = match scalar {
        Some(value) => value,
        None => {
            return Ok(ColumnarValue::Array(new_null_array(
                &DataType::Int64,
                arr.len(),
            )));
        }
    };

    let values = arr.as_primitive::<Int64Type>();
    let gcds = values
        .iter()
        .map(|item| -> Result<Option<i64>> {
            // `Option<Result<_>>` -> `Result<Option<_>>`, keeping nulls as nulls.
            item.map(|v| compute_gcd(v, operand))
                .transpose()
                .map_err(Into::into)
        })
        .collect::<Result<Int64Array>>()?;

    Ok(ColumnarValue::Array(Arc::new(gcds) as ArrayRef))
}
/// Greatest common divisor of two unsigned 64-bit integers.
///
/// Uses only shifts, comparisons and subtraction (the binary GCD approach).
/// If either input is zero the other input is returned, so
/// `unsigned_gcd(0, 0) == 0`.
pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
    // With one operand equal to zero, `a | b` is simply the other operand.
    if a == 0 || b == 0 {
        return a | b;
    }

    // Number of factors of two shared by both inputs; restored at the end.
    let common_twos = (a | b).trailing_zeros();

    // Strip all factors of two so that both values are odd from here on.
    a >>= a.trailing_zeros();
    b >>= b.trailing_zeros();

    while a != b {
        // Keep `a` as the smaller value so the subtraction cannot underflow.
        if a > b {
            swap(&mut a, &mut b);
        }
        // odd - odd is even and, since a != b, non-zero; make it odd again.
        b -= a;
        b >>= b.trailing_zeros();
    }

    a << common_twos
}
/// Greatest common divisor of two signed 64-bit integers.
///
/// The GCD is computed on the absolute values of `x` and `y`, so the result is
/// never negative.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` when the result does not fit in an
/// `i64`. This happens when the GCD is 2^63, e.g. for
/// `compute_gcd(i64::MIN, i64::MIN)` or `compute_gcd(i64::MIN, 0)`.
pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
    // `unsigned_abs` is used because the magnitude of `i64::MIN` is not
    // representable as an `i64`.
    let magnitude = unsigned_gcd(x.unsigned_abs(), y.unsigned_abs());

    i64::try_from(magnitude).map_err(|_| {
        ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
    })
}