use crate::error::Result;
use crate::execution::physical_plan::common::RecordBatchIterator;
use crate::execution::physical_plan::{common, ExecutionPlan};
use crate::execution::physical_plan::{BatchIterator, Partition};
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use std::sync::{Arc, Mutex};
use std::thread;
use std::thread::JoinHandle;
/// Physical plan operator that combines all of its input partitions into a
/// single output partition.
pub struct MergeExec {
    /// Schema reported for the merged output.
    schema: Arc<Schema>,
    /// Input partitions whose batches are combined when the plan executes.
    partitions: Vec<Arc<dyn Partition>>,
}

impl MergeExec {
    /// Creates a merge operator over `partitions`, reporting `schema` as the
    /// schema of its single output partition.
    pub fn new(schema: Arc<Schema>, partitions: Vec<Arc<dyn Partition>>) -> Self {
        Self { schema, partitions }
    }
}
impl ExecutionPlan for MergeExec {
    /// Returns the schema of the merged output.
    fn schema(&self) -> Arc<Schema> {
        Arc::clone(&self.schema)
    }

    /// Returns exactly one partition; executing it runs and combines every
    /// input partition of this plan.
    fn partitions(&self) -> Result<Vec<Arc<dyn Partition>>> {
        let merged = MergePartition {
            schema: Arc::clone(&self.schema),
            partitions: self.partitions.clone(),
        };
        Ok(vec![Arc::new(merged)])
    }
}
/// The single output partition produced by [`MergeExec`].
struct MergePartition {
    /// Schema handed to the resulting batch iterator.
    schema: Arc<Schema>,
    /// Input partitions to execute and combine.
    partitions: Vec<Arc<dyn Partition>>,
}

impl Partition for MergePartition {
    /// Executes every input partition on its own thread and gathers all of
    /// their batches into one in-memory iterator.
    ///
    /// Batches from input partition `i` precede those of partition `i + 1`,
    /// because worker results are consumed in input order.
    ///
    /// # Errors
    ///
    /// Returns the first error, in partition order, produced while executing
    /// or collecting an input partition.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread panicked before any earlier partition
    /// reported an error.
    fn execute(&self) -> Result<Arc<Mutex<dyn BatchIterator>>> {
        // Spawn all workers up front so the inputs are read concurrently.
        let threads: Vec<JoinHandle<Result<Vec<RecordBatch>>>> = self
            .partitions
            .iter()
            .map(|p| {
                let p = p.clone();
                thread::spawn(move || {
                    let it = p.execute()?;
                    common::collect(it)
                })
            })
            .collect();

        // Join every worker before inspecting any outcome. Returning early on
        // the first error would drop the remaining handles, detaching threads
        // that keep running after this method has returned.
        let outcomes: Vec<_> = threads.into_iter().map(JoinHandle::join).collect();

        let mut combined_results: Vec<Arc<RecordBatch>> = vec![];
        for outcome in outcomes {
            let batches = outcome.expect("Failed to join thread")?;
            // Move the owned batches into `Arc`s instead of cloning each one.
            combined_results.extend(batches.into_iter().map(Arc::new));
        }

        Ok(Arc::new(Mutex::new(RecordBatchIterator::new(
            self.schema.clone(),
            combined_results,
        ))))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::physical_plan::common;
    use crate::execution::physical_plan::csv::CsvExec;
    use crate::test;

    /// Merging a four-way partitioned CSV yields a single partition that still
    /// holds one batch per input partition and every row of the source file.
    #[test]
    fn merge() -> Result<()> {
        let num_partitions = 4;
        let schema = test::aggr_test_schema();
        let path =
            test::create_partitioned_csv("aggregate_test_100.csv", num_partitions)?;

        let csv = CsvExec::try_new(&path, schema.clone(), true, None, 1024)?;
        let input = csv.partitions()?;
        assert_eq!(input.len(), num_partitions);

        let merged = MergeExec::new(schema.clone(), input).partitions()?;
        assert_eq!(merged.len(), 1);

        let batches = common::collect(merged[0].execute()?)?;
        assert_eq!(batches.len(), num_partitions);

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 100);

        Ok(())
    }
}