use arrow::datatypes::DataType;
use cranelift::codegen::ir;
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
use std::fmt::{Display, Formatter};
/// A statement of the JIT language.
///
/// `Stmt::fmt_ident` renders each variant as source-like text; the shapes
/// quoted on the variants below are taken from that rendering.
#[derive(Clone, Debug)]
pub enum Stmt {
/// `if <cond> { <then statements> } else { <else statements> }`.
IfElse(Box<Expr>, Vec<Stmt>, Vec<Stmt>),
/// `while <cond> { <body statements> }`.
WhileLoop(Box<Expr>, Vec<Stmt>),
/// `<name> = <expr>;`.
Assign(String, Box<Expr>),
/// `<name>(<args>);` — a call used as a statement.
Call(String, Vec<Expr>),
/// `let <name>: <type>;`.
Declare(String, JITType),
/// `*(<ptr>) = <value>` — note the field order is `(value, ptr)`.
Store(Box<Expr>, Box<Expr>),
}
/// A literal carrying an already-typed Rust value.
///
/// `TypedLit::get_type` maps each variant to a fixed `JITType`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum TypedLit {
/// Boolean literal; typed as `BOOL`.
Bool(bool),
/// 64-bit signed integer literal; typed as `I64`.
Int(i64),
/// 32-bit float literal; typed as `F32`.
Float(f32),
/// 64-bit float literal; typed as `F64`.
Double(f64),
}
/// An expression of the JIT language.
///
/// Every expression has a `JITType`, available through `Expr::get_type`.
#[derive(Clone, Debug, PartialEq)]
pub enum Expr {
/// A literal value.
Literal(Literal),
/// A named value together with its type; only the name is displayed.
Identifier(String, JITType),
/// A binary comparison or arithmetic expression.
Binary(BinaryExpr),
/// A call `<name>(<args>)`; the `JITType` is what `get_type` reports for it.
Call(String, Vec<Expr>, JITType),
/// A load through a pointer expression, displayed as `*(<ptr>)`; the
/// `JITType` is what `get_type` reports for it.
Load(Box<Expr>, JITType),
}
impl Expr {
    /// Returns the JIT type of this expression.
    ///
    /// Identifiers, calls and loads store their type directly; literals and
    /// binary expressions compute it from their contents.
    pub fn get_type(&self) -> JITType {
        match self {
            // Variants that carry their type explicitly.
            Expr::Identifier(_, ty) | Expr::Call(_, _, ty) | Expr::Load(_, ty) => *ty,
            // Variants whose type is derived from the wrapped node.
            Expr::Literal(literal) => literal.get_type(),
            Expr::Binary(binary) => binary.get_type(),
        }
    }
}
impl Literal {
    /// Returns the JIT type of this literal: the declared type for
    /// `Parsing`, or the type implied by the wrapped `TypedLit`.
    fn get_type(&self) -> JITType {
        match self {
            Literal::Typed(typed) => typed.get_type(),
            Literal::Parsing(_, declared) => *declared,
        }
    }
}
impl TypedLit {
    /// Returns the fixed JIT type associated with this literal's variant.
    fn get_type(&self) -> JITType {
        match *self {
            TypedLit::Double(_) => F64,
            TypedLit::Float(_) => F32,
            TypedLit::Int(_) => I64,
            TypedLit::Bool(_) => BOOL,
        }
    }
}
impl BinaryExpr {
    /// Returns the JIT type this binary expression evaluates to.
    ///
    /// Comparisons are always `BOOL`; arithmetic takes the type of its
    /// left-hand operand (the right-hand operand is not inspected).
    fn get_type(&self) -> JITType {
        match self {
            BinaryExpr::Eq(..)
            | BinaryExpr::Ne(..)
            | BinaryExpr::Lt(..)
            | BinaryExpr::Le(..)
            | BinaryExpr::Gt(..)
            | BinaryExpr::Ge(..) => BOOL,
            BinaryExpr::Add(lhs, _)
            | BinaryExpr::Sub(lhs, _)
            | BinaryExpr::Mul(lhs, _)
            | BinaryExpr::Div(lhs, _) => lhs.get_type(),
        }
    }
}
/// A binary expression; each variant holds `(lhs, rhs)`.
///
/// The comparison variants are typed `BOOL` and the arithmetic variants take
/// the type of `lhs` (see `BinaryExpr::get_type`).
#[derive(Clone, Debug, PartialEq)]
pub enum BinaryExpr {
/// `lhs == rhs`
Eq(Box<Expr>, Box<Expr>),
/// `lhs != rhs`
Ne(Box<Expr>, Box<Expr>),
/// `lhs < rhs`
Lt(Box<Expr>, Box<Expr>),
/// `lhs <= rhs`
Le(Box<Expr>, Box<Expr>),
/// `lhs > rhs`
Gt(Box<Expr>, Box<Expr>),
/// `lhs >= rhs`
Ge(Box<Expr>, Box<Expr>),
/// `lhs + rhs`
Add(Box<Expr>, Box<Expr>),
/// `lhs - rhs`
Sub(Box<Expr>, Box<Expr>),
/// `lhs * rhs`
Mul(Box<Expr>, Box<Expr>),
/// `lhs / rhs`
Div(Box<Expr>, Box<Expr>),
}
/// A literal expression.
#[derive(Clone, Debug, PartialEq)]
pub enum Literal {
/// Literal kept as text together with its type; the text is displayed
/// verbatim. NOTE(review): presumably the raw token produced by a parser
/// that is not shown here — confirm.
Parsing(String, JITType),
/// Literal holding an already-typed value.
Typed(TypedLit),
}
impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr {
type Error = DataFusionError;
fn try_from(
(value, schema): (datafusion_expr::Expr, DFSchemaRef),
) -> Result<Self, Self::Error> {
match &value {
datafusion_expr::Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
left,
op,
right,
}) => {
let op = match op {
datafusion_expr::Operator::Eq => BinaryExpr::Eq,
datafusion_expr::Operator::NotEq => BinaryExpr::Ne,
datafusion_expr::Operator::Lt => BinaryExpr::Lt,
datafusion_expr::Operator::LtEq => BinaryExpr::Le,
datafusion_expr::Operator::Gt => BinaryExpr::Gt,
datafusion_expr::Operator::GtEq => BinaryExpr::Ge,
datafusion_expr::Operator::Plus => BinaryExpr::Add,
datafusion_expr::Operator::Minus => BinaryExpr::Sub,
datafusion_expr::Operator::Multiply => BinaryExpr::Mul,
datafusion_expr::Operator::Divide => BinaryExpr::Div,
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Compiling binary expression {value} not yet supported"
)))
}
};
Ok(Expr::Binary(op(
Box::new((*left.clone(), schema.clone()).try_into()?),
Box::new((*right.clone(), schema).try_into()?),
)))
}
datafusion_expr::Expr::Column(col) => {
let field = schema.field_from_column(col)?;
let ty = field.data_type();
let jit_type = JITType::try_from(ty)?;
Ok(Expr::Identifier(field.qualified_name(), jit_type))
}
datafusion_expr::Expr::Literal(s) => {
let lit = match s {
ScalarValue::Boolean(Some(b)) => TypedLit::Bool(*b),
ScalarValue::Float32(Some(f)) => TypedLit::Float(*f),
ScalarValue::Float64(Some(f)) => TypedLit::Double(*f),
ScalarValue::Int64(Some(i)) => TypedLit::Int(*i),
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Compiling Scalar {s} not yet supported in JIT mode"
)))
}
};
Ok(Expr::Literal(Literal::Typed(lit)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Compiling {value} not yet supported"
))),
}
}
}
/// A type in the JIT language: a cranelift IR type paired with a numeric code.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct JITType {
/// The underlying cranelift IR type.
pub(crate) native: ir::Type,
/// Numeric tag for the type; the `Debug` impl selects the printed name
/// from this value alone.
pub(crate) code: u8,
}
// NOTE(review): the `code` values below look like they mirror cranelift's
// own type numbering — confirm against the cranelift version in use.

/// Placeholder type wrapping cranelift's `INVALID`; printed as `nil`.
pub const NIL: JITType = JITType {
native: ir::types::INVALID,
code: 0,
};
/// Boolean type (cranelift `B1`); printed as `bool`.
pub const BOOL: JITType = JITType {
native: ir::types::B1,
code: 0x70,
};
/// 8-bit integer type (cranelift `I8`); printed as `i8`.
pub const I8: JITType = JITType {
native: ir::types::I8,
code: 0x76,
};
/// 16-bit integer type (cranelift `I16`); printed as `i16`.
pub const I16: JITType = JITType {
native: ir::types::I16,
code: 0x77,
};
/// 32-bit integer type (cranelift `I32`); printed as `i32`.
pub const I32: JITType = JITType {
native: ir::types::I32,
code: 0x78,
};
/// 64-bit integer type (cranelift `I64`); printed as `i64`.
pub const I64: JITType = JITType {
native: ir::types::I64,
code: 0x79,
};
/// 32-bit float type (cranelift `F32`); printed as `f32`.
pub const F32: JITType = JITType {
native: ir::types::F32,
code: 0x7b,
};
/// 64-bit float type (cranelift `F64`); printed as `f64`.
pub const F64: JITType = JITType {
native: ir::types::F64,
code: 0x7c,
};
/// Type wrapping cranelift's `R32`; printed as `small_ptr`.
pub const R32: JITType = JITType {
native: ir::types::R32,
code: 0x7e,
};
/// Type wrapping cranelift's `R64`; printed as `ptr`.
pub const R64: JITType = JITType {
native: ir::types::R64,
code: 0x7f,
};
/// Size in bytes of `usize` on the compilation target.
pub const PTR_SIZE: usize = std::mem::size_of::<usize>();
/// Pointer type for the compilation target: `R64` when `PTR_SIZE` is 8,
/// otherwise `R32`.
pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 };
impl TryFrom<&DataType> for JITType {
    type Error = DataFusionError;

    /// Maps an Arrow `DataType` to the matching JIT type.
    ///
    /// Only `Boolean`, `Int64`, `Float32` and `Float64` are supported.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::NotImplemented` for every other data type.
    fn try_from(df_type: &DataType) -> Result<Self, Self::Error> {
        let jit_type = match df_type {
            DataType::Boolean => BOOL,
            DataType::Int64 => I64,
            DataType::Float32 => F32,
            DataType::Float64 => F64,
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Compiling Expression with type {df_type} not yet supported in JIT mode"
                )))
            }
        };
        Ok(jit_type)
    }
}
impl Stmt {
    /// Writes this statement to `f` as source-like text, indented by `ident`
    /// spaces.
    ///
    /// Statements nested inside `if`/`else` and `while` bodies are written
    /// recursively with four additional spaces of indentation. Every
    /// statement ends with a newline.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying formatter.
    pub fn fmt_ident(&self, ident: usize, f: &mut Formatter) -> std::fmt::Result {
        // Build the padding with `str::repeat` instead of pushing one space
        // at a time.
        let ident_str = " ".repeat(ident);
        match self {
            Stmt::IfElse(cond, then_stmts, else_stmts) => {
                writeln!(f, "{ident_str}if {cond} {{")?;
                for stmt in then_stmts {
                    stmt.fmt_ident(ident + 4, f)?;
                }
                writeln!(f, "{ident_str}}} else {{")?;
                for stmt in else_stmts {
                    stmt.fmt_ident(ident + 4, f)?;
                }
                writeln!(f, "{ident_str}}}")
            }
            Stmt::WhileLoop(cond, stmts) => {
                writeln!(f, "{ident_str}while {cond} {{")?;
                for stmt in stmts {
                    stmt.fmt_ident(ident + 4, f)?;
                }
                writeln!(f, "{ident_str}}}")
            }
            Stmt::Assign(name, expr) => {
                writeln!(f, "{ident_str}{name} = {expr};")
            }
            Stmt::Call(name, args) => {
                let rendered_args = args
                    .iter()
                    .map(|e| e.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                writeln!(f, "{}{}({});", ident_str, name, rendered_args)
            }
            Stmt::Declare(name, ty) => {
                writeln!(f, "{ident_str}let {name}: {ty};")
            }
            Stmt::Store(value, ptr) => {
                // NOTE(review): unlike the other simple statements, no
                // trailing `;` is emitted here. Left unchanged because
                // callers may compare this rendered text.
                writeln!(f, "{ident_str}*({ptr}) = {value}")
            }
        }
    }
}
impl Display for Stmt {
    /// Formats the statement with no leading indentation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.fmt_ident(0, f)
    }
}
impl Display for Expr {
    /// Formats the expression as source-like text; type annotations stored
    /// on the variants are not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Expr::Call(name, exprs, _) => {
                // Render each argument, then join them with ", ".
                let rendered: Vec<String> =
                    exprs.iter().map(ToString::to_string).collect();
                write!(f, "{}({})", name, rendered.join(", "))
            }
            Expr::Load(ptr, _) => write!(f, "*({ptr})"),
            Expr::Binary(be) => write!(f, "{be}"),
            Expr::Identifier(name, _) => write!(f, "{name}"),
            Expr::Literal(l) => write!(f, "{l}"),
        }
    }
}
impl Display for Literal {
    /// Prints the stored text for `Parsing` literals and the value for
    /// `Typed` literals.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Literal::Typed(typed) => write!(f, "{typed}"),
            Literal::Parsing(text, _) => write!(f, "{text}"),
        }
    }
}
impl Display for TypedLit {
    /// Prints the wrapped value using its own `Display` formatting.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match *self {
            TypedLit::Double(value) => write!(f, "{value}"),
            TypedLit::Float(value) => write!(f, "{value}"),
            TypedLit::Int(value) => write!(f, "{value}"),
            TypedLit::Bool(value) => write!(f, "{value}"),
        }
    }
}
impl Display for BinaryExpr {
    /// Prints the expression as `<lhs> <operator> <rhs>`.
    ///
    /// No parentheses are added around nested operands.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Pick the operands and operator symbol, then format once.
        let (lhs, symbol, rhs) = match self {
            BinaryExpr::Eq(l, r) => (l, "==", r),
            BinaryExpr::Ne(l, r) => (l, "!=", r),
            BinaryExpr::Lt(l, r) => (l, "<", r),
            BinaryExpr::Le(l, r) => (l, "<=", r),
            BinaryExpr::Gt(l, r) => (l, ">", r),
            BinaryExpr::Ge(l, r) => (l, ">=", r),
            BinaryExpr::Add(l, r) => (l, "+", r),
            BinaryExpr::Sub(l, r) => (l, "-", r),
            BinaryExpr::Mul(l, r) => (l, "*", r),
            BinaryExpr::Div(l, r) => (l, "/", r),
        };
        write!(f, "{lhs} {symbol} {rhs}")
    }
}
impl std::fmt::Display for JITType {
    /// Prints the same text as the `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}
impl std::fmt::Debug for JITType {
    /// Prints the type's short name, chosen from `code` alone; codes without
    /// a known name print as `unknown`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self.code {
            0 => "nil",
            0x70 => "bool",
            0x76 => "i8",
            0x77 => "i16",
            0x78 => "i32",
            0x79 => "i64",
            0x7b => "f32",
            0x7c => "f64",
            0x7e => "small_ptr",
            0x7f => "ptr",
            _ => "unknown",
        };
        f.write_str(name)
    }
}
impl From<&str> for JITType {
    /// Looks up a JIT type by its short name (`bool`, `i8`, `i16`, `i32`,
    /// `i64`, `f32`, `f64`, `small_ptr`, `ptr`).
    ///
    /// # Panics
    ///
    /// Panics with `unknown type: <name>` when the name is not one of the
    /// names listed above.
    fn from(x: &str) -> Self {
        // Name-to-type table, scanned in order.
        const NAMED_TYPES: [(&str, JITType); 9] = [
            ("bool", BOOL),
            ("i8", I8),
            ("i16", I16),
            ("i32", I32),
            ("i64", I64),
            ("f32", F32),
            ("f64", F64),
            ("small_ptr", R32),
            ("ptr", R64),
        ];
        for (name, ty) in NAMED_TYPES {
            if name == x {
                return ty;
            }
        }
        panic!("unknown type: {x}")
    }
}