use datafusion_common::{not_impl_err, plan_datafusion_err, Result};
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::collections::HashMap;
use std::{collections::HashSet, sync::Arc};
/// A registry of user-defined scalar, aggregate and window functions that
/// can be looked up by name.
///
/// Only the lookup methods ([`Self::udfs`], [`Self::udf`], [`Self::udaf`]
/// and [`Self::udwf`]) are required. The `register_*`, `deregister_*` and
/// [`Self::register_function_rewrite`] methods have default implementations
/// that return a "not implemented" error (via `not_impl_err!`). A read-only
/// registry therefore only needs to provide lookups.
pub trait FunctionRegistry {
    /// Returns the names of the registered scalar UDFs, i.e. the names that
    /// can be passed to [`Self::udf`].
    fn udfs(&self) -> HashSet<String>;

    /// Returns the scalar UDF registered under `name`, or an error (for
    /// example when no such function is registered).
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;

    /// Returns the aggregate UDF registered under `name`, or an error (for
    /// example when no such function is registered).
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;

    /// Returns the window UDF registered under `name`, or an error (for
    /// example when no such function is registered).
    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;

    /// Registers a scalar UDF. Implementations return the previously
    /// registered function with the same name, if any.
    ///
    /// The default implementation returns a "not implemented" error. This
    /// suits registries that do not support registration.
    fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        not_impl_err!("Registering ScalarUDF")
    }

    /// Registers an aggregate UDF. Implementations return the previously
    /// registered function with the same name, if any.
    ///
    /// The default implementation returns a "not implemented" error.
    fn register_udaf(
        &mut self,
        _udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        not_impl_err!("Registering AggregateUDF")
    }

    /// Registers a window UDF. Implementations return the previously
    /// registered function with the same name, if any.
    ///
    /// The default implementation returns a "not implemented" error.
    fn register_udwf(&mut self, _udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        not_impl_err!("Registering WindowUDF")
    }

    /// Removes the scalar UDF registered under `name`. Implementations
    /// return the removed function, if any.
    ///
    /// The default implementation returns a "not implemented" error.
    fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
        not_impl_err!("Deregistering ScalarUDF")
    }

    /// Removes the aggregate UDF registered under `name`. Implementations
    /// return the removed function, if any.
    ///
    /// The default implementation returns a "not implemented" error.
    fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
        not_impl_err!("Deregistering AggregateUDF")
    }

    /// Removes the window UDF registered under `name`. Implementations
    /// return the removed function, if any.
    ///
    /// The default implementation returns a "not implemented" error.
    fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
        not_impl_err!("Deregistering WindowUDF")
    }

    /// Registers a [`FunctionRewrite`] with this registry.
    ///
    /// The default implementation returns a "not implemented" error.
    fn register_function_rewrite(
        &mut self,
        _rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
    ) -> Result<()> {
        not_impl_err!("Registering FunctionRewrite")
    }
}
/// Converts user-defined logical plan nodes ([`UserDefinedLogicalNode`])
/// to and from raw bytes.
///
/// Implementors must be `Send + Sync`.
pub trait SerializerRegistry: Send + Sync {
    /// Serializes `node` into a byte vector.
    fn serialize_logical_plan(
        &self,
        node: &dyn UserDefinedLogicalNode,
    ) -> Result<Vec<u8>>;
    /// Reconstructs a user-defined logical node from `bytes`.
    ///
    /// `name` tells the implementation which node type to decode.
    /// NOTE(review): presumably this is the name of the node that was
    /// passed to `serialize_logical_plan`; confirm against callers.
    fn deserialize_logical_plan(
        &self,
        name: &str,
        bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
}
/// A simple [`FunctionRegistry`] that keeps functions in in-memory `HashMap`s.
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
    /// Scalar UDFs, keyed by function name.
    udfs: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate UDFs, keyed by function name.
    udafs: HashMap<String, Arc<AggregateUDF>>,
    /// Window UDFs, keyed by function name.
    udwfs: HashMap<String, Arc<WindowUDF>>,
}
impl MemoryFunctionRegistry {
pub fn new() -> Self {
Self::default()
}
}
/// Map-backed implementation: every function is stored (and looked up)
/// only under the string returned by its `name()` method.
impl FunctionRegistry for MemoryFunctionRegistry {
    /// Returns the names of all registered scalar UDFs.
    fn udfs(&self) -> HashSet<String> {
        self.udfs.keys().cloned().collect()
    }

    /// Looks up a scalar UDF by name; returns a plan error if it is absent.
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        self.udfs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
    }

    /// Looks up an aggregate UDF by name; returns a plan error if it is absent.
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.udafs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
    }

    /// Looks up a window UDF by name; returns a plan error if it is absent.
    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.udwfs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
    }

    /// Inserts `udf` under its name, returning any function it replaced.
    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        Ok(self.udfs.insert(udf.name().to_string(), udf))
    }

    /// Inserts `udaf` under its name, returning any function it replaced.
    fn register_udaf(
        &mut self,
        udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        Ok(self.udafs.insert(udaf.name().into(), udaf))
    }

    /// Inserts `udwf` under its name, returning any function it replaced.
    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        Ok(self.udwfs.insert(udwf.name().into(), udwf))
    }

    // The backing maps support removal, so override the trait's
    // "not implemented" defaults rather than refusing to deregister.

    /// Removes the scalar UDF registered under `name`, returning it if present.
    /// An unknown name yields `Ok(None)`.
    fn deregister_udf(&mut self, name: &str) -> Result<Option<Arc<ScalarUDF>>> {
        Ok(self.udfs.remove(name))
    }

    /// Removes the aggregate UDF registered under `name`, returning it if
    /// present. An unknown name yields `Ok(None)`.
    fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
        Ok(self.udafs.remove(name))
    }

    /// Removes the window UDF registered under `name`, returning it if present.
    /// An unknown name yields `Ok(None)`.
    fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
        Ok(self.udwfs.remove(name))
    }
}