use crate::error::{_plan_datafusion_err, _plan_err};
use crate::{Result, ScalarValue};
use arrow::datatypes::DataType;
use std::collections::HashMap;
/// Values bound to the placeholders of a parameterized query.
///
/// Values are supplied either positionally ([`ParamValues::List`]) or by
/// name ([`ParamValues::Map`]).
#[derive(Clone, Debug)]
pub enum ParamValues {
    /// Positional values; see `get_placeholders_with_values` for how a
    /// placeholder id is turned into an index into this list.
    List(Vec<ScalarValue>),
    /// Named values, keyed by placeholder name.
    Map(std::collections::HashMap<String, ScalarValue>),
}
impl ParamValues {
    /// Verifies the supplied values against the expected parameter types.
    ///
    /// For [`ParamValues::List`], checks two things:
    /// - the number of values equals `expect.len()`;
    /// - each value's [`DataType`] equals the expected type at the same
    ///   position.
    ///
    /// For [`ParamValues::Map`], no check is performed and `Ok(())` is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns a plan error if the counts differ, or if a value's type does
    /// not match the expected type.
    pub fn verify(&self, expect: &[DataType]) -> Result<()> {
        match self {
            ParamValues::List(list) => {
                if expect.len() != list.len() {
                    return _plan_err!(
                        "Expected {} parameters, got {}",
                        expect.len(),
                        list.len()
                    );
                }
                for (i, (param_type, value)) in expect.iter().zip(list).enumerate() {
                    // Compute the value's type once; it is used for both the
                    // comparison and the error message.
                    let actual = value.data_type();
                    if *param_type != actual {
                        return _plan_err!(
                            "Expected parameter of type {:?}, got {:?} at index {}",
                            param_type,
                            actual,
                            i
                        );
                    }
                }
                Ok(())
            }
            // Named values are intentionally not checked against `expect`.
            // NOTE(review): presumably because named placeholders may be
            // reused, so the counts need not line up -- confirm.
            ParamValues::Map(_) => Ok(()),
        }
    }

    /// Looks up the value bound to the placeholder `id`.
    ///
    /// The first character of `id` is skipped. NOTE(review): presumably this
    /// is a sigil such as `$` -- confirm with callers. The remainder is then
    /// used as follows:
    /// - [`ParamValues::List`]: parsed as a 1-based index, so `"$1"` selects
    ///   the first value.
    /// - [`ParamValues::Map`]: used directly as the key.
    ///
    /// # Errors
    ///
    /// - a plan error if `id` is empty (for both variants);
    /// - for lists, an error if the index does not parse as a `usize`;
    /// - an error if no value is bound at the index or name.
    pub fn get_placeholders_with_values(&self, id: &str) -> Result<ScalarValue> {
        let key = Self::strip_placeholder_prefix(id)?;
        match self {
            ParamValues::List(list) => {
                // Placeholder ids are 1-based; `checked_sub` maps "0" to None
                // so that it reports "no value" rather than underflowing.
                let idx = key
                    .parse::<usize>()
                    .map_err(|e| {
                        _plan_datafusion_err!("Failed to parse placeholder id: {e}")
                    })?
                    .checked_sub(1);
                idx.and_then(|idx| list.get(idx)).cloned().ok_or_else(|| {
                    _plan_datafusion_err!("No value found for placeholder with id {id}")
                })
            }
            ParamValues::Map(map) => map.get(key).cloned().ok_or_else(|| {
                _plan_datafusion_err!("No value found for placeholder with name {id}")
            }),
        }
    }

    /// Returns `id` without its first character, or a plan error if `id` is
    /// empty.
    ///
    /// Skips a whole `char` rather than one byte, so a multi-byte leading
    /// character cannot cause a char-boundary slicing panic.
    fn strip_placeholder_prefix(id: &str) -> Result<&str> {
        let mut chars = id.chars();
        match chars.next() {
            Some(_) => Ok(chars.as_str()),
            None => _plan_err!("Empty placeholder id"),
        }
    }
}
impl From<Vec<ScalarValue>> for ParamValues {
fn from(value: Vec<ScalarValue>) -> Self {
Self::List(value)
}
}
impl<K> From<Vec<(K, ScalarValue)>> for ParamValues
where
K: Into<String>,
{
fn from(value: Vec<(K, ScalarValue)>) -> Self {
let value: HashMap<String, ScalarValue> =
value.into_iter().map(|(k, v)| (k.into(), v)).collect();
Self::Map(value)
}
}
impl<K> From<HashMap<K, ScalarValue>> for ParamValues
where
K: Into<String>,
{
fn from(value: HashMap<K, ScalarValue>) -> Self {
let value: HashMap<String, ScalarValue> =
value.into_iter().map(|(k, v)| (k.into(), v)).collect();
Self::Map(value)
}
}