use crate::type_coercion::{is_date, is_decimal, is_interval, is_numeric, is_timestamp};
use crate::Operator;
use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
/// Returns the data type produced by evaluating `lhs_type op rhs_type`.
///
/// Comparison, logical, regex and `IS [NOT] DISTINCT FROM` operators always yield
/// `Boolean`. Every other operator yields the common type chosen by [`coerce_types`].
///
/// When either operand is a decimal (possibly dictionary-encoded), both sides are
/// first aligned with [`math_decimal_coercion`]. Arithmetic operators then prefer the
/// precision/scale rules of [`decimal_op_mathematics_type`], and fall back to
/// [`coerce_types`] when those rules do not apply.
///
/// # Errors
/// Propagates the `Plan` error from [`coerce_types`] when no common type exists. For
/// decimal arithmetic where neither strategy succeeds, returns
/// `DataFusionError::Internal`.
pub fn binary_operator_data_type(
    lhs_type: &DataType,
    op: &Operator,
    rhs_type: &DataType,
) -> Result<DataType> {
    let result_type = if any_decimal(lhs_type, rhs_type) {
        // Cast the non-decimal side (if any) so both operands line up.
        let (lhs_cast, rhs_cast) = math_decimal_coercion(lhs_type, rhs_type);
        let lhs = lhs_cast.unwrap_or_else(|| lhs_type.clone());
        let rhs = rhs_cast.unwrap_or_else(|| rhs_type.clone());

        let is_arithmetic = matches!(
            op,
            Operator::Plus
                | Operator::Minus
                | Operator::Divide
                | Operator::Multiply
                | Operator::Modulo
        );

        if !is_arithmetic {
            coerce_types(&lhs, op, &rhs)?
        } else if let Some(decimal_type) = decimal_op_mathematics_type(op, &lhs, &rhs) {
            decimal_type
        } else if let Ok(fallback_type) = coerce_types(&lhs, op, &rhs) {
            // Non-decimal operands after alignment (e.g. both sides became Float64).
            fallback_type
        } else {
            return Err(DataFusionError::Internal(format!(
                "Could not get return type for {:?} between {:?} and {:?}",
                op, lhs, rhs
            )));
        }
    } else {
        coerce_types(lhs_type, op, rhs_type)?
    };

    match op {
        Operator::Eq
        | Operator::NotEq
        | Operator::And
        | Operator::Or
        | Operator::Lt
        | Operator::Gt
        | Operator::GtEq
        | Operator::LtEq
        | Operator::RegexMatch
        | Operator::RegexIMatch
        | Operator::RegexNotMatch
        | Operator::RegexNotIMatch
        | Operator::IsDistinctFrom
        | Operator::IsNotDistinctFrom => Ok(DataType::Boolean),
        Operator::BitwiseAnd
        | Operator::BitwiseOr
        | Operator::BitwiseXor
        | Operator::BitwiseShiftLeft
        | Operator::BitwiseShiftRight
        | Operator::Plus
        | Operator::Minus
        | Operator::Divide
        | Operator::Multiply
        | Operator::Modulo
        | Operator::StringConcat => Ok(result_type),
    }
}
pub fn coerce_types(
lhs_type: &DataType,
op: &Operator,
rhs_type: &DataType,
) -> Result<DataType> {
let result = match op {
Operator::BitwiseAnd
| Operator::BitwiseOr
| Operator::BitwiseXor
| Operator::BitwiseShiftRight
| Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type),
Operator::And | Operator::Or => match (lhs_type, rhs_type) {
(DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean),
(DataType::Null, DataType::Null) => Some(DataType::Boolean),
(DataType::Boolean, DataType::Null) | (DataType::Null, DataType::Boolean) => {
Some(DataType::Boolean)
}
_ => None,
},
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::Gt
| Operator::GtEq
| Operator::LtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom => comparison_coercion(lhs_type, rhs_type),
Operator::Plus | Operator::Minus
if (is_date(lhs_type)
|| is_date(rhs_type)
|| is_timestamp(lhs_type)
|| is_timestamp(rhs_type)
|| is_interval(lhs_type)
|| is_interval(rhs_type))
&& (!is_interval(lhs_type)
|| !is_timestamp(rhs_type)
|| *op != Operator::Minus) =>
{
temporal_add_sub_coercion(lhs_type, rhs_type, op)
}
Operator::Plus
| Operator::Minus
| Operator::Modulo
| Operator::Divide
| Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type),
Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
| Operator::RegexNotIMatch => regex_coercion(lhs_type, rhs_type),
Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type),
};
match result {
None => Err(DataFusionError::Plan(
format!(
"{lhs_type:?} {op} {rhs_type:?} can't be evaluated because there isn't a common type to coerce the types to"
),
)),
Some(t) => Ok(t)
}
}
/// Decides how the operands of an arithmetic operation involving a decimal should
/// be cast before the decimal result type is computed.
///
/// Returns `(lhs, rhs)`. Each element is `Some(new_type)` if that side should be
/// cast, or `None` if it should be left unchanged.
///
/// * `NULL` paired with a `Decimal128` is cast to that decimal type.
/// * Two decimals are left untouched.
/// * Dictionary-encoded operands are coerced on their value type, and the
///   dictionary wrapper (with its original key type) is kept.
/// * A decimal paired with `Float32`/`Float64` turns both sides into `Float64`.
/// * A decimal paired with another type casts that type with
///   `coerce_numeric_type_to_decimal`, or leaves it unchanged if no mapping exists.
pub fn math_decimal_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> (Option<DataType>, Option<DataType>) {
    use arrow::datatypes::DataType::*;

    match (lhs_type, rhs_type) {
        // These arms must run before the `both_decimal` check. `both_decimal`
        // reports `true` for a decimal paired with `Null`, so checking it first made
        // these arms unreachable and left `Decimal op NULL` without a decimal type.
        (Null, dec_type @ Decimal128(_, _)) => (Some(dec_type.clone()), None),
        (dec_type @ Decimal128(_, _), Null) => (None, Some(dec_type.clone())),
        _ if both_decimal(lhs_type, rhs_type) => (None, None),
        (Dictionary(key_type, value_type), _) => {
            let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type);
            let lhs_type = value_type
                .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)));
            (lhs_type, rhs_type)
        }
        (_, Dictionary(key_type, value_type)) => {
            let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type);
            let rhs_type = value_type
                .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type)));
            (lhs_type, rhs_type)
        }
        (Decimal128(_, _), Float32 | Float64) => (Some(Float64), Some(Float64)),
        (Float32 | Float64, Decimal128(_, _)) => (Some(Float64), Some(Float64)),
        (Decimal128(_, _), _) => {
            let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
            (None, converted_decimal_type)
        }
        (_, Decimal128(_, _)) => {
            let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type);
            (converted_decimal_type, None)
        }
        _ => (None, None),
    }
}
/// Picks the common type for the operands of a bitwise operator
/// (`&`, `|`, `^`, `<<`, `>>`).
///
/// Returns `None` unless `both_numeric_or_null_and_numeric` accepts the pair.
/// Identical types are kept as-is. Otherwise the first matching arm of the
/// precedence table below wins. Anything paired with `UInt64` becomes `UInt64`;
/// mixed signed/unsigned pairs widen to a signed type.
///
/// NOTE(review): floats pass the numeric check, and a float paired with an integer
/// resolves to the integer type via this table (e.g. `Float64 & Int32` -> `Int32`).
fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    match (left_type, right_type) {
        _ if !both_numeric_or_null_and_numeric(left_type, right_type) => None,
        (same, other) if same == other => Some(same.clone()),
        (UInt64, _) | (_, UInt64) => Some(UInt64),
        (Int64, _)
        | (_, Int64)
        | (UInt32, Int8)
        | (Int8, UInt32)
        | (UInt32, Int16)
        | (Int16, UInt32)
        | (UInt32, Int32)
        | (Int32, UInt32) => Some(Int64),
        (Int32, _)
        | (_, Int32)
        | (UInt16, Int16)
        | (Int16, UInt16)
        | (UInt16, Int8)
        | (Int8, UInt16) => Some(Int32),
        (UInt32, _) | (_, UInt32) => Some(UInt32),
        (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
        (UInt16, _) | (_, UInt16) => Some(UInt16),
        (Int8, _) | (_, Int8) => Some(Int8),
        (UInt8, _) | (_, UInt8) => Some(UInt8),
        _ => None,
    }
}
/// Finds the common type for the operands of a comparison operator
/// (`=`, `!=`, `<`, `<=`, `>`, `>=`, `IS [NOT] DISTINCT FROM`).
///
/// Identical types are returned unchanged. Otherwise these rules are tried in
/// order, and the first that produces a type wins: numeric, dictionary (keeping
/// dictionaries), temporal, string, null, then string-vs-numeric. Returns `None`
/// if none applies.
pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    if lhs_type != rhs_type {
        comparison_binary_numeric_coercion(lhs_type, rhs_type)
            .or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
            .or_else(|| temporal_coercion(lhs_type, rhs_type))
            .or_else(|| string_coercion(lhs_type, rhs_type))
            .or_else(|| null_coercion(lhs_type, rhs_type))
            .or_else(|| string_numeric_coercion(lhs_type, rhs_type))
    } else {
        Some(lhs_type.clone())
    }
}
/// Result type of `+` / `-` when at least one operand is a date, timestamp or
/// interval.
///
/// Rules, checked in order:
/// 1. `interval op date/timestamp` -> the date/timestamp type.
/// 2. `date/timestamp op interval` -> the date/timestamp type.
/// 3. `timestamp - timestamp` -> an interval type (see `handle_timestamp_minus`).
/// 4. `interval op interval` -> an interval type (see `handle_interval_addition`).
/// 5. `date/timestamp - date/timestamp` -> [`temporal_coercion`].
///
/// Any other combination returns `None`.
pub fn temporal_add_sub_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
    op: &Operator,
) -> Option<DataType> {
    let is_datetime = |t: &DataType| is_date(t) || is_timestamp(t);
    let subtracting = matches!(op, Operator::Minus);

    if is_interval(lhs_type) && is_datetime(rhs_type) {
        Some(rhs_type.clone())
    } else if is_interval(rhs_type) && is_datetime(lhs_type) {
        Some(lhs_type.clone())
    } else if subtracting && is_timestamp(lhs_type) && is_timestamp(rhs_type) {
        handle_timestamp_minus(lhs_type, rhs_type)
    } else if is_interval(lhs_type) && is_interval(rhs_type) {
        handle_interval_addition(lhs_type, rhs_type)
    } else if subtracting && is_datetime(lhs_type) && is_datetime(rhs_type) {
        temporal_coercion(lhs_type, rhs_type)
    } else {
        None
    }
}
/// Result type of adding or subtracting two intervals.
///
/// If both sides share the same interval unit, that unit is kept. Any mix of units
/// (or a non-interval input) falls back to `Interval(MonthDayNano)`. This function
/// never returns `None`.
fn handle_interval_addition(lhs: &DataType, rhs: &DataType) -> Option<DataType> {
    let unit = match (lhs, rhs) {
        (DataType::Interval(lhs_unit), DataType::Interval(rhs_unit))
            if lhs_unit == rhs_unit =>
        {
            lhs_unit.clone()
        }
        _ => IntervalUnit::MonthDayNano,
    };
    Some(DataType::Interval(unit))
}
/// Result type of `timestamp - timestamp`.
///
/// Only timestamps with the same time unit can be subtracted:
/// * second / millisecond -> `Interval(DayTime)`
/// * microsecond / nanosecond -> `Interval(MonthDayNano)`
///
/// Differing units (or non-timestamp inputs) yield `None`.
fn handle_timestamp_minus(lhs: &DataType, rhs: &DataType) -> Option<DataType> {
    match (lhs, rhs) {
        (DataType::Timestamp(lhs_unit, _), DataType::Timestamp(rhs_unit, _))
            if lhs_unit == rhs_unit =>
        {
            let interval_unit = match lhs_unit {
                TimeUnit::Second | TimeUnit::Millisecond => IntervalUnit::DayTime,
                TimeUnit::Microsecond | TimeUnit::Nanosecond => {
                    IntervalUnit::MonthDayNano
                }
            };
            Some(DataType::Interval(interval_unit))
        }
        _ => None,
    }
}
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Utf8, _) if DataType::is_numeric(rhs_type) => Some(Utf8),
(LargeUtf8, _) if DataType::is_numeric(rhs_type) => Some(LargeUtf8),
(_, Utf8) if DataType::is_numeric(lhs_type) => Some(Utf8),
(_, LargeUtf8) if DataType::is_numeric(lhs_type) => Some(LargeUtf8),
_ => None,
}
}
/// Common type for comparing two numeric operands.
///
/// Returns `None` if either side is not numeric. Identical types are kept. Decimals
/// take precedence over everything else. Floats come next (the wider float wins).
/// Integer pairs resolve by the precedence table below; mixed signed/unsigned pairs
/// widen to a signed type where one exists.
fn comparison_binary_numeric_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    match (lhs_type, rhs_type) {
        _ if !(is_numeric(lhs_type) && is_numeric(rhs_type)) => None,
        (same, other) if same == other => Some(same.clone()),
        (Decimal128(..), Decimal128(..)) => get_wider_decimal_type(lhs_type, rhs_type),
        (Decimal128(..), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
        (_, Decimal128(..)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
        (Float64, _) | (_, Float64) => Some(Float64),
        (_, Float32) | (Float32, _) => Some(Float32),
        (Int64, _)
        | (_, Int64)
        | (UInt64, Int8)
        | (Int8, UInt64)
        | (UInt64, Int16)
        | (Int16, UInt64)
        | (UInt64, Int32)
        | (Int32, UInt64)
        | (UInt32, Int8)
        | (Int8, UInt32)
        | (UInt32, Int16)
        | (Int16, UInt32)
        | (UInt32, Int32)
        | (Int32, UInt32) => Some(Int64),
        (UInt64, _) | (_, UInt64) => Some(UInt64),
        (Int32, _)
        | (_, Int32)
        | (UInt16, Int16)
        | (Int16, UInt16)
        | (UInt16, Int8)
        | (Int8, UInt16) => Some(Int32),
        (UInt32, _) | (_, UInt32) => Some(UInt32),
        (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
        (UInt16, _) | (_, UInt16) => Some(UInt16),
        (Int8, _) | (_, Int8) => Some(Int8),
        (UInt8, _) | (_, UInt8) => Some(UInt8),
        _ => None,
    }
}
fn get_comparison_common_decimal_type(
decimal_type: &DataType,
other_type: &DataType,
) -> Option<DataType> {
let other_decimal_type = &match other_type {
DataType::Int8 => DataType::Decimal128(3, 0),
DataType::Int16 => DataType::Decimal128(5, 0),
DataType::Int32 => DataType::Decimal128(10, 0),
DataType::Int64 => DataType::Decimal128(20, 0),
DataType::Float32 => DataType::Decimal128(14, 7),
DataType::Float64 => DataType::Decimal128(30, 15),
_ => {
return None;
}
};
match (decimal_type, &other_decimal_type) {
(d1 @ DataType::Decimal128(_, _), d2 @ DataType::Decimal128(_, _)) => {
get_wider_decimal_type(d1, d2)
}
_ => None,
}
}
/// Smallest `Decimal128` able to hold values of both decimal inputs.
///
/// The scale is `max(s1, s2)`. The precision is that scale plus the larger count of
/// integer digits, `max(p1 - s1, p2 - s2)`. Both are capped by
/// `create_decimal_type`. Returns `None` unless both inputs are `Decimal128`.
fn get_wider_decimal_type(
    lhs_decimal_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    if let (DataType::Decimal128(lhs_precision, lhs_scale), DataType::Decimal128(rhs_precision, rhs_scale)) =
        (lhs_decimal_type, rhs_type)
    {
        let scale = (*lhs_scale).max(*rhs_scale);
        let integer_digits = (*lhs_precision as i8 - *lhs_scale)
            .max(*rhs_precision as i8 - *rhs_scale);
        Some(create_decimal_type((integer_digits + scale) as u8, scale))
    } else {
        None
    }
}
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
match numeric_type {
DataType::Int8 => Some(DataType::Decimal128(3, 0)),
DataType::Int16 => Some(DataType::Decimal128(5, 0)),
DataType::Int32 => Some(DataType::Decimal128(10, 0)),
DataType::Int64 => Some(DataType::Decimal128(20, 0)),
DataType::Float32 => Some(DataType::Decimal128(14, 7)),
DataType::Float64 => Some(DataType::Decimal128(30, 15)),
_ => None,
}
}
/// Common type for the operands of an arithmetic operator (`+ - * / %`).
///
/// Returns `None` unless `both_numeric_or_null_and_numeric` accepts the pair.
/// Dictionary operands are coerced on their value types; only a dictionary on the
/// left-hand side alone keeps its encoding in the result. Identical non-dictionary
/// types are kept. Otherwise the first matching arm wins. Floats beat integers, and
/// any signed integer beats any unsigned integer (e.g. `Int32` with `UInt32` ->
/// `Int32`).
///
/// NOTE(review): that signed-first precedence can narrow, e.g. `Int8` with `UInt64`
/// resolves to `Int8`.
fn mathematics_numerical_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) {
        return None;
    }

    match (lhs_type, rhs_type) {
        (Dictionary(_, lhs_values), Dictionary(_, rhs_values)) => {
            mathematics_numerical_coercion(lhs_values, rhs_values)
        }
        (Dictionary(key_type, values), _) => {
            mathematics_numerical_coercion(values, rhs_type)
                .map(|coerced| Dictionary(key_type.clone(), Box::new(coerced)))
        }
        (_, Dictionary(_, values)) => mathematics_numerical_coercion(lhs_type, values),
        (lhs, rhs) if lhs == rhs => Some(lhs.clone()),
        (Float64, _) | (_, Float64) => Some(Float64),
        (_, Float32) | (Float32, _) => Some(Float32),
        (Int64, _) | (_, Int64) => Some(Int64),
        (Int32, _) | (_, Int32) => Some(Int32),
        (Int16, _) | (_, Int16) => Some(Int16),
        (Int8, _) | (_, Int8) => Some(Int8),
        (UInt64, _) | (_, UInt64) => Some(UInt64),
        (UInt32, _) | (_, UInt32) => Some(UInt32),
        (UInt16, _) | (_, UInt16) => Some(UInt16),
        (UInt8, _) | (_, UInt8) => Some(UInt8),
        _ => None,
    }
}
/// Builds a `Decimal128`, capping precision and scale at the Decimal128 maximums
/// (`DECIMAL128_MAX_PRECISION` / `DECIMAL128_MAX_SCALE`).
fn create_decimal_type(precision: u8, scale: i8) -> DataType {
    let capped_precision = precision.min(DECIMAL128_MAX_PRECISION);
    let capped_scale = scale.min(DECIMAL128_MAX_SCALE);
    DataType::Decimal128(capped_precision, capped_scale)
}
/// Type both decimal operands should be coerced to before an arithmetic operation.
///
/// * `+` / `-` -> the result type from [`decimal_op_mathematics_type`]
/// * `*` / `/` / `%` -> the wider of the two decimals (`get_wider_decimal_type`)
///
/// Returns `None` for any other operator or when either side is not `Decimal128`.
pub fn coercion_decimal_mathematics_type(
    mathematics_op: &Operator,
    left_decimal_type: &DataType,
    right_decimal_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    match (left_decimal_type, right_decimal_type, mathematics_op) {
        (Decimal128(..), Decimal128(..), Operator::Plus | Operator::Minus) => {
            decimal_op_mathematics_type(
                mathematics_op,
                left_decimal_type,
                right_decimal_type,
            )
        }
        (
            Decimal128(..),
            Decimal128(..),
            Operator::Multiply | Operator::Divide | Operator::Modulo,
        ) => get_wider_decimal_type(left_decimal_type, right_decimal_type),
        _ => None,
    }
}
/// Result type of `left op right` for decimal operands, which may be
/// dictionary-encoded.
///
/// With `left = Decimal128(p1, s1)` and `right = Decimal128(p2, s2)`:
/// * `+` / `-`: scale `max(s1, s2)`, precision `scale + max(p1 - s1, p2 - s2) + 1`
/// * `*`: scale `s1 + s2`, precision `p1 + p2 + 1`
/// * `/`: scale `max(6, s1 + p2 + 1)`, precision `scale + p1 - s1 + s2`
/// * `%`: scale `max(s1, s2)`, precision `scale + min(p1 - s1, p2 - s2)`
///
/// Precision and scale are then capped by `create_decimal_type`. Dictionary
/// operands are unwrapped. The result is re-wrapped in a dictionary only when the
/// left side alone is a dictionary. Returns `None` for other operators or
/// non-decimal inputs.
pub fn decimal_op_mathematics_type(
    mathematics_op: &Operator,
    left_decimal_type: &DataType,
    right_decimal_type: &DataType,
) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    match (left_decimal_type, right_decimal_type) {
        (Decimal128(p1, s1), Decimal128(p2, s2)) => {
            let (p1, s1, p2, s2) = (*p1, *s1, *p2, *s2);
            match mathematics_op {
                Operator::Plus | Operator::Minus => {
                    let scale = s1.max(s2);
                    let integer_digits = (p1 as i8 - s1).max(p2 as i8 - s2);
                    Some(create_decimal_type(
                        (scale + integer_digits + 1) as u8,
                        scale,
                    ))
                }
                Operator::Multiply => {
                    let scale = s1 + s2;
                    Some(create_decimal_type(p1 + p2 + 1, scale))
                }
                Operator::Divide => {
                    let scale = 6i8.max(s1 + p2 as i8 + 1);
                    let precision = scale + p1 as i8 - s1 + s2;
                    Some(create_decimal_type(precision as u8, scale))
                }
                Operator::Modulo => {
                    let scale = s1.max(s2);
                    let integer_digits = (p1 as i8 - s1).min(p2 as i8 - s2);
                    Some(create_decimal_type((scale + integer_digits) as u8, scale))
                }
                _ => None,
            }
        }
        (Dictionary(_, lhs_values), Dictionary(_, rhs_values)) => {
            decimal_op_mathematics_type(mathematics_op, lhs_values, rhs_values)
        }
        (Dictionary(key_type, values), _) => {
            decimal_op_mathematics_type(mathematics_op, values, right_decimal_type)
                .map(|inner| Dictionary(key_type.clone(), Box::new(inner)))
        }
        (_, Dictionary(_, values)) => {
            decimal_op_mathematics_type(mathematics_op, left_decimal_type, values)
        }
        _ => None,
    }
}
/// Returns `true` when both operands are numeric, or one is `Null` and the other is
/// numeric.
///
/// When neither side is `Null`, one level of dictionary encoding is looked through
/// on each side. When one side is `Null`, the other side is checked as-is.
fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool {
    /// Value type of a dictionary, or the type itself for anything else.
    fn value_type(data_type: &DataType) -> &DataType {
        match data_type {
            DataType::Dictionary(_, inner) => inner.as_ref(),
            other => other,
        }
    }

    match (lhs_type, rhs_type) {
        (other, DataType::Null) | (DataType::Null, other) => is_numeric(other),
        (lhs, rhs) => is_numeric(value_type(lhs)) && is_numeric(value_type(rhs)),
    }
}
/// Returns `true` when the operand pair is treated as "both decimal":
/// * `Decimal128` paired with `Null` (either order)
/// * `Decimal128` paired with `Decimal128`
/// * a dictionary whose value type is decimal, paired with a plain (non-dictionary)
///   `Decimal128`
fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
    match (lhs_type, rhs_type) {
        (other, DataType::Null) | (DataType::Null, other) => is_decimal(other),
        (DataType::Decimal128(..), DataType::Decimal128(..)) => true,
        // The first alternative wins for dictionary/dictionary pairs, so `other` is
        // then itself a dictionary and `is_decimal(other)` is false.
        (DataType::Dictionary(_, inner), other)
        | (other, DataType::Dictionary(_, inner)) => {
            is_decimal(inner) && is_decimal(other)
        }
        _ => false,
    }
}
/// Returns `true` if either operand is a `Decimal128`, looking through one level of
/// dictionary encoding on each side.
///
/// Both sides are unwrapped independently. Otherwise a decimal hidden in the
/// right-hand dictionary would be missed whenever the left side is also a
/// dictionary, and `Dictionary(_, Int32) op Dictionary(_, Decimal128)` would skip
/// decimal coercion entirely.
pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
    /// Value type of a dictionary, or the type itself for anything else.
    fn value_type(data_type: &DataType) -> &DataType {
        match data_type {
            DataType::Dictionary(_, inner) => inner.as_ref(),
            other => other,
        }
    }

    is_decimal(value_type(lhs_type)) || is_decimal(value_type(rhs_type))
}
/// Coercion rules when at least one operand is dictionary-encoded.
///
/// * Two dictionaries: coerce their value types with [`comparison_coercion`].
/// * A dictionary and a plain type equal to its value type: if
///   `preserve_dictionaries` is set, keep the dictionary type unchanged.
/// * Otherwise: coerce the dictionary's value type against the other operand.
///
/// Returns `None` when neither side is a dictionary.
fn dictionary_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_dictionaries: bool,
) -> Option<DataType> {
    let values_of = |t: &DataType| match t {
        DataType::Dictionary(_, inner) => Some(inner.as_ref().clone()),
        _ => None,
    };

    match (values_of(lhs_type), values_of(rhs_type)) {
        (Some(lhs_values), Some(rhs_values)) => {
            comparison_coercion(&lhs_values, &rhs_values)
        }
        (Some(lhs_values), None) => {
            if preserve_dictionaries && &lhs_values == rhs_type {
                Some(lhs_type.clone())
            } else {
                comparison_coercion(&lhs_values, rhs_type)
            }
        }
        (None, Some(rhs_values)) => {
            if preserve_dictionaries && &rhs_values == lhs_type {
                Some(rhs_type.clone())
            } else {
                comparison_coercion(lhs_type, &rhs_values)
            }
        }
        (None, None) => None,
    }
}
/// Common type for the operands of `||` (string concatenation).
///
/// Two string types coerce via `string_coercion`. Otherwise, if one side is `Utf8`
/// (checked first) or `LargeUtf8`, the other side is coerced to that string type,
/// provided arrow can cast it.
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    string_coercion(lhs_type, rhs_type).or_else(|| {
        let (from_type, target_type) = match (lhs_type, rhs_type) {
            (Utf8, other) | (other, Utf8) => (other, Utf8),
            (LargeUtf8, other) | (other, LargeUtf8) => (other, LargeUtf8),
            _ => return None,
        };
        string_concat_internal_coercion(from_type, &target_type)
    })
}
/// Returns `to_type` if arrow can cast `from_type` to it, otherwise `None`.
fn string_concat_internal_coercion(
    from_type: &DataType,
    to_type: &DataType,
) -> Option<DataType> {
    can_cast_types(from_type, to_type).then(|| to_type.clone())
}
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Utf8, Utf8) => Some(Utf8),
(LargeUtf8, Utf8) => Some(LargeUtf8),
(Utf8, LargeUtf8) => Some(LargeUtf8),
(LargeUtf8, LargeUtf8) => Some(LargeUtf8),
_ => None,
}
}
/// Common type for the operands of `LIKE` / `ILIKE`.
///
/// Tries string coercion, then dictionary coercion (without preserving
/// dictionaries), then null coercion, and returns the first match.
pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    if let Some(coerced) = string_coercion(lhs_type, rhs_type) {
        return Some(coerced);
    }
    if let Some(coerced) = dictionary_coercion(lhs_type, rhs_type, false) {
        return Some(coerced);
    }
    null_coercion(lhs_type, rhs_type)
}
/// Common type for the operands of regex match operators (`~`, `~*`, `!~`, `!~*`).
///
/// Tries string coercion, then dictionary coercion (without preserving
/// dictionaries).
pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    match string_coercion(lhs_type, rhs_type) {
        Some(coerced) => Some(coerced),
        None => dictionary_coercion(lhs_type, rhs_type, false),
    }
}
/// Returns `true` for the time types arrow supports: `Time32` in seconds or
/// milliseconds, and `Time64` in microseconds or nanoseconds.
fn is_time_with_valid_unit(datatype: DataType) -> bool {
    match datatype {
        DataType::Time32(unit) => {
            matches!(unit, TimeUnit::Second | TimeUnit::Millisecond)
        }
        DataType::Time64(unit) => {
            matches!(unit, TimeUnit::Microsecond | TimeUnit::Nanosecond)
        }
        _ => false,
    }
}
fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Date64, Date32) | (Date32, Date64) => Some(Date64),
(Utf8, Date32) | (Date32, Utf8) => Some(Date32),
(Utf8, Date64) | (Date64, Utf8) => Some(Date64),
(Utf8, Time32(unit)) | (Time32(unit), Utf8) => {
match is_time_with_valid_unit(Time32(unit.clone())) {
false => None,
true => Some(Time32(unit.clone())),
}
}
(Utf8, Time64(unit)) | (Time64(unit), Utf8) => {
match is_time_with_valid_unit(Time64(unit.clone())) {
false => None,
true => Some(Time64(unit.clone())),
}
}
(Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => {
Some(Timestamp(TimeUnit::Nanosecond, tz.clone()))
}
(Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => {
Some(Timestamp(TimeUnit::Nanosecond, None))
}
(Timestamp(_, _tz), Date32) | (Date32, Timestamp(_, _tz)) => {
Some(Timestamp(TimeUnit::Nanosecond, None))
}
(Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => {
let tz = match (lhs_tz, rhs_tz) {
(Some(lhs_tz), Some(rhs_tz)) => {
if lhs_tz != rhs_tz {
return None;
} else {
Some(lhs_tz.clone())
}
}
(Some(lhs_tz), None) => Some(lhs_tz.clone()),
(None, Some(rhs_tz)) => Some(rhs_tz.clone()),
(None, None) => None,
};
let unit = match (lhs_unit, rhs_unit) {
(TimeUnit::Second, TimeUnit::Millisecond) => TimeUnit::Second,
(TimeUnit::Second, TimeUnit::Microsecond) => TimeUnit::Second,
(TimeUnit::Second, TimeUnit::Nanosecond) => TimeUnit::Second,
(TimeUnit::Millisecond, TimeUnit::Second) => TimeUnit::Second,
(TimeUnit::Millisecond, TimeUnit::Microsecond) => TimeUnit::Millisecond,
(TimeUnit::Millisecond, TimeUnit::Nanosecond) => TimeUnit::Millisecond,
(TimeUnit::Microsecond, TimeUnit::Second) => TimeUnit::Second,
(TimeUnit::Microsecond, TimeUnit::Millisecond) => TimeUnit::Millisecond,
(TimeUnit::Microsecond, TimeUnit::Nanosecond) => TimeUnit::Microsecond,
(TimeUnit::Nanosecond, TimeUnit::Second) => TimeUnit::Second,
(TimeUnit::Nanosecond, TimeUnit::Millisecond) => TimeUnit::Millisecond,
(TimeUnit::Nanosecond, TimeUnit::Microsecond) => TimeUnit::Microsecond,
(l, r) => {
assert_eq!(l, r);
l.clone()
}
};
Some(Timestamp(unit, tz))
}
_ => None,
}
}
/// When one side is `Null`, coerce to the other side's type if arrow can cast
/// `Null` to it. Returns `None` otherwise, or when neither side is `Null`.
fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    let other_type = match (lhs_type, rhs_type) {
        (DataType::Null, other) | (other, DataType::Null) => other,
        _ => return None,
    };
    can_cast_types(&DataType::Null, other_type).then(|| other_type.clone())
}
#[cfg(test)]
mod tests {
    // Unit tests for the binary-operator type coercion rules in this module.
    use super::*;
    use crate::Operator;
    use arrow::datatypes::DataType;
    use datafusion_common::assert_contains;
    use datafusion_common::DataFusionError;
    use datafusion_common::Result;

    // Incompatible operands must produce a `Plan` error with a descriptive message.
    #[test]
    fn test_coercion_error() -> Result<()> {
        let result_type =
            coerce_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
        if let Err(DataFusionError::Plan(e)) = result_type {
            assert_eq!(e, "Float32 + Utf8 can't be evaluated because there isn't a common type to coerce the types to");
            Ok(())
        } else {
            // The message names the error variant the test actually expects.
            Err(DataFusionError::Internal(
                "Coercion should have returned a DataFusionError::Plan".to_string(),
            ))
        }
    }

    // Comparing a decimal with other numeric types widens to a common decimal.
    #[test]
    fn test_decimal_binary_comparison_coercion() -> Result<()> {
        let input_decimal = DataType::Decimal128(20, 3);
        let input_types = [
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::Float32,
            DataType::Float64,
            DataType::Decimal128(38, 10),
            DataType::Decimal128(20, 8),
            DataType::Null,
        ];
        let result_types = [
            DataType::Decimal128(20, 3),
            DataType::Decimal128(20, 3),
            DataType::Decimal128(20, 3),
            DataType::Decimal128(23, 3),
            DataType::Decimal128(24, 7),
            DataType::Decimal128(32, 15),
            DataType::Decimal128(38, 10),
            DataType::Decimal128(25, 8),
            DataType::Decimal128(20, 3),
        ];
        let comparison_op_types = [
            Operator::NotEq,
            Operator::Eq,
            Operator::Gt,
            Operator::GtEq,
            Operator::Lt,
            Operator::LtEq,
        ];
        for (i, input_type) in input_types.iter().enumerate() {
            let expect_type = &result_types[i];
            for op in comparison_op_types {
                let result_type = coerce_types(&input_decimal, &op, input_type)?;
                assert_eq!(expect_type, &result_type);
            }
        }
        let result_type = coerce_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
        assert!(result_type.is_err());
        Ok(())
    }

    // Numeric -> decimal mapping plus per-operator decimal result types.
    #[test]
    fn test_decimal_mathematics_op_type() {
        assert_eq!(
            coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(),
            DataType::Decimal128(3, 0)
        );
        assert_eq!(
            coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(),
            DataType::Decimal128(5, 0)
        );
        assert_eq!(
            coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(),
            DataType::Decimal128(10, 0)
        );
        assert_eq!(
            coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(),
            DataType::Decimal128(20, 0)
        );
        assert_eq!(
            coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(),
            DataType::Decimal128(14, 7)
        );
        assert_eq!(
            coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(),
            DataType::Decimal128(30, 15)
        );
        let op = Operator::Plus;
        let left_decimal_type = DataType::Decimal128(10, 3);
        let right_decimal_type = DataType::Decimal128(20, 4);
        let result = coercion_decimal_mathematics_type(
            &op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(DataType::Decimal128(21, 4), result.unwrap());
        let op = Operator::Minus;
        let result = coercion_decimal_mathematics_type(
            &op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(DataType::Decimal128(21, 4), result.unwrap());
        let op = Operator::Multiply;
        let result = coercion_decimal_mathematics_type(
            &op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(DataType::Decimal128(20, 4), result.unwrap());
        let result =
            decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type);
        assert_eq!(DataType::Decimal128(31, 7), result.unwrap());
        let op = Operator::Divide;
        let result = coercion_decimal_mathematics_type(
            &op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(DataType::Decimal128(20, 4), result.unwrap());
        let result =
            decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type);
        assert_eq!(DataType::Decimal128(35, 24), result.unwrap());
        let op = Operator::Modulo;
        let result = coercion_decimal_mathematics_type(
            &op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(DataType::Decimal128(20, 4), result.unwrap());
        let result =
            decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type);
        assert_eq!(DataType::Decimal128(11, 4), result.unwrap());
    }

    // Dictionary coercion, with and without preserving the dictionary type.
    #[test]
    fn test_dictionary_type_coercion() {
        use DataType::*;
        let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
        let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32));
        assert_eq!(
            dictionary_coercion(&lhs_type, &rhs_type, false),
            Some(Int32)
        );
        let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
        let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Utf8));
        let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
        let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary));
        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None);
        let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
        let rhs_type = Utf8;
        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
        assert_eq!(
            dictionary_coercion(&lhs_type, &rhs_type, true),
            Some(lhs_type.clone())
        );
        let lhs_type = Utf8;
        let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
        assert_eq!(
            dictionary_coercion(&lhs_type, &rhs_type, true),
            Some(rhs_type.clone())
        );
    }

    /// Asserts that `coerce_types($A_TYPE, $OP, $B_TYPE)` succeeds with `$C_TYPE`.
    macro_rules! test_coercion_binary_rule {
        ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{
            let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;
            assert_eq!(result, $C_TYPE);
        }};
    }

    // Timestamp subtraction needs matching units; date + date is rejected.
    #[test]
    fn test_date_timestamp_arithmetic_error() -> Result<()> {
        let err = coerce_types(
            &DataType::Timestamp(TimeUnit::Nanosecond, None),
            &Operator::Minus,
            &DataType::Timestamp(TimeUnit::Millisecond, None),
        )
        .unwrap_err()
        .to_string();
        assert_contains!(&err, "Timestamp(Nanosecond, None) - Timestamp(Millisecond, None) can't be evaluated because there isn't a common type to coerce the types to");
        let err = coerce_types(&DataType::Date32, &Operator::Plus, &DataType::Date64)
            .unwrap_err()
            .to_string();
        assert_contains!(&err, "Date32 + Date64 can't be evaluated because there isn't a common type to coerce the types to");
        Ok(())
    }

    // String/temporal comparisons, regex operators and bitwise operators.
    #[test]
    fn test_type_coercion() -> Result<()> {
        let result = like_coercion(&DataType::Utf8, &DataType::Utf8);
        assert_eq!(result, Some(DataType::Utf8));
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Date32,
            Operator::Eq,
            DataType::Date32
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Date64,
            Operator::Lt,
            DataType::Date64
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Time32(TimeUnit::Second),
            Operator::Eq,
            DataType::Time32(TimeUnit::Second)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Time32(TimeUnit::Millisecond),
            Operator::Eq,
            DataType::Time32(TimeUnit::Millisecond)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Time64(TimeUnit::Microsecond),
            Operator::Eq,
            DataType::Time64(TimeUnit::Microsecond)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Time64(TimeUnit::Nanosecond),
            Operator::Eq,
            DataType::Time64(TimeUnit::Nanosecond)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Timestamp(TimeUnit::Second, None),
            Operator::Lt,
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Timestamp(TimeUnit::Millisecond, None),
            Operator::Lt,
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Timestamp(TimeUnit::Microsecond, None),
            Operator::Lt,
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            Operator::Lt,
            DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Utf8,
            Operator::RegexMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Utf8,
            Operator::RegexNotMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Utf8,
            DataType::Utf8,
            Operator::RegexNotIMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()),
            DataType::Utf8,
            Operator::RegexMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()),
            DataType::Utf8,
            Operator::RegexIMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()),
            DataType::Utf8,
            Operator::RegexNotMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()),
            DataType::Utf8,
            Operator::RegexNotIMatch,
            DataType::Utf8
        );
        test_coercion_binary_rule!(
            DataType::Int16,
            DataType::Int64,
            Operator::BitwiseAnd,
            DataType::Int64
        );
        test_coercion_binary_rule!(
            DataType::UInt64,
            DataType::UInt64,
            Operator::BitwiseAnd,
            DataType::UInt64
        );
        test_coercion_binary_rule!(
            DataType::Int8,
            DataType::UInt32,
            Operator::BitwiseAnd,
            DataType::Int64
        );
        test_coercion_binary_rule!(
            DataType::UInt32,
            DataType::Int32,
            Operator::BitwiseAnd,
            DataType::Int64
        );
        test_coercion_binary_rule!(
            DataType::UInt16,
            DataType::Int16,
            Operator::BitwiseAnd,
            DataType::Int32
        );
        test_coercion_binary_rule!(
            DataType::UInt32,
            DataType::UInt32,
            Operator::BitwiseAnd,
            DataType::UInt32
        );
        test_coercion_binary_rule!(
            DataType::UInt16,
            DataType::UInt32,
            Operator::BitwiseAnd,
            DataType::UInt32
        );
        Ok(())
    }

    // Arithmetic on plain numeric types.
    #[test]
    fn test_type_coercion_arithmetic() -> Result<()> {
        test_coercion_binary_rule!(
            DataType::Int32,
            DataType::UInt32,
            Operator::Plus,
            DataType::Int32
        );
        test_coercion_binary_rule!(
            DataType::Int32,
            DataType::UInt16,
            Operator::Minus,
            DataType::Int32
        );
        test_coercion_binary_rule!(
            DataType::Int8,
            DataType::Int64,
            Operator::Multiply,
            DataType::Int64
        );
        test_coercion_binary_rule!(
            DataType::Float32,
            DataType::Int32,
            Operator::Plus,
            DataType::Float32
        );
        test_coercion_binary_rule!(
            DataType::Float32,
            DataType::Float64,
            Operator::Multiply,
            DataType::Float64
        );
        Ok(())
    }

    /// Checks the operand casts from `math_decimal_coercion`, then the coerced type
    /// from `coercion_decimal_mathematics_type`, then the output type from
    /// `decimal_op_mathematics_type`.
    fn test_math_decimal_coercion_rule(
        lhs_type: DataType,
        rhs_type: DataType,
        mathematics_op: Operator,
        expected_lhs_type: Option<DataType>,
        expected_rhs_type: Option<DataType>,
        expected_coerced_type: DataType,
        expected_output_type: DataType,
    ) {
        let (l, r) = math_decimal_coercion(&lhs_type, &rhs_type);
        assert_eq!(l, expected_lhs_type);
        assert_eq!(r, expected_rhs_type);
        let lhs_type = l.unwrap_or(lhs_type);
        let rhs_type = r.unwrap_or(rhs_type);
        let coerced_type =
            coercion_decimal_mathematics_type(&mathematics_op, &lhs_type, &rhs_type)
                .unwrap();
        assert_eq!(coerced_type, expected_coerced_type);
        let output_type =
            decimal_op_mathematics_type(&mathematics_op, &lhs_type, &rhs_type).unwrap();
        assert_eq!(output_type, expected_output_type);
    }

    // Decimal arithmetic with decimal and integer operands.
    #[test]
    fn test_coercion_arithmetic_decimal() -> Result<()> {
        test_math_decimal_coercion_rule(
            DataType::Decimal128(10, 2),
            DataType::Decimal128(10, 2),
            Operator::Plus,
            None,
            None,
            DataType::Decimal128(11, 2),
            DataType::Decimal128(11, 2),
        );
        test_math_decimal_coercion_rule(
            DataType::Int32,
            DataType::Decimal128(10, 2),
            Operator::Plus,
            Some(DataType::Decimal128(10, 0)),
            None,
            DataType::Decimal128(13, 2),
            DataType::Decimal128(13, 2),
        );
        test_math_decimal_coercion_rule(
            DataType::Int32,
            DataType::Decimal128(10, 2),
            Operator::Minus,
            Some(DataType::Decimal128(10, 0)),
            None,
            DataType::Decimal128(13, 2),
            DataType::Decimal128(13, 2),
        );
        test_math_decimal_coercion_rule(
            DataType::Int32,
            DataType::Decimal128(10, 2),
            Operator::Multiply,
            Some(DataType::Decimal128(10, 0)),
            None,
            DataType::Decimal128(12, 2),
            DataType::Decimal128(21, 2),
        );
        test_math_decimal_coercion_rule(
            DataType::Int32,
            DataType::Decimal128(10, 2),
            Operator::Divide,
            Some(DataType::Decimal128(10, 0)),
            None,
            DataType::Decimal128(12, 2),
            DataType::Decimal128(23, 11),
        );
        test_math_decimal_coercion_rule(
            DataType::Int32,
            DataType::Decimal128(10, 2),
            Operator::Modulo,
            Some(DataType::Decimal128(10, 0)),
            None,
            DataType::Decimal128(12, 2),
            DataType::Decimal128(10, 2),
        );
        Ok(())
    }

    // Comparisons across booleans, numerics and decimals.
    #[test]
    fn test_type_coercion_compare() -> Result<()> {
        test_coercion_binary_rule!(
            DataType::Boolean,
            DataType::Boolean,
            Operator::Eq,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Float32,
            DataType::Int64,
            Operator::Eq,
            DataType::Float32
        );
        test_coercion_binary_rule!(
            DataType::Float32,
            DataType::Float64,
            Operator::GtEq,
            DataType::Float64
        );
        test_coercion_binary_rule!(
            DataType::Int8,
            DataType::Int32,
            Operator::LtEq,
            DataType::Int32
        );
        test_coercion_binary_rule!(
            DataType::Int64,
            DataType::Int32,
            Operator::LtEq,
            DataType::Int64
        );
        test_coercion_binary_rule!(
            DataType::UInt32,
            DataType::UInt8,
            Operator::Gt,
            DataType::UInt32
        );
        test_coercion_binary_rule!(
            DataType::Int64,
            DataType::Decimal128(10, 0),
            Operator::Eq,
            DataType::Decimal128(20, 0)
        );
        test_coercion_binary_rule!(
            DataType::Int64,
            DataType::Decimal128(10, 2),
            Operator::Lt,
            DataType::Decimal128(22, 2)
        );
        test_coercion_binary_rule!(
            DataType::Float64,
            DataType::Decimal128(10, 3),
            Operator::Gt,
            DataType::Decimal128(30, 15)
        );
        test_coercion_binary_rule!(
            DataType::Int64,
            DataType::Decimal128(10, 0),
            Operator::Eq,
            DataType::Decimal128(20, 0)
        );
        test_coercion_binary_rule!(
            DataType::Decimal128(14, 2),
            DataType::Decimal128(10, 3),
            Operator::GtEq,
            DataType::Decimal128(15, 3)
        );
        Ok(())
    }

    // AND / OR accept Boolean and Null operands in any combination.
    #[test]
    fn test_type_coercion_logical_op() -> Result<()> {
        test_coercion_binary_rule!(
            DataType::Boolean,
            DataType::Boolean,
            Operator::And,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Boolean,
            DataType::Boolean,
            Operator::Or,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Boolean,
            DataType::Null,
            Operator::And,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Boolean,
            DataType::Null,
            Operator::Or,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Null,
            DataType::Null,
            Operator::Or,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Null,
            DataType::Null,
            Operator::And,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Null,
            DataType::Boolean,
            Operator::And,
            DataType::Boolean
        );
        test_coercion_binary_rule!(
            DataType::Null,
            DataType::Boolean,
            Operator::Or,
            DataType::Boolean
        );
        Ok(())
    }
}