use crate::error::{DataFusionError, Result, _plan_err};
use arrow::{
array::{new_null_array, Array, ArrayRef, StructArray},
compute::cast,
datatypes::{DataType::Struct, Field, FieldRef},
};
use std::sync::Arc;
/// Casts a struct column so that its children match `target_fields`.
///
/// For every target field, the source child with the same name is cast via
/// [`cast_column`] (which recurses into nested structs). Target fields that
/// have no same-named child in the source are filled with an all-null array
/// of the target type. Source children that are not named by any target
/// field are dropped.
///
/// The struct-level validity (null) buffer of the source is carried over to
/// the result, so rows that were NULL structs in the source stay NULL.
///
/// # Errors
///
/// Returns a plan error when `source_col` is not a [`StructArray`], and
/// propagates any error from casting a child or from assembling the final
/// struct array.
fn cast_struct_column(
    source_col: &ArrayRef,
    target_fields: &[Arc<Field>],
) -> Result<ArrayRef> {
    let source_struct = match source_col.as_any().downcast_ref::<StructArray>() {
        Some(source_struct) => source_struct,
        None => {
            return Err(DataFusionError::Plan(format!(
                "Cannot cast column of type {:?} to struct type. Source must be a struct to cast to struct.",
                source_col.data_type()
            )));
        }
    };

    let num_rows = source_col.len();
    // Keep the row-level validity of the source struct; rebuilding the struct
    // from its children alone would silently turn NULL structs into non-null
    // ones.
    let nulls = source_struct.nulls().cloned();

    // With no children the row count cannot be inferred from the child
    // arrays, so it has to be supplied explicitly.
    if target_fields.is_empty() {
        return Ok(Arc::new(StructArray::new_empty_fields(num_rows, nulls)));
    }

    let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
    let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
    for target_child_field in target_fields {
        let child = match source_struct.column_by_name(target_child_field.name()) {
            Some(source_child_col) => cast_column(source_child_col, target_child_field)?,
            // Field is absent from the source: fill it with nulls.
            None => new_null_array(target_child_field.data_type(), num_rows),
        };
        fields.push(Arc::clone(target_child_field));
        arrays.push(child);
    }

    // `try_new` validates the assembled children and reports problems as an
    // error instead of panicking.
    let struct_array = StructArray::try_new(fields.into(), arrays, nulls)?;
    Ok(Arc::new(struct_array))
}
/// Casts `source_col` to the data type described by `target_field`.
///
/// When the target type is a struct, the work is delegated to
/// `cast_struct_column`, which matches children by name. Every other target
/// type is handed to arrow's [`cast`] kernel.
///
/// # Errors
///
/// Returns an error if the struct cast fails or if the arrow cast kernel
/// rejects the conversion.
pub fn cast_column(source_col: &ArrayRef, target_field: &Field) -> Result<ArrayRef> {
    let target_type = target_field.data_type();

    // Struct targets need name-based child handling.
    if let Struct(target_children) = target_type {
        return cast_struct_column(source_col, target_children);
    }

    let converted = cast(source_col, target_type)?;
    Ok(converted)
}
/// Checks whether a struct with `source_fields` can be cast to a struct with
/// `target_fields`.
///
/// Fields are matched by name. For each target field that has a same-named
/// source field:
/// * if both sides are structs, their children are validated recursively;
/// * otherwise the pair must be accepted by `arrow::compute::can_cast_types`.
///
/// Target fields with no matching source field are skipped, and source fields
/// not named by any target field are ignored. Field nullability is not
/// inspected.
///
/// # Returns
///
/// `Ok(true)` when no incompatible pair was found. The function never returns
/// `Ok(false)`; incompatibility is reported through the error path.
///
/// # Errors
///
/// Returns a plan error naming the first matched field whose source type
/// cannot be cast to the target type.
pub fn validate_struct_compatibility(
    source_fields: &[FieldRef],
    target_fields: &[FieldRef],
) -> Result<bool> {
    // Pair each target field with its same-named source field; targets that
    // have no counterpart in the source drop out here.
    let matched_pairs = target_fields.iter().filter_map(|wanted| {
        source_fields
            .iter()
            .find(|candidate| candidate.name() == wanted.name())
            .map(|found| (found, wanted))
    });

    for (found, wanted) in matched_pairs {
        let from_type = found.data_type();
        let to_type = wanted.data_type();

        if let (Struct(from_children), Struct(to_children)) = (from_type, to_type) {
            // Struct-to-struct: compatibility is decided by the children.
            validate_struct_compatibility(from_children, to_children)?;
            continue;
        }

        if !arrow::compute::can_cast_types(from_type, to_type) {
            return _plan_err!(
                "Cannot cast struct field '{}' from type {:?} to type {:?}",
                wanted.name(),
                from_type,
                to_type
            );
        }
    }

    Ok(true)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::{Int32Array, Int64Array, StringArray},
        datatypes::{DataType, Field},
    };

    /// Fetches the child column `name` of `struct_array` downcast to the
    /// concrete array type `T`. Panics if the column is missing or has a
    /// different type.
    fn column_as<'a, T: 'static>(struct_array: &'a StructArray, name: &str) -> &'a T {
        struct_array
            .column_by_name(name)
            .unwrap()
            .as_any()
            .downcast_ref::<T>()
            .unwrap()
    }

    /// A non-struct target goes through the plain arrow cast (Int32 -> Int64).
    #[test]
    fn test_cast_simple_column() {
        let target_field = Field::new("ints", DataType::Int64, true);
        let source: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

        let casted = cast_column(&source, &target_field).unwrap();
        let casted = casted.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(casted.len(), 3);
        for (idx, expected) in [1_i64, 2, 3].iter().enumerate() {
            assert_eq!(casted.value(idx), *expected);
        }
    }

    /// A target field that is absent from the source struct is filled with
    /// nulls, while present fields keep their values.
    #[test]
    fn test_cast_struct_with_missing_field() {
        let source_children = vec![(
            Arc::new(Field::new("a", DataType::Int32, true)),
            Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
        )];
        let source_col: ArrayRef = Arc::new(StructArray::from(source_children));

        let target_children = vec![
            Arc::new(Field::new("a", DataType::Int32, true)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ];
        let target_field = Field::new("s", Struct(target_children.into()), true);

        let casted = cast_column(&source_col, &target_field).unwrap();
        let struct_array = casted.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_array.fields().len(), 2);

        let a_values = column_as::<Int32Array>(struct_array, "a");
        assert_eq!(a_values.value(0), 1);
        assert_eq!(a_values.value(1), 2);

        let b_values = column_as::<StringArray>(struct_array, "b");
        assert_eq!(b_values.len(), 2);
        for row in 0..2 {
            assert!(b_values.is_null(row));
        }
    }

    /// Casting a non-struct source to a struct target is rejected with a
    /// descriptive error.
    #[test]
    fn test_cast_struct_source_not_struct() {
        let target_children = vec![Arc::new(Field::new("a", DataType::Int32, true))];
        let target_field = Field::new("s", Struct(target_children.into()), true);
        let source: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));

        let outcome = cast_column(&source, &target_field);
        assert!(outcome.is_err());

        let error_msg = outcome.unwrap_err().to_string();
        let expected_fragments = [
            "Cannot cast column of type",
            "to struct type",
            "Source must be a struct",
        ];
        for fragment in expected_fragments.iter() {
            assert!(error_msg.contains(*fragment));
        }
    }

    /// Binary -> Int32 is not castable, so validation fails and names the
    /// offending field and both types.
    #[test]
    fn test_validate_struct_compatibility_incompatible_types() {
        let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))];
        let source_fields = vec![
            Arc::new(Field::new("field1", DataType::Binary, true)),
            Arc::new(Field::new("field2", DataType::Utf8, true)),
        ];

        let outcome = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(outcome.is_err());

        let error_msg = outcome.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast struct field 'field1'"));
        assert!(error_msg.contains("Binary"));
        assert!(error_msg.contains("Int32"));
    }

    /// Int32 -> Int64 is castable; the extra source field is ignored.
    #[test]
    fn test_validate_struct_compatibility_compatible_types() {
        let target_fields = vec![Arc::new(Field::new("field1", DataType::Int64, true))];
        let source_fields = vec![
            Arc::new(Field::new("field1", DataType::Int32, true)),
            Arc::new(Field::new("field2", DataType::Utf8, true)),
        ];

        let outcome = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap());
    }

    /// A target field with no counterpart in the source does not fail
    /// validation.
    #[test]
    fn test_validate_struct_compatibility_missing_field_in_source() {
        let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))];
        let source_fields = vec![Arc::new(Field::new("field2", DataType::Utf8, true))];

        let outcome = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap());
    }
}