use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::collections::HashMap;
use std::{collections::HashSet, sync::Arc};
/// A registry of user-defined scalar, aggregate, and window functions that can
/// be looked up by name.
///
/// Only the lookup methods are required. The `register_*` methods have default
/// implementations that return a "not implemented" error, so a registry that
/// does not support registration does not need to override them.
pub trait FunctionRegistry {
    /// Returns the set of scalar function names available in this registry.
    fn udfs(&self) -> HashSet<String>;

    /// Returns the scalar function registered under `name`.
    ///
    /// # Errors
    /// Implementations return an error when no such function is available.
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;

    /// Returns the aggregate function registered under `name`.
    ///
    /// # Errors
    /// Implementations return an error when no such function is available.
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;

    /// Returns the window function registered under `name`.
    ///
    /// # Errors
    /// Implementations return an error when no such function is available.
    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;

    /// Registers a scalar function. On success, returns the function that was
    /// previously registered under the same name, if any.
    ///
    /// # Errors
    /// The default implementation always returns a "not implemented" error.
    fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        not_impl_err!("Registering ScalarUDF")
    }

    /// Registers an aggregate function. On success, returns the function that
    /// was previously registered under the same name, if any.
    ///
    /// # Errors
    /// The default implementation always returns a "not implemented" error.
    fn register_udaf(
        &mut self,
        _udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        not_impl_err!("Registering AggregateUDF")
    }

    /// Registers a window function. On success, returns the function that was
    /// previously registered under the same name, if any.
    ///
    /// # Errors
    /// The default implementation always returns a "not implemented" error.
    fn register_udwf(&mut self, _udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        not_impl_err!("Registering WindowUDF")
    }
}
/// Converts user-defined logical plan nodes to and from raw bytes.
///
/// The `Send + Sync` bound allows a single registry to be shared across threads.
pub trait SerializerRegistry: Send + Sync {
    /// Encodes `node` into a byte buffer.
    ///
    /// # Errors
    /// Returns an error if the node cannot be serialized.
    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>>;

    /// Reconstructs a user-defined logical node from `bytes`.
    ///
    /// `name` presumably identifies which node type the bytes encode; confirm
    /// against implementors.
    ///
    /// # Errors
    /// Returns an error if the bytes cannot be decoded.
    fn deserialize_logical_plan(
        &self,
        name: &str,
        bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
}
/// A function registry backed by in-memory hash maps, one per function kind,
/// each keyed by function name.
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
    /// Scalar functions, keyed by name.
    udfs: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate functions, keyed by name.
    udafs: HashMap<String, Arc<AggregateUDF>>,
    /// Window functions, keyed by name.
    udwfs: HashMap<String, Arc<WindowUDF>>,
}

impl MemoryFunctionRegistry {
    /// Creates a registry with no functions registered.
    pub fn new() -> Self {
        Self {
            udfs: HashMap::new(),
            udafs: HashMap::new(),
            udwfs: HashMap::new(),
        }
    }
}
impl FunctionRegistry for MemoryFunctionRegistry {
    /// Returns the names of all registered scalar functions.
    fn udfs(&self) -> HashSet<String> {
        self.udfs.keys().cloned().collect()
    }

    /// Looks up a scalar function by name.
    ///
    /// # Errors
    /// Returns a plan error if no scalar function named `name` is registered.
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        self.udfs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
    }

    /// Looks up an aggregate function by name.
    ///
    /// # Errors
    /// Returns a plan error if no aggregate function named `name` is registered.
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.udafs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
    }

    /// Looks up a window function by name.
    ///
    /// # Errors
    /// Returns a plan error if no window function named `name` is registered.
    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.udwfs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
    }

    /// Stores `udf` under its own name, replacing and returning any scalar
    /// function previously registered under that name.
    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        Ok(self.udfs.insert(udf.name().to_string(), udf))
    }

    /// Stores `udaf` under its own name, replacing and returning any aggregate
    /// function previously registered under that name.
    // Without this override the trait default returns "not implemented",
    // leaving `udafs` permanently empty and `udaf()` unable to succeed.
    fn register_udaf(
        &mut self,
        udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        Ok(self.udafs.insert(udaf.name().to_string(), udaf))
    }

    /// Stores `udwf` under its own name, replacing and returning any window
    /// function previously registered under that name.
    // Without this override the trait default returns "not implemented",
    // leaving `udwfs` permanently empty and `udwf()` unable to succeed.
    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        Ok(self.udwfs.insert(udwf.name().to_string(), udwf))
    }
}