use arrow::array::ArrayRef;
use arrow::array::NullArray;
use arrow::compute::{kernels, CastOptions};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
use datafusion_common::{internal_err, Result, ScalarValue};
use std::sync::Arc;
/// Holds either a whole Arrow array or a single scalar value.
///
/// A scalar can be expanded into an array of a requested length with
/// [`ColumnarValue::into_array`].
#[derive(Clone, Debug)]
pub enum ColumnarValue {
/// An Arrow array of values.
Array(ArrayRef),
/// A single scalar value.
Scalar(ScalarValue),
}
impl From<ArrayRef> for ColumnarValue {
fn from(value: ArrayRef) -> Self {
ColumnarValue::Array(value)
}
}
impl From<ScalarValue> for ColumnarValue {
fn from(value: ScalarValue) -> Self {
ColumnarValue::Scalar(value)
}
}
impl ColumnarValue {
    /// Returns the [`DataType`] of this value: a clone of the array's type
    /// for `Array`, or the scalar's type for `Scalar`.
    pub fn data_type(&self) -> DataType {
        match self {
            ColumnarValue::Scalar(scalar) => scalar.data_type(),
            ColumnarValue::Array(array) => array.data_type().clone(),
        }
    }

    /// Consumes `self` and produces an [`ArrayRef`].
    ///
    /// An `Array` is handed back unchanged (`num_rows` is not consulted); a
    /// `Scalar` is expanded via `ScalarValue::to_array_of_size(num_rows)`.
    ///
    /// # Errors
    /// Propagates any error from the scalar-to-array conversion.
    pub fn into_array(self, num_rows: usize) -> Result<ArrayRef> {
        let array = match self {
            ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?,
            ColumnarValue::Array(array) => array,
        };
        Ok(array)
    }

    /// Builds an `Array` variant wrapping a [`NullArray`] of `num_rows` entries.
    pub fn create_null_array(num_rows: usize) -> Self {
        let nulls: ArrayRef = Arc::new(NullArray::new(num_rows));
        ColumnarValue::Array(nulls)
    }

    /// Converts every argument into an [`ArrayRef`] of one common length.
    ///
    /// The common length is the length of the `Array` arguments, which must
    /// all agree. `Scalar` arguments are expanded to that length. When no
    /// `Array` argument is present the length defaults to 1, so each scalar
    /// becomes a one-element array. An empty `args` yields an empty `Vec`.
    ///
    /// # Errors
    /// Returns an internal error if two `Array` arguments differ in length,
    /// and propagates any error from expanding a scalar.
    pub fn values_to_arrays(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> {
        // Lengths of the `Array` arguments only; scalars impose no constraint.
        let mut array_lengths = args.iter().filter_map(|arg| match arg {
            ColumnarValue::Array(a) => Some(a.len()),
            ColumnarValue::Scalar(_) => None,
        });

        let inferred_length = match array_lengths.next() {
            // The first array fixes the expected length; any later array
            // with a different length is reported against it.
            Some(array_len) => {
                if let Some(found) = array_lengths.find(|&len| len != array_len) {
                    return internal_err!(
                        "Arguments has mixed length. Expected length: {array_len}, found length: {}", found
                    );
                }
                array_len
            }
            // Only scalars (or nothing at all): fall back to one row.
            None => 1,
        };

        args.iter()
            .map(|arg| -> Result<ArrayRef> {
                Ok(match arg {
                    // Cloning the `Arc` shares the underlying array.
                    ColumnarValue::Array(array) => Arc::clone(array),
                    ColumnarValue::Scalar(scalar) => {
                        scalar.to_array_of_size(inferred_length)?
                    }
                })
            })
            .collect()
    }

    /// Casts this value to `cast_type`, keeping the variant (`Array` stays
    /// `Array`, `Scalar` stays `Scalar`).
    ///
    /// `cast_options` falls back to `DEFAULT_CAST_OPTIONS` when `None`.
    ///
    /// A scalar is cast by converting it to an array, casting that array and
    /// reading element 0 back. One scalar case is special: a non-null
    /// `Float64` cast to `Timestamp(Nanosecond, None)` is first multiplied by
    /// 1e9 and truncated to an `Int64`, i.e. the float is scaled as if it
    /// were a count of seconds. The `Array` path applies no such scaling.
    ///
    /// # Errors
    /// Propagates errors from the cast kernel and from the scalar/array
    /// conversions.
    pub fn cast_to(
        &self,
        cast_type: &DataType,
        cast_options: Option<&CastOptions<'static>>,
    ) -> Result<ColumnarValue> {
        let options = match cast_options {
            Some(opts) => opts.clone(),
            None => DEFAULT_CAST_OPTIONS,
        };

        let scalar = match self {
            ColumnarValue::Array(array) => {
                let cast_array =
                    kernels::cast::cast_with_options(array, cast_type, &options)?;
                return Ok(ColumnarValue::Array(cast_array));
            }
            ColumnarValue::Scalar(scalar) => scalar,
        };

        let to_nanos = *cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None);
        let source_array = match scalar {
            ScalarValue::Float64(Some(float_ts)) if to_nanos => {
                // `as i64` is a saturating float-to-int conversion.
                let nanos = (float_ts * 1_000_000_000_f64).trunc() as i64;
                ScalarValue::Int64(Some(nanos)).to_array()?
            }
            other => other.to_array()?,
        };

        let cast_array =
            kernels::cast::cast_with_options(&source_array, cast_type, &options)?;
        Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
            &cast_array,
            0,
        )?))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Int32Array` holding `len` copies of `val`.
    fn make_array(val: i32, len: usize) -> ArrayRef {
        let values = vec![val; len];
        Arc::new(arrow::array::Int32Array::from(values))
    }

    /// An `Array` argument made of `len` copies of `val`.
    fn array_arg(val: i32, len: usize) -> ColumnarValue {
        ColumnarValue::Array(make_array(val, len))
    }

    /// A non-null `Int32` `Scalar` argument.
    fn scalar_arg(val: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(val)))
    }

    /// One `values_to_arrays` scenario: the arguments and the arrays they
    /// are expected to convert to.
    struct TestCase {
        input: Vec<ColumnarValue>,
        expected: Vec<ArrayRef>,
    }

    impl TestCase {
        /// Runs the conversion and compares the result with `expected`.
        fn run(self) {
            let TestCase { input, expected } = self;
            let actual = ColumnarValue::values_to_arrays(&input).unwrap();
            assert_eq!(
                actual, expected,
                "\ninput: {input:?}\nexpected: {expected:?}"
            );
        }
    }

    #[test]
    fn values_to_arrays() {
        let cases = vec![
            // No arguments at all.
            TestCase {
                input: vec![],
                expected: vec![],
            },
            // A single array passes through.
            TestCase {
                input: vec![array_arg(1, 3)],
                expected: vec![make_array(1, 3)],
            },
            // Two arrays of the same length.
            TestCase {
                input: vec![array_arg(1, 3), array_arg(2, 3)],
                expected: vec![make_array(1, 3), make_array(2, 3)],
            },
            // A trailing scalar is expanded to the array length.
            TestCase {
                input: vec![array_arg(1, 3), scalar_arg(100)],
                expected: vec![make_array(1, 3), make_array(100, 3)],
            },
            // A leading scalar is expanded to the array length.
            TestCase {
                input: vec![scalar_arg(100), array_arg(1, 3)],
                expected: vec![make_array(100, 3), make_array(1, 3)],
            },
            // Scalars on both sides of an array are expanded.
            TestCase {
                input: vec![scalar_arg(100), array_arg(1, 3), scalar_arg(200)],
                expected: vec![
                    make_array(100, 3),
                    make_array(1, 3),
                    make_array(200, 3),
                ],
            },
        ];

        cases.into_iter().for_each(TestCase::run);
    }

    #[test]
    #[should_panic(
        expected = "Arguments has mixed length. Expected length: 3, found length: 4"
    )]
    fn values_to_arrays_mixed_length() {
        let args = vec![array_arg(1, 3), array_arg(2, 4)];
        ColumnarValue::values_to_arrays(&args).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Arguments has mixed length. Expected length: 3, found length: 7"
    )]
    fn values_to_arrays_mixed_length_and_scalar() {
        let args = vec![array_arg(1, 3), scalar_arg(100), array_arg(2, 7)];
        ColumnarValue::values_to_arrays(&args).unwrap();
    }
}