use arrow::array::{Array, ArrayData};
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use pyo3::exceptions::PyException;
use pyo3::prelude::PyErr;
use pyo3::types::{PyAnyMethods, PyList};
use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python};
use crate::{DataFusionError, ScalarValue};
/// Lets a [`DataFusionError`] be raised on the Python side: the error is
/// rendered to text with `to_string()` and wrapped in a `PyException`.
impl From<DataFusionError> for PyErr {
    fn from(err: DataFusionError) -> Self {
        // Only the textual form of the error is carried across the boundary.
        let message = err.to_string();
        PyException::new_err(message)
    }
}
impl FromPyArrow for ScalarValue {
    /// Builds a [`ScalarValue`] from a Python object that exposes a `type`
    /// attribute and an `as_py()` method.
    ///
    /// NOTE(review): presumably `value` is a `pyarrow.Scalar` — confirm
    /// against callers.
    ///
    /// # Errors
    ///
    /// Propagates any Python exception raised while reading the attribute,
    /// calling `as_py`, importing `pyarrow`, or building the array, as well
    /// as any error from `ArrayData::from_pyarrow_bound` or
    /// `ScalarValue::try_from_array`.
    fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
        let py = value.py();
        let arrow_type = value.getattr("type")?;
        let py_value = value.call_method0("as_py")?;

        // Call `pyarrow.array([py_value], arrow_type)` to obtain a
        // one-element pyarrow array holding the value.
        let make_py_array = py.import("pyarrow")?.getattr("array")?;
        let py_array = make_py_array.call1((PyList::new(py, [py_value])?, arrow_type))?;

        // Bring the pyarrow array over to a Rust arrow array, then read
        // element 0 back out as a scalar.
        let data = ArrayData::from_pyarrow_bound(&py_array)?;
        let array = arrow::array::make_array(data);
        Ok(ScalarValue::try_from_array(&array, 0)?)
    }
}
impl ToPyArrow for ScalarValue {
    /// Converts this scalar to a Python object by turning it into an array,
    /// exporting that array to pyarrow, and taking element `0` of the result.
    ///
    /// # Errors
    ///
    /// Returns an error when `to_array` fails, when the array cannot be
    /// exported to pyarrow, or when the Python `__getitem__` call raises.
    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
        let data = self.to_array()?.to_data();
        // Index the exported pyarrow array at position 0 to get the scalar.
        data.to_pyarrow(py)?
            .call_method1(py, "__getitem__", (0,))
    }
}
impl<'source> FromPyObject<'source> for ScalarValue {
fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
Self::from_pyarrow_bound(value)
}
}
impl<'source> IntoPyObject<'source> for ScalarValue {
type Target = PyAny;
type Output = Bound<'source, Self::Target>;
type Error = PyErr;
fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> {
let array = self.to_array()?;
let pyarray = array.to_data().to_pyarrow(py)?;
let pyarray_bound = pyarray.bind(py);
pyarray_bound.call_method1("__getitem__", (0,))
}
}
#[cfg(test)]
mod tests {
    use pyo3::ffi::c_str;
    use pyo3::prepare_freethreaded_python;
    use pyo3::py_run;
    use pyo3::types::PyDict;

    use super::*;

    /// Prepares the embedded interpreter and verifies that `pyarrow` can be
    /// imported; panics with interpreter details and hints when it cannot.
    fn init_python() {
        prepare_freethreaded_python();
        Python::with_gil(|py| {
            // Nothing more to do when the import succeeds.
            if py.run(c_str!("import pyarrow"), None, None).is_ok() {
                return;
            }

            // Gather the interpreter path and `sys.path` for the panic message.
            let locals = PyDict::new(py);
            py.run(
                c_str!(
                    "import sys; executable = sys.executable; python_path = sys.path"
                ),
                None,
                Some(&locals),
            )
            .expect("Couldn't get python info");

            let executable = locals
                .get_item("executable")
                .unwrap()
                .extract::<String>()
                .unwrap();
            let python_path = locals
                .get_item("python_path")
                .unwrap()
                .extract::<Vec<String>>()
                .unwrap();

            panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\
                    HINT: try `pip install pyarrow`\n\
                    NOTE: On Mac OS, you must compile against a Framework Python \
                    (default in python.org installers and brew, but not pyenv)\n\
                    NOTE: On Mac OS, PYO3 might point to incorrect Python library \
                    path when using virtual environments. Try \
                    `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n");
        })
    }

    /// Each sample scalar must survive a Rust -> pyarrow -> Rust round trip
    /// unchanged.
    #[test]
    fn test_roundtrip() {
        init_python();

        let example_scalars = [
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Int32(Some(23)),
            ScalarValue::Float64(Some(12.34)),
            ScalarValue::from("Hello!"),
            ScalarValue::Date32(Some(1234)),
        ];

        Python::with_gil(|py| {
            for expected in &example_scalars {
                let py_scalar = expected.to_pyarrow(py).unwrap();
                let actual = ScalarValue::from_pyarrow_bound(py_scalar.bind(py)).unwrap();
                assert_eq!(expected, &actual);
            }
        });
    }

    /// `into_pyobject` must produce objects whose `as_py()` value compares
    /// equal to the matching native Python literal.
    #[test]
    fn test_py_scalar() -> PyResult<()> {
        init_python();

        Python::with_gil(|py| -> PyResult<()> {
            // `py_run!` exposes the variable to Python under its Rust name,
            // so the names below are part of the asserted snippets.
            let float_obj = ScalarValue::Float64(Some(12.34)).into_pyobject(py)?;
            let py_float = float_obj.call_method0("as_py").unwrap();
            py_run!(py, py_float, "assert py_float == 12.34");

            let string_obj = ScalarValue::Utf8(Some("Hello!".to_string())).into_pyobject(py)?;
            let py_string = string_obj.call_method0("as_py").unwrap();
            py_run!(py, py_string, "assert py_string == 'Hello!'");

            Ok(())
        })
    }
}