use std::fmt::{Display, Formatter};
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::Result;
use datafusion_common::{exec_err, ScalarValue};
use datafusion_expr::ColumnarValue;
/// Selects which side(s) of a string a trim operation strips.
pub(crate) enum TrimType {
    /// Strip leading characters only.
    Left,
    /// Strip trailing characters only.
    Right,
    /// Strip both leading and trailing characters.
    Both,
}

impl Display for TrimType {
    /// Writes the name of the trim function for this mode (`ltrim`, `rtrim`
    /// or `btrim`); this text is used in error messages.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let function_name = match self {
            Self::Left => "ltrim",
            Self::Right => "rtrim",
            Self::Both => "btrim",
        };
        f.write_str(function_name)
    }
}
/// Trims characters from the start, the end, or both ends of every string in
/// `args[0]`, as selected by `trim_type`.
///
/// * With one argument, only the space character `' '` is trimmed.
/// * With two arguments, each row of `args[0]` is trimmed of every character
///   that appears in the corresponding row of `args[1]`.
///
/// A null in either input produces a null in the output.
///
/// # Errors
///
/// Returns an execution error if `args` does not hold exactly one or two
/// arrays. Propagates the error from [`as_generic_string_array`] when an
/// argument is not a string array with offset type `T`.
pub(crate) fn general_trim<T: OffsetSizeTrait>(
    args: &[ArrayRef],
    trim_type: TrimType,
) -> Result<ArrayRef> {
    /// Strips every leading and/or trailing `char` contained in `pattern`.
    fn trim_chars<'s>(trim_type: &TrimType, input: &'s str, pattern: &[char]) -> &'s str {
        match trim_type {
            TrimType::Left => input.trim_start_matches(pattern),
            TrimType::Right => input.trim_end_matches(pattern),
            TrimType::Both => input.trim_start_matches(pattern).trim_end_matches(pattern),
        }
    }

    // Check the arity before touching `args[0]`. An empty argument list then
    // yields an error instead of an index-out-of-bounds panic.
    match args.len() {
        1 => {
            let string_array = as_generic_string_array::<T>(&args[0])?;
            // The pattern is constant in this form, so build it once rather
            // than once per row.
            let spaces = [' '];
            let result = string_array
                .iter()
                .map(|string| string.map(|string| trim_chars(&trim_type, string, &spaces)))
                .collect::<GenericStringArray<T>>();
            Ok(Arc::new(result) as ArrayRef)
        }
        2 => {
            let string_array = as_generic_string_array::<T>(&args[0])?;
            let characters_array = as_generic_string_array::<T>(&args[1])?;
            // Reused across rows so each row does not allocate a new Vec<char>.
            let mut pattern: Vec<char> = Vec::new();
            let result = string_array
                .iter()
                .zip(characters_array.iter())
                .map(|(string, characters)| match (string, characters) {
                    (Some(string), Some(characters)) => {
                        pattern.clear();
                        pattern.extend(characters.chars());
                        Some(trim_chars(&trim_type, string, &pattern))
                    }
                    _ => None,
                })
                .collect::<GenericStringArray<T>>();
            Ok(Arc::new(result) as ArrayRef)
        }
        other => {
            exec_err!(
                "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2."
            )
        }
    }
}
/// Maps `op` over the single string array in `args` and collects the results
/// into a new string array with offset type `O`. Null inputs stay null.
///
/// `T` is the offset type of the input array. `name` is only used in the
/// error message.
///
/// # Errors
///
/// Returns an execution error unless `args` holds exactly one array.
/// Propagates the error from `as_generic_string_array` if that array is not a
/// string array with offset type `T`.
pub(crate) fn unary_string_function<'a, T, O, F, R>(
    args: &[&'a dyn Array],
    op: F,
    name: &str,
) -> Result<GenericStringArray<O>>
where
    R: AsRef<str>,
    O: OffsetSizeTrait,
    T: OffsetSizeTrait,
    F: Fn(&'a str) -> R,
{
    let input = match args {
        [single] => *single,
        _ => {
            return exec_err!(
                "{:?} args were supplied but {} takes exactly one argument",
                args.len(),
                name
            )
        }
    };

    let strings = as_generic_string_array::<T>(input)?;
    let mapped = strings.iter().map(|value| value.map(&op));
    Ok(mapped.collect::<GenericStringArray<O>>())
}
/// Applies `op` to the first argument, which must be a `Utf8` or `LargeUtf8`
/// array or scalar. The result keeps the input's string type, and nulls stay
/// null.
///
/// `name` is only used in error messages. Arguments after the first are not
/// inspected.
///
/// # Errors
///
/// Returns an execution error when `args` is empty, or when the first argument
/// is not a `Utf8`/`LargeUtf8` array or scalar.
pub(crate) fn handle<'a, F, R>(
    args: &'a [ColumnarValue],
    op: F,
    name: &str,
) -> Result<ColumnarValue>
where
    R: AsRef<str>,
    F: Fn(&'a str) -> R,
{
    // Indexing `args[0]` would panic on an empty slice; report an error instead.
    let first = match args.first() {
        Some(first) => first,
        None => {
            return exec_err!(
                "{:?} args were supplied but {} takes exactly one argument",
                args.len(),
                name
            )
        }
    };
    match first {
        ColumnarValue::Array(a) => match a.data_type() {
            DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(
                unary_string_function::<i32, i32, _, _>(&[a.as_ref()], op, name)?,
            ))),
            DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new(
                unary_string_function::<i64, i64, _, _>(&[a.as_ref()], op, name)?,
            ))),
            other => exec_err!("Unsupported data type {other:?} for function {name}"),
        },
        ColumnarValue::Scalar(scalar) => match scalar {
            ScalarValue::Utf8(a) => {
                let result = a.as_deref().map(|x| op(x).as_ref().to_string());
                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
            }
            ScalarValue::LargeUtf8(a) => {
                let result = a.as_deref().map(|x| op(x).as_ref().to_string());
                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
            }
            other => exec_err!("Unsupported data type {other:?} for function {name}"),
        },
    }
}