From c5e321e883e4f231839388a43aec8336eaee7a63 Mon Sep 17 00:00:00 2001 From: Gabriel <45515538+gabotechs@users.noreply.github.com> Date: Tue, 19 May 2026 07:51:21 +0200 Subject: [PATCH 1/2] Factor out distributed recursion --- src/common/mod.rs | 2 + src/common/recursion.rs | 678 ++++++++++++++++++ src/coordinator/task_spawner.rs | 133 +--- .../prepare_network_boundaries.rs | 96 +-- 4 files changed, 748 insertions(+), 161 deletions(-) create mode 100644 src/common/recursion.rs diff --git a/src/common/mod.rs b/src/common/mod.rs index 2e245388..de4e60e5 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -2,6 +2,7 @@ mod children_helpers; mod map_last_stream; mod on_drop_stream; mod once_lock; +mod recursion; mod task_context_helpers; mod uuid; @@ -9,5 +10,6 @@ pub(crate) use children_helpers::require_one_child; pub(crate) use map_last_stream::map_last_stream; pub(crate) use on_drop_stream::on_drop_stream; pub(crate) use once_lock::OnceLockResult; +pub(crate) use recursion::TreeNodeExt; pub(crate) use task_context_helpers::task_ctx_with_extension; pub(crate) use uuid::{deserialize_uuid, serialize_uuid}; diff --git a/src/common/recursion.rs b/src/common/recursion.rs new file mode 100644 index 00000000..1e1e0257 --- /dev/null +++ b/src/common/recursion.rs @@ -0,0 +1,678 @@ +use crate::execution_plans::ChildrenIsolatorUnionExec; +use crate::{DistributedTaskContext, NetworkBoundaryExt}; +use datafusion::common::Result; +use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion}; +use datafusion::physical_plan::ExecutionPlan; +use std::cell::RefCell; +use std::sync::Arc; + +pub(crate) trait TreeNodeExt { + /// Applies `f` to the node then each of its children, recursively (a + /// top-down, pre-order traversal), propagating the [DistributedTaskContext] correctly + /// across nodes that mutate this context, and ignoring nodes that do not belong to + /// the passed [DistributedTaskContext]. + /// + /// For example, the presence of [ChildrenIsolatorUnionExec] will make this function + /// not recurse into nodes that would be ignored because of the contextual + /// [DistributedTaskContext], and while recursing into its children, the propagated + /// [DistributedTaskContext] will be mutated. + /// + /// The return [`TreeNodeRecursion`] controls the recursion and can cause an early return. + /// + /// This function does not recurse into the input of network boundaries. + fn apply_with_dt_ctx Result>( + &self, + ctx: DistributedTaskContext, + f: F, + ) -> Result; + + /// Recursively rewrite the tree using `f` in a top-down (pre-order) fashion, propagating + /// the appropriate [DistributedTaskContext] based on the presence of nodes that can isolate + /// tasks, like [ChildrenIsolatorUnionExec]. + /// + /// `f` is applied to the node first, and then its children. + #[allow(dead_code)] // Used in follow up work. + fn transform_down_with_dt_ctx< + F: FnMut(Self, DistributedTaskContext) -> Result>, + >( + self, + dt_ctx: DistributedTaskContext, + f: F, + ) -> Result> + where + Self: Sized; + + /// Recursively rewrite the tree using `f` in a bottom-up (post-order) fashion, propagating + /// the appropriate task count based on the presence of nodes that can isolate tasks, like + /// [ChildrenIsolatorUnionExec]. + /// + /// `f` is applied to the node's children first, and then to the node itself. + fn transform_up_with_task_count Result>>( + self, + task_count: usize, + f: F, + ) -> Result> + where + Self: Sized; + + /// Recursively rewrite the tree using `f` in a top-down (pre-order) fashion, propagating + /// the appropriate task count based on the presence of nodes that can isolate tasks, like + /// [ChildrenIsolatorUnionExec]. + /// + /// `f` is applied to the node first, and then its children. + #[allow(dead_code)] // Used in follow up work. + fn transform_down_with_task_count Result>>( + self, + task_count: usize, + f: F, + ) -> Result> + where + Self: Sized; +} + +impl TreeNodeExt for Arc { + fn apply_with_dt_ctx Result>( + &self, + ctx: DistributedTaskContext, + mut f: F, + ) -> Result { + fn recurse< + F: FnMut(&Arc, DistributedTaskContext) -> Result, + >( + plan: &Arc, + ctx: DistributedTaskContext, + f: &mut F, + ) -> Result { + f(plan, ctx.clone())?.visit_children(|| { + if let Some(ciu) = plan.as_any().downcast_ref::() { + // Just recurse to children that will actually get executed by this + // ChildrenIsolatorUnionExec. + ciu.task_idx_map[ctx.task_index].iter().apply_until_stop( + |(child_i, child_ctx)| { + recurse(&ciu.children[*child_i], child_ctx.clone(), f) + }, + ) + } else if plan.is_network_boundary() { + Ok(TreeNodeRecursion::Continue) + } else { + plan.children() + .into_iter() + .apply_until_stop(|child| recurse(child, ctx.clone(), f)) + } + }) + } + recurse(self, ctx, &mut f) + } + + fn transform_down_with_dt_ctx< + F: FnMut(Self, DistributedTaskContext) -> Result>, + >( + self, + dt_ctx: DistributedTaskContext, + mut f: F, + ) -> Result> + where + Self: Sized, + { + // None = skip this subtree (irrelevant CIU child for our task index). + let stack = RefCell::new(vec![Some(dt_ctx)]); + self.transform_down_up( + |node| { + let Some(dt_ctx) = stack.borrow_mut().pop().unwrap() else { + return Ok(Transformed { + data: node, + transformed: false, + tnr: TreeNodeRecursion::Jump, + }); + }; + let transformed = f(node, dt_ctx.clone())?; + if transformed.tnr != TreeNodeRecursion::Continue + || transformed.data.is_network_boundary() + { + return Ok(Transformed { + tnr: TreeNodeRecursion::Jump, + ..transformed + }); + } + let node = &transformed.data; + if let Some(ciu) = node.as_any().downcast_ref::() { + let mut child_ctxs = vec![None; ciu.children.len()]; + for (child_idx, child_ctx) in &ciu.task_idx_map[dt_ctx.task_index] { + child_ctxs[*child_idx] = Some(child_ctx.clone()); + } + stack.borrow_mut().extend(child_ctxs.into_iter().rev()); + } else { + stack + .borrow_mut() + .extend(node.children().iter().map(|_| Some(dt_ctx.clone())).rev()); + } + Ok(transformed) + }, + |node| Ok(Transformed::no(node)), + ) + } + + fn transform_up_with_task_count Result>>( + self, + task_count: usize, + mut f: F, + ) -> Result> { + let stack = RefCell::new(vec![task_count]); + self.transform_down_up( + |node| { + let cur = *stack.borrow().last().unwrap(); + let child_tcs = + if let Some(ciu) = node.as_any().downcast_ref::() { + ciu.child_task_counts() + } else if let Some(nb) = node.as_network_boundary() { + vec![nb.input_stage().task_count(); node.children().len()] + } else { + vec![cur; node.children().len()] + }; + stack.borrow_mut().extend(child_tcs.into_iter().rev()); + Ok(Transformed::no(node)) + }, + |node| { + let tc = stack.borrow_mut().pop().unwrap(); + f(node, tc) + }, + ) + } + + fn transform_down_with_task_count Result>>( + self, + task_count: usize, + mut f: F, + ) -> Result> { + let stack = RefCell::new(vec![task_count]); + self.transform_down_up( + |node| { + let tc = stack.borrow_mut().pop().unwrap(); + let transformed = f(node, tc)?; + if transformed.tnr != TreeNodeRecursion::Continue { + return Ok(transformed); + } + let child_tcs = if let Some(ciu) = transformed + .data + .as_any() + .downcast_ref::() + { + ciu.child_task_counts() + } else if let Some(nb) = transformed.data.as_network_boundary() { + vec![nb.input_stage().task_count(); transformed.data.children().len()] + } else { + vec![tc; transformed.data.children().len()] + }; + stack.borrow_mut().extend(child_tcs.into_iter().rev()); + Ok(transformed) + }, + |node| Ok(Transformed::no(node)), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::NetworkCoalesceExec; + use crate::execution_plans::ChildWeight; + use datafusion::arrow::datatypes::Schema; + use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion::physical_plan::empty::EmptyExec; + use datafusion::physical_plan::union::UnionExec; + use insta::assert_snapshot; + + // ── apply_with_dt_ctx ──────────────────────────────────────────────────────── + + #[test] + fn apply_leaf() { + let plan = leaf(); + assert_snapshot!(trace_apply(&plan, ctx(0, 1)), @"Leaf [ctx=0/1]"); + } + + #[test] + fn apply_top_down_order() { + let plan = union(vec![leaf(), leaf()]); + assert_snapshot!(trace_apply(&plan, ctx(0, 1)), @r" + Union [ctx=0/1] + Leaf [ctx=0/1] + Leaf [ctx=0/1] + "); + } + + #[test] + fn apply_deep_tree() { + let plan = single(single(leaf())); + assert_snapshot!(trace_apply(&plan, ctx(0, 1)), @r" + Single [ctx=0/1] + Single [ctx=0/1] + Leaf [ctx=0/1] + "); + } + + #[test] + fn apply_stop() { + let plan = single(leaf()); + assert_snapshot!( + trace_apply_with(&plan, ctx(0, 1), |_| TreeNodeRecursion::Stop), + @"Single [ctx=0/1] [->stop]", + ); + } + + #[test] + fn apply_jump_skips_subtree() { + let child = single(leaf()); + let plan = single(Arc::clone(&child)); + assert_snapshot!( + trace_apply_with(&plan, ctx(0, 1), |p| { + if Arc::ptr_eq(p, &child) { TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue } + }), + @r" + Single [ctx=0/1] + Single [ctx=0/1] [->jump] + "); + } + + #[test] + fn apply_network_boundary() { + let plan = network_boundary(leaf(), 2); + assert_snapshot!(trace_apply(&plan, ctx(0, 1)), @"Network [ctx=0/1]"); + } + + #[test] + fn apply_ciu_routing() { + let plan = ciu(vec![leaf(), leaf()], vec![1, 1], 2).unwrap(); + assert_snapshot!(trace_apply(&plan, ctx(0, 2)), @r" + CIU [ctx=0/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(1, 2)), @r" + CIU [ctx=1/2] + Leaf [ctx=0/1] + "); + } + + #[test] + fn apply_ciu_context_remapping() { + let plan = ciu(vec![leaf(), leaf(), leaf()], vec![1, 1, 1], 3).unwrap(); + assert_snapshot!(trace_apply(&plan, ctx(0, 3)), @r" + CIU [ctx=0/3] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(1, 3)), @r" + CIU [ctx=1/3] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(2, 3)), @r" + CIU [ctx=2/3] + Leaf [ctx=0/1] + "); + } + + #[test] + fn apply_nested_ciu() { + let inner0 = ciu(vec![leaf(), leaf()], vec![1, 1], 2).unwrap(); + let inner1 = ciu(vec![leaf(), leaf()], vec![1, 1], 2).unwrap(); + let plan = ciu(vec![inner0, inner1], vec![2, 2], 4).unwrap(); + assert_snapshot!(trace_apply(&plan, ctx(0, 4)), @r" + CIU [ctx=0/4] + CIU [ctx=0/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(1, 4)), @r" + CIU [ctx=1/4] + CIU [ctx=1/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(2, 4)), @r" + CIU [ctx=2/4] + CIU [ctx=0/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(3, 4)), @r" + CIU [ctx=3/4] + CIU [ctx=1/2] + Leaf [ctx=0/1] + "); + } + + #[test] + fn apply_ciu_multi_children_per_task() { + // 4 children split across 2 tasks → each task sees 2 children + let plan = ciu(vec![leaf(), leaf(), leaf(), leaf()], vec![1, 1, 1, 1], 2).unwrap(); + assert_snapshot!(trace_apply(&plan, ctx(0, 2)), @r" + CIU [ctx=0/2] + Leaf [ctx=0/1] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_apply(&plan, ctx(1, 2)), @r" + CIU [ctx=1/2] + Leaf [ctx=0/1] + Leaf [ctx=0/1] + "); + } + + // ── transform_down_with_dt_ctx ──────────────────────────────────────────── + + #[test] + fn dt_ctx_down_leaf() { + let plan = leaf(); + assert_snapshot!(trace_dt_ctx_down(plan, ctx(2, 4)), @"Leaf [ctx=2/4]"); + } + + #[test] + fn dt_ctx_down_top_down_order() { + let plan = single(leaf()); + assert_snapshot!(trace_dt_ctx_down(plan, ctx(0, 1)), @r" + Single [ctx=0/1] + Leaf [ctx=0/1] + "); + } + + #[test] + fn dt_ctx_down_ctx_propagation() { + let plan = union(vec![leaf(), leaf()]); + assert_snapshot!(trace_dt_ctx_down(plan, ctx(1, 3)), @r" + Union [ctx=1/3] + Leaf [ctx=1/3] + Leaf [ctx=1/3] + "); + } + + #[test] + fn dt_ctx_down_network_boundary() { + let plan = network_boundary(leaf(), 2); + assert_snapshot!(trace_dt_ctx_down(plan, ctx(0, 1)), @"Network [ctx=0/1]"); + } + + #[test] + fn dt_ctx_down_ciu_routing() { + let plan = ciu(vec![leaf(), leaf()], vec![1, 1], 2).unwrap(); + assert_snapshot!(trace_dt_ctx_down(Arc::clone(&plan), ctx(0, 2)), @r" + CIU [ctx=0/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_dt_ctx_down(plan, ctx(1, 2)), @r" + CIU [ctx=1/2] + Leaf [ctx=0/1] + "); + } + + #[test] + fn dt_ctx_down_nested_ciu() { + let inner0 = ciu(vec![leaf(), leaf()], vec![1, 1], 2).unwrap(); + let inner1 = ciu(vec![leaf(), leaf()], vec![1, 1], 2).unwrap(); + let plan = ciu(vec![inner0, inner1], vec![2, 2], 4).unwrap(); + assert_snapshot!(trace_dt_ctx_down(Arc::clone(&plan), ctx(0, 4)), @r" + CIU [ctx=0/4] + CIU [ctx=0/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_dt_ctx_down(Arc::clone(&plan), ctx(1, 4)), @r" + CIU [ctx=1/4] + CIU [ctx=1/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_dt_ctx_down(Arc::clone(&plan), ctx(2, 4)), @r" + CIU [ctx=2/4] + CIU [ctx=0/2] + Leaf [ctx=0/1] + "); + assert_snapshot!(trace_dt_ctx_down(Arc::clone(&plan), ctx(3, 4)), @r" + CIU [ctx=3/4] + CIU [ctx=1/2] + Leaf [ctx=0/1] + "); + } + + #[test] + fn dt_ctx_down_jump_skips_subtree() { + let child = single(leaf()); + let root = single(Arc::clone(&child)); + assert_snapshot!(trace_dt_ctx_down_with(root, ctx(0, 1), |p| { + if Arc::ptr_eq(p, &child) { TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue } + }), @r" + Single [ctx=0/1] + Single [ctx=0/1] [->jump] + "); + } + + // ── transform_up_with_task_count ────────────────────────────────────────── + + #[test] + fn tc_up_leaf() { + let plan = leaf(); + assert_snapshot!(trace_tc_up(plan, 7), @"Leaf [tc=7]"); + } + + #[test] + fn tc_up_bottom_up_order() { + let plan = single(leaf()); + assert_snapshot!(trace_tc_up(plan, 1), @r" + Leaf [tc=1] + Single [tc=1] + "); + } + + #[test] + fn tc_up_uniform_task_count() { + let plan = union(vec![leaf(), leaf()]); + assert_snapshot!(trace_tc_up(plan, 5), @r" + Leaf [tc=5] + Leaf [tc=5] + Union [tc=5] + "); + } + + #[test] + fn tc_up_ciu_per_child_task_counts() { + let plan = ciu(vec![leaf(), leaf()], vec![2, 3], 5).unwrap(); + assert_snapshot!(trace_tc_up(plan, 5), @r" + Leaf [tc=2] + Leaf [tc=3] + CIU [tc=5] + "); + } + + #[test] + fn tc_up_network_boundary_changes_tc() { + // Nodes inside the NB run at the producer task count (2), not the outer count (5) + let plan = single(network_boundary(leaf(), 2)); + assert_snapshot!(trace_tc_up(plan, 5), @r" + Leaf [tc=2] + Network [tc=5] + Single [tc=5] + "); + } + + // ── transform_down_with_task_count ──────────────────────────────────────── + + #[test] + fn tc_down_leaf() { + let plan = leaf(); + assert_snapshot!(trace_tc_down(plan, 7), @"Leaf [tc=7]"); + } + + #[test] + fn tc_down_top_down_order() { + let plan = single(leaf()); + assert_snapshot!(trace_tc_down(plan, 1), @r" + Single [tc=1] + Leaf [tc=1] + "); + } + + #[test] + fn tc_down_uniform_task_count() { + let plan = union(vec![leaf(), leaf()]); + assert_snapshot!(trace_tc_down(plan, 5), @r" + Union [tc=5] + Leaf [tc=5] + Leaf [tc=5] + "); + } + + #[test] + fn tc_down_ciu_per_child_task_counts() { + let plan = ciu(vec![leaf(), leaf()], vec![2, 3], 5).unwrap(); + assert_snapshot!(trace_tc_down(plan, 5), @r" + CIU [tc=5] + Leaf [tc=2] + Leaf [tc=3] + "); + } + + #[test] + fn tc_down_network_boundary_changes_tc() { + let plan = single(network_boundary(leaf(), 2)); + assert_snapshot!(trace_tc_down(plan, 5), @r" + Single [tc=5] + Network [tc=5] + Leaf [tc=2] + "); + } + + // ── helpers: plan builders ──────────────────────────────────────────────── + + fn leaf() -> Arc { + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))) + } + + fn single(child: Arc) -> Arc { + Arc::new(CoalescePartitionsExec::new(child)) + } + + fn union(children: Vec>) -> Arc { + UnionExec::try_new(children).unwrap() + } + + fn network_boundary( + input: Arc, + producer_tasks: usize, + ) -> Arc { + Arc::new(NetworkCoalesceExec::try_new(input, producer_tasks, 1).unwrap()) + } + + fn ciu( + children: Vec>, + child_task_counts: Vec, + task_count: usize, + ) -> Result> { + Ok(Arc::new( + ChildrenIsolatorUnionExec::from_children_and_weights( + children, + child_task_counts + .iter() + .map(|v| ChildWeight::desired(*v as f64)), + task_count, + )?, + )) + } + + fn ctx(task_index: usize, task_count: usize) -> DistributedTaskContext { + DistributedTaskContext { + task_index, + task_count, + } + } + + // ── helpers: trace renderers ────────────────────────────────────────────── + + fn plan_label(p: &Arc) -> &'static str { + if p.as_any().is::() { + "Leaf" + } else if p.as_any().is::() { + "Single" + } else if p.as_any().is::() { + "Union" + } else if p.as_any().is::() { + "CIU" + } else if p.as_any().is::() { + "Network" + } else { + "?" + } + } + + fn trace_apply(root: &Arc, dt_ctx: DistributedTaskContext) -> String { + trace_apply_with(root, dt_ctx, |_| TreeNodeRecursion::Continue) + } + + fn trace_apply_with) -> TreeNodeRecursion>( + root: &Arc, + dt_ctx: DistributedTaskContext, + mut decide: F, + ) -> String { + let mut lines = vec![]; + root.apply_with_dt_ctx(dt_ctx, |p, c| { + let rec = decide(p); + let suffix = match rec { + TreeNodeRecursion::Continue => "", + TreeNodeRecursion::Jump => " [->jump]", + TreeNodeRecursion::Stop => " [->stop]", + }; + lines.push(format!( + "{} [ctx={}/{}]{suffix}", + plan_label(p), + c.task_index, + c.task_count, + )); + Ok(rec) + }) + .unwrap(); + lines.join("\n") + } + + fn trace_dt_ctx_down(root: Arc, dt_ctx: DistributedTaskContext) -> String { + trace_dt_ctx_down_with(root, dt_ctx, |_| TreeNodeRecursion::Continue) + } + + fn trace_dt_ctx_down_with) -> TreeNodeRecursion>( + root: Arc, + dt_ctx: DistributedTaskContext, + mut decide: F, + ) -> String { + let mut lines = vec![]; + root.transform_down_with_dt_ctx(dt_ctx, |p, c| { + let rec = decide(&p); + let suffix = match rec { + TreeNodeRecursion::Continue => "", + TreeNodeRecursion::Jump => " [->jump]", + TreeNodeRecursion::Stop => " [->stop]", + }; + lines.push(format!( + "{} [ctx={}/{}]{suffix}", + plan_label(&p), + c.task_index, + c.task_count, + )); + Ok(Transformed { + data: p, + transformed: false, + tnr: rec, + }) + }) + .unwrap(); + lines.join("\n") + } + + fn trace_tc_up(root: Arc, tc: usize) -> String { + let mut lines = vec![]; + root.transform_up_with_task_count(tc, |p, tc| { + lines.push(format!("{} [tc={tc}]", plan_label(&p))); + Ok(Transformed::no(p)) + }) + .unwrap(); + lines.join("\n") + } + + fn trace_tc_down(root: Arc, tc: usize) -> String { + let mut lines = vec![]; + root.transform_down_with_task_count(tc, |p, tc| { + lines.push(format!("{} [tc={tc}]", plan_label(&p))); + Ok(Transformed::no(p)) + }) + .unwrap(); + lines.join("\n") + } +} diff --git a/src/coordinator/task_spawner.rs b/src/coordinator/task_spawner.rs index f960da45..c959a75b 100644 --- a/src/coordinator/task_spawner.rs +++ b/src/coordinator/task_spawner.rs @@ -1,7 +1,6 @@ -use crate::common::{serialize_uuid, task_ctx_with_extension}; +use crate::common::{TreeNodeExt, serialize_uuid, task_ctx_with_extension}; use crate::config_extension_ext::get_config_extension_propagation_headers; use crate::coordinator::MetricsStore; -use crate::execution_plans::ChildrenIsolatorUnionExec; use crate::passthrough_headers::get_passthrough_headers; use crate::protobuf::tonic_status_to_datafusion_error; use crate::stage::LocalStage; @@ -15,6 +14,7 @@ use crate::{ use datafusion::common::Result; use datafusion::common::instant::Instant; use datafusion::common::runtime::JoinSet; +use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::common::{DataFusionError, exec_datafusion_err}; use datafusion::execution::TaskContext; use datafusion::physical_expr_common::metrics::{ @@ -24,7 +24,6 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; -use futures::future::BoxFuture; use http::Extensions; use prost::Message; use std::fmt::Display; @@ -120,55 +119,23 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { UnboundedReceiver, )> { let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?; - /// Searches recursively for nodes exposing [crate::WorkUnitFeed]s, and executes their - /// feeds, keeping into account that some of them might be executed within a - /// [ChildrenIsolatorUnionExec] context. This means that some of them are irrelevant for - /// the current [task_i], and we don't want to account for them here. - /// - /// It places in the `out` argument all the collected [WorkUnitFeedDeclaration]s necessary - /// for sending the plan. - fn gather_work_unit_feed_declarations( - plan: &Arc, - ctx: DistributedTaskContext, - d_cfg: &DistributedConfig, - out: &mut Vec, - ) { - let wuf = if let Some(wuf) = d_cfg - .__private_work_unit_feed_registry - .get_work_unit_feed(plan) - { - wuf - } else if let Some(ciu) = plan.as_any().downcast_ref::() { - for (child_i, ctx) in &ciu.task_idx_map[ctx.task_index] { - let child = &ciu.children[*child_i]; - // Just recurse to children that will actually get executed by this - // ChildrenIsolatorUnionExec. - gather_work_unit_feed_declarations(child, ctx.clone(), d_cfg, out) - } - return; - } else { - for child in plan.children() { - gather_work_unit_feed_declarations(child, ctx.clone(), d_cfg, out) - } - return; - }; + let wuf_registry = &d_cfg.__private_work_unit_feed_registry; - out.push(WorkUnitFeedDeclaration { + let mut work_unit_feed_declarations = vec![]; + let d_ctx = DistributedTaskContext { + task_index: task_i, + task_count: self.task_count, + }; + self.plan.apply_with_dt_ctx(d_ctx, |plan, _| { + let Some(wuf) = wuf_registry.get_work_unit_feed(plan) else { + return Ok(TreeNodeRecursion::Continue); + }; + work_unit_feed_declarations.push(WorkUnitFeedDeclaration { id: serialize_uuid(&wuf.id()), partitions: plan.properties().partitioning.partition_count() as u64, - }) - } - - let mut work_unit_feed_declarations = vec![]; - gather_work_unit_feed_declarations( - self.plan, - DistributedTaskContext { - task_index: task_i, - task_count: self.task_count, - }, - d_cfg, - &mut work_unit_feed_declarations, - ); + }); + Ok(TreeNodeRecursion::Continue) + })?; let task_key = TaskKey { query_id: serialize_uuid(&self.query_id), @@ -273,49 +240,26 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { tx: UnboundedSender, ) -> Result<()> { let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?; - /// Recurses into the plan looking for [WorkUnitFeedExec] nodes that should be handled by - /// the provided [task_i]. Because of [ChildrenIsolatorUnionExec]s being present in the - /// plan, there might be some present [WorkUnitFeedExec] that will not necessarily get - /// executed, so we don't want to stream any [WorkUnit] to those. - /// - /// It places in `out` the list of futures that should be polled for driving the [WorkUnit] - /// network streams forward. - fn gather_work_unit_feed_tasks( - plan: &Arc, - dt_ctx: DistributedTaskContext, - t_ctx: &Arc, - d_cfg: &DistributedConfig, - tx: &UnboundedSender, - out: &mut Vec>>, - ) -> Result<()> { - let wuf = if let Some(wuf) = d_cfg - .__private_work_unit_feed_registry - .get_work_unit_feed(plan) - { - wuf - } else if let Some(ciu) = plan.as_any().downcast_ref::() { - for (child_i, dt_ctx) in &ciu.task_idx_map[dt_ctx.task_index] { - // Just recurse to children that will actually get executed by this - // ChildrenIsolatorUnionExec. - let child = &ciu.children[*child_i]; - gather_work_unit_feed_tasks(child, dt_ctx.clone(), t_ctx, d_cfg, tx, out)?; - } - return Ok(()); - } else { - for child in plan.children() { - gather_work_unit_feed_tasks(child, dt_ctx.clone(), t_ctx, d_cfg, tx, out)? - } - return Ok(()); + let wuf_registry = &d_cfg.__private_work_unit_feed_registry; + + let d_ctx = DistributedTaskContext { + task_index: task_i, + task_count: self.task_count, + }; + let mut futures = vec![]; + self.plan.apply_with_dt_ctx(d_ctx, |plan, d_ctx| { + let Some(wuf) = wuf_registry.get_work_unit_feed(plan) else { + return Ok(TreeNodeRecursion::Continue); }; let partitions = plan.properties().partitioning.partition_count(); - let start_partition = partitions * dt_ctx.task_index; + let start_partition = partitions * d_ctx.task_index; let end_partition = start_partition + partitions; let dist_feed_ctx = DistributedWorkUnitFeedContext { - fan_out_tasks: dt_ctx.task_count, + fan_out_tasks: d_ctx.task_count, }; - let t_ctx = Arc::new(task_ctx_with_extension(t_ctx, dist_feed_ctx)); + let t_ctx = Arc::new(task_ctx_with_extension(&ctx, dist_feed_ctx)); // There should be as many partition feeds as [num partitions] * [num tasks], so that // each task index handles a non-overlapping set of partition feeds. @@ -326,7 +270,7 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { let mut work_unit_feed = wuf.feed(feed_idx, Arc::clone(&t_ctx))?; let tx = tx.clone(); let id = wuf.id(); - out.push(Box::pin(async move { + futures.push(Box::pin(async move { // At this point, the partition feed contains a stream of decoded messages, // so they must be encoded in order to send them over the wire. while let Some(data_or_err) = work_unit_feed.next().await { @@ -346,21 +290,8 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { Ok::<_, DataFusionError>(()) })); } - Ok(()) - } - - let mut futures = vec![]; - gather_work_unit_feed_tasks( - self.plan, - DistributedTaskContext { - task_index: task_i, - task_count: self.task_count, - }, - &ctx, - d_cfg, - &tx, - &mut futures, - )?; + Ok(TreeNodeRecursion::Continue) + })?; self.join_set.spawn(async move { futures::future::try_join_all(futures).await?; Ok(()) diff --git a/src/distributed_planner/prepare_network_boundaries.rs b/src/distributed_planner/prepare_network_boundaries.rs index f00360ff..26c04544 100644 --- a/src/distributed_planner/prepare_network_boundaries.rs +++ b/src/distributed_planner/prepare_network_boundaries.rs @@ -1,8 +1,9 @@ +use crate::common::TreeNodeExt; use crate::distributed_planner::network_boundary::network_boundary_scale_input; -use crate::execution_plans::ChildrenIsolatorUnionExec; use crate::stage::LocalStage; use crate::{NetworkBoundaryExt, Stage}; use datafusion::common::Result; +use datafusion::common::tree_node::Transformed; use datafusion::physical_plan::ExecutionPlan; use std::sync::Arc; use uuid::Uuid; @@ -14,68 +15,43 @@ use uuid::Uuid; pub(crate) fn prepare_network_boundaries( plan: Arc, ) -> Result> { - prepare(plan, 1, Uuid::new_v4(), &mut 1) -} + let mut stage_id = 1; + let query_id = Uuid::new_v4(); -fn prepare( - plan: Arc, - consumer_task_count: usize, - query_id: Uuid, - stage_id: &mut usize, -) -> Result> { - // A `ChildrenIsolatorUnionExec` runs each child in only a subset of the surrounding stage's - // tasks. Boundaries living inside one of those children must scale by that child's own task - // count, not the full stage's task count, otherwise hash partitioning is over-scaled and - // data ends up routed to partitions no consumer reads. - if let Some(ciu) = plan.as_any().downcast_ref::() { - let children_and_task_count = ciu.children().into_iter().zip(ciu.child_task_counts()); - let new_children = children_and_task_count - .map(|(child, per_child_count)| { - prepare(Arc::clone(child), per_child_count, query_id, stage_id) - }) - .collect::>>()?; - return plan.with_new_children(new_children); - } + let transformed = plan.transform_up_with_task_count(1, |plan, task_count| { + let Some(nb) = plan.as_network_boundary() else { + return Ok(Transformed::no(plan)); + }; + // If the input stage is already remote, it was already sent over the network, so nothing else + // we can do here. + let Stage::Local(input_stage) = nb.input_stage() else { + return Ok(Transformed::no(plan)); + }; - let Some(nb) = plan.as_network_boundary() else { - let new_children = plan - .children() - .into_iter() - .map(|c| prepare(Arc::clone(c), consumer_task_count, query_id, stage_id)) - .collect::>>()?; - return plan.with_new_children(new_children); - }; + // 1) If there are both 1 producer and consumer tasks, optimize the network boundary out. + if task_count == 1 && input_stage.tasks == 1 { + return Ok(Transformed::yes(Arc::clone(&input_stage.plan))); + } - // If the input stage is already remote, it was already sent over the network, so nothing else - // we can do here. - let Stage::Local(local_stage) = nb.input_stage() else { - return Ok(plan); - }; - let producer_task_count = local_stage.tasks; - let new_input = prepare( - Arc::clone(&local_stage.plan), - producer_task_count, - query_id, - stage_id, - )?; - // 1) If there are both 1 producer and consumer tasks, optimize the network boundary out. - if consumer_task_count == 1 && producer_task_count == 1 { - return Ok(new_input); - } - let consumer_partitions = nb.properties().partitioning.partition_count(); + // 2) Scale up the head node of the input stage in order to account for the amount of partition + // and consumer count above it. + let plan = network_boundary_scale_input( + Arc::clone(&input_stage.plan), + nb.properties().partitioning.partition_count(), + task_count, + )?; - // 2) Scale up the head node of the input stage in order to account for the amount of partition - // and consumer count above it. - let plan = network_boundary_scale_input(new_input, consumer_partitions, consumer_task_count)?; + // 3) Make sure the input stage can be uniquely identified with a stage index and query id. + // If there were already some `query_id` and `num` that's fine. + let nb = nb.with_input_stage(Stage::Local(LocalStage { + query_id, + num: stage_id, + plan, + tasks: input_stage.tasks, + }))?; + stage_id += 1; + Ok(Transformed::yes(nb)) + })?; - // 3) Make sure the input stage can be uniquely identified with a stage index and query id. - // If there were already some `query_id` and `num` that's fine. - let nb = nb.with_input_stage(Stage::Local(LocalStage { - query_id, - num: *stage_id, - plan, - tasks: local_stage.tasks, - })); - *stage_id += 1; - nb + Ok(transformed.data) } From b0078f5065638eff42358aaf3296c972f486d3c3 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Wed, 27 May 2026 08:40:18 +0200 Subject: [PATCH 2/2] Address feedback --- src/common/recursion.rs | 128 +++++++++++++++++++++++++++------------- 1 file changed, 86 insertions(+), 42 deletions(-) diff --git a/src/common/recursion.rs b/src/common/recursion.rs index 1e1e0257..2abbb0d7 100644 --- a/src/common/recursion.rs +++ b/src/common/recursion.rs @@ -14,8 +14,8 @@ pub(crate) trait TreeNodeExt { /// /// For example, the presence of [ChildrenIsolatorUnionExec] will make this function /// not recurse into nodes that would be ignored because of the contextual - /// [DistributedTaskContext], and while recursing into its children, the propagated - /// [DistributedTaskContext] will be mutated. + /// [DistributedTaskContext], and while recursing into its children, a different + /// [DistributedTaskContext] will be passed. /// /// The return [`TreeNodeRecursion`] controls the recursion and can cause an early return. /// @@ -43,8 +43,9 @@ pub(crate) trait TreeNodeExt { Self: Sized; /// Recursively rewrite the tree using `f` in a bottom-up (post-order) fashion, propagating - /// the appropriate task count based on the presence of nodes that can isolate tasks, like - /// [ChildrenIsolatorUnionExec]. + /// the appropriate task count based on the presence of nodes that can isolate tasks (e.g., + /// [ChildrenIsolatorUnionExec]) and the presence of network boundaries that change the task + /// count. /// /// `f` is applied to the node's children first, and then to the node itself. fn transform_up_with_task_count Result>>( @@ -56,8 +57,9 @@ pub(crate) trait TreeNodeExt { Self: Sized; /// Recursively rewrite the tree using `f` in a top-down (pre-order) fashion, propagating - /// the appropriate task count based on the presence of nodes that can isolate tasks, like - /// [ChildrenIsolatorUnionExec]. + /// the appropriate task count based on the presence of nodes that can isolate tasks (e.g., + /// [ChildrenIsolatorUnionExec]) and the presence of network boundaries that change the task + /// count. /// /// `f` is applied to the node first, and then its children. #[allow(dead_code)] // Used in follow up work. @@ -115,41 +117,39 @@ impl TreeNodeExt for Arc { Self: Sized, { // None = skip this subtree (irrelevant CIU child for our task index). - let stack = RefCell::new(vec![Some(dt_ctx)]); - self.transform_down_up( - |node| { - let Some(dt_ctx) = stack.borrow_mut().pop().unwrap() else { - return Ok(Transformed { - data: node, - transformed: false, - tnr: TreeNodeRecursion::Jump, - }); - }; - let transformed = f(node, dt_ctx.clone())?; - if transformed.tnr != TreeNodeRecursion::Continue - || transformed.data.is_network_boundary() - { - return Ok(Transformed { - tnr: TreeNodeRecursion::Jump, - ..transformed - }); - } - let node = &transformed.data; - if let Some(ciu) = node.as_any().downcast_ref::() { - let mut child_ctxs = vec![None; ciu.children.len()]; - for (child_idx, child_ctx) in &ciu.task_idx_map[dt_ctx.task_index] { - child_ctxs[*child_idx] = Some(child_ctx.clone()); - } - stack.borrow_mut().extend(child_ctxs.into_iter().rev()); - } else { - stack - .borrow_mut() - .extend(node.children().iter().map(|_| Some(dt_ctx.clone())).rev()); + let mut stack = vec![Some(dt_ctx)]; + self.transform_down(|node| { + let Some(dt_ctx) = stack.pop().unwrap() else { + return Ok(Transformed { + data: node, + transformed: false, + tnr: TreeNodeRecursion::Jump, + }); + }; + let transformed = f(node, dt_ctx.clone())?; + if transformed.tnr == TreeNodeRecursion::Stop { + return Ok(transformed); + } + if transformed.tnr != TreeNodeRecursion::Continue + || transformed.data.is_network_boundary() + { + return Ok(Transformed { + tnr: TreeNodeRecursion::Jump, + ..transformed + }); + } + let node = &transformed.data; + if let Some(ciu) = node.as_any().downcast_ref::() { + let mut child_ctxs = vec![None; ciu.children.len()]; + for (child_idx, child_ctx) in &ciu.task_idx_map[dt_ctx.task_index] { + child_ctxs[*child_idx] = Some(child_ctx.clone()); } - Ok(transformed) - }, - |node| Ok(Transformed::no(node)), - ) + stack.extend(child_ctxs.into_iter().rev()); + } else { + stack.extend(node.children().iter().map(|_| Some(dt_ctx.clone())).rev()); + } + Ok(transformed) + }) } fn transform_up_with_task_count Result>>( @@ -214,14 +214,14 @@ impl TreeNodeExt for Arc { #[cfg(test)] mod tests { use super::*; - use crate::NetworkCoalesceExec; use crate::execution_plans::ChildWeight; + use crate::stage::RemoteStage; + use crate::{NetworkCoalesceExec, Stage}; use datafusion::arrow::datatypes::Schema; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::union::UnionExec; use insta::assert_snapshot; - // ── apply_with_dt_ctx ──────────────────────────────────────────────────────── #[test] @@ -485,6 +485,22 @@ mod tests { "); } + #[test] + fn tc_up_remote_nb_has_no_subtree() { + let plan = union(vec![ + single(network_boundary(leaf(), 2)), + single(remote_network_boundary()), + ]); + assert_snapshot!(trace_tc_up(plan, 5), @r" + Leaf [tc=2] + Network [tc=5] + Single [tc=5] + Network [tc=5] + Single [tc=5] + Union [tc=5] + "); + } + // ── transform_down_with_task_count ──────────────────────────────────────── #[test] @@ -532,6 +548,22 @@ mod tests { "); } + #[test] + fn tc_down_remote_nb_has_no_subtree() { + let plan = union(vec![ + single(network_boundary(leaf(), 2)), + single(remote_network_boundary()), + ]); + assert_snapshot!(trace_tc_down(plan, 5), @r" + Union [tc=5] + Single [tc=5] + Network [tc=5] + Leaf [tc=2] + Single [tc=5] + Network [tc=5] + "); + } + // ── helpers: plan builders ──────────────────────────────────────────────── fn leaf() -> Arc { @@ -553,6 +585,18 @@ mod tests { Arc::new(NetworkCoalesceExec::try_new(input, producer_tasks, 1).unwrap()) } + fn remote_network_boundary() -> Arc { + network_boundary(leaf(), 1) + .as_network_boundary() + .unwrap() + .with_input_stage(Stage::Remote(RemoteStage { + query_id: uuid::Uuid::nil(), + num: 0, + workers: vec![], + })) + .unwrap() + } + fn ciu( children: Vec>, child_task_counts: Vec,