use std::any::type_name;
use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait};
use arrow::compute;
use hashbrown::HashMap;
use lazy_static::lazy_static;
use regex::Regex;
/// Downcasts `$ARG` (any value with an `as_any()` method, e.g. an `ArrayRef`)
/// to `&GenericStringArray<$T>`.
///
/// On a type mismatch this returns early from the *enclosing* function with a
/// `DataFusionError::Internal`. `$NAME` labels the argument in that error.
/// The offset type is taken from `$T`, the macro argument, rather than from
/// whatever identifier `T` happens to be in scope at the call site.
macro_rules! downcast_string_arg {
    ($ARG:expr, $NAME:expr, $T:ident) => {{
        $ARG.as_any()
            .downcast_ref::<GenericStringArray<$T>>()
            .ok_or_else(|| {
                DataFusionError::Internal(format!(
                    "could not cast {} to {}",
                    $NAME,
                    type_name::<GenericStringArray<$T>>()
                ))
            })?
    }};
}
/// Returns the first regex match of each input string as a list of capture groups.
///
/// `args` is either `[values, patterns]` or `[values, patterns, flags]`. Each
/// argument must be a `GenericStringArray<T>`. The per-row work is delegated to
/// arrow's `compute::regexp_match`.
///
/// # Errors
/// - `DataFusionError::Internal` if the argument count is not 2 or 3, or if an
///   argument is not a `GenericStringArray<T>`.
/// - `DataFusionError::ArrowError` wrapping any error returned by arrow.
pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    let arg_count = args.len();
    if !(2..=3).contains(&arg_count) {
        return Err(DataFusionError::Internal(format!(
            "regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
            arg_count
        )));
    }

    let values = downcast_string_arg!(args[0], "string", T);
    let regex = downcast_string_arg!(args[1], "pattern", T);
    // The optional third argument carries the regex flags.
    let flags = if arg_count == 3 {
        Some(downcast_string_arg!(args[2], "flags", T))
    } else {
        None
    };

    compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
}
/// Converts a PostgreSQL-style replacement string into the template syntax
/// used by the `regex` crate's `Replacer` for `&str`.
///
/// - Each backslash and the digits that follow it (`\N`) become `${N}`.
/// - Each literal `$` is escaped as `$$`. Without this, the `regex` crate
///   treats `$` as the start of a capture-group reference, so replacement
///   text such as `"$1"` or `"$x"` would be silently expanded or dropped.
///
/// NOTE(review): a backslash with no following digits is rewritten to `${}`.
/// The `regex` crate then expands that to an empty string, so the backslash
/// is dropped from the output. This differs from PostgreSQL's `\\` handling.
fn regex_replace_posix_groups(replacement: &str) -> String {
    lazy_static! {
        static ref CAPTURE_GROUPS_RE: Regex = Regex::new(r"(\\)(\d*)").unwrap();
    }
    // Escape `$` first, so the `${N}` references produced below are the only
    // live `$` sequences in the resulting template.
    let escaped = replacement.replace('$', "$$");
    CAPTURE_GROUPS_RE
        .replace_all(&escaped, "$${$2}")
        .into_owned()
}
pub fn regexp_replace<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut patterns: HashMap<String, Regex> = HashMap::new();
match args.len() {
3 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.map(|((string, pattern), replacement)| match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);
let re = match patterns.get(pattern) {
Some(re) => Ok(re.clone()),
None => {
match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re.clone());
Ok(re)
},
Err(err) => Err(DataFusionError::Execution(err.to_string())),
}
}
};
Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let flags_array = downcast_string_arg!(args[3], "flags", T);
let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(format!("(?{}){}", flags.to_string().replace("g", ""), pattern), true)
} else {
(format!("(?{}){}", flags, pattern), false)
};
let re = match patterns.get(&pattern) {
Some(re) => Ok(re.clone()),
None => {
match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern, re.clone());
Ok(re)
},
Err(err) => Err(DataFusionError::Execution(err.to_string())),
}
}
};
Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
})).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
other => Err(DataFusionError::Internal(format!(
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
other
))),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::from_slice::FromSlice;
    use arrow::array::*;

    /// Builds the expected `regexp_match` output for a list of rows.
    ///
    /// `Some(m)` becomes a one-element list `[m]`; `None` becomes a null
    /// list slot.
    fn expected_matches(rows: &[Option<&str>]) -> ListArray {
        let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new(0));
        for &row in rows {
            match row {
                Some(matched) => {
                    builder.values().append_value(matched).unwrap();
                    builder.append(true).unwrap();
                }
                None => builder.append(false).unwrap(),
            }
        }
        builder.finish()
    }

    #[test]
    fn test_case_sensitive_regexp_match() {
        let values = StringArray::from_slice(&["abc"; 5]);
        let patterns =
            StringArray::from_slice(&["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
        let expected = expected_matches(&[Some("a"), None, Some("b"), None, None]);

        let actual = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }

    #[test]
    fn test_case_insensitive_regexp_match() {
        let values = StringArray::from_slice(&["abc"; 5]);
        let patterns =
            StringArray::from_slice(&["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
        let flags = StringArray::from_slice(&["i"; 5]);
        let expected =
            expected_matches(&[Some("a"), Some("a"), Some("b"), Some("b"), None]);

        let actual =
            regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .unwrap();
        assert_eq!(actual.as_ref(), &expected);
    }
}