use super::{
type_coercion::{coerce, data_types},
PhysicalExpr,
};
use crate::error::{ExecutionError, Result};
use crate::physical_plan::array_expressions;
use crate::physical_plan::datetime_expressions;
use crate::physical_plan::math_expressions;
use crate::physical_plan::string_expressions;
use arrow::{
array::ArrayRef,
compute::kernels::length::length,
datatypes::TimeUnit,
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use fmt::{Debug, Formatter};
use std::{fmt, str::FromStr, sync::Arc};
/// Describes which argument types a built-in scalar function accepts.
///
/// Values are produced by `signature` and interpreted by
/// `type_coercion::{coerce, data_types}` (not visible here) when validating and
/// coercing call arguments. The per-variant notes below are inferred from how
/// `signature` uses each variant and should be confirmed against `type_coercion`.
#[derive(Debug, Clone)]
pub enum Signature {
/// Any number of arguments drawn from the listed types (used for `concat` and
/// `array`). NOTE(review): the exact coercion rule, e.g. whether all arguments
/// are coerced to one common type, lives in `type_coercion` — confirm there.
Variadic(Vec<DataType>),
/// Presumably any number of arguments that must all share one type — confirm.
VariadicEqual,
/// A fixed number of arguments (the `usize`), presumably all coerced to a single
/// type picked from the list — confirm in `type_coercion`.
Uniform(usize, Vec<DataType>),
/// Presumably exactly these argument types, in this order — confirm.
Exact(Vec<DataType>),
/// Presumably this many arguments of any type — confirm.
Any(usize),
}
/// Shared, thread-safe kernel of a scalar function: it receives the already
/// evaluated argument arrays and returns the result array.
pub type ScalarFunctionImplementation =
Arc<dyn Fn(&[ArrayRef]) -> Result<ArrayRef> + Send + Sync>;
/// Shared, thread-safe function that computes a function's return type from its
/// argument types. NOTE(review): not referenced in this section; presumably used
/// by callers defining their own functions.
pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
/// The scalar functions built into the engine.
///
/// Parsed from lowercase SQL names via `FromStr` and rendered via `Display`;
/// `signature` and `return_type` define the accepted and produced types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BuiltinScalarFunction {
// Unary math functions (`Sqrt` through `Signum`): each takes one Float64/Float32
// argument and returns Float64, dispatched to `math_expressions`.
Sqrt,
Sin,
Cos,
Tan,
Asin,
Acos,
Atan,
Exp,
/// SQL `log`; dispatched to `math_expressions::ln`.
Log,
Log2,
Log10,
Floor,
Ceil,
Round,
Trunc,
Abs,
Signum,
/// String length: Int32 for Utf8 input, Int64 for LargeUtf8 input.
Length,
/// Concatenation of Utf8 arguments into a Utf8 result.
Concat,
/// Converts a single Utf8 argument into a nanosecond `Timestamp`.
ToTimestamp,
/// Builds a `FixedSizeList` with one element per argument.
Array,
}
impl fmt::Display for BuiltinScalarFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", format!("{:?}", self).to_lowercase())
}
}
impl FromStr for BuiltinScalarFunction {
    type Err = ExecutionError;

    /// Parses a lowercase SQL function name into a [`BuiltinScalarFunction`].
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    /// Returns `ExecutionError::General` when `name` is not a known built-in.
    fn from_str(name: &str) -> Result<BuiltinScalarFunction> {
        Ok(match name {
            "sqrt" => BuiltinScalarFunction::Sqrt,
            "sin" => BuiltinScalarFunction::Sin,
            "cos" => BuiltinScalarFunction::Cos,
            "tan" => BuiltinScalarFunction::Tan,
            "asin" => BuiltinScalarFunction::Asin,
            "acos" => BuiltinScalarFunction::Acos,
            "atan" => BuiltinScalarFunction::Atan,
            "exp" => BuiltinScalarFunction::Exp,
            "log" => BuiltinScalarFunction::Log,
            "log2" => BuiltinScalarFunction::Log2,
            "log10" => BuiltinScalarFunction::Log10,
            "floor" => BuiltinScalarFunction::Floor,
            "ceil" => BuiltinScalarFunction::Ceil,
            "round" => BuiltinScalarFunction::Round,
            // Was misspelled "truc", which made `trunc` unreachable by name.
            "trunc" => BuiltinScalarFunction::Trunc,
            "abs" => BuiltinScalarFunction::Abs,
            "signum" => BuiltinScalarFunction::Signum,
            "length" => BuiltinScalarFunction::Length,
            "concat" => BuiltinScalarFunction::Concat,
            "to_timestamp" => BuiltinScalarFunction::ToTimestamp,
            "array" => BuiltinScalarFunction::Array,
            _ => {
                return Err(ExecutionError::General(format!(
                    "There is no built-in function named {}",
                    name
                )))
            }
        })
    }
}
/// Returns the output [`DataType`] of applying `fun` to arguments of types
/// `arg_types`.
///
/// `create_physical_expr` stores this value as the type reported by
/// `ScalarFunctionExpr::data_type`. It should therefore match the array type
/// the chosen implementation actually produces.
///
/// # Errors
/// - `ExecutionError::General` if `arg_types` is empty.
/// - Any error from `data_types` when `arg_types` does not satisfy
///   `signature(fun)`.
/// - `ExecutionError::InternalError` if `Length` receives a non-string type.
pub fn return_type(
    fun: &BuiltinScalarFunction,
    arg_types: &Vec<DataType>,
) -> Result<DataType> {
    // Reject zero-argument calls before anything else. The `Length` and `Array`
    // arms below index `arg_types[0]`, and checking first keeps this message from
    // being masked by a signature-validation error.
    if arg_types.is_empty() {
        return Err(ExecutionError::General(format!(
            "Function '{}' requires at least one argument",
            fun
        )));
    }

    // Validate the argument types against the function's declared signature.
    data_types(arg_types, &signature(fun))?;

    match fun {
        BuiltinScalarFunction::Length => match arg_types[0] {
            DataType::LargeUtf8 => Ok(DataType::Int64),
            DataType::Utf8 => Ok(DataType::Int32),
            _ => Err(ExecutionError::InternalError(
                "The length function can only accept strings.".to_string(),
            )),
        },
        BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
        BuiltinScalarFunction::ToTimestamp => {
            Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
        }
        // One list element per argument, typed after the first argument.
        BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList(
            Box::new(arg_types[0].clone()),
            arg_types.len() as i32,
        )),
        // The remaining built-ins are the unary math functions, which produce
        // Float64 for both Float32 and Float64 input.
        _ => Ok(DataType::Float64),
    }
}
/// Builds the physical expression that evaluates built-in `fun` over `args`.
///
/// The arguments are first coerced against `input_schema` according to the
/// function's [`Signature`]. The coerced argument types then determine the
/// reported output type via [`return_type`].
///
/// # Errors
/// Propagates failures from argument coercion, from resolving a coerced
/// argument's data type, and from [`return_type`].
pub fn create_physical_expr(
    fun: &BuiltinScalarFunction,
    args: &Vec<Arc<dyn PhysicalExpr>>,
    input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
    // Select the kernel. Every arm is a plain function or a non-capturing
    // closure over the argument arrays. The `Concat`/`ToTimestamp` wrappers
    // put their concrete result arrays behind an `Arc`.
    let implementation: ScalarFunctionImplementation = Arc::new(match fun {
        BuiltinScalarFunction::Sqrt => math_expressions::sqrt,
        BuiltinScalarFunction::Sin => math_expressions::sin,
        BuiltinScalarFunction::Cos => math_expressions::cos,
        BuiltinScalarFunction::Tan => math_expressions::tan,
        BuiltinScalarFunction::Asin => math_expressions::asin,
        BuiltinScalarFunction::Acos => math_expressions::acos,
        BuiltinScalarFunction::Atan => math_expressions::atan,
        BuiltinScalarFunction::Exp => math_expressions::exp,
        BuiltinScalarFunction::Log => math_expressions::ln,
        BuiltinScalarFunction::Log2 => math_expressions::log2,
        BuiltinScalarFunction::Log10 => math_expressions::log10,
        BuiltinScalarFunction::Floor => math_expressions::floor,
        BuiltinScalarFunction::Ceil => math_expressions::ceil,
        BuiltinScalarFunction::Round => math_expressions::round,
        BuiltinScalarFunction::Trunc => math_expressions::trunc,
        BuiltinScalarFunction::Abs => math_expressions::abs,
        BuiltinScalarFunction::Signum => math_expressions::signum,
        BuiltinScalarFunction::Length => |args| Ok(length(args[0].as_ref())?),
        BuiltinScalarFunction::Concat => {
            |args| Ok(Arc::new(string_expressions::concatenate(args)?))
        }
        BuiltinScalarFunction::ToTimestamp => {
            |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?))
        }
        BuiltinScalarFunction::Array => |args| Ok(array_expressions::array(args)?),
    });

    // Coerce the arguments to the signature, then read back their types
    // (stopping at the first failure) to derive the output type.
    let coerced_args = coerce(args, input_schema, &signature(fun))?;
    let mut coerced_types = Vec::with_capacity(coerced_args.len());
    for arg in &coerced_args {
        coerced_types.push(arg.data_type(input_schema)?);
    }
    let output_type = return_type(fun, &coerced_types)?;

    let expr = ScalarFunctionExpr::new(
        &fun.to_string(),
        implementation,
        coerced_args,
        &output_type,
    );
    Ok(Arc::new(expr))
}
/// Returns the argument [`Signature`] accepted by `fun`.
///
/// `return_type` uses it to validate argument types, and
/// `create_physical_expr` uses it to coerce arguments.
fn signature(fun: &BuiltinScalarFunction) -> Signature {
    match *fun {
        BuiltinScalarFunction::Length => {
            let string_types = vec![DataType::Utf8, DataType::LargeUtf8];
            Signature::Uniform(1, string_types)
        }
        BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
        BuiltinScalarFunction::ToTimestamp => {
            Signature::Uniform(1, vec![DataType::Utf8])
        }
        BuiltinScalarFunction::Array => {
            let element_types = array_expressions::SUPPORTED_ARRAY_TYPES.to_vec();
            Signature::Variadic(element_types)
        }
        // Everything else is a unary math function over floats.
        _ => {
            let float_types = vec![DataType::Float64, DataType::Float32];
            Signature::Uniform(1, float_types)
        }
    }
}
/// Physical expression that applies a scalar function kernel to the arrays
/// produced by its argument expressions.
pub struct ScalarFunctionExpr {
/// Kernel invoked with the evaluated argument arrays.
fun: ScalarFunctionImplementation,
/// Function name, shown by the `Display` and `Debug` impls.
name: String,
/// Argument expressions, evaluated in order against each batch.
args: Vec<Arc<dyn PhysicalExpr>>,
/// Type reported by `data_type`; fixed at construction.
return_type: DataType,
}
impl Debug for ScalarFunctionExpr {
    /// Debug-formats every field. The kernel is an opaque closure, so it is
    /// shown as the placeholder `<FUNC>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ScalarFunctionExpr");
        out.field("fun", &"<FUNC>");
        out.field("name", &self.name);
        out.field("args", &self.args);
        out.field("return_type", &self.return_type);
        out.finish()
    }
}
impl ScalarFunctionExpr {
    /// Creates an expression named `name` that applies `fun` to the arrays
    /// produced by `args` and reports `return_type` as its data type.
    pub fn new(
        name: &str,
        fun: ScalarFunctionImplementation,
        args: Vec<Arc<dyn PhysicalExpr>>,
        return_type: &DataType,
    ) -> Self {
        let name = String::from(name);
        let return_type = return_type.clone();
        Self {
            name,
            fun,
            return_type,
            args,
        }
    }
}
impl fmt::Display for ScalarFunctionExpr {
    /// Renders the call as `name(arg1, arg2, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rendered_args: Vec<String> =
            self.args.iter().map(|arg| arg.to_string()).collect();
        write!(f, "{}({})", self.name, rendered_args.join(", "))
    }
}
impl PhysicalExpr for ScalarFunctionExpr {
    /// Returns the type fixed at construction; the schema is not consulted.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        let data_type = self.return_type.clone();
        Ok(data_type)
    }

    /// Always reports the result as nullable.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluates each argument against `batch`, left to right, stopping at the
    /// first error. It then passes the resulting arrays to the kernel.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        let mut arrays = Vec::with_capacity(self.args.len());
        for arg in &self.args {
            arrays.push(arg.evaluate(batch)?);
        }
        (self.fun)(&arrays)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Result, physical_plan::expressions::lit, scalar::ScalarValue};
    use arrow::{
        array::{
            ArrayRef, FixedSizeListArray, Float64Array, Int32Array, PrimitiveArrayOps,
            StringArray,
        },
        datatypes::Field,
        record_batch::RecordBatch,
    };

    /// Evaluates `exp(value)` over a one-row batch and checks that the Float64
    /// result renders as `expected`.
    fn generic_test_math(value: ScalarValue, expected: &str) -> Result<()> {
        // The expression only uses a literal; the column just provides one row.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
        let arg = lit(value);
        let expr =
            create_physical_expr(&BuiltinScalarFunction::Exp, &vec![arg], &schema)?;
        assert_eq!(expr.data_type(&schema)?, DataType::Float64);
        let batch = RecordBatch::try_new(Arc::new(schema), columns)?;
        let result = expr.evaluate(&batch)?;
        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(result.value(0).to_string(), expected);
        Ok(())
    }

    #[test]
    fn test_math_function() -> Result<()> {
        let exp_f64 = "2.718281828459045";
        let exp_f32 = "2.7182817459106445";
        generic_test_math(ScalarValue::from(1i32), exp_f64)?;
        generic_test_math(ScalarValue::from(1u32), exp_f64)?;
        generic_test_math(ScalarValue::from(1u64), exp_f64)?;
        generic_test_math(ScalarValue::from(1f64), exp_f64)?;
        generic_test_math(ScalarValue::from(1f32), exp_f32)?;
        Ok(())
    }

    /// Evaluates `concat(value, value)` over a one-row batch and checks the
    /// Utf8 result equals `expected`.
    fn test_concat(value: ScalarValue, expected: &str) -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
        let expr = create_physical_expr(
            &BuiltinScalarFunction::Concat,
            &vec![lit(value.clone()), lit(value)],
            &schema,
        )?;
        assert_eq!(expr.data_type(&schema)?, DataType::Utf8);
        let batch = RecordBatch::try_new(Arc::new(schema), columns)?;
        let result = expr.evaluate(&batch)?;
        let result = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(result.value(0), expected);
        Ok(())
    }

    #[test]
    fn test_concat_utf8() -> Result<()> {
        test_concat(ScalarValue::Utf8(Some("aa".to_string())), "aaaa")
    }

    #[test]
    fn test_concat_error() -> Result<()> {
        let result = return_type(&BuiltinScalarFunction::Concat, &vec![]);
        if result.is_ok() {
            return Err(ExecutionError::General(
                "Function 'concat' cannot accept zero arguments".to_string(),
            ));
        }
        Ok(())
    }

    /// Evaluates `array(value1, value2)` over a one-row batch. Checks that the
    /// list element type is `expected_type` and that the first list's Debug
    /// rendering equals `expected`.
    fn generic_test_array(
        value1: ScalarValue,
        value2: ScalarValue,
        expected_type: DataType,
        expected: &str,
    ) -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
        let expr = create_physical_expr(
            &BuiltinScalarFunction::Array,
            &vec![lit(value1), lit(value2)],
            &schema,
        )?;
        assert_eq!(
            expr.data_type(&schema)?,
            DataType::FixedSizeList(Box::new(expected_type), 2)
        );
        let batch = RecordBatch::try_new(Arc::new(schema), columns)?;
        let result = expr.evaluate(&batch)?;
        let result = result
            .as_any()
            .downcast_ref::<FixedSizeListArray>()
            .unwrap();
        assert_eq!(format!("{:?}", result.value(0)), expected);
        Ok(())
    }

    #[test]
    fn test_array() -> Result<()> {
        generic_test_array(
            ScalarValue::Utf8(Some("aa".to_string())),
            ScalarValue::Utf8(Some("aa".to_string())),
            DataType::Utf8,
            "StringArray\n[\n \"aa\",\n \"aa\",\n]",
        )?;
        // Mixed unsigned widths are coerced to the wider type in both orders.
        generic_test_array(
            ScalarValue::from(1u32),
            ScalarValue::from(1u64),
            DataType::UInt64,
            "PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
        )?;
        generic_test_array(
            ScalarValue::from(1u64),
            ScalarValue::from(1u32),
            DataType::UInt64,
            "PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
        )
    }
}