use std::sync::Arc;
use super::ColumnarValue;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::binary::{eq_decimal, eq_decimal_scalar};
use crate::scalar::ScalarValue;
use arrow::array::Array;
use arrow::array::*;
use arrow::compute::kernels::boolean::nullif;
use arrow::compute::kernels::comparison::{
eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar,
};
use arrow::datatypes::{DataType, TimeUnit};
/// Downcasts `$LEFT` to the concrete array type `$DT` and `$RIGHT` to a
/// `BooleanArray`, then applies the binary kernel `$OP` to the pair.
///
/// Expands to a `Result<ArrayRef>` block expression. An error returned by `$OP`
/// is propagated with `?` out of the function the macro is expanded in.
///
/// # Panics
///
/// Panics if either downcast fails (the caller must dispatch on the real type).
macro_rules! compute_bool_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let values = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let mask = $RIGHT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("compute_op failed to downcast array");
        let output: ArrayRef = Arc::new($OP(values, mask)?);
        Ok(output)
    }};
}
/// Dispatches the kernel `$OP` over `$LEFT` (a numeric array) and `$RIGHT`
/// (a boolean mask), choosing the concrete array type from `$LEFT`'s runtime
/// data type.
///
/// Expands to a `Result<ArrayRef>`. Only the signed/unsigned integer and
/// `Float32`/`Float64` types are dispatched; any other type yields
/// `DataFusionError::Internal`.
// NOTE(review): `SUPPORTED_NULLIF_TYPES` lists `DataType::Boolean`, but there is
// no Boolean arm here, so a Boolean `$LEFT` falls through to the error arm —
// confirm whether that is intended.
macro_rules! primitive_bool_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let left_type = $LEFT.data_type();
        match left_type {
            DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array),
            unsupported => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for NULLIF/primitive/boolean operator",
                unsupported
            ))),
        }
    }};
}
pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return Err(DataFusionError::Internal(format!(
"{:?} args were supplied but NULLIF takes exactly two args",
args.len(),
)));
}
let (lhs, rhs) = (&args[0], &args[1]);
match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?;
let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;
Ok(ColumnarValue::Array(array))
}
(ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
let cond_array = binary_array_op!(lhs, rhs, eq)?;
let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;
Ok(ColumnarValue::Array(array))
}
_ => Err(DataFusionError::NotImplemented(
"nullif does not support a literal as first argument".to_string(),
)),
}
}
/// Data types listed as supported for `NULLIF` arguments.
///
/// Presumably consulted by NULLIF's signature/type-coercion logic elsewhere in
/// the crate — TODO confirm against callers. Entry order may matter to such
/// callers, so keep it stable.
pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
// NOTE(review): `primitive_bool_array_op!` in this file has no Boolean arm, so a
// Boolean first argument appears to reach its "Unsupported data type" error at
// execution time — confirm whether Boolean should be listed here.
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;
    use crate::from_slice::FromSlice;

    /// Evaluates `NULLIF(input, rhs)` with an array first argument and a
    /// scalar second argument, and materialises the result as an array.
    fn eval_nullif(input: Int32Array, rhs: ScalarValue) -> Result<ArrayRef> {
        let args = [
            ColumnarValue::Array(Arc::new(input)),
            ColumnarValue::Scalar(rhs),
        ];
        Ok(nullif_func(&args)?.into_array(0))
    }

    #[test]
    fn nullif_int32() -> Result<()> {
        let input = Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            None,
            Some(3),
            None,
            None,
            Some(4),
            Some(5),
        ]);
        let actual = eval_nullif(input, ScalarValue::Int32(Some(2i32)))?;

        // Only the row equal to 2 becomes NULL; pre-existing NULLs stay NULL.
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            None,
            None,
            Some(3),
            None,
            None,
            Some(4),
            Some(5),
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }

    #[test]
    fn nullif_int32_nonulls() -> Result<()> {
        let input = Int32Array::from_slice(&[1, 3, 10, 7, 8, 1, 2, 4, 5]);
        let actual = eval_nullif(input, ScalarValue::Int32(Some(1i32)))?;

        // Both occurrences of 1 become NULL; every other value is untouched.
        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
            None,
            Some(3),
            Some(10),
            Some(7),
            Some(8),
            None,
            Some(2),
            Some(4),
            Some(5),
        ]));
        assert_eq!(expected.as_ref(), actual.as_ref());
        Ok(())
    }
}