diff --git a/src/distributed_planner/distributed_config.rs b/src/distributed_planner/distributed_config.rs index 4e8591e9..09016298 100644 --- a/src/distributed_planner/distributed_config.rs +++ b/src/distributed_planner/distributed_config.rs @@ -64,6 +64,10 @@ extensions_options! { /// budget will still be admitted (otherwise we would livelock), so the actual peak per /// connection is `worker_connection_buffer_budget_bytes + max_message_size`. pub worker_connection_buffer_budget_bytes: usize, default = 64 * 1024 * 1024 + /// Distributed DataFusion relies on row count estimation in order to infer how many workers + /// should be used in serving the query. Some plans might not implement any kind of row count + /// estimation, and this parameter sets the default estimated row count for those plans. + pub default_estimated_row_count: Option, default = Some(0) /// Collection of [TaskEstimator]s that will be applied to leaf nodes in order to /// estimate how many tasks should be spawned for the [Stage] containing the leaf node. pub(crate) __private_task_estimator: CombinedTaskEstimator, default = CombinedTaskEstimator::default() diff --git a/src/distributed_planner/mod.rs b/src/distributed_planner/mod.rs index a1a7c9fb..f12c684c 100644 --- a/src/distributed_planner/mod.rs +++ b/src/distributed_planner/mod.rs @@ -7,6 +7,7 @@ mod partial_reduce_below_network_shuffles; mod prepare_network_boundaries; mod push_fetch_into_network_coalesce; mod session_state_builder_ext; +mod statistics; mod task_estimator; pub use distributed_config::DistributedConfig; diff --git a/src/distributed_planner/statistics/compute_per_node.rs b/src/distributed_planner/statistics/compute_per_node.rs new file mode 100644 index 00000000..27b31176 --- /dev/null +++ b/src/distributed_planner/statistics/compute_per_node.rs @@ -0,0 +1,1029 @@ +use crate::BroadcastExec; +use crate::execution_plans::ChildrenIsolatorUnionExec; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::common::{JoinSide, Statistics}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::aggregates::AggregateExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::empty::EmptyExec; +use datafusion::physical_plan::expressions::{Column, Literal}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use datafusion::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, SymmetricHashJoinExec, +}; +use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; +use datafusion::physical_plan::{ExecutionPlan, Partitioning}; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +#[derive(Clone, PartialEq, Eq)] +pub(super) enum Complexity { + /// Constant complexity + Constant(usize), + /// Linear with a specific column from a specific child. + Linear(LinearComplexity), + /// NLogM + Log(Box, Box), + /// N+M + Plus(Box, Box), + /// N*M + Multiply(Box, Box), +} + +#[derive(Clone, PartialEq, Eq)] +pub(super) enum LinearComplexity { + /// Depends on linearly with the input column with the provided index + Column(usize), + /// Depends on linearly with the all the input columns + AllColumns, + /// Depends on linearly with the input column with the provided index from the left child + ColumnFromLeft(usize), + /// Depends on linearly with the all the input columns from the left child + AllColumnsFromLeft, + /// Depends on linearly with the input column with the provided index from the right child + ColumnFromRight(usize), + /// Depends on linearly with the all the input columns from the right child + AllColumnsFromRight, + /// Depends on linearly with the all the output columns + AllOutputColumns, +} + +impl Complexity { + fn log(self, other: Self) -> Self { + match (self, other) { + (Self::Constant(n), Self::Constant(m)) => { + Self::Constant(n * (m as f64).log2() as usize) + } + (s, o) => Self::Log(Box::new(s), Box::new(o)), + } + } + + fn plus(self, other: Self) -> Self { + match (self, other) { + (Self::Constant(n), Self::Constant(m)) => Self::Constant(n + m), + // (A + k1) + k2 = A + (k1 + k2): bubble constants rightward so they can fold + (Self::Plus(a, b), Self::Constant(m)) if matches!(*b, Self::Constant(_)) => { + (*a).plus((*b).plus(Self::Constant(m))) + } + (s, o) if s == o => Self::Constant(2).multiply(s), + (s, o) => Self::Plus(Box::new(s), Box::new(o)), + } + } + + fn multiply(self, other: Self) -> Self { + match (self, other) { + (Self::Constant(n), Self::Constant(m)) => Self::Constant(n * m), + (s, o) => Self::Multiply(Box::new(s), Box::new(o)), + } + } + + /// Computes the total bytes processed given per-child row counts. + /// Returns None if statistics are unavailable for any required input. + pub(super) fn cost( + &self, + output_stat: &Arc, + input_stats: &[Arc], + ) -> Option { + Some(match self { + Self::Constant(v) => *v, + Self::Linear(linear) => match linear { + LinearComplexity::Column(i) => { + let col_stats = &input_stats.first()?.column_statistics; + *col_stats.get(*i)?.byte_size.get_value()? + } + LinearComplexity::AllColumns => { + *input_stats.first()?.total_byte_size.get_value()? + } + LinearComplexity::ColumnFromLeft(i) => { + let col_stats = &input_stats.first()?.column_statistics; + *col_stats.get(*i)?.byte_size.get_value()? + } + LinearComplexity::AllColumnsFromLeft => { + *input_stats.first()?.total_byte_size.get_value()? + } + LinearComplexity::ColumnFromRight(i) => { + let col_stats = &input_stats.last()?.column_statistics; + *col_stats.get(*i)?.byte_size.get_value()? + } + LinearComplexity::AllColumnsFromRight => { + *input_stats.last()?.total_byte_size.get_value()? + } + LinearComplexity::AllOutputColumns => *output_stat.total_byte_size.get_value()?, + }, + Self::Log(n, m) => { + let n = n.cost(output_stat, input_stats)?; + let m = m.cost(output_stat, input_stats)?; + // `ilog2` panics on 0, which happens whenever the logged input has zero estimated + // bytes/rows (e.g. an empty or fully-pruned relation). Flooring at 1 makes log2 + // contribute 0 there, i.e. sorting/merging nothing costs nothing. + n * m.checked_ilog2().unwrap_or(0) as usize + } + Self::Plus(n, m) => { + n.cost(output_stat, input_stats)? + m.cost(output_stat, input_stats)? + } + Self::Multiply(n, m) => { + n.cost(output_stat, input_stats)? * m.cost(output_stat, input_stats)? + } + }) + } +} + +impl Debug for Complexity { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn trim_parenthesis(dbg: &Complexity) -> String { + let s = format!("{dbg:?}"); + if s.starts_with('(') && s.ends_with(')') { + s[1..s.len() - 1].to_string() + } else { + s + } + } + match self { + Self::Constant(v) => write!(f, "{v}"), + Self::Linear(linear) => match linear { + LinearComplexity::Column(i) => write!(f, "Col{i}"), + LinearComplexity::AllColumns => write!(f, "Cols"), + LinearComplexity::ColumnFromLeft(i) => write!(f, "left_Col{i}"), + LinearComplexity::AllColumnsFromLeft => write!(f, "left_Cols"), + LinearComplexity::ColumnFromRight(i) => write!(f, "right_Col{i}"), + LinearComplexity::AllColumnsFromRight => write!(f, "right_Cols"), + LinearComplexity::AllOutputColumns => write!(f, "out_Cols"), + }, + Self::Log(n, m) => write!(f, "{n:?}*Log({m:?})"), + Self::Plus(n, m) => { + if matches!(n.as_ref(), &Self::Plus(_, _)) { + write!(f, "({}+{m:?})", trim_parenthesis(n)) + } else { + write!(f, "({n:?}+{m:?})") + } + } + Self::Multiply(n, m) => { + if matches!(n.as_ref(), &Self::Multiply(_, _)) { + write!(f, "({}*{m:?})", trim_parenthesis(n)) + } else { + write!(f, "({n:?}*{m:?})") + } + } + } + } +} + +/// Calculates what's the cost, expressed as a number, per input row for each input children. +/// +/// The Vec return has equal size to `node.children()`, and determines how many each input needs +/// to be processed +pub(super) fn calculate_compute_complexity(node: &Arc) -> Complexity { + // NestedLoopJoinExec: O(n*m) - evaluates join condition for each pair of rows + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/joins/nested_loop_join.rs + if let Some(node) = node.downcast_ref::() { + // Assume we need to do read all input rows one by one. + let n = Complexity::Linear(LinearComplexity::AllColumnsFromLeft); + let m = Complexity::Linear(LinearComplexity::AllColumnsFromRight); + let mut c = n.multiply(m); + // The join condition is evaluated on every (left, right) pair. We can't express the + // exact per-pair cost (it would be filter_cost * n * m), so we add the filter columns + // as a lower-bound refinement; the O(n*m) materialization term above already dominates. + if let Some(filter) = node.filter() { + c = c.plus(join_filter_complexity(filter)); + } + return c; + } + + // CrossJoinExec: O(n*m) - produces Cartesian product of all row pairs + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/joins/cross_join.rs + if let Some(_node) = node.downcast_ref::() { + // Assume we need to do read all input rows one by one. + let n = Complexity::Linear(LinearComplexity::AllColumnsFromLeft); + let m = Complexity::Linear(LinearComplexity::AllColumnsFromRight); + return n.multiply(m); + } + + // SortExec: O(n log n) - uses lexsort_to_indices, may spill to disk + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/sorts/sort.rs + if let Some(node) = node.downcast_ref::() { + // All the rows will need to be copied one by one. + let mut n = Complexity::Linear(LinearComplexity::AllColumns); + // The sort comparators read every sort key on every row, so even a plain column key costs + // its bytes (a wide UTF8 key is far costlier to compare than an int). + for expr in node.expr() { + n = n.plus(hashed_or_sorted_key_complexity(&expr.expr)) + } + return n.clone().log(n); + } + + // HashJoinExec: hash table build (O(n)) + probe (O(m)) + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/joins/hash_join/exec.rs + if let Some(join) = node.downcast_ref::() { + // Build side (left): concat_batches copies all data (2x read), plus hash table storage, + // plus hashing left join keys. + let mut c = Complexity::Linear(LinearComplexity::AllColumnsFromLeft) + .plus(Complexity::Linear(LinearComplexity::AllColumnsFromLeft)); + for (left_key, _) in join.on() { + c = c.plus(join_key_complexity(left_key, true)); + } + // Probe side (right): read all columns + hash right join keys + c = c.plus(Complexity::Linear(LinearComplexity::AllColumnsFromRight)); + for (_, right_key) in join.on() { + c = c.plus(join_key_complexity(right_key, false)); + } + // Optional join filter evaluated on candidate matches during the probe. + if let Some(filter) = join.filter() { + c = c.plus(join_filter_complexity(filter)); + } + return c; + } + + // SortMergeJoinExec: merge of sorted streams with comparisons + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs + // Unlike hash join, sort-merge doesn't buffer all data or build hash tables. It streams + // through both sorted inputs with O(max_group_size) memory, using partial_cmp comparisons + // (no hashing). Per-row cost is just key comparisons + optional filter evaluation. + if let Some(node) = node.downcast_ref::() { + let mut c: Option = None; + // Left side: compare join keys during merge + for (left_key, _) in node.on() { + let key = join_key_complexity(left_key, true); + c = Some(match c { + Some(existing) => existing.plus(key), + None => key, + }); + } + // Right side: compare join keys during merge + for (_, right_key) in node.on() { + let key = join_key_complexity(right_key, false); + c = Some(match c { + Some(existing) => existing.plus(key), + None => key, + }); + } + // Optional join filter evaluated on matched pairs during the merge. + if let Some(filter) = node.filter() { + let f = join_filter_complexity(filter); + c = Some(match c { + Some(existing) => existing.plus(f), + None => f, + }); + } + return c.unwrap_or(Complexity::Constant(1)); + } + + // SymmetricHashJoinExec: streaming join with hash tables on both sides + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/joins/symmetric_hash_join.rs + // More expensive than HashJoinExec: both sides maintain hash tables, concat_batches + // runs on every incoming batch (not once at end), plus pruning interval computation + // and HashSet tracking for visited rows. + if let Some(node) = node.downcast_ref::() { + // Both sides: concat_batches on every batch (2x read) + hash table + hash keys + let mut c = Complexity::Linear(LinearComplexity::AllColumnsFromLeft) + .plus(Complexity::Linear(LinearComplexity::AllColumnsFromLeft)); + for (left_key, _) in node.on() { + c = c.plus(join_key_complexity(left_key, true)); + } + c = c + .plus(Complexity::Linear(LinearComplexity::AllColumnsFromRight)) + .plus(Complexity::Linear(LinearComplexity::AllColumnsFromRight)); + for (_, right_key) in node.on() { + c = c.plus(join_key_complexity(right_key, false)); + } + // Optional join filter evaluated on matched pairs as batches stream in. + if let Some(filter) = node.filter() { + c = c.plus(join_filter_complexity(filter)); + } + return c; + } + + // Aggregation: hash group-by keys + accumulate aggregate inputs + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/aggregates/mod.rs + if let Some(agg) = node.downcast_ref::() { + // Base: read all input columns for accumulation + let mut c = Complexity::Linear(LinearComplexity::AllColumns); + // Additional: evaluate and hash group-by key expressions + for (expr, _) in agg.group_expr().expr() { + c = c.plus(hashed_or_sorted_key_complexity(expr)); + } + // Per-aggregate filter expressions (e.g. COUNT(*) FILTER (WHERE ...)) + for filter in agg.filter_expr().iter().flatten() { + c = c.plus(expression_complexity(filter)); + } + return c; + } + + // Window functions: buffer partitions, compute aggregates over windows + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/windows/window_agg_exec.rs + if let Some(node) = node.downcast_ref::() { + // Read all input data + evaluate/hash partition key expressions + let mut c = Complexity::Linear(LinearComplexity::AllColumns); + for expr in node.partition_keys() { + c = c.plus(hashed_or_sorted_key_complexity(&expr)); + } + return c; + } + + if let Some(node) = node.downcast_ref::() { + let mut c = Complexity::Linear(LinearComplexity::AllColumns); + for expr in node.partition_keys() { + c = c.plus(hashed_or_sorted_key_complexity(&expr)); + } + return c; + } + + // SortPreservingMergeExec: merges pre-sorted streams with comparisons + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs + // K-way merge: O(N log K) comparisons on sort key expressions + if let Some(node) = node.downcast_ref::() { + // need to copy all rows... + let mut n = Complexity::Linear(LinearComplexity::AllColumns); + // and compare the sort keys on all of them; a plain column key still costs its bytes. + for expr in node.expr() { + n = n.plus(hashed_or_sorted_key_complexity(&expr.expr)) + } + return n; + } + + // FilterExec: evaluates predicate expression per row + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/filter.rs + // Cost depends on predicate complexity - LIKE/Regex operations are expensive + if let Some(node) = node.downcast_ref::() { + // It needs to perform a copy operation just to the output rows... + let n = Complexity::Linear(LinearComplexity::AllOutputColumns); + // ...and predicate evaluation on all input rows. + return n.plus(expression_complexity(node.predicate())); + } + + // ProjectionExec: cost depends on whether it's simple columns or expressions + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/projection.rs + if let Some(node) = node.downcast_ref::() { + let mut n: Option = None; + for expr in node.expr() { + n = if let Some(n) = n { + Some(n.plus(expression_complexity(&expr.expr))) + } else { + Some(expression_complexity(&expr.expr)) + }; + } + return n.unwrap_or(Complexity::Constant(1)); + } + + // RepartitionExec with Hash: computes hash per row + take_arrays + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/repartition/mod.rs + if let Some(node) = node.downcast_ref::() { + // It needs to copy all the data for chunking it to the different output partitions... + let mut n = Complexity::Linear(LinearComplexity::AllColumns); + // And it might need to compute a hash per row based on the provided expressions; hashing a + // plain column key still costs its bytes. + match node.partitioning() { + Partitioning::Hash(expressions, _) => { + for expr in expressions { + n = n.plus(hashed_or_sorted_key_complexity(expr)) + } + } + Partitioning::RoundRobinBatch(_) => {} + Partitioning::UnknownPartitioning(_) => {} + }; + return n; + } + + // DataSourceExec: Produces data, so assume that it's an O(N) operation over all the columns. + if node.is::() { + return Complexity::Linear(LinearComplexity::AllOutputColumns); + } + + // Limit: just counts rows and stops early. + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/limit.rs + if node.is::() || node.is::() { + return Complexity::Constant(1); + } + + // CoalescePartitionsExec: receives batches from partitions, just passes through the record + // batches in a zero copy manner. + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/coalesce_partitions.rs + if node.is::() { + return Complexity::Constant(1); + } + + // BroadcastExec: This node does not do any computation, does not even read the data. + if node.is::() { + return Complexity::Constant(1); + } + + // UnionExec: combines multiple input streams, no processing + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/union.rs + if node.is::() || node.is::() { + return Complexity::Constant(1); + } + + // InterleaveExec: round-robin merging of inputs, no processing + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/union.rs + if node.is::() { + return Complexity::Constant(1); + } + + // EmptyExec: produces no data + // https://github.com/apache/datafusion/blob/branch-52/datafusion/physical-plan/src/empty.rs + if node.is::() { + return Complexity::Constant(1); + } + + // For unknown node types, assume we have to do an O(N) operation over all the rows. + Complexity::Linear(LinearComplexity::AllOutputColumns) +} + +struct BytesPerRow { + processed: Option, + cols_read: Vec, +} + +fn expression_complexity(expression: &Arc) -> Complexity { + _expression_complexity(expression) + .processed + .unwrap_or(Complexity::Constant(1)) +} + +/// Computes the complexity of processing a join key expression, including the cost of +/// reading the leaf columns from the appropriate child (left or right). +/// Unlike `expression_complexity`, this accounts for the cost of hashing/comparing +/// simple column references (which have zero evaluation cost but real I/O cost). +fn join_key_complexity(expression: &Arc, from_left: bool) -> Complexity { + let bpr = _expression_complexity(expression); + let mut result: Option = None; + for col_idx in &bpr.cols_read { + let linear = if from_left { + LinearComplexity::ColumnFromLeft(*col_idx) + } else { + LinearComplexity::ColumnFromRight(*col_idx) + }; + result = Some(match result { + Some(r) => r.plus(Complexity::Linear(linear)), + None => Complexity::Linear(linear), + }); + } + result.unwrap_or(Complexity::Constant(1)) +} + +/// Computes the per-row processing cost of a join filter predicate. +/// +/// A `JoinFilter` is evaluated against an intermediate batch whose columns are described by +/// `column_indices`: intermediate column `i` originates from the left or right child at some +/// original index. `expression_complexity` returns a `Complexity` whose `LinearComplexity::Column` +/// terms reference those intermediate indices, so we remap each of them back onto the +/// corresponding child column before the cost can be evaluated against child statistics. +fn join_filter_complexity(filter: &JoinFilter) -> Complexity { + remap_filter_columns( + expression_complexity(filter.expression()), + filter.column_indices(), + ) +} + +/// Rewrites a `Complexity` built from a join filter's intermediate schema so that every +/// `LinearComplexity::Column` term refers to the left/right child column it actually reads. +/// Columns belonging to neither side (the mark-join sentinel) carry no child bytes, so they +/// collapse to a constant. +fn remap_filter_columns(c: Complexity, column_indices: &[ColumnIndex]) -> Complexity { + match c { + Complexity::Constant(v) => Complexity::Constant(v), + Complexity::Linear(LinearComplexity::Column(i)) => match column_indices.get(i) { + Some(ColumnIndex { + index, + side: JoinSide::Left, + }) => Complexity::Linear(LinearComplexity::ColumnFromLeft(*index)), + Some(ColumnIndex { + index, + side: JoinSide::Right, + }) => Complexity::Linear(LinearComplexity::ColumnFromRight(*index)), + _ => Complexity::Constant(1), + }, + // `expression_complexity` only ever emits `Column` linear terms, but keep the rest + // intact so the remapping stays total. + Complexity::Linear(other) => Complexity::Linear(other), + Complexity::Log(n, m) => { + remap_filter_columns(*n, column_indices).log(remap_filter_columns(*m, column_indices)) + } + Complexity::Plus(n, m) => { + remap_filter_columns(*n, column_indices).plus(remap_filter_columns(*m, column_indices)) + } + Complexity::Multiply(n, m) => remap_filter_columns(*n, column_indices) + .multiply(remap_filter_columns(*m, column_indices)), + } +} + +/// Cost of using an expression as a hashing or comparison key. +/// +/// Unlike `expression_complexity` (which only counts the CPU of *evaluating* a `PhysicalExpr`, +/// so a bare column passthrough is free), this charges the bytes of each underlying leaf column. +/// The hashing/comparison itself is performed by the operator — hash-table build, partition +/// hashing, sort comparators — not by any expression in the plan, and its cost scales with the +/// key's byte width. Use it for group-by keys, hash-partition keys and sort keys. +fn hashed_or_sorted_key_complexity(expression: &Arc) -> Complexity { + let bpr = _expression_complexity(expression); + let mut result: Option = None; + for col_idx in &bpr.cols_read { + result = Some(match result { + Some(r) => r.plus(Complexity::Linear(LinearComplexity::Column(*col_idx))), + None => Complexity::Linear(LinearComplexity::Column(*col_idx)), + }); + } + result.unwrap_or(Complexity::Constant(1)) +} + +fn _expression_complexity(expression: &Arc) -> BytesPerRow { + if let Some(col) = expression.downcast_ref::() { + BytesPerRow { + processed: None, + cols_read: vec![col.index()], + } + } else if expression.is::() { + BytesPerRow { + processed: None, + cols_read: vec![], + } + } else { + // Generic handler for all other expressions: CastExpr, TryCastExpr, CaseExpr, + // InListExpr, IsNullExpr, IsNotNullExpr, NotExpr, NegativeExpr, LikeExpr, + // ScalarFunctionExpr, AsyncFuncExpr, etc. + let mut bytes_per_row = BytesPerRow { + processed: None, + cols_read: vec![], + }; + // This operation processes the result of every child once. We model its per-row cost as + // the sum of (1) the processing already incurred inside each child sub-expression and + // (2) one linear pass over each leaf column feeding the child. A leaf column therefore + // contributes once per operation sitting above it, i.e. its bytes are weighted by its + // depth in the expression tree. Carrying (1) is what keeps nested operations + // (e.g. the `+` in `(a + b) * c`) from being silently dropped. + for child in expression.children() { + let c = _expression_complexity(child); + if let Some(child_processed) = c.processed { + bytes_per_row.processed = Some(match bytes_per_row.processed.take() { + Some(processed) => processed.plus(child_processed), + None => child_processed, + }); + } + for col_read in &c.cols_read { + bytes_per_row.processed = Some(match bytes_per_row.processed.take() { + Some(processed) => { + processed.plus(Complexity::Linear(LinearComplexity::Column(*col_read))) + } + None => Complexity::Linear(LinearComplexity::Column(*col_read)), + }); + } + bytes_per_row.cols_read.extend(&c.cols_read); + } + bytes_per_row + } +} + +#[cfg(test)] +mod tests { + use crate::assert_snapshot; + use crate::distributed_planner::statistics::compute_per_node::calculate_compute_complexity; + use crate::test_utils::plans::TestPlanBuilder; + use datafusion::common::tree_node::{Transformed, TreeNode}; + use datafusion::physical_plan::{ExecutionPlan, displayable}; + use std::cell::RefCell; + use std::sync::Arc; + /* schema for the "weather" table + + MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL] + MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL] + Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL] + Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL] + Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL] + Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL] + Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL] + Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL] + Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL] + RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL] + RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL] + */ + + // DataSourceExec: produces data, modeled as O(N) over all output columns. + #[tokio::test] + async fn data_source_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT "MinTemp" FROM weather"#) + .await; + assert_snapshot!(plan_costs(plan), @"O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet"); + } + + // FilterExec: copies the output rows + evaluates the predicate over the input rows. + #[tokio::test] + async fn filter_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT * FROM weather WHERE "MinTemp" > 5"#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O((out_Cols+Col0)) | FilterExec: MinTemp@0 > 5 + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet, predicate=MinTemp@0 > 5, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 5, required_guarantees=[] + "); + } + + // ProjectionExec: cost is the sum of its expressions; plain column passthroughs are free. + #[tokio::test] + async fn projection_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT "MinTemp" + "MaxTemp" AS s FROM weather"#) + .await; + assert_snapshot!(plan_costs(plan), @"O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp@0 + MaxTemp@1 as s], file_type=parquet"); + } + + // AggregateExec: reads all input columns + hashes the group-by keys. + #[tokio::test] + async fn aggregate_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT "RainToday", COUNT(*) FROM weather GROUP BY "RainToday""#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(2) | ProjectionExec: expr=[RainToday@0 as RainToday, count(Int64(1))@1 as count(*)] + O((Cols+Col0)) | AggregateExec: mode=Single, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet + "); + } + + // SortExec: O(n log n) copy + sort-key evaluation. + #[tokio::test] + async fn sort_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT * FROM weather ORDER BY "WindGustDir""#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O((Cols+Col5)*Log((Cols+Col5))) | SortExec: expr=[WindGustDir@5 ASC NULLS LAST], preserve_partitioning=[false] + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet, sort_order_for_reorder=[WindGustDir@5 ASC NULLS LAST] + "); + } + + // SortPreservingMergeExec: appears when several pre-sorted partitions are merged. + #[tokio::test] + async fn sort_preserving_merge_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .physical_plan(r#"SELECT * FROM weather ORDER BY "WindGustDir""#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O((Cols+Col5)) | SortPreservingMergeExec: [WindGustDir@5 ASC NULLS LAST] + O((Cols+Col5)*Log((Cols+Col5))) | SortExec: expr=[WindGustDir@5 ASC NULLS LAST], preserve_partitioning=[true] + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet, sort_order_for_reorder=[WindGustDir@5 ASC NULLS LAST] + "); + } + + // RepartitionExec (Hash): copies all data + hashes the partition keys. + #[tokio::test] + async fn repartition_hash_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .physical_plan(r#"SELECT "RainToday", COUNT(*) FROM weather GROUP BY "RainToday""#) + .await; + + assert_snapshot!(plan_costs(plan), @r" + O(2) | ProjectionExec: expr=[RainToday@0 as RainToday, count(Int64(1))@1 as count(*)] + O((Cols+Col0)) | AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + O((Cols+Col0)) | RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=3 + O((Cols+Col0)) | AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet + "); + } + + // HashJoinExec: build side (2x read + key hash) + probe side (read + key hash). + #[tokio::test] + async fn hash_join_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan( + r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a JOIN weather b ON a."RainToday" = b."RainToday" + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(((2*left_Cols)+left_Col1+right_Cols+right_Col1)) | HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ] + "); + } + + // HashJoinExec with a residual filter: the equi-predicate becomes the hash join key while the + // inequality (`a.MinTemp > b.MaxTemp`) becomes a JoinFilter over an intermediate schema, so + // the cost must include the left/right columns the filter reads, not just the join keys. + #[tokio::test] + async fn hash_join_with_filter_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan( + r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a + JOIN weather b + ON a."RainToday" = b."RainToday" + AND a."MinTemp" > b."MaxTemp" + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(((2*left_Cols)+left_Col1+right_Cols+right_Col1+(left_Col0+right_Col0))) | HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], filter=MinTemp@0 > MaxTemp@1, projection=[MinTemp@0, MaxTemp@2] + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ] + "); + } + + // CrossJoinExec: O(n*m) Cartesian product over all columns of both sides. + #[tokio::test] + async fn cross_join_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT a."MinTemp", b."MaxTemp" FROM weather a CROSS JOIN weather b"#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O((left_Cols*right_Cols)) | CrossJoinExec + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp], file_type=parquet + "); + } + + // NestedLoopJoinExec: produced when a join has no equi-key, only an inequality filter. + #[tokio::test] + async fn nested_loop_join_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan( + r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a JOIN weather b ON a."MinTemp" > b."MaxTemp" + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(((left_Cols*right_Cols)+(left_Col0+right_Col0))) | NestedLoopJoinExec: join_type=Inner, filter=MinTemp@0 > MaxTemp@1 + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp], file_type=parquet + "); + } + + // SortMergeJoinExec: produced when hash joins are disabled; streams both sorted inputs. + // Requires target_partitions > 1 + repartition_joins + !prefer_hash_join (see physical_planner). + #[tokio::test] + async fn sort_merge_join_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .information_schema(true) + .prefer_hash_joins(false) + .physical_plan( + r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a JOIN weather b ON a."RainToday" = b."RainToday" + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(2) | ProjectionExec: expr=[MinTemp@0 as MinTemp, MaxTemp@2 as MaxTemp] + O((left_Col1+right_Col1)) | SortMergeJoinExec: join_type=Inner, on=[(RainToday@1, RainToday@1)] + O((Cols+Col1)*Log((Cols+Col1))) | SortExec: expr=[RainToday@1 ASC], preserve_partitioning=[true] + O((Cols+Col1)) | RepartitionExec: partitioning=Hash([RainToday@1], 4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + O((Cols+Col1)*Log((Cols+Col1))) | SortExec: expr=[RainToday@1 ASC], preserve_partitioning=[true] + O((Cols+Col1)) | RepartitionExec: partitioning=Hash([RainToday@1], 4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet + "); + } + + // BoundedWindowAggExec: window function with an ORDER BY frame (RANK). + #[tokio::test] + async fn bounded_window_agg_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan( + r#" + SELECT RANK() OVER (PARTITION BY "RainToday" ORDER BY "MaxTemp") FROM weather + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r#" + O(1) | ProjectionExec: expr=[rank() PARTITION BY [weather.RainToday] ORDER BY [weather.MaxTemp ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank() PARTITION BY [weather.RainToday] ORDER BY [weather.MaxTemp ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] + O((Cols+Col1)) | BoundedWindowAggExec: wdw=[rank() PARTITION BY [weather.RainToday] ORDER BY [weather.MaxTemp ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { "rank() PARTITION BY [weather.RainToday] ORDER BY [weather.MaxTemp ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW": UInt64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + O((Cols+Col1+Col0)*Log((Cols+Col1+Col0))) | SortExec: expr=[RainToday@1 ASC NULLS LAST, MaxTemp@0 ASC NULLS LAST], preserve_partitioning=[false] + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000002.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000000.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, sort_order_for_reorder=[RainToday@1 ASC NULLS LAST, MaxTemp@0 ASC NULLS LAST] + "#); + } + + // WindowAggExec: window aggregate without an ORDER BY (unbounded frame). + #[tokio::test] + async fn window_agg_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan( + r#"SELECT SUM("Rainfall") OVER (PARTITION BY "WindGustDir") FROM weather"#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r#" + O(1) | ProjectionExec: expr=[sum(weather.Rainfall) PARTITION BY [weather.WindGustDir] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as sum(weather.Rainfall) PARTITION BY [weather.WindGustDir] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] + O((Cols+Col1)) | WindowAggExec: wdw=[sum(weather.Rainfall) PARTITION BY [weather.WindGustDir] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(weather.Rainfall) PARTITION BY [weather.WindGustDir] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + O((Cols+Col1)*Log((Cols+Col1))) | SortExec: expr=[WindGustDir@1 ASC NULLS LAST], preserve_partitioning=[false] + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[Rainfall, WindGustDir], file_type=parquet, sort_order_for_reorder=[WindGustDir@1 ASC NULLS LAST] + "#); + } + + // UnionExec: combines input streams with no per-row processing. + #[tokio::test] + async fn union_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan( + r#" + SELECT "MinTemp" AS t FROM weather + UNION ALL + SELECT "MaxTemp" AS t FROM weather + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(1) | UnionExec + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp@0 as t], file_type=parquet + O(out_Cols) | DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp@1 as t], file_type=parquet + "); + } + + // AggregateExec with no GROUP BY + CoalescePartitionsExec: the filter prevents the planner from + // answering COUNT(*) straight from parquet metadata, so a real partial aggregate runs per + // partition and is merged through a CoalescePartitionsExec before the single final aggregate. + #[tokio::test] + async fn aggregate_no_group_by_and_coalesce_partitions() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .physical_plan(r#"SELECT COUNT(*) FROM weather WHERE "MinTemp" > 5"#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(1) | ProjectionExec: expr=[count(Int64(1))@0 as count(*)] + O(Cols) | AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] + O(1) | CoalescePartitionsExec + O(Cols) | AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] + O((out_Cols+Col0)) | FilterExec: MinTemp@0 > 5, projection=[] + O(Cols) | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp], file_type=parquet, predicate=MinTemp@0 > 5, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 5, required_guarantees=[] + "); + } + + // GlobalLimitExec: an OFFSET can't be pushed down as a per-partition fetch, so a GlobalLimitExec + // is materialized. (LocalLimitExec shares this exact cost branch but the DF53 planner prefers to + // carry `fetch` on CoalescePartitionsExec rather than emit a separate LocalLimitExec node.) + #[tokio::test] + async fn global_limit_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .physical_plan(r#"SELECT * FROM weather WHERE "MinTemp" > 5 LIMIT 10 OFFSET 5"#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(1) | GlobalLimitExec: skip=5, fetch=10 + O(1) | CoalescePartitionsExec: fetch=15 + O((out_Cols+Col0)) | FilterExec: MinTemp@0 > 5, fetch=15 + O(Cols) | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet, predicate=MinTemp@0 > 5, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 5, required_guarantees=[] + "); + } + + // EmptyExec: an always-false predicate collapses to an empty relation that produces no data. + #[tokio::test] + async fn empty_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT "MinTemp" FROM weather WHERE 1 = 0"#) + .await; + assert_snapshot!(plan_costs(plan), @"O(1) | EmptyExec"); + } + + // RoundRobin RepartitionExec: has no hash keys, so it takes the bare all-columns copy path. + #[tokio::test] + async fn round_robin_repartition_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .physical_plan(r#"SELECT * FROM weather WHERE "MinTemp" > 5"#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O((out_Cols+Col0)) | FilterExec: MinTemp@0 > 5 + O(Cols) | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, MaxTemp, Rainfall, Evaporation, Sunshine, WindGustDir, WindGustSpeed, WindDir9am, WindDir3pm, WindSpeed9am, WindSpeed3pm, Humidity9am, Humidity3pm, Pressure9am, Pressure3pm, Cloud9am, Cloud3pm, Temp9am, Temp3pm, RainToday, RISK_MM, RainTomorrow], file_type=parquet, predicate=MinTemp@0 > 5, pruning_predicate=MinTemp_null_count@1 != row_count@2 AND MinTemp_max@0 > 5, required_guarantees=[] + "); + } + + // HashJoinExec in Partitioned mode: a distinct planner path from CollectLeft. The cost formula + // is the same (build 2x read + key hash, probe read + key hash), now over hash-repartitioned + // inputs rather than a collected left side. + #[tokio::test] + async fn partitioned_hash_join_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .information_schema(true) + // Zero the single-partition thresholds so the planner uses Partitioned mode (hash + // repartition on both sides) instead of collecting the left side. + .hash_join_single_partition_threshold(0) + .hash_join_single_partition_threshold_rows(0) + .physical_plan( + r#" + SELECT a."MinTemp", b."MaxTemp" + FROM weather a JOIN weather b ON a."RainToday" = b."RainToday" + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(((2*left_Cols)+left_Col1+right_Cols+right_Col1)) | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2] + O((Cols+Col1)) | RepartitionExec: partitioning=Hash([RainToday@1], 4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet + O((Cols+Col1)) | RepartitionExec: partitioning=Hash([RainToday@1], 4), input_partitions=3 + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ] + "); + } + + // InterleaveExec: unioning two identically hash-partitioned aggregates lets the planner + // interleave the partitions instead of concatenating streams. + #[tokio::test] + async fn interleave_exec() { + let plan = TestPlanBuilder::new() + .target_partitions(4) + .physical_plan( + r#" + SELECT "RainToday" AS k, COUNT(*) AS c FROM weather GROUP BY "RainToday" + UNION ALL + SELECT "RainTomorrow" AS k, COUNT(*) AS c FROM weather GROUP BY "RainTomorrow" + "#, + ) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(1) | InterleaveExec + O(2) | ProjectionExec: expr=[RainToday@0 as k, count(Int64(1))@1 as c] + O((Cols+Col0)) | AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + O((Cols+Col0)) | RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=3 + O((Cols+Col0)) | AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet + O(2) | ProjectionExec: expr=[RainTomorrow@0 as k, count(Int64(1))@1 as c] + O((Cols+Col0)) | AggregateExec: mode=FinalPartitioned, gby=[RainTomorrow@0 as RainTomorrow], aggr=[count(Int64(1))] + O((Cols+Col0)) | RepartitionExec: partitioning=Hash([RainTomorrow@0], 4), input_partitions=3 + O((Cols+Col0)) | AggregateExec: mode=Partial, gby=[RainTomorrow@0 as RainTomorrow], aggr=[count(Int64(1))] + O(out_Cols) | DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainTomorrow], file_type=parquet + "); + } + + // Default fallback for unhandled nodes: `SELECT 1` plans a PlaceholderRowExec, which has no + // dedicated branch and therefore takes the catch-all O(N)-over-output-columns estimate. This is + // intentionally conservative; for a 1-row placeholder the output byte size is tiny anyway. + #[tokio::test] + async fn default_fallback_unhandled_node() { + let plan = TestPlanBuilder::new() + .target_partitions(1) + .physical_plan(r#"SELECT 1"#) + .await; + assert_snapshot!(plan_costs(plan), @r" + O(1) | ProjectionExec: expr=[1 as Int64(1)] + O(out_Cols) | PlaceholderRowExec + "); + } + + // NOTE: BroadcastExec and ChildrenIsolatorUnionExec are inserted only by the distributed + // planner (not by plain DataFusion planning), and SymmetricHashJoinExec requires unbounded + // streaming inputs — none are reachable from these parquet-backed queries. + + fn plan_costs(plan: Arc) -> String { + let mut display = String::new(); + let depth = RefCell::new(0); + plan.transform_down_up( + |plan| { + let indent = " ".repeat(*depth.borrow()); + // `one_line()` renders just this node (with its full config), not its children. + let node = displayable(plan.as_ref()).one_line().to_string(); + display += &format!( + "{indent}O({:?}) | {}\n", + calculate_compute_complexity(&plan), + node.trim_end() + ); + *depth.borrow_mut() += 1; + Ok(Transformed::no(plan)) + }, + |plan| { + *depth.borrow_mut() -= 1; + Ok(Transformed::no(plan)) + }, + ) + .expect("Cannot fail"); + display + } +} diff --git a/src/distributed_planner/statistics/cost.rs b/src/distributed_planner/statistics/cost.rs new file mode 100644 index 00000000..befb8de7 --- /dev/null +++ b/src/distributed_planner/statistics/cost.rs @@ -0,0 +1,29 @@ +use crate::DistributedConfig; +use crate::distributed_planner::statistics::compute_per_node::calculate_compute_complexity; +use crate::distributed_planner::statistics::plan_statistics::plan_statistics; +use datafusion::common::Result; +use datafusion::physical_plan::{ExecutionPlan, Statistics}; +use std::sync::Arc; + +pub(crate) fn calculate_cost( + plan: &Arc, + cfg: &DistributedConfig, +) -> Result { + f(plan, cfg).map(|(cost, _stats)| cost) +} + +fn f(plan: &Arc, d_cfg: &DistributedConfig) -> Result<(usize, Arc)> { + let children = plan.children(); + let mut child_stats = Vec::with_capacity(children.len()); + let mut acc_cost = 0; + for child in children { + let (cost, child_stat) = f(child, d_cfg)?; + acc_cost += cost; + child_stats.push(child_stat); + } + + let stats = plan_statistics(plan, &child_stats, d_cfg)?; + let complexity = calculate_compute_complexity(plan); + acc_cost += complexity.cost(&stats, &child_stats).unwrap_or(0); + Ok((acc_cost, stats)) +} diff --git a/src/distributed_planner/statistics/default_bytes_for_datatype.rs b/src/distributed_planner/statistics/default_bytes_for_datatype.rs new file mode 100644 index 00000000..994fd6ea --- /dev/null +++ b/src/distributed_planner/statistics/default_bytes_for_datatype.rs @@ -0,0 +1,137 @@ +use datafusion::arrow::datatypes::{DataType, IntervalUnit}; + +/// Default data size estimate for variable-width columns when no statistics are available. +/// +/// Reference: Trino's PlanNodeStatsEstimate.java:40 +/// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L40 +const DEFAULT_DATA_SIZE_PER_COLUMN: usize = 50; + +/// This function returns the amount of bytes each row is estimated to occupy. +/// +/// The estimation follows Trino's approach for calculating output size per row: +/// - For fixed-width (primitive) types: uses the type's fixed byte width +/// - For variable-width types: uses a default estimate plus offset overhead +/// - Accounts for validity bitmap overhead (1 bit per value, rounded to 1 byte per row) +/// +/// DataFusion has `Statistics::calculate_total_byte_size()` which uses `DataType::primitive_width()`, +/// but it returns `Precision::Absent` (unknown) when encountering any non-primitive type: +/// https://github.com/apache/datafusion/blob/branch-52/datafusion/common/src/stats.rs#L326-L347 +/// +/// For distributed query planning, we need estimates even for variable-width types to make +/// cost-based decisions about data shuffling and task count assignation. This implementation +/// provides estimates for all types following Trino's cost model. +/// +/// Reference: Trino's PlanNodeStatsEstimate.getOutputSizeForSymbol() +/// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L89-L114 +pub(super) fn default_bytes_for_datatype(data_type: &DataType) -> usize { + // 1 byte for validity bitmap per row (Arrow uses 1 bit, but we round up for estimation). + // Trino calls this the "is null" boolean array. + // Reference: PlanNodeStatsEstimate.java:98-99 + // https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L98-L99 + const VALIDITY_OVERHEAD: usize = 1; + + // Handle non-primitive types. + // NOTE: The cases below are Arrow-specific adaptations. Trino only distinguishes between + // FixedWidthType and variable-width types, using Integer.BYTES (4) for offsets. + // Reference: PlanNodeStatsEstimate.java:108-109 + // https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L108-L109 + match data_type { + // Primitive types from data_type.primitive_width() + DataType::Int8 => VALIDITY_OVERHEAD + 1, + DataType::Int16 => VALIDITY_OVERHEAD + 2, + DataType::Int32 => VALIDITY_OVERHEAD + 4, + DataType::Int64 => VALIDITY_OVERHEAD + 8, + DataType::UInt8 => VALIDITY_OVERHEAD + 1, + DataType::UInt16 => VALIDITY_OVERHEAD + 2, + DataType::UInt32 => VALIDITY_OVERHEAD + 4, + DataType::UInt64 => VALIDITY_OVERHEAD + 8, + DataType::Float16 => VALIDITY_OVERHEAD + 2, + DataType::Float32 => VALIDITY_OVERHEAD + 4, + DataType::Float64 => VALIDITY_OVERHEAD + 8, + DataType::Timestamp(_, _) => VALIDITY_OVERHEAD + 8, + DataType::Date32 => VALIDITY_OVERHEAD + 4, + DataType::Date64 => VALIDITY_OVERHEAD + 8, + DataType::Time32(_) => VALIDITY_OVERHEAD + 4, + DataType::Time64(_) => VALIDITY_OVERHEAD + 8, + DataType::Duration(_) => VALIDITY_OVERHEAD + 8, + DataType::Interval(IntervalUnit::YearMonth) => VALIDITY_OVERHEAD + 4, + DataType::Interval(IntervalUnit::DayTime) => VALIDITY_OVERHEAD + 8, + DataType::Interval(IntervalUnit::MonthDayNano) => VALIDITY_OVERHEAD + 16, + DataType::Decimal32(_, _) => VALIDITY_OVERHEAD + 4, + DataType::Decimal64(_, _) => VALIDITY_OVERHEAD + 8, + DataType::Decimal128(_, _) => VALIDITY_OVERHEAD + 16, + DataType::Decimal256(_, _) => VALIDITY_OVERHEAD + 32, + // Null type has no data (Arrow-specific) + DataType::Null => 0, + + // Boolean is stored as bits (1/8 byte per value), but we round up (Arrow-specific) + DataType::Boolean => VALIDITY_OVERHEAD + 1, + + // Fixed-size binary: just the fixed size + validity (Arrow-specific) + DataType::FixedSizeBinary(size) => VALIDITY_OVERHEAD + (*size as usize), + + // Fixed-size list: fixed count * element size (Arrow-specific) + DataType::FixedSizeList(field, size) => { + VALIDITY_OVERHEAD + (*size as usize) * default_bytes_for_datatype(field.data_type()) + } + + // Struct: sum of all child field sizes (Arrow-specific) + // Trino would treat ROW types as variable-width + DataType::Struct(fields) => fields + .iter() + .map(|f| default_bytes_for_datatype(f.data_type())) + .sum(), + + // Dictionary-encoded: just the key indices, values are shared across rows (Arrow-specific) + // Trino doesn't have dictionary encoding at the type level + DataType::Dictionary(key_type, _value_type) => default_bytes_for_datatype(key_type), + + // Union: type_id (1 byte) + max child size (Arrow-specific) + DataType::Union(fields, _) => { + let max_child_size = fields + .iter() + .map(|(_, f)| default_bytes_for_datatype(f.data_type())) + .max() + .unwrap_or(0); + 1 + max_child_size + } + + // Run-end encoded: estimate as if it were the value type (Arrow-specific) + // Actual compression depends on data distribution + DataType::RunEndEncoded(_, values) => default_bytes_for_datatype(values.data_type()), + + // Variable-width string/binary types. + // Offset size follows Trino's Integer.BYTES (4 bytes). + // Reference: PlanNodeStatsEstimate.java:109 + DataType::Utf8 | DataType::Binary => { + VALIDITY_OVERHEAD + size_of::() + DEFAULT_DATA_SIZE_PER_COLUMN + } + // Large variants use i64 offsets (Arrow-specific, Trino doesn't have large variants) + DataType::LargeUtf8 | DataType::LargeBinary => { + VALIDITY_OVERHEAD + size_of::() + DEFAULT_DATA_SIZE_PER_COLUMN + } + // View types use 16-byte inline representation (Arrow-specific) + // Reference: https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-view-layout + DataType::Utf8View | DataType::BinaryView => VALIDITY_OVERHEAD + 16, + + // List types (Arrow-specific adaptation) + // Spark assumes 1 element average for collections (SPARK-18853). Trino treats them + // as flat variable-width with 50-byte default. We follow Spark's 1-element assumption + // to avoid massive overestimation (e.g. Map was 605 bytes with 10 elements). + DataType::List(field) => { + VALIDITY_OVERHEAD + size_of::() + default_bytes_for_datatype(field.data_type()) + } + DataType::LargeList(field) => { + VALIDITY_OVERHEAD + size_of::() + default_bytes_for_datatype(field.data_type()) + } + DataType::ListView(field) | DataType::LargeListView(field) => { + VALIDITY_OVERHEAD + 8 + default_bytes_for_datatype(field.data_type()) + } + + // Map type: stored as List> (Arrow-specific) + // Uses same 1-element assumption as List types (following Spark). + DataType::Map(field, _) => { + VALIDITY_OVERHEAD + size_of::() + default_bytes_for_datatype(field.data_type()) + } // Fallback for any other types - use Trino's default + } +} diff --git a/src/distributed_planner/statistics/mod.rs b/src/distributed_planner/statistics/mod.rs new file mode 100644 index 00000000..9a23b794 --- /dev/null +++ b/src/distributed_planner/statistics/mod.rs @@ -0,0 +1,7 @@ +mod compute_per_node; +mod cost; +mod default_bytes_for_datatype; +mod plan_statistics; + +#[allow(unused)] // will be used in a follow-up PR. +pub(crate) use cost::calculate_cost; diff --git a/src/distributed_planner/statistics/plan_statistics.rs b/src/distributed_planner/statistics/plan_statistics.rs new file mode 100644 index 00000000..b1888df1 --- /dev/null +++ b/src/distributed_planner/statistics/plan_statistics.rs @@ -0,0 +1,160 @@ +use crate::DistributedConfig; +use crate::distributed_planner::statistics::default_bytes_for_datatype::default_bytes_for_datatype; +use datafusion::common::stats::Precision; +use datafusion::common::{Statistics, not_impl_err, plan_err}; +use datafusion::config::ConfigOptions; +use datafusion::error::Result; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::execution_plan::CardinalityEffect; +use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use delegate::delegate; +use itertools::Itertools; +use std::fmt::Formatter; +use std::sync::Arc; + +/// Uses upstream DataFusion stats system with some small overrides. +pub(super) fn plan_statistics( + node: &Arc, + children_stats: &[Arc], + opts: &DistributedConfig, +) -> Result> { + let mut stats = partition_statistics_with_children_override(node, None, children_stats)?; + + // If rows are absent, be conservative and assume that all the rows from all the children + // are going to be returned. + if matches!(stats.num_rows, Precision::Absent) { + let num_rows = children_stats + .iter() + .flat_map(|v| v.num_rows.get_value()) + .sum1::(); + let num_rows = if let Some(num_rows) = num_rows { + num_rows + } else if let Some(default) = opts.default_estimated_row_count { + default + } else { + return plan_err!( + "{} does not provide row stats, and none of its children [{}] provides a row count", + node.name(), + node.children() + .iter() + .map(|v| v.name()) + .collect::>() + .join(", ") + ); + }; + stats.num_rows = Precision::Inexact(num_rows) + } + + let schema = node.schema(); + + for (i, col_stats) in &mut stats.column_statistics.iter_mut().enumerate() { + let rows = stats.num_rows.get_value().unwrap_or(&0); + + // If some of the NDVs are not present in one of the column-level stats, assume the + // worst and use the same as the input number of rows. + if matches!(col_stats.distinct_count, Precision::Absent) { + col_stats.distinct_count = Precision::Inexact(*rows); + } + + // If the per-column byte size stats are not present, estimate the byte size based on the + // data type and the row count. + let Some(dt) = schema.fields.get(i).map(|v| v.data_type()) else { + return plan_err!("Field with index {i} not present in schema: {schema:?}"); + }; + + // If it turns out that we do not have `byte_size` stats, but we do have an estimated number + // of rows, do a best-effort in trying to infer the byte size for each column. + if matches!(col_stats.byte_size, Precision::Absent) { + col_stats.byte_size = Precision::Inexact(default_bytes_for_datatype(dt) * rows) + } + } + + // If bytes are absent, let's just infer them based on the schema and the + // number of rows. + if matches!(stats.total_byte_size, Precision::Absent) { + let mut total_byte_size = 0; + for col_stats in &stats.column_statistics { + total_byte_size += col_stats.byte_size.get_value().unwrap_or(&0); + } + stats.total_byte_size = Precision::Inexact(total_byte_size); + } + + Ok(Arc::new(stats)) +} + +// FIXME: because of limitations the the statistics API on DataFusion, we need to resource to +// this sketchy way of overriding child statistics, as we cannot just provide our own. +// If we don't do this: +// 1. we cannot tell nodes to compute statistics based on the ones we provide. +// 2. we recompute statistics unnecessarily across the plan +// This is tracked by https://github.com/apache/datafusion/issues/20184 upstream, and until +// that one is solved, we need to resource to this wrapper. +fn partition_statistics_with_children_override( + node: &Arc, + partition: Option, + child_stats: &[Arc], +) -> Result { + // DataFusion stats system is not very mature yet. This override layer brings in changes + // that might not have already been released or informed overrides. + let statistics_wrapped_children = child_stats + .iter() + .zip(node.children()) + .map(|(stats, child)| StatisticsWrapper { + inner: Arc::clone(child), + stats: Arc::clone(stats), + }) + .map(|v| Arc::new(v) as _) + .collect(); + + let stats = Arc::clone(node) + .with_new_children(statistics_wrapped_children)? + .partition_statistics(partition)?; + + Ok(stats.as_ref().clone()) +} + +#[derive(Debug)] +struct StatisticsWrapper { + stats: Arc, + inner: Arc, +} + +impl DisplayAs for StatisticsWrapper { + delegate! { + to self.inner { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result; + } + } +} + +impl ExecutionPlan for StatisticsWrapper { + fn partition_statistics(&self, partition: Option) -> Result> { + if partition.is_some() { + return plan_err!("StatisticsWrapper not prepared for partition-specific stats"); + } + Ok(Arc::clone(&self.stats)) + } + + delegate! { + to self.inner { + fn name(&self) -> &str; + fn properties(&self) -> &Arc; + fn maintains_input_order(&self) -> Vec; + fn benefits_from_input_partitioning(&self) -> Vec; + fn children(&self) -> Vec<&Arc>; + fn repartitioned(&self, _target_partitions: usize, _config: &ConfigOptions) -> Result>>; + fn execute(&self, partition: usize, context: Arc) -> Result; + fn supports_limit_pushdown(&self) -> bool; + fn with_fetch(&self, _limit: Option) -> Option>; + fn fetch(&self) -> Option; + fn cardinality_effect(&self) -> CardinalityEffect; + } + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + not_impl_err!("with_new_children not implemented") + } +} diff --git a/src/test_utils/plans.rs b/src/test_utils/plans.rs index c5245697..7d299489 100644 --- a/src/test_utils/plans.rs +++ b/src/test_utils/plans.rs @@ -139,6 +139,9 @@ pub(crate) struct TestPlanBuilder { distributed_partial_reduce: Option, distributed_children_isolator_unions: Option, distributed_max_tasks_per_stage: Option, + prefer_hash_join: Option, + hash_join_single_partition_threshold: Option, + hash_join_single_partition_threshold_rows: Option, } #[cfg(test)] @@ -156,6 +159,9 @@ impl TestPlanBuilder { distributed_partial_reduce: None, distributed_children_isolator_unions: None, distributed_max_tasks_per_stage: None, + prefer_hash_join: None, + hash_join_single_partition_threshold: None, + hash_join_single_partition_threshold_rows: None, } } @@ -192,6 +198,21 @@ impl TestPlanBuilder { self } + pub fn prefer_hash_joins(mut self, prefer_hash_joins: bool) -> Self { + self.prefer_hash_join = Some(prefer_hash_joins); + self + } + + pub fn hash_join_single_partition_threshold(mut self, v: usize) -> Self { + self.hash_join_single_partition_threshold = Some(v); + self + } + + pub fn hash_join_single_partition_threshold_rows(mut self, v: usize) -> Self { + self.hash_join_single_partition_threshold_rows = Some(v); + self + } + pub fn broadcast_joins(mut self, enabled: bool) -> Self { self.broadcast_joins = enabled; self @@ -250,6 +271,21 @@ impl TestPlanBuilder { if let Some(enabled) = self.information_schema { config = config.with_information_schema(enabled); } + if let Some(enabled) = self.prefer_hash_join { + config.options_mut().optimizer.prefer_hash_join = enabled + } + if let Some(value) = self.hash_join_single_partition_threshold { + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold = value + } + if let Some(value) = self.hash_join_single_partition_threshold_rows { + config + .options_mut() + .optimizer + .hash_join_single_partition_threshold_rows = value + } config } @@ -310,6 +346,9 @@ impl Default for TestPlanBuilder { distributed_partial_reduce: None, distributed_children_isolator_unions: None, distributed_max_tasks_per_stage: None, + prefer_hash_join: None, + hash_join_single_partition_threshold: None, + hash_join_single_partition_threshold_rows: None, } } }