use datafusion_common::{DataFusionError, Result};
use sqlparser::ast;
use std::cmp::Ordering;
use std::convert::{From, TryFrom};
use std::fmt;
use std::hash::{Hash, Hasher};
/// A SQL window frame such as `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW`.
///
/// Frames produced through `TryFrom<ast::WindowFrame>` are validated (see that
/// impl). The fields are public, so values constructed directly are NOT
/// validated.
///
/// NOTE: `PartialOrd` and `Hash` are derived, so they depend on the field
/// declaration order below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct WindowFrame {
    /// The frame units: `ROWS`, `RANGE` or `GROUPS`.
    pub units: WindowFrameUnits,
    /// The bound written after `BETWEEN`.
    pub start_bound: WindowFrameBound,
    /// The bound written after `AND`.
    pub end_bound: WindowFrameBound,
}
impl fmt::Display for WindowFrame {
    /// Renders the frame in SQL form, e.g.
    /// `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            units,
            start_bound,
            end_bound,
        } = self;
        write!(f, "{} BETWEEN {} AND {}", units, start_bound, end_bound)
    }
}
impl TryFrom<ast::WindowFrame> for WindowFrame {
    type Error = DataFusionError;

    /// Converts a parsed SQL window frame into a validated [`WindowFrame`].
    ///
    /// A missing end bound defaults to `CURRENT ROW`.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::Execution` if the start bound is `UNBOUNDED FOLLOWING`,
    ///   if the end bound is `UNBOUNDED PRECEDING`, or if the start bound sorts
    ///   after the end bound. These checks are applied in that order.
    /// * `DataFusionError::NotImplemented` if `RANGE` units are combined with a
    ///   non-zero `n PRECEDING` / `n FOLLOWING` offset. The start bound is
    ///   reported in preference to the end bound.
    fn try_from(value: ast::WindowFrame) -> Result<Self> {
        let start_bound = WindowFrameBound::from(value.start_bound);
        let end_bound = value
            .end_bound
            .map_or(WindowFrameBound::CurrentRow, WindowFrameBound::from);

        if matches!(start_bound, WindowFrameBound::Following(None)) {
            return Err(DataFusionError::Execution(
                "Invalid window frame: start bound cannot be unbounded following"
                    .to_owned(),
            ));
        }
        if matches!(end_bound, WindowFrameBound::Preceding(None)) {
            return Err(DataFusionError::Execution(
                "Invalid window frame: end bound cannot be unbounded preceding"
                    .to_owned(),
            ));
        }
        if start_bound > end_bound {
            return Err(DataFusionError::Execution(format!(
                "Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
                start_bound, end_bound
            )));
        }

        let units = WindowFrameUnits::from(value.units);
        if units == WindowFrameUnits::Range {
            // RANGE frames currently only accept unbounded or zero-offset bounds;
            // pick the first (start before end) bound carrying a non-zero offset.
            let unsupported_offset =
                [start_bound, end_bound].iter().find_map(|bound| match bound {
                    WindowFrameBound::Preceding(Some(v))
                    | WindowFrameBound::Following(Some(v))
                        if *v > 0 =>
                    {
                        Some(*v)
                    }
                    _ => None,
                });
            if let Some(v) = unsupported_offset {
                return Err(DataFusionError::NotImplemented(format!(
                    "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment",
                    units, v
                )));
            }
        }

        Ok(Self {
            units,
            start_bound,
            end_bound,
        })
    }
}
impl Default for WindowFrame {
    /// Returns `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`.
    fn default() -> Self {
        let start_bound = WindowFrameBound::Preceding(None);
        let end_bound = WindowFrameBound::CurrentRow;
        Self {
            units: WindowFrameUnits::Range,
            start_bound,
            end_bound,
        }
    }
}
/// One end (start or end) of a window frame, as written in SQL.
///
/// The offset is `None` for `UNBOUNDED` and `Some(n)` for `n PRECEDING` /
/// `n FOLLOWING`. Equality, ordering and hashing are implemented by hand in
/// this file via `get_rank`, not derived. As a result `Preceding(Some(0))`,
/// `CurrentRow` and `Following(Some(0))` all compare, and hash, as equal.
#[derive(Debug, Clone, Copy, Eq)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING` (`None`) or `n PRECEDING` (`Some(n)`).
    Preceding(Option<u64>),
    /// `CURRENT ROW`.
    CurrentRow,
    /// `UNBOUNDED FOLLOWING` (`None`) or `n FOLLOWING` (`Some(n)`).
    Following(Option<u64>),
}
impl From<ast::WindowFrameBound> for WindowFrameBound {
fn from(value: ast::WindowFrameBound) -> Self {
match value {
ast::WindowFrameBound::Preceding(v) => Self::Preceding(v),
ast::WindowFrameBound::Following(v) => Self::Following(v),
ast::WindowFrameBound::CurrentRow => Self::CurrentRow,
}
}
}
impl fmt::Display for WindowFrameBound {
    /// Renders the bound in SQL form: `CURRENT ROW`, `UNBOUNDED PRECEDING`,
    /// `UNBOUNDED FOLLOWING`, `n PRECEDING` or `n FOLLOWING`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (offset, direction) = match self {
            WindowFrameBound::CurrentRow => return f.write_str("CURRENT ROW"),
            WindowFrameBound::Preceding(offset) => (offset, "PRECEDING"),
            WindowFrameBound::Following(offset) => (offset, "FOLLOWING"),
        };
        match offset {
            Some(n) => write!(f, "{} {}", n, direction),
            None => write!(f, "UNBOUNDED {}", direction),
        }
    }
}
impl PartialEq for WindowFrameBound {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for WindowFrameBound {
    /// Always `Some`, because bounds are totally ordered; this delegates to `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for WindowFrameBound {
    /// Orders bounds by position along the frame, from `UNBOUNDED PRECEDING`
    /// (smallest) to `UNBOUNDED FOLLOWING` (largest), by comparing ranks.
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.get_rank(), other.get_rank());
        lhs.cmp(&rhs)
    }
}
impl Hash for WindowFrameBound {
    /// Hashes the rank rather than the variant. Bounds that compare equal
    /// (e.g. `CurrentRow` and `Preceding(Some(0))`) therefore hash equally too,
    /// as the `Hash`/`Eq` contract requires.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let rank = self.get_rank();
        rank.hash(state);
    }
}
impl WindowFrameBound {
fn get_rank(&self) -> (u8, u64) {
match self {
WindowFrameBound::Preceding(None) => (0, 0),
WindowFrameBound::Following(None) => (4, 0),
WindowFrameBound::Preceding(Some(0))
| WindowFrameBound::CurrentRow
| WindowFrameBound::Following(Some(0)) => (2, 0),
WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v),
WindowFrameBound::Following(Some(v)) => (3, *v),
}
}
}
/// The units clause of a SQL window frame: `ROWS`, `RANGE` or `GROUPS`.
///
/// NOTE: `PartialOrd` is derived, so the declaration order of the variants is
/// significant. Reordering them changes how units (and frames) compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum WindowFrameUnits {
    Rows,
    Range,
    Groups,
}

impl fmt::Display for WindowFrameUnits {
    /// Writes the SQL keyword for these units.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let keyword = match self {
            Self::Rows => "ROWS",
            Self::Range => "RANGE",
            Self::Groups => "GROUPS",
        };
        f.write_str(keyword)
    }
}
impl From<ast::WindowFrameUnits> for WindowFrameUnits {
fn from(value: ast::WindowFrameUnits) -> Self {
match value {
ast::WindowFrameUnits::Range => Self::Range,
ast::WindowFrameUnits::Groups => Self::Groups,
ast::WindowFrameUnits::Rows => Self::Rows,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parsed SQL window frame from its three parts.
    fn sql_frame(
        units: ast::WindowFrameUnits,
        start_bound: ast::WindowFrameBound,
        end_bound: Option<ast::WindowFrameBound>,
    ) -> ast::WindowFrame {
        ast::WindowFrame {
            units,
            start_bound,
            end_bound,
        }
    }

    /// Asserts that converting `frame` fails and that the rendered error is
    /// exactly `expected`.
    fn assert_rejected(frame: ast::WindowFrame, expected: &str) {
        let error = WindowFrame::try_from(frame).err().unwrap();
        assert_eq!(error.to_string(), expected);
    }

    #[test]
    fn test_window_frame_creation() -> Result<()> {
        assert_rejected(
            sql_frame(
                ast::WindowFrameUnits::Range,
                ast::WindowFrameBound::Following(None),
                None,
            ),
            "Execution error: Invalid window frame: start bound cannot be unbounded following",
        );
        assert_rejected(
            sql_frame(
                ast::WindowFrameUnits::Range,
                ast::WindowFrameBound::Preceding(None),
                Some(ast::WindowFrameBound::Preceding(None)),
            ),
            "Execution error: Invalid window frame: end bound cannot be unbounded preceding",
        );
        assert_rejected(
            sql_frame(
                ast::WindowFrameUnits::Range,
                ast::WindowFrameBound::Preceding(Some(1)),
                Some(ast::WindowFrameBound::Preceding(Some(2))),
            ),
            "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)",
        );
        assert_rejected(
            sql_frame(
                ast::WindowFrameUnits::Range,
                ast::WindowFrameBound::Preceding(Some(2)),
                Some(ast::WindowFrameBound::Preceding(Some(1))),
            ),
            "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment",
        );

        // The same non-zero offsets are accepted when the units are ROWS.
        let rows_frame = sql_frame(
            ast::WindowFrameUnits::Rows,
            ast::WindowFrameBound::Preceding(Some(2)),
            Some(ast::WindowFrameBound::Preceding(Some(1))),
        );
        assert!(WindowFrame::try_from(rows_frame).is_ok());
        Ok(())
    }

    #[test]
    fn test_eq() {
        let equal_pairs = [
            (
                WindowFrameBound::Preceding(Some(0)),
                WindowFrameBound::CurrentRow,
            ),
            (
                WindowFrameBound::CurrentRow,
                WindowFrameBound::Following(Some(0)),
            ),
            (
                WindowFrameBound::Following(Some(2)),
                WindowFrameBound::Following(Some(2)),
            ),
            (
                WindowFrameBound::Following(None),
                WindowFrameBound::Following(None),
            ),
            (
                WindowFrameBound::Preceding(Some(2)),
                WindowFrameBound::Preceding(Some(2)),
            ),
            (
                WindowFrameBound::Preceding(None),
                WindowFrameBound::Preceding(None),
            ),
        ];
        for (lhs, rhs) in equal_pairs.iter() {
            assert_eq!(lhs, rhs);
        }
    }

    #[test]
    fn test_ord() {
        // Each pair is (smaller, larger).
        let ascending_pairs = [
            (
                WindowFrameBound::Preceding(Some(1)),
                WindowFrameBound::CurrentRow,
            ),
            (
                WindowFrameBound::Preceding(Some(2)),
                WindowFrameBound::Preceding(Some(1)),
            ),
            (
                WindowFrameBound::Preceding(Some(u64::MAX)),
                WindowFrameBound::Preceding(Some(u64::MAX - 1)),
            ),
            (
                WindowFrameBound::Preceding(None),
                WindowFrameBound::Preceding(Some(1000000)),
            ),
            (
                WindowFrameBound::Preceding(None),
                WindowFrameBound::Preceding(Some(u64::MAX)),
            ),
            (
                WindowFrameBound::Preceding(None),
                WindowFrameBound::Following(Some(0)),
            ),
            (
                WindowFrameBound::Preceding(Some(1)),
                WindowFrameBound::Following(Some(1)),
            ),
            (
                WindowFrameBound::CurrentRow,
                WindowFrameBound::Following(Some(1)),
            ),
            (
                WindowFrameBound::Following(Some(1)),
                WindowFrameBound::Following(Some(2)),
            ),
            (
                WindowFrameBound::Following(Some(2)),
                WindowFrameBound::Following(None),
            ),
            (
                WindowFrameBound::Following(Some(u64::MAX)),
                WindowFrameBound::Following(None),
            ),
        ];
        for (smaller, larger) in ascending_pairs.iter() {
            assert!(smaller < larger, "expected {} < {}", smaller, larger);
        }
    }
}