use std::fmt;
use std::str::FromStr;
use crate::type_coercion::functions::data_types;
use crate::utils;
use crate::{Signature, TypeSignature, Volatility};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
use arrow::datatypes::DataType;
use strum_macros::EnumIter;
impl fmt::Display for BuiltInWindowFunction {
    /// Writes the upper-case name returned by `name()` to the formatter.
    ///
    /// Width, fill and alignment flags on the formatter are not applied;
    /// the name is emitted verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = self.name();
        f.write_str(label)
    }
}
/// A window function that is built in, identified by an upper-case name.
///
/// The name of each variant is returned by `name`, printed by the `Display`
/// implementation and accepted (in any letter case) by the `FromStr`
/// implementation. The argument signature and the result type of each
/// variant are defined by `signature` and `return_type`.
///
/// `EnumIter` is derived so that all variants can be enumerated, as the unit
/// test in this file does.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
pub enum BuiltInWindowFunction {
/// The `ROW_NUMBER` function: declared with zero arguments, returns `UInt64`.
RowNumber,
/// The `RANK` function: declared with zero arguments, returns `UInt64`.
Rank,
/// The `DENSE_RANK` function: declared with zero arguments, returns `UInt64`.
DenseRank,
/// The `PERCENT_RANK` function: declared with zero arguments, returns `Float64`.
PercentRank,
/// The `CUME_DIST` function: declared with zero arguments, returns `Float64`.
CumeDist,
/// The `NTILE` function: declared with one argument drawn from the signed or
/// unsigned 8/16/32/64-bit integer types, returns `UInt64`.
Ntile,
/// The `LAG` function: declared with one, two or three arguments of any
/// type; the result has the type of the first argument.
Lag,
/// The `LEAD` function: declared with one, two or three arguments of any
/// type; the result has the type of the first argument.
Lead,
/// The `FIRST_VALUE` function: declared with one argument of any type; the
/// result has the type of that argument.
FirstValue,
/// The `LAST_VALUE` function: declared with one argument of any type; the
/// result has the type of that argument.
LastValue,
/// The `NTH_VALUE` function: declared with two arguments of any type; the
/// result has the type of the first argument.
NthValue,
}
impl BuiltInWindowFunction {
fn name(&self) -> &str {
use BuiltInWindowFunction::*;
match self {
RowNumber => "ROW_NUMBER",
Rank => "RANK",
DenseRank => "DENSE_RANK",
PercentRank => "PERCENT_RANK",
CumeDist => "CUME_DIST",
Ntile => "NTILE",
Lag => "LAG",
Lead => "LEAD",
FirstValue => "FIRST_VALUE",
LastValue => "LAST_VALUE",
NthValue => "NTH_VALUE",
}
}
}
impl FromStr for BuiltInWindowFunction {
    type Err = DataFusionError;

    /// Parses a built-in window function from its name.
    ///
    /// The input is upper-cased before matching, so the lookup does not
    /// depend on the letter case of `name`.
    ///
    /// # Errors
    ///
    /// Returns the error built by `plan_err!` when the upper-cased input is
    /// not one of the known function names.
    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
        let upper = name.to_uppercase();
        let function = match upper.as_str() {
            "ROW_NUMBER" => Self::RowNumber,
            "RANK" => Self::Rank,
            "DENSE_RANK" => Self::DenseRank,
            "PERCENT_RANK" => Self::PercentRank,
            "CUME_DIST" => Self::CumeDist,
            "NTILE" => Self::Ntile,
            "LAG" => Self::Lag,
            "LEAD" => Self::Lead,
            "FIRST_VALUE" => Self::FirstValue,
            "LAST_VALUE" => Self::LastValue,
            "NTH_VALUE" => Self::NthValue,
            _ => {
                // The message reports the caller's original spelling.
                return plan_err!("There is no built-in window function named {name}");
            }
        };
        Ok(function)
    }
}
impl BuiltInWindowFunction {
pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
data_types(input_expr_types, &self.signature())
.map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
&format!("{self}"),
self.signature(),
input_expr_types,
)
)
})?;
match self {
BuiltInWindowFunction::RowNumber
| BuiltInWindowFunction::Rank
| BuiltInWindowFunction::DenseRank
| BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
Ok(DataType::Float64)
}
BuiltInWindowFunction::Lag
| BuiltInWindowFunction::Lead
| BuiltInWindowFunction::FirstValue
| BuiltInWindowFunction::LastValue
| BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
}
}
pub fn signature(&self) -> Signature {
match self {
BuiltInWindowFunction::RowNumber
| BuiltInWindowFunction::Rank
| BuiltInWindowFunction::DenseRank
| BuiltInWindowFunction::PercentRank
| BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
Signature::one_of(
vec![
TypeSignature::Any(1),
TypeSignature::Any(2),
TypeSignature::Any(3),
],
Volatility::Immutable,
)
}
BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
Signature::any(1, Volatility::Immutable)
}
BuiltInWindowFunction::Ntile => Signature::uniform(
1,
vec![
DataType::UInt64,
DataType::UInt32,
DataType::UInt16,
DataType::UInt8,
DataType::Int64,
DataType::Int32,
DataType::Int16,
DataType::Int8,
],
Volatility::Immutable,
),
BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use strum::IntoEnumIterator;

    /// Every variant must survive a `Display` -> `FromStr` round trip
    /// unchanged.
    #[test]
    fn test_display_and_from_str() {
        for expected in BuiltInWindowFunction::iter() {
            let rendered = format!("{expected}");
            let parsed: BuiltInWindowFunction = rendered.parse().unwrap();
            assert_eq!(parsed, expected);
        }
    }
}