use arrow::array::timezone::Tz;
use arrow::array::types::TimestampNanosecondType;
use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
use arrow::datatypes::{
DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_expr::{Expr, TableType};
use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
use datafusion_physical_plan::ExecutionPlan;
use parking_lot::RwLock;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
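/// A [`LazyBatchGenerator`] that yields no batches. Used when any argument to
/// the table function was NULL, so the resulting table is empty.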
#[derive(Debug, Clone)]
struct Empty {
name: &'static str,
}
impl LazyBatchGenerator for Empty {
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
Ok(None)
}
}
impl fmt::Display for Empty {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}: empty", self.name)
}
}
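/// Abstraction over the element types a series can be generated over (plain
/// `i64` values and nanosecond timestamps), letting a single
/// [`GenericSeriesState`] drive integer, timestamp, and date series alike.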
trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
type StepType: fmt::Debug + Clone + Send + Sync;
type ValueType: fmt::Debug + Clone + Send + Sync;
fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;
fn advance(&mut self, step: &Self::StepType) -> Result<()>;
fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;
fn to_value_type(&self) -> Self::ValueType;
fn display_value(&self) -> String;
}
impl SeriesValue for i64 {
type StepType = i64;
type ValueType = i64;
fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
reach_end_int64(*self, end, *step, include_end)
}
fn advance(&mut self, step: &Self::StepType) -> Result<()> {
*self += step;
Ok(())
}
fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
Ok(Arc::new(Int64Array::from(values)))
}
fn to_value_type(&self) -> Self::ValueType {
*self
}
fn display_value(&self) -> String {
self.to_string()
}
}
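/// Current position within a timestamp (or date) series. `parsed_tz` is used
/// for interval arithmetic while advancing; `tz_str` is the original timezone
/// string attached to the output [`TimestampNanosecondArray`].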
#[derive(Debug, Clone)]
struct TimestampValue {
value: i64,
parsed_tz: Option<Tz>,
tz_str: Option<Arc<str>>,
}
impl SeriesValue for TimestampValue {
type StepType = IntervalMonthDayNano;
type ValueType = i64;
fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
if include_end {
if step_negative {
self.value < end.value
} else {
self.value > end.value
}
} else if step_negative {
self.value <= end.value
} else {
self.value >= end.value
}
}
fn advance(&mut self, step: &Self::StepType) -> Result<()> {
let tz = self
.parsed_tz
.unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
let Some(next_ts) =
TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
else {
return plan_err!(
"Failed to add interval {:?} to timestamp {}",
step,
self.value
);
};
self.value = next_ts;
Ok(())
}
fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
let array = TimestampNanosecondArray::from(values);
let array = match self.tz_str.as_ref() {
Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
None => array,
};
Ok(Arc::new(array))
}
fn to_value_type(&self) -> Self::ValueType {
self.value
}
fn display_value(&self) -> String {
self.value.to_string()
}
}
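/// Arguments the table function was invoked with.
///
/// `ContainsNull` means at least one argument was NULL, so no rows will be
/// generated. The remaining variants carry the validated start, end, and step
/// values for integer, timestamp, and date series respectively.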
#[derive(Debug, Clone)]
enum GenSeriesArgs {
ContainsNull { name: &'static str },
Int64Args {
start: i64,
end: i64,
step: i64,
include_end: bool,
name: &'static str,
},
TimestampArgs {
start: i64,
end: i64,
step: IntervalMonthDayNano,
tz: Option<Arc<str>>,
include_end: bool,
name: &'static str,
},
DateArgs {
start: i64,
end: i64,
step: IntervalMonthDayNano,
include_end: bool,
name: &'static str,
},
}
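/// Table that lazily generates values from `start` to `end`, advancing by
/// `step`. Whether `end` itself is emitted depends on `include_end`
/// (`generate_series` includes it, `range` does not).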
#[derive(Debug, Clone)]
struct GenerateSeriesTable {
schema: SchemaRef,
args: GenSeriesArgs,
}
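/// Mutable generator state shared by integer and timestamp series: `current`
/// advances by `step` until it passes `end`, emitting batches of at most
/// `batch_size` rows.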
#[derive(Debug, Clone)]
struct GenericSeriesState<T: SeriesValue> {
schema: SchemaRef,
start: T,
end: T,
step: T::StepType,
batch_size: usize,
current: T,
include_end: bool,
name: &'static str,
}
impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
let mut buf = Vec::with_capacity(self.batch_size);
while buf.len() < self.batch_size
&& !self
.current
.should_stop(self.end.clone(), &self.step, self.include_end)
{
buf.push(self.current.to_value_type());
self.current.advance(&self.step)?;
}
if buf.is_empty() {
return Ok(None);
}
let array = self.current.create_array(buf)?;
let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
Ok(Some(batch))
}
}
impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}: start={}, end={}, batch_size={}",
self.name,
self.start.display_value(),
self.end.display_value(),
self.batch_size
)
}
}
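/// Returns true once `val` has passed `end` for the given step direction,
/// honoring whether the end point itself should be included.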
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
if step > 0 {
if include_end {
val > end
} else {
val >= end
}
} else if include_end {
val < end
} else {
val <= end
}
}
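/// Rejects zero intervals and interval steps whose direction can never reach
/// `end` from `start`, which would otherwise produce an infinite series.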
fn validate_interval_step(
step: IntervalMonthDayNano,
start: i64,
end: i64,
) -> Result<()> {
if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
return plan_err!("Step interval cannot be zero");
}
let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
if start > end && step_is_positive {
return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
}
if start < end && step_is_negative {
return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
}
Ok(())
}
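// `scan` wraps the appropriate generator state in a `LazyMemoryExec`, so rows
// are produced batch-by-batch on demand rather than materialized up front.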
#[async_trait]
impl TableProvider for GenerateSeriesTable {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
state: &dyn Session,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let batch_size = state.config_options().execution.batch_size;
let schema = match projection {
Some(projection) => Arc::new(self.schema.project(projection)?),
None => self.schema(),
};
let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
GenSeriesArgs::Int64Args {
start,
end,
step,
include_end,
name,
} => Arc::new(RwLock::new(GenericSeriesState {
schema: self.schema(),
start: *start,
end: *end,
step: *step,
current: *start,
batch_size,
include_end: *include_end,
name,
})),
GenSeriesArgs::TimestampArgs {
start,
end,
step,
tz,
include_end,
name,
} => {
let parsed_tz = tz
.as_ref()
.map(|s| Tz::from_str(s.as_ref()))
.transpose()
.map_err(|e| {
datafusion_common::DataFusionError::Internal(format!(
"Failed to parse timezone: {e}"
))
})?
.unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
Arc::new(RwLock::new(GenericSeriesState {
schema: self.schema(),
start: TimestampValue {
value: *start,
parsed_tz: Some(parsed_tz),
tz_str: tz.clone(),
},
end: TimestampValue {
value: *end,
parsed_tz: Some(parsed_tz),
tz_str: tz.clone(),
},
step: *step,
current: TimestampValue {
value: *start,
parsed_tz: Some(parsed_tz),
tz_str: tz.clone(),
},
batch_size,
include_end: *include_end,
name,
}))
}
GenSeriesArgs::DateArgs {
start,
end,
step,
include_end,
name,
} => Arc::new(RwLock::new(GenericSeriesState {
schema: self.schema(),
start: TimestampValue {
value: *start,
parsed_tz: None,
tz_str: None,
},
end: TimestampValue {
value: *end,
parsed_tz: None,
tz_str: None,
},
step: *step,
current: TimestampValue {
value: *start,
parsed_tz: None,
tz_str: None,
},
batch_size,
include_end: *include_end,
name,
})),
};
Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?))
}
}
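/// Shared implementation behind `generate_series` and `range`. The only
/// difference between the two functions is `include_end`: whether the
/// generated series contains the `end` value itself.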
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
name: &'static str,
include_end: bool,
}
impl TableFunctionImpl for GenerateSeriesFuncImpl {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
if exprs.is_empty() || exprs.len() > 3 {
return plan_err!("{} function requires 1 to 3 arguments", self.name);
}
match &exprs[0] {
Expr::Literal(
ScalarValue::Null | ScalarValue::Int64(_),
_,
) => self.call_int64(exprs),
Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
self.call_timestamp(exprs)
}
Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
self.call_date(exprs)
}
Expr::Literal(scalar, _) => {
plan_err!(
"Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
scalar.data_type()
)
}
_ => plan_err!("Arguments must be literals"),
}
}
}
impl GenerateSeriesFuncImpl {
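/// Handles integer arguments: `(end)`, `(start, end)`, or `(start, end, step)`,
/// where a missing start defaults to 0 and a missing step defaults to 1. Any
/// NULL argument produces an empty table.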
fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let mut normalize_args = Vec::new();
for (expr_index, expr) in exprs.iter().enumerate() {
match expr {
Expr::Literal(ScalarValue::Null, _) => {}
Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
other => {
return plan_err!(
"Argument #{} must be an INTEGER or NULL, got {:?}",
expr_index + 1,
other
)
}
};
}
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int64,
false,
)]));
if normalize_args.len() != exprs.len() {
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull { name: self.name },
}));
}
let (start, end, step) = match &normalize_args[..] {
[end] => (0, *end, 1),
[start, end] => (*start, *end, 1),
[start, end, step] => (*start, *end, *step),
_ => {
return plan_err!("{} function requires 1 to 3 arguments", self.name);
}
};
if start > end && step > 0 {
return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
}
if start < end && step < 0 {
return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
}
if step == 0 {
return plan_err!("Step cannot be zero");
}
Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::Int64Args {
start,
end,
step,
include_end: self.include_end,
name: self.name,
},
}))
}
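/// Handles timestamp arguments, which must be exactly
/// `(start_timestamp, end_timestamp, step_interval)`. The output column keeps
/// the timezone of the start timestamp; a NULL in any position produces an
/// empty table.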
fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
if exprs.len() != 3 {
return plan_err!(
"{} function with timestamps requires exactly 3 arguments",
self.name
);
}
let (start_ts, tz) = match &exprs[0] {
Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
(*ts, tz.clone())
}
other => {
return plan_err!(
"First argument must be a timestamp or NULL, got {:?}",
other
)
}
};
let end_ts = match &exprs[1] {
Expr::Literal(ScalarValue::Null, _) => None,
Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
other => {
return plan_err!(
"Second argument must be a timestamp or NULL, got {:?}",
other
)
}
};
let step_interval = match &exprs[2] {
Expr::Literal(ScalarValue::Null, _) => None,
Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
other => {
return plan_err!(
"Third argument must be an interval or NULL, got {:?}",
other
)
}
};
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
false,
)]));
let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
else {
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull { name: self.name },
}));
};
validate_interval_step(step, start, end)?;
Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::TimestampArgs {
start,
end,
step,
tz,
include_end: self.include_end,
name: self.name,
},
}))
}
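/// Handles date arguments: `(start_date, end_date, step_interval)`. Date32
/// values (days since the Unix epoch) are converted to nanosecond timestamps,
/// so the output column is a timezone-less nanosecond timestamp rather than a
/// date.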
fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
if exprs.len() != 3 {
return plan_err!(
"{} function with dates requires exactly 3 arguments",
self.name
);
}
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
)]));
let start_date = match &exprs[0] {
Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
Expr::Literal(ScalarValue::Date32(None), _)
| Expr::Literal(ScalarValue::Null, _) => {
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull { name: self.name },
}));
}
other => {
return plan_err!(
"First argument must be a date or NULL, got {:?}",
other
)
}
};
let end_date = match &exprs[1] {
Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
Expr::Literal(ScalarValue::Date32(None), _)
| Expr::Literal(ScalarValue::Null, _) => {
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull { name: self.name },
}));
}
other => {
return plan_err!(
"Second argument must be a date or NULL, got {:?}",
other
)
}
};
let step_interval = match &exprs[2] {
Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
*interval
}
Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
| Expr::Literal(ScalarValue::Null, _) => {
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull { name: self.name },
}));
}
other => {
return plan_err!(
"Third argument must be an interval or NULL, got {:?}",
other
)
}
};
const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
let start_ts = start_date as i64 * NANOS_PER_DAY;
let end_ts = end_date as i64 * NANOS_PER_DAY;
validate_interval_step(step_interval, start_ts, end_ts)?;
Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::DateArgs {
start: start_ts,
end: end_ts,
step: step_interval,
include_end: self.include_end,
name: self.name,
},
}))
}
}
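/// The `generate_series(start, end [, step])` table function. The generated
/// series includes `end`; for example, `SELECT * FROM generate_series(1, 5)`
/// yields the values 1 through 5.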
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
impl TableFunctionImpl for GenerateSeriesFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let impl_func = GenerateSeriesFuncImpl {
name: "generate_series",
include_end: true,
};
impl_func.call(exprs)
}
}
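/// The `range(start, end [, step])` table function. Identical to
/// `generate_series` except that the series stops before `end`; for example,
/// `SELECT * FROM range(1, 5)` yields the values 1 through 4.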
#[derive(Debug)]
pub struct RangeFunc {}
impl TableFunctionImpl for RangeFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let impl_func = GenerateSeriesFuncImpl {
name: "range",
include_end: false,
};
impl_func.call(exprs)
}
}