use futures::StreamExt;
use std::any::Any;
use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::{SessionState, TaskContext};
use crate::logical_plan::Expr;
use crate::physical_plan::common;
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
/// In-memory table: a [`TableProvider`] backed by partitioned vectors of
/// [`RecordBatch`]es that are already materialized in memory.
pub struct MemTable {
    /// Schema of the table; every batch in `batches` must be contained by it
    /// (validated in `try_new`).
    schema: SchemaRef,
    /// The data, as one inner `Vec<RecordBatch>` per partition.
    batches: Vec<Vec<RecordBatch>>,
}
impl MemTable {
    /// Create a new in-memory table from the schema and the pre-partitioned
    /// record batches (`partitions[i]` is the data for partition `i`).
    ///
    /// # Errors
    /// Returns a `DataFusionError::Plan` if any batch's schema is not
    /// contained in `schema`.
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        // Vacuously true for empty partitions: an empty MemTable is valid.
        if partitions
            .iter()
            .flatten()
            .all(|batch| schema.contains(&batch.schema()))
        {
            Ok(Self {
                schema,
                batches: partitions,
            })
        } else {
            Err(DataFusionError::Plan(
                "Mismatch between schema and batches".to_string(),
            ))
        }
    }

    /// Eagerly scan the provider `t` and materialize all of its data into a
    /// new `MemTable`, optionally round-robin repartitioning the batches into
    /// `output_partitions` partitions.
    ///
    /// # Errors
    /// Propagates any error from planning or executing the scan.
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        ctx: &SessionState,
    ) -> Result<Self> {
        let schema = t.schema();
        let exec = t.scan(ctx, &None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();

        // Spawn one task per input partition so partitions are collected
        // concurrently; the collect() is needed to start all tasks before
        // awaiting any of them.
        let tasks = (0..partition_count)
            .map(|part_i| {
                let task = Arc::new(TaskContext::from(ctx));
                let exec = exec.clone();
                tokio::spawn(async move {
                    let stream = exec.execute(part_i, task)?;
                    common::collect(stream).await
                })
            })
            .collect::<Vec<_>>();

        // Reuse the partition count computed above instead of re-deriving it.
        let mut data: Vec<Vec<RecordBatch>> = Vec::with_capacity(partition_count);
        for task in tasks {
            // A join failure means the spawned task panicked — a bug, not a
            // recoverable query error — so a panic here is appropriate.
            let result = task.await.expect("MemTable::load could not join task")?;
            data.push(result);
        }

        let exec = MemoryExec::try_new(&data, schema.clone(), None)?;

        if let Some(num_partitions) = output_partitions {
            // Redistribute the collected batches round-robin across the
            // requested number of output partitions.
            let exec = RepartitionExec::try_new(
                Arc::new(exec),
                Partitioning::RoundRobinBatch(num_partitions),
            )?;

            // Execute the repartitioning plan and collect each output
            // partition's batches.
            let mut output_partitions = vec![];
            for i in 0..exec.output_partitioning().partition_count() {
                let task_ctx = Arc::new(TaskContext::from(ctx));
                let mut stream = exec.execute(i, task_ctx)?;
                let mut batches = vec![];
                while let Some(result) = stream.next().await {
                    batches.push(result?);
                }
                output_partitions.push(batches);
            }

            // `schema` is moved here: this branch diverges, so no clone is
            // needed for the fall-through case below.
            return MemTable::try_new(schema, output_partitions);
        }
        MemTable::try_new(schema, data)
    }
}
#[async_trait]
impl TableProvider for MemTable {
    /// Expose `self` for downcasting by callers that need the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The table schema (cheap `Arc` clone).
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Build a `MemoryExec` over the stored batches, applying the optional
    /// column `projection`. Filters and limit are ignored — `MemoryExec`
    /// does not push them down, so they are handled by the caller's plan.
    async fn scan(
        &self,
        _ctx: &SessionState,
        projection: &Option<Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Borrow the batches directly: the previous `self.batches.clone()`
        // deep-copied the whole Vec<Vec<RecordBatch>> only to take a
        // reference to the temporary; MemoryExec copies what it needs.
        Ok(Arc::new(MemoryExec::try_new(
            &self.batches,
            self.schema(),
            projection.clone(),
        )?))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::from_slice::FromSlice;
    use crate::prelude::SessionContext;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;
    use futures::StreamExt;
    use std::collections::HashMap;

    /// Scanning with projection `[2, 1]` must restrict AND reorder the output
    /// columns: "c" first, then "b".
    #[tokio::test]
    async fn test_with_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[4, 5, 6])),
                Arc::new(Int32Array::from_slice(&[7, 8, 9])),
                // Column "d" is nullable, so it can hold None values.
                Arc::new(Int32Array::from(vec![None, None, Some(9)])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        let exec = provider
            .scan(&session_ctx.state(), &Some(vec![2, 1]), &[], None)
            .await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch2 = it.next().await.unwrap()?;
        assert_eq!(2, batch2.schema().fields().len());
        assert_eq!("c", batch2.schema().field(0).name());
        assert_eq!("b", batch2.schema().field(1).name());
        assert_eq!(2, batch2.num_columns());
        Ok(())
    }

    /// Scanning with no projection must yield all columns unchanged.
    #[tokio::test]
    async fn test_without_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[4, 5, 6])),
                Arc::new(Int32Array::from_slice(&[7, 8, 9])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        let exec = provider
            .scan(&session_ctx.state(), &None, &[], None)
            .await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());
        Ok(())
    }

    /// A projection index past the last column (4 on a 3-column schema) must
    /// surface as an Arrow SchemaError with the exact message asserted below.
    #[tokio::test]
    async fn test_invalid_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[4, 5, 6])),
                Arc::new(Int32Array::from_slice(&[7, 8, 9])),
            ],
        )?;
        let provider = MemTable::try_new(schema, vec![vec![batch]])?;
        // Index 4 is out of bounds for the 3-column schema above.
        let projection: Vec<usize> = vec![0, 4];
        match provider
            .scan(&session_ctx.state(), &Some(projection), &[], None)
            .await
        {
            Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
                assert_eq!(
                    "\"project index 4 out of bounds, max field 3\"",
                    format!("{:?}", e)
                )
            }
            res => panic!("Scan should failed on invalid projection, got {:?}", res),
        };
        Ok(())
    }

    /// try_new must reject batches whose column type (here "b": Int32 vs
    /// Float64) disagrees with the declared table schema.
    #[test]
    fn test_schema_validation_incompatible_column() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        // Same column names, but "b" has a different type.
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[4, 5, 6])),
                Arc::new(Int32Array::from_slice(&[7, 8, 9])),
            ],
        )?;
        match MemTable::try_new(schema2, vec![vec![batch]]) {
            Err(DataFusionError::Plan(e)) => assert_eq!(
                "\"Mismatch between schema and batches\"",
                format!("{:?}", e)
            ),
            _ => panic!("MemTable::new should have failed due to schema mismatch"),
        }
        Ok(())
    }

    /// try_new must reject batches whose column count (2) differs from the
    /// declared table schema's (3).
    #[test]
    fn test_schema_validation_different_column_count() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[7, 5, 9])),
            ],
        )?;
        match MemTable::try_new(schema2, vec![vec![batch]]) {
            Err(DataFusionError::Plan(e)) => assert_eq!(
                "\"Mismatch between schema and batches\"",
                format!("{:?}", e)
            ),
            _ => panic!("MemTable::new should have failed due to schema mismatch"),
        }
        Ok(())
    }

    /// Batches with schemas that differ only in nullability/metadata must be
    /// accepted under the merged schema (Schema::try_merge), and scan cleanly.
    #[tokio::test]
    async fn test_merged_schema() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        // schema1 carries metadata; schema2 differs in nullability of "a".
        let schema1 = Schema::new_with_metadata(
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
                Field::new("c", DataType::Int32, false),
            ],
            metadata,
        );
        let schema2 = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);
        let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[4, 5, 6])),
                Arc::new(Int32Array::from_slice(&[7, 8, 9])),
            ],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(Int32Array::from_slice(&[1, 2, 3])),
                Arc::new(Int32Array::from_slice(&[4, 5, 6])),
                Arc::new(Int32Array::from_slice(&[7, 8, 9])),
            ],
        )?;
        let provider =
            MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
        let exec = provider
            .scan(&session_ctx.state(), &None, &[], None)
            .await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());
        Ok(())
    }
}