//! Repartition optimizer that introduces [`RepartitionExec`] nodes to
//! increase the level of parallelism available when executing a plan.

use std::sync::Arc;
use super::optimizer::PhysicalOptimizerRule;
use crate::physical_plan::Partitioning::*;
use crate::physical_plan::{
repartition::RepartitionExec, with_new_children_if_necessary, ExecutionPlan,
};
use crate::{error::Result, execution::context::SessionConfig};
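
/// Physical optimizer rule that inserts round-robin [`RepartitionExec`]
/// nodes wherever doing so increases parallelism without changing the
/// results of the query.
///
/// As a sketch (operator details elided), a plan such as:
///
/// ```text
/// FilterExec
///   ParquetExec (1 partition)
/// ```
///
/// is rewritten, with `target_partitions = 10`, to:
///
/// ```text
/// FilterExec
///   RepartitionExec: partitioning=RoundRobinBatch(10)
///     ParquetExec (1 partition)
/// ```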
#[derive(Default)]
pub struct Repartition {}
impl Repartition {
    /// Creates a new [`Repartition`] optimizer rule
pub fn new() -> Self {
Self {}
}
}
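
/// Recursively visits `plan`, working bottom up so that any inserted
/// [`RepartitionExec`] lands as deep in the plan as possible, and wraps a
/// node in a round-robin repartition when all of the following hold:
///
/// * `would_benefit`: the parent declared (via
///   `benefits_from_input_partitioning`) that more input partitions help it
/// * the node currently exposes fewer than `target_partitions` partitions
/// * `can_reorder`: no ancestor relies on the row order this node produces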
fn optimize_partitions(
target_partitions: usize,
plan: Arc<dyn ExecutionPlan>,
can_reorder: bool,
would_benefit: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
let new_plan = if plan.children().is_empty() {
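        // leaf node: there are no children to repartition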
plan
} else {
let can_reorder_children =
match (plan.relies_on_input_order(), plan.maintains_input_order()) {
(true, _) => {
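                    // `plan` itself relies on the order of its input;
                    // reordering its children would change its behavior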
false
}
(false, false) => {
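                    // `plan` neither needs nor preserves its input order,
                    // so its children may be reordered freely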
true
}
(false, true) => {
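                    // `plan` passes its input order through; its children may
                    // be reordered only if `plan`'s own output may be reordered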
can_reorder
}
};
let children = plan
.children()
.iter()
.map(|child| {
optimize_partitions(
target_partitions,
child.clone(),
                    // a child may only be reordered if this node allows it AND
                    // the child's output is not already sorted (repartitioning
                    // would destroy that order)
                    can_reorder_children && child.output_ordering().is_none(),
plan.benefits_from_input_partitioning(),
)
})
.collect::<Result<_>>()?;
with_new_children_if_necessary(plan, children)?
};
let could_repartition = match new_plan.output_partitioning() {
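        // repartition when the node exposes fewer partitions than requested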
RoundRobinBatch(x) => x < target_partitions,
UnknownPartitioning(x) => x < target_partitions,
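        // an existing hash partitioning was chosen deliberately (e.g. for an
        // aggregate or join); do not override it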
Hash(_, _) => false,
};
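
    // Wrap the node in a round-robin repartition only when the parent benefits
    // from extra partitions, the node is under-partitioned, and no ancestor
    // depends on the current row order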
if would_benefit && could_repartition && can_reorder {
Ok(Arc::new(RepartitionExec::try_new(
new_plan,
RoundRobinBatch(target_partitions),
)?))
} else {
Ok(new_plan)
}
}
impl PhysicalOptimizerRule for Repartition {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
config: &SessionConfig,
) -> Result<Arc<dyn ExecutionPlan>> {
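        // a single target partition leaves no parallelism to exploit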
if config.target_partitions == 1 {
Ok(plan)
} else {
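            // the root's output is what the user observes: assume its order
            // matters (`can_reorder = false`) and, having no parent, nothing
            // benefits from repartitioning it (`would_benefit = false`)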
optimize_partitions(config.target_partitions, plan, false, false)
}
}
fn name(&self) -> &str {
"repartition"
}
}
#[cfg(test)]
mod tests {
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use super::*;
use crate::datasource::listing::PartitionedFile;
use crate::datasource::object_store::ObjectStoreUrl;
use crate::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use crate::physical_plan::expressions::{col, PhysicalSortExpr};
use crate::physical_plan::file_format::{FileScanConfig, ParquetExec};
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use crate::physical_plan::union::UnionExec;
use crate::physical_plan::{displayable, Statistics};
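
    /// Schema with a single boolean column `c1`, shared by all test plans.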
fn schema() -> SchemaRef {
Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)]))
}
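
    /// `ParquetExec` scanning one file, i.e. exposing a single partition.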
fn parquet_exec() -> Arc<ParquetExec> {
Arc::new(ParquetExec::new(
FileScanConfig {
object_store_url: ObjectStoreUrl::parse("test:///").unwrap(),
file_schema: schema(),
file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
statistics: Statistics::default(),
projection: None,
limit: None,
table_partition_cols: vec![],
},
None,
))
}
fn sort_preserving_merge_exec(
input: Arc<dyn ExecutionPlan>,
) -> Arc<dyn ExecutionPlan> {
let expr = vec![PhysicalSortExpr {
expr: col("c1", &schema()).unwrap(),
            options: SortOptions::default(),
}];
Arc::new(SortPreservingMergeExec::new(expr, input))
}
fn filter_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(FilterExec::try_new(col("c1", &schema()).unwrap(), input).unwrap())
}
fn sort_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
let sort_exprs = vec![PhysicalSortExpr {
expr: col("c1", &schema()).unwrap(),
options: SortOptions::default(),
}];
Arc::new(SortExec::try_new(sort_exprs, input).unwrap())
}
fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
let exprs = vec![(col("c1", &schema()).unwrap(), "c1".to_string())];
Arc::new(ProjectionExec::try_new(exprs, input).unwrap())
}
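
    /// Two-stage aggregate (`Partial` feeding `Final`) over `input`.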
fn aggregate(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
let schema = schema();
Arc::new(
AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![],
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![],
input,
schema.clone(),
)
.unwrap(),
),
schema,
)
.unwrap(),
)
}
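
    /// `GlobalLimitExec(fetch=100)` atop a `LocalLimitExec(fetch=100)`.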
fn limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
Arc::new(LocalLimitExec::new(input, 100)),
None,
Some(100),
))
}
fn limit_exec_with_skip(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Arc::new(GlobalLimitExec::new(
Arc::new(LocalLimitExec::new(input, 100)),
Some(5),
Some(100),
))
}
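
    /// Splits a displayed plan into trimmed, non-empty lines for comparison.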
fn trim_plan_display(plan: &str) -> Vec<&str> {
plan.split('\n')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect()
}
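
    /// Runs the `Repartition` rule with `target_partitions = 10` over `$PLAN`
    /// and asserts that the displayed result matches `$EXPECTED_LINES`.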
macro_rules! assert_optimized {
($EXPECTED_LINES: expr, $PLAN: expr) => {
            let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().copied().collect();
let optimizer = Repartition {};
let optimized = optimizer
.optimize($PLAN, &SessionConfig::new().with_target_partitions(10))?;
let plan = displayable(optimized.as_ref()).indent().to_string();
let actual_lines = trim_plan_display(&plan);
assert_eq!(
&expected_lines, &actual_lines,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected_lines, actual_lines
);
};
}
#[test]
fn added_repartition_to_single_partition() -> Result<()> {
let plan = aggregate(parquet_exec());
let expected = [
"AggregateExec: mode=Final, gby=[], aggr=[]",
"AggregateExec: mode=Partial, gby=[], aggr=[]",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_deepest_node() -> Result<()> {
let plan = aggregate(filter_exec(parquet_exec()));
let expected = &[
"AggregateExec: mode=Final, gby=[], aggr=[]",
"AggregateExec: mode=Partial, gby=[], aggr=[]",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_unsorted_limit() -> Result<()> {
let plan = limit_exec(filter_exec(parquet_exec()));
let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_unsorted_limit_with_skip() -> Result<()> {
let plan = limit_exec_with_skip(filter_exec(parquet_exec()));
let expected = &[
"GlobalLimitExec: skip=5, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_sorted_limit() -> Result<()> {
let plan = limit_exec(sort_exec(parquet_exec()));
let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"LocalLimitExec: fetch=100",
"SortExec: [c1@0 ASC]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_sorted_limit_with_filter() -> Result<()> {
let plan = limit_exec(filter_exec(sort_exec(parquet_exec())));
let expected = &[
"GlobalLimitExec: skip=None, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
"SortExec: [c1@0 ASC]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_ignores_limit() -> Result<()> {
let plan = aggregate(limit_exec(filter_exec(limit_exec(parquet_exec()))));
let expected = &[
"AggregateExec: mode=Final, gby=[], aggr=[]",
"AggregateExec: mode=Partial, gby=[], aggr=[]",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"LocalLimitExec: fetch=100",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_ignores_limit_with_skip() -> Result<()> {
let plan = aggregate(limit_exec_with_skip(filter_exec(limit_exec(
parquet_exec(),
))));
let expected = &[
"AggregateExec: mode=Final, gby=[], aggr=[]",
"AggregateExec: mode=Partial, gby=[], aggr=[]",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=5, fetch=100",
"LocalLimitExec: fetch=100",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"GlobalLimitExec: skip=None, fetch=100",
"LocalLimitExec: fetch=100",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_ignores_union() -> Result<()> {
let plan = Arc::new(UnionExec::new(vec![parquet_exec(); 5]));
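        // a union merely concatenates its inputs, so none of them is repartitioned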
let expected = &[
"UnionExec",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_ignores_sort_preserving_merge() -> Result<()> {
let plan = sort_preserving_merge_exec(parquet_exec());
let expected = &[
"SortPreservingMergeExec: [c1@0 ASC]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_does_not_repartition_transitively() -> Result<()> {
let plan = sort_preserving_merge_exec(projection_exec(parquet_exec()));
let expected = &[
"SortPreservingMergeExec: [c1@0 ASC]",
"ProjectionExec: expr=[c1@0 as c1]",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_transitively_past_sort_with_projection() -> Result<()> {
let plan = sort_preserving_merge_exec(sort_exec(projection_exec(parquet_exec())));
let expected = &[
"SortPreservingMergeExec: [c1@0 ASC]",
"SortExec: [c1@0 ASC]",
"ProjectionExec: expr=[c1@0 as c1]",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_transitively_past_sort_with_filter() -> Result<()> {
let plan = sort_preserving_merge_exec(sort_exec(filter_exec(parquet_exec())));
let expected = &[
"SortPreservingMergeExec: [c1@0 ASC]",
"SortExec: [c1@0 ASC]",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
#[test]
fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> {
let plan = sort_preserving_merge_exec(sort_exec(projection_exec(filter_exec(
parquet_exec(),
))));
let expected = &[
"SortPreservingMergeExec: [c1@0 ASC]",
"SortExec: [c1@0 ASC]",
"ProjectionExec: expr=[c1@0 as c1]",
"FilterExec: c1@0",
"RepartitionExec: partitioning=RoundRobinBatch(10)",
"ParquetExec: limit=None, partitions=[x], projection=[c1]",
];
assert_optimized!(expected, plan);
Ok(())
}
}