use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use datafusion_physical_expr::functions::Hint;
use std::sync::Arc;
macro_rules! get_optimal_return_type {
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
Ok(match arg_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
DataType::Utf8 | DataType::Binary => $utf8Type,
DataType::Null => DataType::Null,
DataType::Dictionary(_, value_type) => match **value_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
DataType::Utf8 | DataType::Binary => $utf8Type,
DataType::Null => DataType::Null,
_ => {
return datafusion_common::exec_err!(
"The {} function can only accept strings, but got {:?}.",
name.to_uppercase(),
**value_type
);
}
},
data_type => {
return datafusion_common::exec_err!(
"The {} function can only accept strings, but got {:?}.",
name.to_uppercase(),
data_type
);
}
})
}
};
}
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
/// Adapts an array-only kernel `inner` into a [`ScalarFunctionImplementation`]
/// that accepts any mix of scalar and array [`ColumnarValue`] arguments.
///
/// Each argument is materialized with `into_array` before `inner` runs. The
/// length requested for each argument depends on its entry in `hints`:
///
/// - [`Hint::Pad`]: the length of the last array argument, or `1` when every
///   argument is a scalar.
/// - [`Hint::AcceptsSingular`]: always `1`.
///
/// Arguments past the end of `hints` are treated as [`Hint::Pad`].
///
/// If no argument is an array, element `0` of the array returned by `inner`
/// is converted back into a [`ColumnarValue::Scalar`]; otherwise the array is
/// returned as a [`ColumnarValue::Array`].
///
/// # Errors
///
/// Propagates the first error from materializing an argument, any error from
/// `inner`, and any error converting the single-row result into a
/// [`ScalarValue`].
pub(super) fn make_scalar_function<F>(
    inner: F,
    hints: Vec<Hint>,
) -> ScalarFunctionImplementation
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
    Arc::new(move |args: &[ColumnarValue]| {
        // Length of the last array argument, if there is any array argument.
        let array_len = args.iter().rev().find_map(|arg| match arg {
            ColumnarValue::Array(array) => Some(array.len()),
            ColumnarValue::Scalar(_) => None,
        });
        let pad_len = array_len.unwrap_or(1);

        // Arguments without an explicit hint default to padding.
        let arg_hints = hints.iter().chain(std::iter::repeat(&Hint::Pad));

        let mut arrays = Vec::with_capacity(args.len());
        for (arg, hint) in args.iter().zip(arg_hints) {
            let target_len = match hint {
                Hint::AcceptsSingular => 1,
                Hint::Pad => pad_len,
            };
            arrays.push(arg.clone().into_array(target_len)?);
        }

        let output = inner(&arrays);
        if array_len.is_some() {
            output.map(ColumnarValue::Array)
        } else {
            // All inputs were scalars: collapse the one-row result to a scalar.
            output
                .and_then(|array| ScalarValue::try_from_array(&array, 0))
                .map(ColumnarValue::Scalar)
        }
    })
}
/// Helpers shared by the unit tests of these string functions.
#[cfg(test)]
pub mod test {
    /// Checks a function's reported return type and its result on `$ARGS`.
    ///
    /// - `$FUNC`: the function under test; `return_type` is called with the
    ///   argument data types, and `invoke($ARGS)` is called on it.
    /// - `$ARGS`: the argument list. It is expanded several times (to collect the
    ///   argument types, to invoke the function, and to compute the result
    ///   length), so the expression is evaluated more than once.
    /// - `$EXPECTED`: a `Result<Option<$EXPECTED_TYPE>>`. `Ok(Some(v))` requires
    ///   row 0 of the result to equal `v`. `Ok(None)` requires row 0 to be null.
    ///   `Err(e)` requires either `return_type` or `invoke` to fail.
    /// - `$EXPECTED_TYPE`: Rust type of the expected row-0 value.
    /// - `$EXPECTED_DATA_TYPE`: the type `return_type` must report on success.
    /// - `$ARRAY_TYPE`: concrete array type the result is downcast to.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Resolve the return type from the data types of the arguments.
            let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            let return_type = func.return_type(&type_array);
            match expected {
                Ok(expected) => {
                    // NOTE(review): `assert_eq!(x.is_ok(), true)` does not print the
                    // underlying error when it fails.
                    assert_eq!(return_type.is_ok(), true);
                    assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE);
                    let result = func.invoke($ARGS);
                    assert_eq!(result.is_ok(), true);
                    // Length of the last array argument, or 1 when every argument
                    // is a scalar.
                    let len = $ARGS
                        .iter()
                        .fold(Option::<usize>::None, |acc, arg| match arg {
                            ColumnarValue::Scalar(_) => acc,
                            ColumnarValue::Array(a) => Some(a.len()),
                        });
                    let inferred_length = len.unwrap_or(1);
                    // Materialize the result as an array so scalar and array
                    // results are checked the same way.
                    let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                    // Only row 0 of the result is compared.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => {
                    if return_type.is_err() {
                        // The `Ok(_)` arm below is unreachable: this branch is only
                        // entered when `return_type` is `Err`.
                        match return_type {
                            Ok(_) => assert!(false, "expected error"),
                            // NOTE(review): the expected message is passed first and the
                            // actual error second; confirm against `assert_contains!`'s
                            // argument order that the intended containment direction is
                            // the one being checked.
                            Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); }
                        }
                    }
                    else {
                        // `return_type` succeeded, so the error must come from `invoke`.
                        match func.invoke($ARGS) {
                            Ok(_) => assert!(false, "expected error"),
                            Err(error) => {
                                // Passes when the expected message starts with the
                                // actual error message.
                                assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));
                            }
                        }
                    }
                }
            };
        };
    }
    // Re-export the macro so it can be imported by path from elsewhere in the
    // crate; `unused_imports` is allowed in case nothing imports it.
    #[allow(unused_imports)]
    pub(crate) use test_function;
}