use crate::{DataFusionError, Result, ScalarValue};
use arrow::array::ArrayRef;
use arrow::compute::SortOptions;
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserError};
use sqlparser::tokenizer::{Token, TokenWithLocation};
use std::cmp::Ordering;
/// Materializes the row at position `idx` across `columns`.
///
/// The returned vector holds one [`ScalarValue`] per input array, in the same
/// order as `columns`. An empty `columns` slice yields an empty row.
///
/// # Errors
///
/// Returns the first error produced by [`ScalarValue::try_from_array`]; no
/// further columns are read once a conversion fails.
pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValue>> {
    let mut row = Vec::with_capacity(columns.len());
    for column in columns {
        row.push(ScalarValue::try_from_array(column, idx)?);
    }
    Ok(row)
}
/// Lexicographically compares two rows, `x` and `y`, column by column.
///
/// Column `i` is compared according to `sort_options[i]`:
///
/// * If exactly one side is null, `nulls_first` alone decides the outcome
///   (the `descending` flag is not consulted for null placement).
/// * If both sides are null, the column is treated as equal and skipped.
/// * Otherwise the values are compared with `partial_cmp`, with the operands
///   swapped when `descending` is set.
///
/// The first column that is not `Ordering::Equal` determines the result.
/// Only as many columns as the shortest of `x`, `y` and `sort_options` are
/// inspected; any surplus entries are ignored.
///
/// # Errors
///
/// Returns a [`DataFusionError::Internal`] when two non-null values have no
/// defined ordering (`partial_cmp` returned `None`).
pub fn compare_rows(
    x: &[ScalarValue],
    y: &[ScalarValue],
    sort_options: &[SortOptions],
) -> Result<Ordering> {
    let zip_it = x.iter().zip(y.iter()).zip(sort_options.iter());
    // Preserve lexicographic ordering: the first non-equal column wins.
    for ((lhs, rhs), sort_options) in zip_it {
        // Consider every combination of null-ness and NULLS FIRST / NULLS LAST.
        let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) {
            (true, false, false) | (false, true, true) => Ordering::Greater,
            (true, false, true) | (false, true, false) => Ordering::Less,
            (false, false, _) => {
                let ordering = if sort_options.descending {
                    rhs.partial_cmp(lhs)
                } else {
                    lhs.partial_cmp(rhs)
                };
                ordering.ok_or_else(|| {
                    DataFusionError::Internal(
                        "Cannot compare rows: scalar values have no defined ordering"
                            .to_string(),
                    )
                })?
            }
            (true, true, _) => continue,
        };
        if result != Ordering::Equal {
            return Ok(result);
        }
    }
    Ok(Ordering::Equal)
}
/// Binary-searches `item_columns` for the insertion point of the row `target`.
///
/// Rows are compared with [`compare_rows`] using `sort_options`. Because this
/// is a binary search, the result is only meaningful when the rows of
/// `item_columns` are already ordered consistently with `sort_options`.
///
/// * `SIDE == true`: returns the leftmost insertion point, i.e. the number of
///   leading rows that compare strictly less than `target`.
/// * `SIDE == false`: returns the rightmost insertion point, i.e. the number
///   of leading rows that compare less than or equal to `target`.
///
/// The row count is taken from the first column.
///
/// # Errors
///
/// Returns a [`DataFusionError::Internal`] if `item_columns` is empty, and
/// propagates any error raised while reading or comparing rows.
pub fn bisect<const SIDE: bool>(
    item_columns: &[ArrayRef],
    target: &[ScalarValue],
    sort_options: &[SortOptions],
) -> Result<usize> {
    let num_rows = match item_columns.first() {
        Some(column) => column.len(),
        None => {
            return Err(DataFusionError::Internal(
                "Column array shouldn't be empty".to_string(),
            ))
        }
    };
    // `true` means the probed row sorts before the insertion point.
    let sorts_before = |current: &[ScalarValue], target: &[ScalarValue]| -> Result<bool> {
        compare_rows(current, target, sort_options).map(|ordering| {
            if SIDE {
                ordering.is_lt()
            } else {
                ordering.is_le()
            }
        })
    };
    find_bisect_point(item_columns, target, sorts_before, 0, num_rows)
}
/// Binary search over the row range `[low, high)` of `item_columns`.
///
/// `compare_fn(row, target)` must return `true` for rows that lie before the
/// sought position and `false` for rows at or after it. The function returns
/// the first index in the range for which `compare_fn` is `false`, or `high`
/// if it is `true` for every probed row. When `low >= high` on entry, `low`
/// is returned unchanged and no row is read.
///
/// # Errors
///
/// Propagates errors from [`get_row_at_idx`] and from `compare_fn`.
pub fn find_bisect_point<F>(
    item_columns: &[ArrayRef],
    target: &[ScalarValue],
    compare_fn: F,
    mut low: usize,
    mut high: usize,
) -> Result<usize>
where
    F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>,
{
    loop {
        if low >= high {
            return Ok(low);
        }
        // Written as `low + half-width` so the midpoint cannot overflow.
        let midpoint = low + (high - low) / 2;
        let row = get_row_at_idx(item_columns, midpoint)?;
        if compare_fn(&row, target)? {
            // The probed row is before the answer: discard the lower half.
            low = midpoint + 1;
        } else {
            // The probed row is a candidate: keep it as the new upper bound.
            high = midpoint;
        }
    }
}
/// Scans `item_columns` from the first row to find the insertion point of the
/// row `target`, comparing rows with [`compare_rows`] and `sort_options`.
///
/// * `SIDE == true`: stops at the first row that does not compare strictly
///   less than `target` (leftmost insertion point).
/// * `SIDE == false`: stops at the first row that compares greater than
///   `target` (rightmost insertion point).
///
/// If no row stops the scan, the row count of the first column is returned.
///
/// # Errors
///
/// Returns a [`DataFusionError::Internal`] if `item_columns` is empty, and
/// propagates any error raised while reading or comparing rows.
pub fn linear_search<const SIDE: bool>(
    item_columns: &[ArrayRef],
    target: &[ScalarValue],
    sort_options: &[SortOptions],
) -> Result<usize> {
    let first_column = item_columns.first().ok_or_else(|| {
        DataFusionError::Internal("Column array shouldn't be empty".to_string())
    })?;
    let num_rows = first_column.len();
    // `true` means the scanned row still sorts before the insertion point.
    let keep_scanning = |current: &[ScalarValue], target: &[ScalarValue]| -> Result<bool> {
        let ordering = compare_rows(current, target, sort_options)?;
        let before = if SIDE {
            ordering.is_lt()
        } else {
            ordering.is_le()
        };
        Ok(before)
    };
    search_in_slice(item_columns, target, keep_scanning, 0, num_rows)
}
/// Linear scan over the row range `[low, high)` of `item_columns`.
///
/// Rows are visited in increasing index order. The index of the first row for
/// which `compare_fn(row, target)` returns `false` is returned. If every row
/// in the range satisfies `compare_fn`, the result is `high`; if the range is
/// empty (`low >= high`), the result is `low` and no row is read.
///
/// # Errors
///
/// Propagates errors from [`get_row_at_idx`] and from `compare_fn`.
pub fn search_in_slice<F>(
    item_columns: &[ArrayRef],
    target: &[ScalarValue],
    compare_fn: F,
    low: usize,
    high: usize,
) -> Result<usize>
where
    F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>,
{
    for idx in low..high {
        let row = get_row_at_idx(item_columns, idx)?;
        if !compare_fn(&row, target)? {
            return Ok(idx);
        }
    }
    // Nothing stopped the scan: a non-empty range ends at `high`, while an
    // empty one (`low >= high`) leaves the starting index untouched.
    Ok(low.max(high))
}
/// Wraps `s` in double quotes, doubling every embedded `"` so the result can
/// be read back as a single quoted identifier.
///
/// For example, `a"b` becomes `"a""b"`, and the empty string becomes `""`.
pub fn quote_identifier(s: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes may grow it.
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for ch in s.chars() {
        if ch == '"' {
            // Escape an embedded quote by emitting it twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
pub(crate) fn parse_identifiers(s: &str) -> Result<Vec<Ident>> {
let dialect = GenericDialect;
let mut parser = Parser::new(&dialect).try_with_sql(s)?;
let mut idents = vec![];
match parser.next_token_no_skip() {
Some(TokenWithLocation {
token: Token::Word(w),
..
}) => idents.push(w.to_ident()),
Some(TokenWithLocation { token, .. }) => {
return Err(ParserError::ParserError(format!(
"Unexpected token in identifier: {token}"
)))?
}
None => {
return Err(ParserError::ParserError(
"Empty input when parsing identifier".to_string(),
))?
}
};
while let Some(TokenWithLocation { token, .. }) = parser.next_token_no_skip() {
match token {
Token::Period => match parser.next_token_no_skip() {
Some(TokenWithLocation {
token: Token::Word(w),
..
}) => idents.push(w.to_ident()),
Some(TokenWithLocation { token, .. }) => {
return Err(ParserError::ParserError(format!(
"Unexpected token following period in identifier: {token}"
)))?
}
None => {
return Err(ParserError::ParserError(
"Trailing period in identifier".to_string(),
))?
}
},
_ => {
return Err(ParserError::ParserError(format!(
"Unexpected token in identifier: {token}"
)))?
}
}
}
Ok(idents)
}
/// Parses `s` with [`parse_identifiers`] and returns the normalized name of
/// each part.
///
/// Quoted parts keep their exact spelling; unquoted parts are converted to
/// ASCII lowercase. If `s` cannot be parsed, the error is discarded and an
/// empty vector is returned.
pub(crate) fn parse_identifiers_normalized(s: &str) -> Vec<String> {
    let idents = match parse_identifiers(s) {
        Ok(idents) => idents,
        // Parse failures are deliberately mapped to "no identifiers".
        Err(_) => return Vec::new(),
    };
    let mut normalized = Vec::with_capacity(idents.len());
    for ident in idents {
        let name = if ident.quote_style.is_some() {
            ident.value
        } else {
            ident.value.to_ascii_lowercase()
        };
        normalized.push(name);
    }
    normalized
}
#[cfg(test)]
mod tests {
    use arrow::array::Float64Array;
    use std::sync::Arc;

    use crate::from_slice::FromSlice;
    use crate::ScalarValue;
    use crate::ScalarValue::Null;

    use super::*;

    /// Builds a `Float64Array` column from `values`.
    fn float_column(values: &[f64]) -> ArrayRef {
        Arc::new(Float64Array::from_slice(values))
    }

    /// Builds a row of non-null `Float64` scalars from `values`.
    fn float_row(values: &[f64]) -> Vec<ScalarValue> {
        values
            .iter()
            .map(|value| ScalarValue::Float64(Some(*value)))
            .collect()
    }

    /// Ascending order with nulls first.
    fn asc_nulls_first() -> SortOptions {
        SortOptions {
            descending: false,
            nulls_first: true,
        }
    }

    /// Descending order with nulls first.
    fn desc_nulls_first() -> SortOptions {
        SortOptions {
            descending: true,
            nulls_first: true,
        }
    }

    /// Four-column search: three ascending columns and one descending column.
    #[test]
    fn test_bisect_linear_left_and_right() -> Result<()> {
        let arrays = vec![
            float_column(&[5.0, 7.0, 8.0, 9., 10.]),
            float_column(&[2.0, 3.0, 3.0, 4.0, 5.0]),
            float_column(&[5.0, 7.0, 8.0, 10., 11.0]),
            float_column(&[15.0, 13.0, 8.0, 5., 0.0]),
        ];
        let search_tuple = float_row(&[8.0, 3.0, 8.0, 8.0]);
        let ords = [
            asc_nulls_first(),
            asc_nulls_first(),
            asc_nulls_first(),
            desc_nulls_first(),
        ];

        assert_eq!(bisect::<true>(&arrays, &search_tuple, &ords)?, 2);
        assert_eq!(bisect::<false>(&arrays, &search_tuple, &ords)?, 3);
        assert_eq!(linear_search::<true>(&arrays, &search_tuple, &ords)?, 2);
        assert_eq!(linear_search::<false>(&arrays, &search_tuple, &ords)?, 3);
        Ok(())
    }

    /// `Vec` ordering is lexicographic, for integers and for `ScalarValue`s.
    #[test]
    fn vector_ord() {
        assert!(vec![1, 0, 0, 0, 0, 0, 0, 1] < vec![1, 0, 0, 0, 0, 0, 0, 2]);
        assert!(vec![1, 0, 0, 0, 0, 0, 1, 1] > vec![1, 0, 0, 0, 0, 0, 0, 2]);

        // Untyped nulls in the middle column.
        let smaller = vec![
            ScalarValue::Int32(Some(2)),
            Null,
            ScalarValue::Int32(Some(0)),
        ];
        let larger = vec![
            ScalarValue::Int32(Some(2)),
            Null,
            ScalarValue::Int32(Some(1)),
        ];
        assert!(smaller < larger);

        // Typed nulls in the middle column.
        let smaller = vec![
            ScalarValue::Int32(Some(2)),
            ScalarValue::Int32(None),
            ScalarValue::Int32(Some(0)),
        ];
        let larger = vec![
            ScalarValue::Int32(Some(2)),
            ScalarValue::Int32(None),
            ScalarValue::Int32(Some(1)),
        ];
        assert!(smaller < larger);
    }

    #[test]
    fn ord_same_type() {
        let two = ScalarValue::Int32(Some(2));
        let three = ScalarValue::Int32(Some(3));
        assert!(two < three);
    }

    /// Left/right insertion points under different sort directions.
    #[test]
    fn test_bisect_linear_left_and_right_diff_sort() -> Result<()> {
        // Descending single column, left insertion point.
        let arrays = vec![float_column(&[4.0, 3.0, 2.0, 1.0, 0.0])];
        let search_tuple = float_row(&[4.0]);
        let ords = [desc_nulls_first()];
        assert_eq!(bisect::<true>(&arrays, &search_tuple, &ords)?, 0);
        assert_eq!(linear_search::<true>(&arrays, &search_tuple, &ords)?, 0);

        // Descending single column, right insertion point.
        let arrays = vec![float_column(&[4.0, 3.0, 2.0, 1.0, 0.0])];
        let search_tuple = float_row(&[4.0]);
        let ords = [desc_nulls_first()];
        assert_eq!(bisect::<false>(&arrays, &search_tuple, &ords)?, 1);
        assert_eq!(linear_search::<false>(&arrays, &search_tuple, &ords)?, 1);

        // Ascending single column, left insertion point.
        let arrays = vec![float_column(&[5.0, 7.0, 8.0, 9., 10.])];
        let search_tuple = float_row(&[7.0]);
        let ords = [asc_nulls_first()];
        assert_eq!(bisect::<true>(&arrays, &search_tuple, &ords)?, 1);
        assert_eq!(linear_search::<true>(&arrays, &search_tuple, &ords)?, 1);

        // Ascending single column, right insertion point.
        let arrays = vec![float_column(&[5.0, 7.0, 8.0, 9., 10.])];
        let search_tuple = float_row(&[7.0]);
        let ords = [asc_nulls_first()];
        assert_eq!(bisect::<false>(&arrays, &search_tuple, &ords)?, 2);
        assert_eq!(linear_search::<false>(&arrays, &search_tuple, &ords)?, 2);

        // Two columns with mixed directions: ascending, then descending.
        let arrays = vec![
            float_column(&[5.0, 7.0, 8.0, 8.0, 9., 10.]),
            float_column(&[10.0, 9.0, 8.0, 7.5, 7., 6.]),
        ];
        let search_tuple = float_row(&[8.0, 8.0]);
        let ords = [asc_nulls_first(), desc_nulls_first()];
        assert_eq!(bisect::<false>(&arrays, &search_tuple, &ords)?, 3);
        assert_eq!(linear_search::<false>(&arrays, &search_tuple, &ords)?, 3);
        assert_eq!(bisect::<true>(&arrays, &search_tuple, &ords)?, 2);
        assert_eq!(linear_search::<true>(&arrays, &search_tuple, &ords)?, 2);
        Ok(())
    }

    #[test]
    fn test_parse_identifiers() -> Result<()> {
        // Happy path: plain, quoted (with an escaped quote), and plain again.
        let actual = parse_identifiers("CATALOG.\"F(o)o. \"\"bar\".table")?;
        let expected = vec![
            Ident {
                value: "CATALOG".to_string(),
                quote_style: None,
            },
            Ident {
                value: "F(o)o. \"bar".to_string(),
                quote_style: Some('"'),
            },
            Ident {
                value: "table".to_string(),
                quote_style: None,
            },
        ];
        assert_eq!(expected, actual);

        // Each malformed input paired with the debug rendering of its error.
        let failing_inputs = [
            (
                "",
                "SQL(ParserError(\"Empty input when parsing identifier\"))",
            ),
            (
                "*schema.table",
                "SQL(ParserError(\"Unexpected token in identifier: *\"))",
            ),
            (
                "schema.table*",
                "SQL(ParserError(\"Unexpected token in identifier: *\"))",
            ),
            (
                "schema.table.",
                "SQL(ParserError(\"Trailing period in identifier\"))",
            ),
            (
                "schema.*",
                "SQL(ParserError(\"Unexpected token following period in identifier: *\"))",
            ),
        ];
        for &(input, expected_debug) in &failing_inputs {
            let err = parse_identifiers(input).expect_err("didn't fail to parse");
            assert_eq!(expected_debug, format!("{err:?}"));
        }
        Ok(())
    }
}