use std::cmp::Ordering;
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
MIN_DECIMAL128_FOR_EACH_PRECISION,
};
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::ScalarValue;
pub fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Option<ScalarValue> {
let lit_data_type = lit_value.data_type();
if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
return None;
}
if lit_value.is_null() {
return ScalarValue::try_from(target_type).ok();
}
try_cast_numeric_literal(lit_value, target_type)
.or_else(|| try_cast_string_literal(lit_value, target_type))
.or_else(|| try_cast_dictionary(lit_value, target_type))
.or_else(|| try_cast_binary(lit_value, target_type))
}
pub fn is_supported_type(data_type: &DataType) -> bool {
is_supported_numeric_type(data_type)
|| is_supported_string_type(data_type)
|| is_supported_dictionary_type(data_type)
|| is_supported_binary_type(data_type)
}
/// Returns `true` for the 8/16/32/64-bit signed and unsigned integer types,
/// `Decimal128`, and `Timestamp` (any unit, any timezone).
fn is_supported_numeric_type(data_type: &DataType) -> bool {
    let is_decimal_or_timestamp = matches!(
        data_type,
        DataType::Decimal128(_, _) | DataType::Timestamp(_, _)
    );
    data_type.is_integer() || is_decimal_or_timestamp
}
fn is_supported_string_type(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
)
}
/// Returns `true` for a `Dictionary` whose value type is itself a supported
/// type; the key type is not inspected.
fn is_supported_dictionary_type(data_type: &DataType) -> bool {
    match data_type {
        DataType::Dictionary(_, value_type) => is_supported_type(value_type),
        _ => false,
    }
}
/// Returns `true` for `Binary` and for `FixedSizeBinary` of any width.
fn is_supported_binary_type(data_type: &DataType) -> bool {
    let is_variable_width = matches!(data_type, DataType::Binary);
    let is_fixed_width = matches!(data_type, DataType::FixedSizeBinary(_));
    is_variable_width || is_fixed_width
}
/// Returns `10^scale` as an `i128`.
///
/// Yields `None` for a negative scale or when the power does not fit in an
/// `i128`, so that callers can decline the cast instead of overflowing.
fn decimal_scale_multiplier(scale: i8) -> Option<i128> {
    if scale.is_negative() {
        return None;
    }
    10_i128.checked_pow(u32::from(scale.unsigned_abs()))
}

/// Tries to convert an integer, `Decimal128`, or timestamp literal into another
/// type of that same group.
///
/// The literal is first brought to the target's integer representation
/// (multiplied by `10^scale` for a decimal target, rescaled exactly for a
/// decimal literal) and then checked against the target's value range.
///
/// Returns `None` when:
/// * either type is not a supported numeric type,
/// * a decimal scale is negative or `10^scale` does not fit in an `i128`,
/// * a decimal precision has no entry in the min/max tables,
/// * rescaling overflows, or a decimal literal cannot be divided exactly,
/// * the resulting value lies outside the target's range.
///
/// For timestamp targets the unit conversion is delegated to
/// `cast_between_timestamp`; when that returns `None` (multiplication
/// overflow) the result is a null timestamp scalar rather than `None`.
fn try_cast_numeric_literal(
    lit_value: &ScalarValue,
    target_type: &DataType,
) -> Option<ScalarValue> {
    let lit_data_type = lit_value.data_type();
    if !is_supported_numeric_type(&lit_data_type)
        || !is_supported_numeric_type(target_type)
    {
        return None;
    }

    // Factor that turns a whole number into the target's integer representation.
    let mul = match target_type {
        DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Timestamp(_, _) => 1_i128,
        // Checked: a negative or oversized scale declines the cast instead of
        // overflowing the power computation.
        DataType::Decimal128(_, scale) => decimal_scale_multiplier(*scale)?,
        _ => return None,
    };

    let (target_min, target_max) = match target_type {
        DataType::UInt8 => (i128::from(u8::MIN), i128::from(u8::MAX)),
        DataType::UInt16 => (i128::from(u16::MIN), i128::from(u16::MAX)),
        DataType::UInt32 => (i128::from(u32::MIN), i128::from(u32::MAX)),
        DataType::UInt64 => (i128::from(u64::MIN), i128::from(u64::MAX)),
        DataType::Int8 => (i128::from(i8::MIN), i128::from(i8::MAX)),
        DataType::Int16 => (i128::from(i16::MIN), i128::from(i16::MAX)),
        DataType::Int32 => (i128::from(i32::MIN), i128::from(i32::MAX)),
        DataType::Int64 => (i128::from(i64::MIN), i128::from(i64::MAX)),
        DataType::Timestamp(_, _) => (i128::from(i64::MIN), i128::from(i64::MAX)),
        DataType::Decimal128(precision, _) => {
            // Checked lookup: a precision without a table entry declines the
            // cast instead of panicking on an out-of-bounds index.
            let index = usize::from(*precision);
            (
                *MIN_DECIMAL128_FOR_EACH_PRECISION.get(index)?,
                *MAX_DECIMAL128_FOR_EACH_PRECISION.get(index)?,
            )
        }
        _ => return None,
    };

    let lit_value_target_type = match lit_value {
        ScalarValue::Int8(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::Int16(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::Int32(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::Int64(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt8(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt16(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt32(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::UInt64(Some(v)) => i128::from(*v).checked_mul(mul),
        ScalarValue::TimestampSecond(Some(v), _)
        | ScalarValue::TimestampMillisecond(Some(v), _)
        | ScalarValue::TimestampMicrosecond(Some(v), _)
        | ScalarValue::TimestampNanosecond(Some(v), _) => {
            i128::from(*v).checked_mul(mul)
        }
        ScalarValue::Decimal128(Some(v), _, scale) => {
            let lit_scale_mul = decimal_scale_multiplier(*scale)?;
            if mul >= lit_scale_mul {
                // Target keeps at least as many fractional digits: scale up.
                (*v).checked_mul(mul / lit_scale_mul)
            } else {
                // Target keeps fewer fractional digits: only accept literals
                // that divide without remainder.
                let divisor = lit_scale_mul / mul;
                if *v % divisor == 0 {
                    Some(*v / divisor)
                } else {
                    None
                }
            }
        }
        _ => None,
    };

    let value = lit_value_target_type?;
    if value < target_min || value > target_max {
        return None;
    }

    // The range check above guarantees that the narrowing `as` casts below
    // cannot truncate.
    let result_scalar = match target_type {
        DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
        DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
        DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
        DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
        DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
        DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
        DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
        DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
        DataType::Timestamp(unit, tz) => {
            // `None` from the unit conversion becomes a null timestamp scalar.
            let value = cast_between_timestamp(&lit_data_type, target_type, value);
            match unit {
                TimeUnit::Second => ScalarValue::TimestampSecond(value, tz.clone()),
                TimeUnit::Millisecond => {
                    ScalarValue::TimestampMillisecond(value, tz.clone())
                }
                TimeUnit::Microsecond => {
                    ScalarValue::TimestampMicrosecond(value, tz.clone())
                }
                TimeUnit::Nanosecond => {
                    ScalarValue::TimestampNanosecond(value, tz.clone())
                }
            }
        }
        DataType::Decimal128(p, s) => ScalarValue::Decimal128(Some(value), *p, *s),
        _ => return None,
    };
    Some(result_scalar)
}
/// Converts a string literal into the `Utf8`, `LargeUtf8`, or `Utf8View`
/// scalar requested by `target_type`.
///
/// Returns `None` when `lit_value.try_as_str()` yields `None` or when
/// `target_type` is not one of those three string types.
fn try_cast_string_literal(
    lit_value: &ScalarValue,
    target_type: &DataType,
) -> Option<ScalarValue> {
    let text: Option<String> = lit_value.try_as_str()?.map(str::to_owned);
    match target_type {
        DataType::Utf8 => Some(ScalarValue::Utf8(text)),
        DataType::LargeUtf8 => Some(ScalarValue::LargeUtf8(text)),
        DataType::Utf8View => Some(ScalarValue::Utf8View(text)),
        _ => None,
    }
}
/// Wraps a literal into, or unwraps it out of, a dictionary scalar.
///
/// * A dictionary literal whose inner value already has `target_type` is
///   unwrapped to that inner value.
/// * Otherwise, if `target_type` is a dictionary whose value type equals the
///   literal's type, the literal is wrapped using the target's key type.
///
/// Any other combination returns `None`.
fn try_cast_dictionary(
    lit_value: &ScalarValue,
    target_type: &DataType,
) -> Option<ScalarValue> {
    // Unwrap: Dictionary(value) -> value
    if let ScalarValue::Dictionary(_, inner_value) = lit_value {
        if inner_value.data_type() == *target_type {
            return Some(ScalarValue::clone(inner_value));
        }
    }

    // Wrap: value -> Dictionary(value)
    if let DataType::Dictionary(key_type, value_type) = target_type {
        if **value_type == lit_value.data_type() {
            let wrapped =
                ScalarValue::Dictionary(key_type.clone(), Box::new(lit_value.clone()));
            return Some(wrapped);
        }
    }

    None
}
/// Rescales a timestamp `value` from the unit of `from` to the unit of `to`.
///
/// If either type is not a `Timestamp`, the value is returned unchanged
/// (narrowed to `i64`). Converting to a finer unit multiplies and returns
/// `None` on overflow; converting to a coarser unit uses integer division.
fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
    let value = value as i64;

    // Units per second for a timestamp type, `None` for anything else.
    let units_per_second = |data_type: &DataType| match data_type {
        DataType::Timestamp(TimeUnit::Second, _) => Some(1),
        DataType::Timestamp(TimeUnit::Millisecond, _) => Some(MILLISECONDS),
        DataType::Timestamp(TimeUnit::Microsecond, _) => Some(MICROSECONDS),
        DataType::Timestamp(TimeUnit::Nanosecond, _) => Some(NANOSECONDS),
        _ => None,
    };

    match (units_per_second(from), units_per_second(to)) {
        (Some(from_scale), Some(to_scale)) => match from_scale.cmp(&to_scale) {
            Ordering::Equal => Some(value),
            Ordering::Less => value.checked_mul(to_scale / from_scale),
            Ordering::Greater => Some(value / (from_scale / to_scale)),
        },
        _ => Some(value),
    }
}
/// Converts a non-null `Binary` literal into a `FixedSizeBinary` scalar when
/// its byte length equals the target width; returns `None` otherwise.
fn try_cast_binary(
    lit_value: &ScalarValue,
    target_type: &DataType,
) -> Option<ScalarValue> {
    if let (ScalarValue::Binary(Some(bytes)), DataType::FixedSizeBinary(width)) =
        (lit_value, target_type)
    {
        if bytes.len() == *width as usize {
            return Some(ScalarValue::FixedSizeBinary(*width, Some(bytes.clone())));
        }
    }
    None
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::{Field, Fields, TimeUnit};
use std::sync::Arc;
/// Outcome a test expects from `try_cast_literal_to_type`.
#[derive(Debug, Clone)]
enum ExpectedCast {
/// The cast must succeed and produce this scalar.
Value(ScalarValue),
/// The cast must be declined, i.e. the function returns `None`.
NoValue,
}
/// Runs `try_cast_literal_to_type` on `literal` / `target_type` and checks the
/// outcome against `expected_result`.
///
/// When a value is expected, the helper also checks that arrow's cast kernel
/// produces the same one-row array, and that timestamp results carry the
/// expected unit and timezone.
fn expect_cast(
    literal: ScalarValue,
    target_type: DataType,
    expected_result: ExpectedCast,
) {
    let actual_value = try_cast_literal_to_type(&literal, &target_type);

    println!("expect_cast: ");
    println!(" {literal:?} --> {target_type:?}");
    println!(" expected_result: {expected_result:?}");
    println!(" actual_result: {actual_value:?}");

    let expected_value = match expected_result {
        ExpectedCast::Value(scalar) => scalar,
        ExpectedCast::NoValue => {
            assert!(
                actual_value.is_none(),
                "Expected no cast value, but got {actual_value:?}"
            );
            return;
        }
    };

    let actual_value = actual_value.expect("Expected cast value but got None");
    assert_eq!(actual_value, expected_value);

    // Cross-check against arrow's cast kernel on one-row arrays.
    let single_row_array = |scalar: &ScalarValue| {
        scalar
            .to_array_of_size(1)
            .expect("Failed to convert to array of size")
    };
    let literal_array = single_row_array(&literal);
    let expected_array = single_row_array(&expected_value);
    let cast_array =
        cast_with_options(&literal_array, &target_type, &CastOptions::default())
            .expect("Expected to be cast array with arrow cast kernel");
    assert_eq!(
        &expected_array, &cast_array,
        "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}"
    );

    // Unit and timezone are compared explicitly for timestamp results.
    if let (
        DataType::Timestamp(left_unit, left_tz),
        DataType::Timestamp(right_unit, right_tz),
    ) = (actual_value.data_type(), expected_value.data_type())
    {
        assert_eq!(left_unit, right_unit);
        assert_eq!(left_tz, right_tz);
    }
}
/// Every listed null literal must convert to the null scalar of every other
/// listed type.
#[test]
fn test_try_cast_to_type_nulls() {
    let null_scalars = vec![
        ScalarValue::Int8(None),
        ScalarValue::Int16(None),
        ScalarValue::Int32(None),
        ScalarValue::Int64(None),
        ScalarValue::UInt8(None),
        ScalarValue::UInt16(None),
        ScalarValue::UInt32(None),
        ScalarValue::UInt64(None),
        ScalarValue::Decimal128(None, 3, 0),
        ScalarValue::Decimal128(None, 8, 2),
        ScalarValue::Utf8(None),
        ScalarValue::LargeUtf8(None),
    ];

    for source in null_scalars.iter() {
        for target in null_scalars.iter() {
            expect_cast(
                source.clone(),
                target.data_type(),
                ExpectedCast::Value(target.clone()),
            );
        }
    }
}
/// The value 123 must round-trip between all integer and decimal
/// representations, and a few boundary values must widen successfully.
#[test]
fn test_try_cast_to_type_int_in_range() {
    let same_value_scalars = vec![
        ScalarValue::Int8(Some(123)),
        ScalarValue::Int16(Some(123)),
        ScalarValue::Int32(Some(123)),
        ScalarValue::Int64(Some(123)),
        ScalarValue::UInt8(Some(123)),
        ScalarValue::UInt16(Some(123)),
        ScalarValue::UInt32(Some(123)),
        ScalarValue::UInt64(Some(123)),
        ScalarValue::Decimal128(Some(123), 3, 0),
        ScalarValue::Decimal128(Some(12300), 8, 2),
    ];
    for source in same_value_scalars.iter() {
        for target in same_value_scalars.iter() {
            expect_cast(
                source.clone(),
                target.data_type(),
                ExpectedCast::Value(target.clone()),
            );
        }
    }

    // (literal, target type, expected scalar)
    let boundary_cases = [
        (
            ScalarValue::Int32(Some(i32::MAX)),
            DataType::UInt64,
            ScalarValue::UInt64(Some(i32::MAX as u64)),
        ),
        (
            ScalarValue::Int32(Some(i32::MIN)),
            DataType::Int64,
            ScalarValue::Int64(Some(i32::MIN as i64)),
        ),
        (
            ScalarValue::Int64(Some(i64::MAX)),
            DataType::UInt64,
            ScalarValue::UInt64(Some(i64::MAX as u64)),
        ),
    ];
    for (literal, target_type, expected) in boundary_cases {
        expect_cast(literal, target_type, ExpectedCast::Value(expected));
    }
}
/// Literals that do not fit the target integer type must not be cast.
#[test]
fn test_try_cast_to_type_int_out_of_range() {
    // (literal, target type) pairs that must all be declined
    let declined_cases = [
        (ScalarValue::Int64(Some(i64::MAX)), DataType::Int8),
        (ScalarValue::Int64(Some(i64::MAX)), DataType::Int16),
        (ScalarValue::Int64(Some(i64::MAX)), DataType::Int32),
        (ScalarValue::UInt64(Some(u64::MAX)), DataType::Int64),
        (ScalarValue::Int64(Some(i64::MIN)), DataType::UInt64),
        (ScalarValue::Int32(Some(i32::MIN)), DataType::UInt64),
        (
            ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
            DataType::Int64,
        ),
        (
            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1),
            DataType::Int64,
        ),
    ];
    for (literal, target_type) in declined_cases {
        expect_cast(literal, target_type, ExpectedCast::NoValue);
    }
}
/// A decimal literal must rescale to other precisions/scales when the result
/// is exact and fits.
#[test]
fn test_try_decimal_cast_in_range() {
    let literal = ScalarValue::Decimal128(Some(12300), 5, 2);

    // (target precision, target scale, expected unscaled value)
    let cases: [(u8, i8, i128); 3] = [(3, 0, 123), (8, 0, 123), (8, 5, 12300000)];
    for (precision, scale, unscaled) in cases {
        expect_cast(
            literal.clone(),
            DataType::Decimal128(precision, scale),
            ExpectedCast::Value(ScalarValue::Decimal128(
                Some(unscaled),
                precision,
                scale,
            )),
        );
    }
}
/// Decimal casts that would lose digits or exceed the target precision must
/// be declined.
#[test]
fn test_try_decimal_cast_out_of_range() {
    // (literal, target type) pairs that must all be declined
    let declined_cases = [
        (
            ScalarValue::Decimal128(Some(12345), 5, 2),
            DataType::Decimal128(3, 0),
        ),
        (
            ScalarValue::Decimal128(Some(12300), 5, 2),
            DataType::Decimal128(2, 0),
        ),
    ];
    for (literal, target_type) in declined_cases {
        expect_cast(literal, target_type, ExpectedCast::NoValue);
    }
}
/// For every time unit: timestamps with and without a timezone must convert
/// to each other and to/from `Int64`, but not to a string type.
#[test]
fn test_try_cast_to_type_timestamps() {
    // Builds the timestamp literal 12345 in `unit` with timezone `tz`.
    let timestamp_literal = |unit: &TimeUnit, tz: Option<Arc<str>>| match unit {
        TimeUnit::Second => ScalarValue::TimestampSecond(Some(12345), tz),
        TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(12345), tz),
        TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(12345), tz),
        TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(12345), tz),
    };

    for time_unit in [
        TimeUnit::Second,
        TimeUnit::Millisecond,
        TimeUnit::Microsecond,
        TimeUnit::Nanosecond,
    ] {
        let lit_tz_none = timestamp_literal(&time_unit, None);
        let lit_tz_utc = timestamp_literal(&time_unit, Some("+00:00".into()));
        assert_eq!(lit_tz_none, lit_tz_utc);

        let dt_tz_none = lit_tz_none.data_type();
        let dt_tz_utc = lit_tz_utc.data_type();
        let int_literal = ScalarValue::Int64(Some(12345));

        // (literal, target type, expected scalar)
        let accepted_cases = [
            (lit_tz_none.clone(), dt_tz_none.clone(), lit_tz_none.clone()),
            (lit_tz_none.clone(), dt_tz_utc.clone(), lit_tz_utc.clone()),
            (lit_tz_utc.clone(), dt_tz_none.clone(), lit_tz_none.clone()),
            (lit_tz_utc.clone(), dt_tz_utc.clone(), lit_tz_utc.clone()),
            (lit_tz_utc.clone(), DataType::Int64, int_literal.clone()),
            (int_literal.clone(), dt_tz_none.clone(), lit_tz_none.clone()),
            (int_literal.clone(), dt_tz_utc.clone(), lit_tz_utc.clone()),
        ];
        for (literal, target_type, expected) in accepted_cases {
            expect_cast(literal, target_type, ExpectedCast::Value(expected));
        }

        expect_cast(
            lit_tz_utc.clone(),
            DataType::LargeUtf8,
            ExpectedCast::NoValue,
        );
    }
}
/// Casting to an unsupported target type (a list) must be declined.
#[test]
fn test_try_cast_to_type_unsupported() {
    let literal = ScalarValue::Int64(Some(12345));
    let list_of_int32 = DataType::List(Arc::new(Field::new("f", DataType::Int32, true)));
    expect_cast(literal, list_of_int32, ExpectedCast::NoValue);
}
/// Timestamp literals must rescale between all unit combinations; the last
/// case shows that an overflowing up-scale yields a null timestamp.
#[test]
fn test_try_cast_literal_to_timestamp() {
    // (literal, target unit, expected scalar)
    let cases = [
        // From nanoseconds
        (
            ScalarValue::TimestampNanosecond(Some(123456), None),
            TimeUnit::Nanosecond,
            ScalarValue::TimestampNanosecond(Some(123456), None),
        ),
        (
            ScalarValue::TimestampNanosecond(Some(123456), None),
            TimeUnit::Microsecond,
            ScalarValue::TimestampMicrosecond(Some(123), None),
        ),
        (
            ScalarValue::TimestampNanosecond(Some(123456), None),
            TimeUnit::Millisecond,
            ScalarValue::TimestampMillisecond(Some(0), None),
        ),
        (
            ScalarValue::TimestampNanosecond(Some(123456), None),
            TimeUnit::Second,
            ScalarValue::TimestampSecond(Some(0), None),
        ),
        // From microseconds
        (
            ScalarValue::TimestampMicrosecond(Some(123), None),
            TimeUnit::Nanosecond,
            ScalarValue::TimestampNanosecond(Some(123000), None),
        ),
        (
            ScalarValue::TimestampMicrosecond(Some(123), None),
            TimeUnit::Millisecond,
            ScalarValue::TimestampMillisecond(Some(0), None),
        ),
        (
            ScalarValue::TimestampMicrosecond(Some(123456789), None),
            TimeUnit::Second,
            ScalarValue::TimestampSecond(Some(123), None),
        ),
        // From milliseconds
        (
            ScalarValue::TimestampMillisecond(Some(123), None),
            TimeUnit::Nanosecond,
            ScalarValue::TimestampNanosecond(Some(123000000), None),
        ),
        (
            ScalarValue::TimestampMillisecond(Some(123), None),
            TimeUnit::Microsecond,
            ScalarValue::TimestampMicrosecond(Some(123000), None),
        ),
        (
            ScalarValue::TimestampMillisecond(Some(123456789), None),
            TimeUnit::Second,
            ScalarValue::TimestampSecond(Some(123456), None),
        ),
        // From seconds
        (
            ScalarValue::TimestampSecond(Some(123), None),
            TimeUnit::Nanosecond,
            ScalarValue::TimestampNanosecond(Some(123000000000), None),
        ),
        (
            ScalarValue::TimestampSecond(Some(123), None),
            TimeUnit::Microsecond,
            ScalarValue::TimestampMicrosecond(Some(123000000), None),
        ),
        (
            ScalarValue::TimestampSecond(Some(123), None),
            TimeUnit::Millisecond,
            ScalarValue::TimestampMillisecond(Some(123000), None),
        ),
        // Overflow while scaling up
        (
            ScalarValue::TimestampSecond(Some(i64::MAX), None),
            TimeUnit::Millisecond,
            ScalarValue::TimestampMillisecond(None, None),
        ),
    ];

    for (literal, target_unit, expected) in cases {
        let new_scalar =
            try_cast_literal_to_type(&literal, &DataType::Timestamp(target_unit, None))
                .unwrap();
        assert_eq!(new_scalar, expected);
    }
}
/// `Utf8` and `LargeUtf8` literals must convert to each other and to
/// themselves.
#[test]
fn test_try_cast_to_string_type() {
    let string_scalars = vec![
        ScalarValue::from("string"),
        ScalarValue::LargeUtf8(Some("string".to_owned())),
    ];
    for source in string_scalars.iter() {
        for target in string_scalars.iter() {
            expect_cast(
                source.clone(),
                target.data_type(),
                ExpectedCast::Value(target.clone()),
            );
        }
    }
}
/// String literals must wrap into an `Int32`-keyed dictionary and unwrap back.
#[test]
fn test_try_cast_to_dictionary_type() {
    let dictionary_type_of = |value_type: DataType| {
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type))
    };
    let dictionary_value_of = |value: ScalarValue| {
        ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
    };

    let plain_strings = vec![
        ScalarValue::from("string"),
        ScalarValue::LargeUtf8(Some("string".to_owned())),
    ];
    for plain in plain_strings.iter() {
        // Wrap: value -> dictionary
        expect_cast(
            plain.clone(),
            dictionary_type_of(plain.data_type()),
            ExpectedCast::Value(dictionary_value_of(plain.clone())),
        );
        // Unwrap: dictionary -> value
        expect_cast(
            dictionary_value_of(plain.clone()),
            plain.data_type(),
            ExpectedCast::Value(plain.clone()),
        );
    }
}
/// A three-byte `Binary` literal must convert to `FixedSizeBinary(3)`.
#[test]
fn test_try_cast_to_fixed_size_binary() {
    let bytes = vec![1, 2, 3];
    let literal = ScalarValue::Binary(Some(bytes.clone()));
    let expected = ScalarValue::FixedSizeBinary(3, Some(bytes));
    expect_cast(
        literal,
        DataType::FixedSizeBinary(3),
        ExpectedCast::Value(expected),
    );
}
/// Integer min/max values must cast only when the target type can hold them.
#[test]
fn test_numeric_boundary_values() {
    // (literal, target type, expected outcome)
    let cases = [
        (
            ScalarValue::Int8(Some(i8::MAX)),
            DataType::UInt8,
            ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))),
        ),
        (
            ScalarValue::Int8(Some(i8::MIN)),
            DataType::UInt8,
            ExpectedCast::NoValue,
        ),
        (
            ScalarValue::UInt8(Some(u8::MAX)),
            DataType::Int8,
            ExpectedCast::NoValue,
        ),
        (
            ScalarValue::Int32(Some(i32::MAX)),
            DataType::Int64,
            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))),
        ),
        (
            ScalarValue::Int64(Some(i64::MIN)),
            DataType::UInt64,
            ExpectedCast::NoValue,
        ),
        (
            ScalarValue::UInt32(Some(u32::MAX)),
            DataType::Int32,
            ExpectedCast::NoValue,
        ),
        (
            ScalarValue::UInt64(Some(u64::MAX)),
            DataType::Int64,
            ExpectedCast::NoValue,
        ),
    ];
    for (literal, target_type, expected) in cases {
        expect_cast(literal, target_type, expected);
    }
}
/// Decimal casts at the precision limits: widening succeeds, while narrowing
/// or lossy rescaling is declined.
#[test]
fn test_decimal_precision_limits() {
    use arrow::datatypes::{
        MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION,
    };

    let max_for_precision_3 = MAX_DECIMAL128_FOR_EACH_PRECISION[3];
    let min_for_precision_3 = MIN_DECIMAL128_FOR_EACH_PRECISION[3];
    let max_for_precision_10 = MAX_DECIMAL128_FOR_EACH_PRECISION[10];

    // (literal, target type, expected outcome)
    let cases = [
        (
            ScalarValue::Decimal128(Some(max_for_precision_3), 3, 0),
            DataType::Decimal128(5, 0),
            ExpectedCast::Value(ScalarValue::Decimal128(
                Some(max_for_precision_3),
                5,
                0,
            )),
        ),
        (
            ScalarValue::Decimal128(Some(min_for_precision_3), 3, 0),
            DataType::Decimal128(5, 0),
            ExpectedCast::Value(ScalarValue::Decimal128(
                Some(min_for_precision_3),
                5,
                0,
            )),
        ),
        (
            ScalarValue::Decimal128(Some(123), 3, 0),
            DataType::Decimal128(5, 2),
            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)),
        ),
        (
            ScalarValue::Decimal128(Some(max_for_precision_10), 10, 0),
            DataType::Decimal128(3, 0),
            ExpectedCast::NoValue,
        ),
        (
            ScalarValue::Decimal128(Some(12345), 5, 3),
            DataType::Int32,
            ExpectedCast::NoValue,
        ),
        (
            ScalarValue::Decimal128(Some(12345), 5, 2),
            DataType::Decimal128(3, 0),
            ExpectedCast::NoValue,
        ),
    ];
    for (literal, target_type, expected) in cases {
        expect_cast(literal, target_type, expected);
    }
}
/// Timestamp rescaling near the `i64` limits and for sub-unit values.
#[test]
fn test_timestamp_overflow_scenarios() {
    let max_seconds = i64::MAX / 1_000_000_000;

    // (literal, target unit, expected scalar)
    let cases = [
        (
            ScalarValue::TimestampSecond(Some(max_seconds), None),
            TimeUnit::Nanosecond,
            ScalarValue::TimestampNanosecond(Some(max_seconds * 1_000_000_000), None),
        ),
        (
            ScalarValue::TimestampNanosecond(Some(i64::MAX), None),
            TimeUnit::Second,
            ScalarValue::TimestampSecond(Some(i64::MAX / 1_000_000_000), None),
        ),
        (
            ScalarValue::TimestampNanosecond(Some(1), None),
            TimeUnit::Second,
            ScalarValue::TimestampSecond(Some(0), None),
        ),
        (
            ScalarValue::TimestampMicrosecond(Some(999), None),
            TimeUnit::Millisecond,
            ScalarValue::TimestampMillisecond(Some(0), None),
        ),
    ];
    for (literal, target_unit, expected) in cases {
        expect_cast(
            literal,
            DataType::Timestamp(target_unit, None),
            ExpectedCast::Value(expected),
        );
    }
}
/// Conversions between `Utf8View` and the other string types, including the
/// empty string and a 1000-character string.
#[test]
fn test_string_view() {
    let large_string = "x".repeat(1000);

    // (literal, target type, expected scalar)
    let cases = [
        (
            ScalarValue::Utf8View(Some("test".to_string())),
            DataType::Utf8,
            ScalarValue::Utf8(Some("test".to_string())),
        ),
        (
            ScalarValue::Utf8View(Some("test".to_string())),
            DataType::LargeUtf8,
            ScalarValue::LargeUtf8(Some("test".to_string())),
        ),
        (
            ScalarValue::Utf8(Some("hello".to_string())),
            DataType::Utf8View,
            ScalarValue::Utf8View(Some("hello".to_string())),
        ),
        (
            ScalarValue::LargeUtf8(Some("world".to_string())),
            DataType::Utf8View,
            ScalarValue::Utf8View(Some("world".to_string())),
        ),
        (
            ScalarValue::Utf8(Some("".to_string())),
            DataType::Utf8View,
            ScalarValue::Utf8View(Some("".to_string())),
        ),
        (
            ScalarValue::LargeUtf8(Some(large_string.clone())),
            DataType::Utf8View,
            ScalarValue::Utf8View(Some(large_string)),
        ),
    ];
    for (literal, target_type, expected) in cases {
        expect_cast(literal, target_type, ExpectedCast::Value(expected));
    }
}
/// `Binary` to `FixedSizeBinary` must succeed only when the byte length
/// matches the target width exactly.
#[test]
fn test_binary_size_edge_cases() {
    // (literal bytes, target width, whether the cast is expected to succeed)
    let cases: [(Vec<u8>, i32, bool); 5] = [
        (vec![1, 2], 3, false),
        (vec![1, 2, 3, 4], 3, false),
        (vec![], 0, true),
        (vec![1, 2, 3], 3, true),
        (vec![42], 1, true),
    ];
    for (bytes, width, should_cast) in cases {
        let expected = if should_cast {
            ExpectedCast::Value(ScalarValue::FixedSizeBinary(width, Some(bytes.clone())))
        } else {
            ExpectedCast::NoValue
        };
        expect_cast(
            ScalarValue::Binary(Some(bytes)),
            DataType::FixedSizeBinary(width),
            expected,
        );
    }
}
/// Wrapping into a dictionary must use the target's key type; unwrapping a
/// dictionary literal must return its inner value.
#[test]
fn test_dictionary_index_types() {
    let string_value = ScalarValue::Utf8(Some("test".to_string()));

    for index_type in [DataType::Int8, DataType::Int16, DataType::Int64] {
        let target_type = DataType::Dictionary(
            Box::new(index_type.clone()),
            Box::new(DataType::Utf8),
        );
        let wrapped =
            ScalarValue::Dictionary(Box::new(index_type), Box::new(string_value.clone()));
        expect_cast(
            string_value.clone(),
            target_type,
            ExpectedCast::Value(wrapped),
        );
    }

    let inner = ScalarValue::LargeUtf8(Some("unwrap_test".to_string()));
    let dict_value =
        ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(inner.clone()));
    expect_cast(dict_value, DataType::LargeUtf8, ExpectedCast::Value(inner));
}
/// Exercises each `is_supported_*` predicate with accepted and rejected types.
#[test]
fn test_type_support_functions() {
    let list_of_int32 =
        || DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    let int32_keyed_dictionary = |value_type: DataType| {
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type))
    };

    // Numeric
    for accepted in [
        DataType::Int8,
        DataType::UInt64,
        DataType::Decimal128(10, 2),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    ] {
        assert!(is_supported_numeric_type(&accepted));
    }
    for rejected in [DataType::Float32, DataType::Float64] {
        assert!(!is_supported_numeric_type(&rejected));
    }

    // String
    for accepted in [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] {
        assert!(is_supported_string_type(&accepted));
    }
    assert!(!is_supported_string_type(&DataType::Binary));

    // Binary
    for accepted in [DataType::Binary, DataType::FixedSizeBinary(10)] {
        assert!(is_supported_binary_type(&accepted));
    }
    assert!(!is_supported_binary_type(&DataType::Utf8));

    // Dictionary
    for accepted_value_type in [DataType::Utf8, DataType::Int64] {
        assert!(is_supported_dictionary_type(&int32_keyed_dictionary(
            accepted_value_type
        )));
    }
    assert!(!is_supported_dictionary_type(&int32_keyed_dictionary(
        list_of_int32()
    )));

    // Combined predicate
    for accepted in [
        DataType::Int32,
        DataType::Utf8,
        DataType::Binary,
        int32_keyed_dictionary(DataType::Utf8),
    ] {
        assert!(is_supported_type(&accepted));
    }
    for rejected in [list_of_int32(), DataType::Struct(Fields::empty())] {
        assert!(!is_supported_type(&rejected));
    }
}
/// Unsupported source or target types (floats, lists, dictionaries of lists)
/// must always be declined.
#[test]
fn test_error_conditions() {
    let list_type = || DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    let bad_dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(list_type()));

    // (literal, target type) pairs that must all be declined
    let declined_cases = [
        (ScalarValue::Float32(Some(1.5)), DataType::Int32),
        (ScalarValue::Int32(Some(123)), DataType::Float64),
        (ScalarValue::Float64(Some(1.5)), DataType::Float32),
        (ScalarValue::Int32(Some(123)), list_type()),
        (ScalarValue::Int32(Some(123)), bad_dict),
    ];
    for (literal, target_type) in declined_cases {
        expect_cast(literal, target_type, ExpectedCast::NoValue);
    }
}
}