Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/distributed_planner/distributed_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ extensions_options! {
/// budget will still be admitted (otherwise we would livelock), so the actual peak per
/// connection is `worker_connection_buffer_budget_bytes + max_message_size`.
pub worker_connection_buffer_budget_bytes: usize, default = 64 * 1024 * 1024
/// Distributed DataFusion relies on row count estimation in order to infer how many workers
/// should be used in serving the query. Some plans might not implement any kind of row count
/// estimation, and this parameter sets the default estimated row count for those plans.
pub default_estimated_row_count: Option<usize>, default = Some(0)
/// Collection of [TaskEstimator]s that will be applied to leaf nodes in order to
/// estimate how many tasks should be spawned for the [Stage] containing the leaf node.
pub(crate) __private_task_estimator: CombinedTaskEstimator, default = CombinedTaskEstimator::default()
Expand Down
1 change: 1 addition & 0 deletions src/distributed_planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod partial_reduce_below_network_shuffles;
mod prepare_network_boundaries;
mod push_fetch_into_network_coalesce;
mod session_state_builder_ext;
mod statistics;
mod task_estimator;

pub use distributed_config::DistributedConfig;
Expand Down
1,029 changes: 1,029 additions & 0 deletions src/distributed_planner/statistics/compute_per_node.rs

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions src/distributed_planner/statistics/cost.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use crate::DistributedConfig;
use crate::distributed_planner::statistics::compute_per_node::calculate_compute_complexity;
use crate::distributed_planner::statistics::plan_statistics::plan_statistics;
use datafusion::common::Result;
use datafusion::physical_plan::{ExecutionPlan, Statistics};
use std::sync::Arc;

/// Estimates the cost of executing `plan`, including the cost of all of its descendants.
///
/// This is a thin entry point over the recursive walker `f`: the statistics computed for the
/// root node along the way are discarded and only the accumulated cost is returned.
///
/// # Errors
///
/// Propagates any error raised while computing the statistics of a node in the plan.
pub(crate) fn calculate_cost(
    plan: &Arc<dyn ExecutionPlan>,
    cfg: &DistributedConfig,
) -> Result<usize> {
    let (total_cost, _root_stats) = f(plan, cfg)?;
    Ok(total_cost)
}

/// Recursively walks `plan` bottom-up and returns the accumulated cost of the subtree rooted at
/// `plan`, together with the statistics estimated for `plan` itself.
///
/// Children are visited first so that their statistics can be handed to `plan_statistics` when
/// estimating the statistics of the current node. The cost of the current node is then derived
/// from its compute complexity, its own statistics and the statistics of its children; when the
/// complexity's `cost` yields no value, the node contributes 0.
///
/// Costs are summed with saturating arithmetic: an estimate that exceeds `usize::MAX` is clamped
/// instead of panicking (debug builds) or silently wrapping around (release builds).
///
/// # Errors
///
/// Propagates any error raised by `plan_statistics` for this node or any of its descendants.
fn f(plan: &Arc<dyn ExecutionPlan>, d_cfg: &DistributedConfig) -> Result<(usize, Arc<Statistics>)> {
    let children = plan.children();
    let mut child_stats = Vec::with_capacity(children.len());
    let mut acc_cost: usize = 0;
    for child in children {
        let (cost, child_stat) = f(child, d_cfg)?;
        acc_cost = acc_cost.saturating_add(cost);
        child_stats.push(child_stat);
    }

    let stats = plan_statistics(plan, &child_stats, d_cfg)?;
    let complexity = calculate_compute_complexity(plan);
    let own_cost = complexity.cost(&stats, &child_stats).unwrap_or(0);
    Ok((acc_cost.saturating_add(own_cost), stats))
}
137 changes: 137 additions & 0 deletions src/distributed_planner/statistics/default_bytes_for_datatype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use datafusion::arrow::datatypes::{DataType, IntervalUnit};

/// Default data size estimate, in bytes per value, for variable-width columns when no
/// statistics are available.
///
/// It is added on top of the validity and offset overhead when estimating the per-row size of
/// string and binary columns.
///
/// Reference: Trino's PlanNodeStatsEstimate.java:40
/// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L40
const DEFAULT_DATA_SIZE_PER_COLUMN: usize = 50;

/// This function returns the amount of bytes each row is estimated to occupy.
///
/// The estimation follows Trino's approach for calculating output size per row:
/// - For fixed-width (primitive) types: uses the type's fixed byte width
/// - For variable-width types: uses a default estimate plus offset overhead
/// - Accounts for validity bitmap overhead (1 bit per value, rounded to 1 byte per row)
///
/// DataFusion has `Statistics::calculate_total_byte_size()` which uses `DataType::primitive_width()`,
/// but it returns `Precision::Absent` (unknown) when encountering any non-primitive type:
/// https://github.com/apache/datafusion/blob/branch-52/datafusion/common/src/stats.rs#L326-L347
///
/// For distributed query planning, we need estimates even for variable-width types to make
/// cost-based decisions about data shuffling and task count assignation. This implementation
/// provides estimates for all types following Trino's cost model.
///
/// Reference: Trino's PlanNodeStatsEstimate.getOutputSizeForSymbol()
/// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L89-L114
pub(super) fn default_bytes_for_datatype(data_type: &DataType) -> usize {
// 1 byte for validity bitmap per row (Arrow uses 1 bit, but we round up for estimation).
// Trino calls this the "is null" boolean array.
// Reference: PlanNodeStatsEstimate.java:98-99
// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L98-L99
const VALIDITY_OVERHEAD: usize = 1;

// Handle non-primitive types.
// NOTE: The cases below are Arrow-specific adaptations. Trino only distinguishes between
// FixedWidthType and variable-width types, using Integer.BYTES (4) for offsets.
// Reference: PlanNodeStatsEstimate.java:108-109
// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L108-L109
match data_type {
// Primitive types from data_type.primitive_width()
DataType::Int8 => VALIDITY_OVERHEAD + 1,
DataType::Int16 => VALIDITY_OVERHEAD + 2,
DataType::Int32 => VALIDITY_OVERHEAD + 4,
DataType::Int64 => VALIDITY_OVERHEAD + 8,
DataType::UInt8 => VALIDITY_OVERHEAD + 1,
DataType::UInt16 => VALIDITY_OVERHEAD + 2,
DataType::UInt32 => VALIDITY_OVERHEAD + 4,
DataType::UInt64 => VALIDITY_OVERHEAD + 8,
DataType::Float16 => VALIDITY_OVERHEAD + 2,
DataType::Float32 => VALIDITY_OVERHEAD + 4,
DataType::Float64 => VALIDITY_OVERHEAD + 8,
DataType::Timestamp(_, _) => VALIDITY_OVERHEAD + 8,
DataType::Date32 => VALIDITY_OVERHEAD + 4,
DataType::Date64 => VALIDITY_OVERHEAD + 8,
DataType::Time32(_) => VALIDITY_OVERHEAD + 4,
DataType::Time64(_) => VALIDITY_OVERHEAD + 8,
DataType::Duration(_) => VALIDITY_OVERHEAD + 8,
DataType::Interval(IntervalUnit::YearMonth) => VALIDITY_OVERHEAD + 4,
DataType::Interval(IntervalUnit::DayTime) => VALIDITY_OVERHEAD + 8,
DataType::Interval(IntervalUnit::MonthDayNano) => VALIDITY_OVERHEAD + 16,
DataType::Decimal32(_, _) => VALIDITY_OVERHEAD + 4,
DataType::Decimal64(_, _) => VALIDITY_OVERHEAD + 8,
DataType::Decimal128(_, _) => VALIDITY_OVERHEAD + 16,
DataType::Decimal256(_, _) => VALIDITY_OVERHEAD + 32,
// Null type has no data (Arrow-specific)
DataType::Null => 0,

// Boolean is stored as bits (1/8 byte per value), but we round up (Arrow-specific)
DataType::Boolean => VALIDITY_OVERHEAD + 1,

// Fixed-size binary: just the fixed size + validity (Arrow-specific)
DataType::FixedSizeBinary(size) => VALIDITY_OVERHEAD + (*size as usize),

// Fixed-size list: fixed count * element size (Arrow-specific)
DataType::FixedSizeList(field, size) => {
VALIDITY_OVERHEAD + (*size as usize) * default_bytes_for_datatype(field.data_type())
}

// Struct: sum of all child field sizes (Arrow-specific)
// Trino would treat ROW types as variable-width
DataType::Struct(fields) => fields
.iter()
.map(|f| default_bytes_for_datatype(f.data_type()))
.sum(),

// Dictionary-encoded: just the key indices, values are shared across rows (Arrow-specific)
// Trino doesn't have dictionary encoding at the type level
DataType::Dictionary(key_type, _value_type) => default_bytes_for_datatype(key_type),

// Union: type_id (1 byte) + max child size (Arrow-specific)
DataType::Union(fields, _) => {
let max_child_size = fields
.iter()
.map(|(_, f)| default_bytes_for_datatype(f.data_type()))
.max()
.unwrap_or(0);
1 + max_child_size
}

// Run-end encoded: estimate as if it were the value type (Arrow-specific)
// Actual compression depends on data distribution
DataType::RunEndEncoded(_, values) => default_bytes_for_datatype(values.data_type()),

// Variable-width string/binary types.
// Offset size follows Trino's Integer.BYTES (4 bytes).
// Reference: PlanNodeStatsEstimate.java:109
DataType::Utf8 | DataType::Binary => {
VALIDITY_OVERHEAD + size_of::<i32>() + DEFAULT_DATA_SIZE_PER_COLUMN
}
// Large variants use i64 offsets (Arrow-specific, Trino doesn't have large variants)
DataType::LargeUtf8 | DataType::LargeBinary => {
VALIDITY_OVERHEAD + size_of::<i64>() + DEFAULT_DATA_SIZE_PER_COLUMN
}
// View types use 16-byte inline representation (Arrow-specific)
// Reference: https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-view-layout
DataType::Utf8View | DataType::BinaryView => VALIDITY_OVERHEAD + 16,

// List types (Arrow-specific adaptation)
// Spark assumes 1 element average for collections (SPARK-18853). Trino treats them
// as flat variable-width with 50-byte default. We follow Spark's 1-element assumption
// to avoid massive overestimation (e.g. Map<Int,String> was 605 bytes with 10 elements).
DataType::List(field) => {
VALIDITY_OVERHEAD + size_of::<i32>() + default_bytes_for_datatype(field.data_type())
}
DataType::LargeList(field) => {
VALIDITY_OVERHEAD + size_of::<i64>() + default_bytes_for_datatype(field.data_type())
}
DataType::ListView(field) | DataType::LargeListView(field) => {
VALIDITY_OVERHEAD + 8 + default_bytes_for_datatype(field.data_type())
}

// Map type: stored as List<Struct<key, value>> (Arrow-specific)
// Uses same 1-element assumption as List types (following Spark).
DataType::Map(field, _) => {
VALIDITY_OVERHEAD + size_of::<i32>() + default_bytes_for_datatype(field.data_type())
} // Fallback for any other types - use Trino's default
}
}
7 changes: 7 additions & 0 deletions src/distributed_planner/statistics/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod compute_per_node;
mod cost;
mod default_bytes_for_datatype;
mod plan_statistics;

#[allow(unused)] // will be used in a follow-up PR.
pub(crate) use cost::calculate_cost;
160 changes: 160 additions & 0 deletions src/distributed_planner/statistics/plan_statistics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
use crate::DistributedConfig;
use crate::distributed_planner::statistics::default_bytes_for_datatype::default_bytes_for_datatype;
use datafusion::common::stats::Precision;
use datafusion::common::{Statistics, not_impl_err, plan_err};
use datafusion::config::ConfigOptions;
use datafusion::error::Result;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::execution_plan::CardinalityEffect;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use delegate::delegate;
use itertools::Itertools;
use std::fmt::Formatter;
use std::sync::Arc;

/// Computes the statistics of `node` using the upstream DataFusion stats system, with some
/// small overrides that fill in whatever upstream leaves absent.
///
/// `children_stats` are the already-computed statistics of `node`'s children; they are injected
/// into `node` so that upstream derives its statistics from them instead of recomputing them.
///
/// The following gaps are filled in, always as [`Precision::Inexact`]:
/// - absent row count: the sum of the children's known row counts, or
///   `opts.default_estimated_row_count` when no child reports one;
/// - absent per-column distinct count: the node's row count;
/// - absent per-column byte size: the default byte width of the column type times the row count;
/// - absent total byte size: the sum of the per-column byte sizes.
///
/// Byte sizes are computed with saturating arithmetic, so very large estimates are clamped to
/// `usize::MAX` instead of overflowing.
///
/// # Errors
///
/// Returns a plan error when the row count cannot be determined and no default is configured, or
/// when a column statistic has no matching field in the node's schema. Errors from the upstream
/// statistics computation are propagated.
pub(super) fn plan_statistics(
    node: &Arc<dyn ExecutionPlan>,
    children_stats: &[Arc<Statistics>],
    opts: &DistributedConfig,
) -> Result<Arc<Statistics>> {
    let mut stats = partition_statistics_with_children_override(node, None, children_stats)?;

    // If rows are absent, be conservative and assume that all the rows from all the children
    // are going to be returned.
    if matches!(stats.num_rows, Precision::Absent) {
        let num_rows = children_stats
            .iter()
            .flat_map(|v| v.num_rows.get_value())
            .sum1::<usize>();
        let num_rows = if let Some(num_rows) = num_rows {
            num_rows
        } else if let Some(default) = opts.default_estimated_row_count {
            default
        } else {
            return plan_err!(
                "{} does not provide row stats, and none of its children [{}] provides a row count",
                node.name(),
                node.children()
                    .iter()
                    .map(|v| v.name())
                    .collect::<Vec<_>>()
                    .join(", ")
            );
        };
        stats.num_rows = Precision::Inexact(num_rows)
    }

    let schema = node.schema();
    // The row count is fixed from here on, so read it once rather than once per column.
    let rows = stats.num_rows.get_value().copied().unwrap_or(0);

    for (i, col_stats) in stats.column_statistics.iter_mut().enumerate() {
        // If some of the NDVs are not present in one of the column-level stats, assume the
        // worst and use the same as the input number of rows.
        if matches!(col_stats.distinct_count, Precision::Absent) {
            col_stats.distinct_count = Precision::Inexact(rows);
        }

        let Some(dt) = schema.fields.get(i).map(|v| v.data_type()) else {
            return plan_err!("Field with index {i} not present in schema: {schema:?}");
        };

        // If the per-column byte size stats are not present, do a best-effort estimate based on
        // the data type and the row count.
        if matches!(col_stats.byte_size, Precision::Absent) {
            let bytes = default_bytes_for_datatype(dt).saturating_mul(rows);
            col_stats.byte_size = Precision::Inexact(bytes);
        }
    }

    // If bytes are absent, let's just infer them based on the schema and the
    // number of rows.
    if matches!(stats.total_byte_size, Precision::Absent) {
        let total_byte_size = stats
            .column_statistics
            .iter()
            .map(|col_stats| col_stats.byte_size.get_value().copied().unwrap_or(0))
            .fold(0, usize::saturating_add);
        stats.total_byte_size = Precision::Inexact(total_byte_size);
    }

    Ok(Arc::new(stats))
}

// FIXME: because of limitations the the statistics API on DataFusion, we need to resource to
// this sketchy way of overriding child statistics, as we cannot just provide our own.
// If we don't do this:
// 1. we cannot tell nodes to compute statistics based on the ones we provide.
// 2. we recompute statistics unnecessarily across the plan
// This is tracked by https://github.com/apache/datafusion/issues/20184 upstream, and until
// that one is solved, we need to resource to this wrapper.
fn partition_statistics_with_children_override(
node: &Arc<dyn ExecutionPlan>,
partition: Option<usize>,
child_stats: &[Arc<Statistics>],
) -> Result<Statistics> {
// DataFusion stats system is not very mature yet. This override layer brings in changes
// that might not have already been released or informed overrides.
let statistics_wrapped_children = child_stats
.iter()
.zip(node.children())
.map(|(stats, child)| StatisticsWrapper {
inner: Arc::clone(child),
stats: Arc::clone(stats),
})
.map(|v| Arc::new(v) as _)
.collect();

let stats = Arc::clone(node)
.with_new_children(statistics_wrapped_children)?
.partition_statistics(partition)?;

Ok(stats.as_ref().clone())
}

/// Plan node that stands in for `inner` while reporting the pre-computed `stats` as its own
/// statistics.
///
/// Everything except statistics and child replacement is forwarded to `inner`.
#[derive(Debug)]
struct StatisticsWrapper {
    /// Statistics handed back from `partition_statistics`.
    stats: Arc<Statistics>,
    /// Wrapped plan that all the remaining trait methods are delegated to.
    inner: Arc<dyn ExecutionPlan>,
}

impl DisplayAs for StatisticsWrapper {
    // Display exactly as the wrapped plan does.
    delegate! {
        to self.inner {
            fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result;
        }
    }
}

impl ExecutionPlan for StatisticsWrapper {
    /// Returns the stored statistics for the whole plan (`partition == None`).
    ///
    /// Asking for the statistics of a specific partition is an error, as only plan-wide
    /// statistics are stored.
    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
        match partition {
            Some(_) => plan_err!("StatisticsWrapper not prepared for partition-specific stats"),
            None => Ok(Arc::clone(&self.stats)),
        }
    }

    // All of the following are answered by the wrapped plan.
    delegate! {
        to self.inner {
            fn name(&self) -> &str;
            fn properties(&self) -> &Arc<PlanProperties>;
            fn maintains_input_order(&self) -> Vec<bool>;
            fn benefits_from_input_partitioning(&self) -> Vec<bool>;
            fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>>;
            fn repartitioned(&self, _target_partitions: usize, _config: &ConfigOptions) -> Result<Option<Arc<dyn ExecutionPlan>>>;
            fn execute(&self, partition: usize, context: Arc<TaskContext>) -> Result<SendableRecordBatchStream>;
            fn supports_limit_pushdown(&self) -> bool;
            fn with_fetch(&self, _limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>>;
            fn fetch(&self) -> Option<usize>;
            fn cardinality_effect(&self) -> CardinalityEffect;
        }
    }

    /// Replacing the children of the wrapper is not supported and always returns an error.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        not_impl_err!("with_new_children not implemented")
    }
}
Loading
Loading