From 554daf8490f0424d9a78da24664071cd6a23bcbf Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 30 May 2026 08:36:32 +0200 Subject: [PATCH 1/2] Add NetworkBoundaryBuilder argument to inject_network_boundaries.rs From 1c864d5fabb261abba490181bed32fec24abef22 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 30 May 2026 08:36:32 +0200 Subject: [PATCH 2/2] Refactor coordinator module and ensure cache invalidation on coordinator->worker stream drop --- src/coordinator/distributed.rs | 111 +++-- src/coordinator/latency_metric.rs | 62 +++ src/coordinator/mod.rs | 3 +- src/coordinator/prepare_static_plan.rs | 67 +-- .../{task_spawner.rs => query_coordinator.rs} | 391 ++++++++++-------- src/execution_plans/metrics.rs | 4 - src/metrics/task_metrics_rewriter.rs | 25 +- src/worker/generated/worker.rs | 5 +- src/worker/impl_coordinator_channel.rs | 47 ++- src/worker/worker.proto | 2 + tests/stateful_data_cleanup.rs | 53 +-- 11 files changed, 456 insertions(+), 314 deletions(-) create mode 100644 src/coordinator/latency_metric.rs rename src/coordinator/{task_spawner.rs => query_coordinator.rs} (54%) diff --git a/src/coordinator/distributed.rs b/src/coordinator/distributed.rs index 3828fbb3..fe1bbff3 100644 --- a/src/coordinator/distributed.rs +++ b/src/coordinator/distributed.rs @@ -1,10 +1,10 @@ use crate::common::{require_one_child, serialize_uuid}; use crate::coordinator::metrics_store::MetricsStore; use crate::coordinator::prepare_static_plan::prepare_static_plan; +use crate::coordinator::query_coordinator::QueryCoordinator; use crate::distributed_planner::NetworkBoundaryExt; use crate::worker::generated::worker::TaskKey; use datafusion::common::internal_datafusion_err; -use datafusion::common::runtime::JoinSet; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::common::{Result, exec_err}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; @@ -14,8 +14,7 @@ use 
datafusion::physical_plan::stream::RecordBatchReceiverStreamBuilder; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use futures::StreamExt; use std::fmt::Formatter; -use std::sync::Arc; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; /// [ExecutionPlan] that executes the inner plan in distributed mode. /// Before executing it, two modifications are lazily performed on the plan: @@ -26,22 +25,39 @@ use std::sync::Mutex; /// over the wire. #[derive(Debug)] pub struct DistributedExec { - plan: Arc, - prepared_plan: Arc>>>, + /// Initial [ExecutionPlan] present before execution. + /// - If the plan was distributed statically, this will be the final distributed plan with all + /// the appropriate network boundaries in it. + /// - If the plan is going to be distributed dynamically during execution, this is the initial + /// non-distributed plan. + base_plan: Arc, + /// Resulting [ExecutionPlan] after execution ready for visualization purposes. + /// - If the plan was distributed statically, this is equal to the base plan. + /// - If the plan is going to be distributed dynamically during execution, this is the resulting + /// plan re-calculated based on runtime statistics. + plan_for_viz: Arc>>>, + /// The head stage meant to be executed locally on [DistributedExec::execute]. + head_stage: Arc>>>, + /// DataFusion metrics. metrics: ExecutionPlanMetricsSet, + /// Storage where metrics collected from workers at runtime will place their results as they + /// finish their respective remote tasks. pub(crate) metrics_store: Option>, } pub(super) struct PreparedPlan { + /// The head stage meant to be executed locally by the coordinator. pub(super) head_stage: Arc, - pub(super) join_set: JoinSet>, + /// A final representation of the plan for visualization purposes. 
+ pub(super) plan_for_viz: Arc, } impl DistributedExec { - pub fn new(plan: Arc) -> Self { + pub fn new(base_plan: Arc) -> Self { Self { - plan, - prepared_plan: Arc::new(Mutex::new(None)), + base_plan, + plan_for_viz: Arc::new(Mutex::new(None)), + head_stage: Arc::new(Mutex::new(None)), metrics: ExecutionPlanMetricsSet::new(), metrics_store: None, } @@ -68,7 +84,10 @@ impl DistributedExec { let Some(task_metrics) = &self.metrics_store else { return; }; - let _ = self.plan.apply(|plan| { + let Some(plan) = self.plan_for_viz.lock().unwrap().as_ref().cloned() else { + return; + }; + let _ = plan.apply(|plan| { if let Some(boundary) = plan.as_network_boundary() { let stage = boundary.input_stage(); for i in 0..stage.task_count() { @@ -93,8 +112,8 @@ impl DistributedExec { /// Returns the plan which is lazily prepared on `execute()` and actually gets executed. /// It is updated on every call to `execute()`. Returns an error if `.execute()` has not been /// called. - pub(crate) fn prepared_plan(&self) -> Result> { - self.prepared_plan + pub(crate) fn plan_for_viz(&self) -> Result> { + self.plan_for_viz .lock() .map_err(|e| internal_datafusion_err!("Failed to lock prepared plan: {}", e))? .clone() @@ -102,6 +121,18 @@ impl DistributedExec { internal_datafusion_err!("No prepared plan found. Was execute() called?") }) } + + /// Returns the head stage that was actually executed. Unlike [`Self::plan_for_viz`] (which is + /// reconstructed for visualization, with `Stage::Local` boundaries and rebuilt ancestor + /// `Arc`s), this returns the original `Arc` instances whose metrics were populated during + /// execution. + pub(crate) fn head_stage(&self) -> Result> { + self.head_stage + .lock() + .map_err(|e| internal_datafusion_err!("Failed to lock head stage: {}", e))? + .clone() + .ok_or_else(|| internal_datafusion_err!("No head stage found. 
Was execute() called?")) + } } impl DisplayAs for DistributedExec { @@ -116,11 +147,11 @@ impl ExecutionPlan for DistributedExec { } fn properties(&self) -> &Arc { - self.plan.properties() + self.base_plan.properties() } fn children(&self) -> Vec<&Arc> { - vec![&self.plan] + vec![&self.base_plan] } fn with_new_children( @@ -128,8 +159,9 @@ impl ExecutionPlan for DistributedExec { children: Vec>, ) -> Result> { Ok(Arc::new(DistributedExec { - plan: require_one_child(&children)?, - prepared_plan: self.prepared_plan.clone(), + base_plan: require_one_child(&children)?, + plan_for_viz: Arc::new(Mutex::new(None)), + head_stage: Arc::new(Mutex::new(None)), metrics: self.metrics.clone(), metrics_store: self.metrics_store.clone(), })) @@ -150,36 +182,43 @@ impl ExecutionPlan for DistributedExec { ); } - let PreparedPlan { - head_stage, - join_set, - } = prepare_static_plan(&self.plan, &self.metrics, &self.metrics_store, &context)?; - { - let mut guard = self - .prepared_plan - .lock() - .map_err(|e| internal_datafusion_err!("Failed to lock prepared plan: {e}"))?; - *guard = Some(head_stage.clone()); - } + let base_plan = Arc::clone(&self.base_plan); + let plan_for_viz = Arc::clone(&self.plan_for_viz); + let head_stage = Arc::clone(&self.head_stage); + + let query_coordinator = QueryCoordinator::new( + Arc::clone(&context), + &self.metrics, + self.metrics_store.clone(), + ); + let mut builder = RecordBatchReceiverStreamBuilder::new(self.schema(), 1); let tx = builder.tx(); - // Spawn the task that pulls data from child... 
+ builder.spawn(async move { - let mut stream = head_stage.execute(partition, context)?; + let _guard = query_coordinator.end_query_guard(); + + let result = prepare_static_plan(&query_coordinator, &base_plan)?; + + plan_for_viz + .lock() + .expect("poisoned lock") + .replace(result.plan_for_viz); + head_stage + .lock() + .expect("poisoned lock") + .replace(Arc::clone(&result.head_stage)); + let mut stream = result.head_stage.execute(partition, context)?; while let Some(msg) = stream.next().await { if tx.send(msg).await.is_err() { break; // channel closed } } + drop(tx); + query_coordinator.drain_pending_tasks().await?; Ok(()) }); - // ...in parallel to the one that feeds the plan to workers. - builder.spawn(async move { - for res in join_set.join_all().await { - res?; - } - Ok(()) - }); + Ok(builder.build()) } diff --git a/src/coordinator/latency_metric.rs b/src/coordinator/latency_metric.rs new file mode 100644 index 00000000..eea47932 --- /dev/null +++ b/src/coordinator/latency_metric.rs @@ -0,0 +1,62 @@ +use datafusion::common::instant::Instant; +use datafusion::physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, MetricBuilder, MetricValue, Time, +}; +use std::fmt::Display; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; + +/// DataFusion metrics system is pretty limited from an API standpoint. This intermediate struct +/// bridges the gaps that are not satisfied by upstream API for measuring latency. 
+pub(super) struct LatencyMetric { + max: Time, + avg: Time, + max_latency_micros: AtomicU64, + sum_latency_micros: AtomicU64, + count_latency_micros: AtomicU64, +} + +impl Drop for LatencyMetric { + fn drop(&mut self) { + self.max.add_duration(Duration::from_micros( + self.max_latency_micros.load(Ordering::Relaxed), + )); + self.avg.add_duration(Duration::from_micros( + self.sum_latency_micros.load(Ordering::Relaxed) + / self.count_latency_micros.load(Ordering::Relaxed).max(1), + )); + } +} + +impl LatencyMetric { + pub(super) fn new( + name: impl Display, + builder: impl Fn(MetricBuilder) -> MetricBuilder, + metrics: &ExecutionPlanMetricsSet, + ) -> Self { + let max = Time::new(); + builder(MetricBuilder::new(metrics)).build(MetricValue::Time { + name: format!("{name}_max").into(), + time: max.clone(), + }); + let avg = Time::new(); + builder(MetricBuilder::new(metrics)).build(MetricValue::Time { + name: format!("{name}_avg").into(), + time: avg.clone(), + }); + Self { + max, + avg, + max_latency_micros: AtomicU64::new(0), + sum_latency_micros: AtomicU64::new(0), + count_latency_micros: AtomicU64::new(0), + } + } + + pub(super) fn record(&self, start: &Instant) { + let micros = start.elapsed().as_micros() as u64; + self.max_latency_micros.fetch_max(micros, Ordering::Relaxed); + self.sum_latency_micros.fetch_add(micros, Ordering::Relaxed); + self.count_latency_micros.fetch_add(1, Ordering::Relaxed); + } +} diff --git a/src/coordinator/mod.rs b/src/coordinator/mod.rs index 2aea8442..8fe771d3 100644 --- a/src/coordinator/mod.rs +++ b/src/coordinator/mod.rs @@ -1,7 +1,8 @@ mod distributed; +mod latency_metric; mod metrics_store; mod prepare_static_plan; -mod task_spawner; +mod query_coordinator; pub use distributed::DistributedExec; pub(crate) use metrics_store::MetricsStore; diff --git a/src/coordinator/prepare_static_plan.rs b/src/coordinator/prepare_static_plan.rs index 3da4c56a..65d74276 100644 --- a/src/coordinator/prepare_static_plan.rs +++ 
b/src/coordinator/prepare_static_plan.rs @@ -1,20 +1,10 @@ -use crate::coordinator::MetricsStore; use crate::coordinator::distributed::PreparedPlan; -use crate::coordinator::task_spawner::{ - CoordinatorToWorkerMetrics, CoordinatorToWorkerTaskSpawner, -}; +use crate::coordinator::query_coordinator::QueryCoordinator; use crate::stage::RemoteStage; -use crate::{ - DistributedConfig, NetworkBoundaryExt, Stage, TaskEstimator, TaskRoutingContext, - get_distributed_worker_resolver, -}; -use datafusion::common::runtime::JoinSet; +use crate::{NetworkBoundaryExt, Stage}; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{Result, exec_err}; -use datafusion::execution::TaskContext; use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; -use rand::Rng; use std::sync::Arc; /// Prepares the distributed plan for execution, which implies: @@ -28,18 +18,9 @@ use std::sync::Arc; /// 4. Spawn a background task per worker that waits for the worker to finish and collects /// its metrics into [DistributedExec::task_metrics] via the coordinator channel. pub(super) fn prepare_static_plan( + query_coordinator: &QueryCoordinator, base_plan: &Arc, - metrics: &ExecutionPlanMetricsSet, - task_metrics: &Option>, - ctx: &Arc, ) -> Result { - let worker_resolver = get_distributed_worker_resolver(ctx.session_config())?; - - let available_urls = worker_resolver.get_urls()?; - - let metrics = CoordinatorToWorkerMetrics::new(metrics); - - let mut join_set = JoinSet::new(); let prepared = Arc::clone(base_plan).transform_up(|plan| { // The following logic is just applied on network boundaries. 
let Some(plan) = plan.as_network_boundary() else { @@ -50,46 +31,18 @@ pub(super) fn prepare_static_plan( return exec_err!("Input stage from network boundary was not in Local state"); }; - let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?; - let task_estimator = &d_cfg.__private_task_estimator; - - let mut spawner = - CoordinatorToWorkerTaskSpawner::new(stage, &metrics, task_metrics, ctx, &mut join_set)?; + let mut stage_coordinator = query_coordinator.stage_coordinator(stage); - let routed_urls = match task_estimator.route_tasks(&TaskRoutingContext { - task_ctx: Arc::clone(ctx), - plan: &stage.plan, - task_count: stage.tasks, - available_urls: &available_urls, - }) { - Ok(Some(routed_urls)) => routed_urls, - // If the user has not defined custom routing with a `route_tasks` implementation, we - // default to round-robin task assignation from a randomized starting point. - Ok(None) => { - let start_idx = rand::rng().random_range(0..available_urls.len()); - (0..stage.tasks) - .map(|i| available_urls[(start_idx + i) % available_urls.len()].clone()) - .collect() - } - Err(e) => return exec_err!("error routing tasks to workers: {e}"), - }; - - if routed_urls.len() != stage.tasks { - return exec_err!( - "number of tasks ({}) was not equal to number of urls ({}) at execution time", - stage.tasks, - routed_urls.len() - ); - } + let routed_urls = stage_coordinator.routed_urls()?; let mut workers = Vec::with_capacity(stage.tasks); for (i, routed_url) in routed_urls.into_iter().enumerate() { workers.push(routed_url.clone()); // Spawn a task that sends the subplan to the chosen URL. // There will be as many spawned tasks as workers. 
- let (tx, worker_rx) = spawner.send_plan_task(Arc::clone(ctx), i, routed_url)?; - spawner.metrics_collection_task(i, worker_rx); - spawner.work_unit_feed_task(Arc::clone(ctx), i, tx)?; + let (worker_tx, worker_rx) = stage_coordinator.send_plan_task(i, routed_url)?; + stage_coordinator.worker_to_coordinator_task(i, worker_rx); + stage_coordinator.coordinator_to_worker_task(i, worker_tx)?; } Ok(Transformed::yes(plan.with_input_stage(Stage::Remote( @@ -102,6 +55,8 @@ pub(super) fn prepare_static_plan( })?; Ok(PreparedPlan { head_stage: prepared.data, - join_set, + // If the plan was statically planned, the base plan is the same one that will be used for + // visualization. + plan_for_viz: Arc::clone(base_plan), }) } diff --git a/src/coordinator/task_spawner.rs b/src/coordinator/query_coordinator.rs similarity index 54% rename from src/coordinator/task_spawner.rs rename to src/coordinator/query_coordinator.rs index b4983ec0..f458f8cf 100644 --- a/src/coordinator/task_spawner.rs +++ b/src/coordinator/query_coordinator.rs @@ -1,6 +1,7 @@ use crate::common::{TreeNodeExt, now_ns, serialize_uuid, task_ctx_with_extension}; use crate::config_extension_ext::get_config_extension_propagation_headers; use crate::coordinator::MetricsStore; +use crate::coordinator::latency_metric::LatencyMetric; use crate::execution_plans::{ChildrenIsolatorUnionExec, DistributedLeafExec}; use crate::passthrough_headers::get_passthrough_headers; use crate::protobuf::tonic_status_to_datafusion_error; @@ -11,28 +12,28 @@ use crate::worker::generated::worker::coordinator_to_worker_msg::Inner; use crate::worker::generated::worker::set_plan_request::WorkUnitFeedDeclaration; use crate::{ DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, DistributedCodec, DistributedConfig, - DistributedTaskContext, DistributedWorkUnitFeedContext, TaskKey, - get_distributed_channel_resolver, + DistributedTaskContext, DistributedWorkUnitFeedContext, TaskEstimator, TaskKey, + TaskRoutingContext, get_distributed_channel_resolver, 
get_distributed_worker_resolver, }; -use datafusion::common::Result; use datafusion::common::instant::Instant; use datafusion::common::runtime::JoinSet; use datafusion::common::tree_node::{Transformed, TreeNodeRecursion}; use datafusion::common::{DataFusionError, exec_datafusion_err}; +use datafusion::common::{Result, exec_err}; use datafusion::execution::TaskContext; use datafusion::physical_expr_common::metrics::{ - Count, ExecutionPlanMetricsSet, Label, MetricBuilder, MetricValue, Time, + Count, ExecutionPlanMetricsSet, Label, MetricBuilder, }; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use http::Extensions; use prost::Message; -use std::fmt::Display; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::Duration; +use rand::Rng; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex}; +use tokio::sync::Notify; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::Request; @@ -45,121 +46,107 @@ use uuid::Uuid; /// small batches. See [StreamExt::ready_chunks] docs for more details about how chunking works. const WORK_UNIT_FEED_CHUNK_SIZE: usize = 256; -/// Metrics that measure network details about communications between [DistributedExec] and a -/// worker. -#[derive(Clone)] -pub(super) struct CoordinatorToWorkerMetrics { - pub(super) plan_bytes_sent: Count, - pub(super) plan_send_latency: Arc, - pub(super) instantiation_time: u64, +/// Manages communication between coordinator and workers for a single query. +/// +/// The [QueryCoordinator]'s lifetime is scoped to a single query, and will instantiate independent +/// [StageCoordinator]s scoped to each individual stage.
+pub(super) struct QueryCoordinator { + task_ctx: Arc, + coordinator_to_worker_metrics: CoordinatorToWorkerMetrics, + metrics_store: Option>, + end_stream_notifier: Arc, + join_set: Mutex>>, } -impl CoordinatorToWorkerMetrics { - pub(super) fn new(metrics: &ExecutionPlanMetricsSet) -> Self { +impl QueryCoordinator { + /// Builds a new [QueryCoordinator] scoped to a query. + pub(super) fn new( + task_ctx: Arc, + metrics_set: &ExecutionPlanMetricsSet, + metrics_store: Option>, + ) -> Self { Self { - // Metric that measures to total sum of bytes worth of subplans sent. - plan_bytes_sent: MetricBuilder::new(metrics) - .with_label(Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, "0")) - .global_counter("plan_bytes_sent"), - // Latency statistics about the network calls issued to the workers for feeding subplans. - plan_send_latency: Arc::new(LatencyMetric::new( - "plan_send_latency", - |b| b.with_label(Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, "0")), - metrics, - )), - instantiation_time: now_ns(), + task_ctx, + metrics_store, + coordinator_to_worker_metrics: CoordinatorToWorkerMetrics::new(metrics_set), + end_stream_notifier: Arc::new(Notify::new()), + join_set: Mutex::new(JoinSet::new()), + } + } + + /// Builds a new [StageCoordinator] that will manage coordinator-worker connections for the given + /// stage. + pub(super) fn stage_coordinator<'a>(&'a self, stage: &'a LocalStage) -> StageCoordinator<'a> { + StageCoordinator { + plan: &stage.plan, + query_id: stage.query_id, + stage_id: stage.num, + task_count: stage.tasks, + task_ctx: &self.task_ctx, + metrics: &self.coordinator_to_worker_metrics, + metrics_store: &self.metrics_store, + end_stream_notifier: &self.end_stream_notifier, + join_set: &self.join_set, } } + + /// Returns a guard that, when dropped, signals all the coordinator->worker connections that + /// the query is finished, ending them and propagating the EOS to the workers so that they can + /// clean up any remaining state.
+ pub(super) fn end_query_guard(&self) -> NotifyGuard { + NotifyGuard(Arc::clone(&self.end_stream_notifier)) + } + + /// Blocks until all background tasks have finished (e.g., sending WorkUnit feeds, or collecting + /// metrics) + pub(super) async fn drain_pending_tasks(self) -> Result<()> { + let join_set = std::mem::take(self.join_set.lock().unwrap().deref_mut()); + for res in join_set.join_all().await { + res?; + } + Ok(()) + } } -/// Builder for the different kind of tasks that handle the communications between the -/// [DistributedExec] node to the workers. This struct is responsible for instantiating the tasks -/// as boxed futures so that [DistributedExec] can tokio-spawn them at will. +/// Manages all the coordinator->worker and worker->coordinator comms that happen during the +/// execution of an individual Stage. As this struct is scoped per Stage, it will handle the +/// connection to N workers, where N is the number of tasks of the managed Stage. /// /// This struct is responsible for: /// - Building tasks that communicate a serialized plan to multiple workers for further execution. /// - Building tasks that stream partition feeds from local [WorkUnitFeedExec] nodes to their /// remote counterparts. -pub(super) struct CoordinatorToWorkerTaskSpawner<'a> { +pub(super) struct StageCoordinator<'a> { plan: &'a Arc, query_id: Uuid, stage_id: usize, task_count: usize, - task_ctx: &'a TaskContext, + task_ctx: &'a Arc, metrics: &'a CoordinatorToWorkerMetrics, - task_metrics: &'a Option>, - join_set: &'a mut JoinSet>, + metrics_store: &'a Option>, + end_stream_notifier: &'a Arc, + join_set: &'a Mutex>>, } -impl<'a> CoordinatorToWorkerTaskSpawner<'a> { - /// Builds a new [CoordinatorToWorkerTaskSpawner] based on the [Stage] that needs to be - /// fanned out to multiple workers. 
- pub(super) fn new( - stage: &'a LocalStage, - metrics: &'a CoordinatorToWorkerMetrics, - task_metrics: &'a Option>, - task_ctx: &'a TaskContext, - join_set: &'a mut JoinSet>, - ) -> Result { - Ok(Self { - plan: &stage.plan, - query_id: stage.query_id, - stage_id: stage.num, - task_count: stage.tasks, - task_ctx, - metrics, - task_metrics, - join_set, - }) - } - +impl<'a> StageCoordinator<'a> { /// Sends a serialized plan to a specific worker and sets up the bidirectional gRPC stream. /// Returns the sender for outbound coordinator-to-worker messages and the receiver for /// inbound worker-to-coordinator messages. pub(super) fn send_plan_task( &mut self, - ctx: Arc, task_i: usize, url: Url, ) -> Result<( UnboundedSender, UnboundedReceiver, )> { - let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?; - let wuf_registry = &d_cfg.__private_work_unit_feed_registry; - - let mut work_unit_feed_declarations = vec![]; - let d_ctx = DistributedTaskContext { - task_index: task_i, - task_count: self.task_count, - }; - - let plan = Arc::clone(self.plan); - let specialized = plan.transform_down_with_dt_ctx(d_ctx, |plan, d_ctx| { - if let Some(wuf) = wuf_registry.get_work_unit_feed(&plan) { - work_unit_feed_declarations.push(WorkUnitFeedDeclaration { - id: serialize_uuid(&wuf.id()), - partitions: plan.properties().partitioning.partition_count() as u64, - }); - }; - - if let Some(ciu) = plan.downcast_ref::() { - let ciu = ciu.to_task_specialized(d_ctx.task_index); - return Ok(Transformed::yes(Arc::new(ciu))); - }; + let session_config = self.task_ctx.session_config(); + let codec = DistributedCodec::new_combined_with_user(session_config); - if let Some(dle) = plan.downcast_ref::() { - let specialized = dle.to_task_specialized(d_ctx.task_index); - return Ok(Transformed::yes(specialized)); - } - - Ok(Transformed::no(plan)) - })?; - - let codec = DistributedCodec::new_combined_with_user(self.task_ctx.session_config()); + let (specialized, 
work_unit_feed_declarations) = self.task_specialized_plan(task_i)?; let plan_proto = - PhysicalPlanNode::try_from_physical_plan(specialized.data, &codec)?.encode_to_vec(); + PhysicalPlanNode::try_from_physical_plan(specialized, &codec)?.encode_to_vec(); let plan_size = plan_proto.len(); let task_key = TaskKey { @@ -183,22 +170,26 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { let (worker_to_coordinator_tx, worker_to_coordinator_rx) = tokio::sync::mpsc::unbounded_channel(); - let channel_resolver = get_distributed_channel_resolver(ctx.as_ref()); + let channel_resolver = get_distributed_channel_resolver(self.task_ctx.as_ref()); - let mut headers = get_config_extension_propagation_headers(ctx.session_config())?; - headers.extend(get_passthrough_headers(ctx.session_config())); + let mut headers = get_config_extension_propagation_headers(session_config)?; + headers.extend(get_passthrough_headers(session_config)); let request = Request::from_parts( MetadataMap::from_headers(headers), Extensions::default(), - futures::stream::once(async { msg }).chain( - UnboundedReceiverStream::new(coordinator_to_worker_rx).map(set_work_unit_send_time), - ), + futures::stream::once(async { msg }) + .chain(UnboundedReceiverStream::new(coordinator_to_worker_rx)) + .map(set_work_unit_send_time) + // Keep the request side of the channel open until the query ends: this tail emits + // no messages and only completes, once the `Notify` fires. Workers interpret this + // EOS of this stream as a query finished/aborted signal. 
+ .chain(keep_stream_alive(Arc::clone(self.end_stream_notifier))), ); let metrics = self.metrics.clone(); - self.join_set.spawn(async move { + self.join_set.lock().unwrap().spawn(async move { let start = Instant::now(); let mut client = channel_resolver.get_worker_client_for_url(&url).await?; let response = client.coordinator_channel(request).await.map_err(|e| { @@ -210,14 +201,11 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { metrics.plan_bytes_sent.add(plan_size); let mut worker_to_coordinator_stream = response.into_inner(); while let Some(msg_or_err) = worker_to_coordinator_stream.next().await { - let msg = match msg_or_err { - Ok(msg) => msg, - Err(err) => { - return Err(tonic_status_to_datafusion_error(err).unwrap_or_else(|| { - exec_datafusion_err!("Unknown error on worker to coordinator stream") - })); - } - }; + let msg = msg_or_err.map_err(|err| { + tonic_status_to_datafusion_error(err).unwrap_or_else(|| { + exec_datafusion_err!("Unknown error on worker to coordinator stream") + }) + })?; if worker_to_coordinator_tx.send(msg).is_err() { break; // receiver dropped } @@ -228,7 +216,10 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { Ok((coordinator_to_worker_tx, worker_to_coordinator_rx)) } - pub(super) fn metrics_collection_task( + /// Spawns a background task in charge of collecting messages sent by a worker. Some things that + /// are collected from workers are: + /// - Execution metrics information, sent once the worker has finished executing the task. + pub(super) fn worker_to_coordinator_task( &mut self, task_i: usize, mut worker_to_coordinator_rx: UnboundedReceiver, @@ -238,7 +229,10 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { stage_id: self.stage_id as u64, task_number: task_i as u64, }; - let task_metrics = self.task_metrics.clone(); + let task_metrics = self.metrics_store.clone(); + + // Cannot use self.join_set because that's tied to the lifetime of the query, and the + // metrics collection process might outlive the query's lifetime. 
#[allow(clippy::disallowed_methods)] tokio::spawn(async move { while let Some(msg) = worker_to_coordinator_rx.recv().await { @@ -255,17 +249,16 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { }); } - /// Launches the task that based on the different local [WorkUnitFeedExec] nodes, sends their - /// inner [WorkUnitFeeds] over the network to their remote counterparts. - /// - /// Once this function is called, all the [WorkUnitFeedExec]s feeds will be consumed. - pub(super) fn work_unit_feed_task( + /// Spawns a background task in charge of sending messages to workers. Some things that are sent + /// to workers here are: + /// - WorkUnits collected from [WorkUnitFeeds] present in the plan. + pub(super) fn coordinator_to_worker_task( &mut self, - ctx: Arc, task_i: usize, tx: UnboundedSender, ) -> Result<()> { - let d_cfg = DistributedConfig::from_config_options(ctx.session_config().options())?; + let session_config = self.task_ctx.session_config(); + let d_cfg = DistributedConfig::from_config_options(session_config.options())?; let wuf_registry = &d_cfg.__private_work_unit_feed_registry; let d_ctx = DistributedTaskContext { @@ -285,7 +278,7 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { let dist_feed_ctx = DistributedWorkUnitFeedContext { fan_out_tasks: d_ctx.task_count, }; - let t_ctx = Arc::new(task_ctx_with_extension(&ctx, dist_feed_ctx)); + let t_ctx = Arc::new(task_ctx_with_extension(self.task_ctx, dist_feed_ctx)); let mut feeds = Vec::with_capacity(end_partition - start_partition); for (partition, feed_idx) in (start_partition..end_partition).enumerate() { @@ -312,65 +305,141 @@ impl<'a> CoordinatorToWorkerTaskSpawner<'a> { })); Ok(TreeNodeRecursion::Continue) })?; - self.join_set.spawn(async move { + + struct WorkUnitEosOnDrop(UnboundedSender); + impl Drop for WorkUnitEosOnDrop { + fn drop(&mut self) { + let _ = self.0.send(pb::CoordinatorToWorkerMsg { + inner: Some(Inner::WorkUnitEos(true)), + }); + } + } + + self.join_set.lock().unwrap().spawn(async move { 
+ let _guard = WorkUnitEosOnDrop(tx); futures::future::try_join_all(futures).await?; Ok(()) }); Ok(()) } + + /// Specializes the [Arc] for this stage to provided task index. This implies + /// trimming down any unnecessary information that the specific `task_i` task is not going to + /// need, like unexecuted branches in [ChildrenIsolatorUnionExec], or unexecuted variants of + /// [DistributedLeafExec]. + fn task_specialized_plan( + &self, + task_i: usize, + ) -> Result<(Arc, Vec)> { + let session_config = self.task_ctx.session_config(); + let d_cfg = DistributedConfig::from_config_options(session_config.options())?; + let wuf_registry = &d_cfg.__private_work_unit_feed_registry; + + let mut work_unit_feed_declarations = vec![]; + let d_ctx = DistributedTaskContext { + task_index: task_i, + task_count: self.task_count, + }; + + let plan = Arc::clone(self.plan); + let transformed = plan.transform_down_with_dt_ctx(d_ctx, |plan, d_ctx| { + if let Some(wuf) = wuf_registry.get_work_unit_feed(&plan) { + work_unit_feed_declarations.push(WorkUnitFeedDeclaration { + id: serialize_uuid(&wuf.id()), + partitions: plan.properties().partitioning.partition_count() as u64, + }); + }; + + if let Some(ciu) = plan.downcast_ref::() { + let ciu = ciu.to_task_specialized(d_ctx.task_index); + return Ok(Transformed::yes(Arc::new(ciu))); + }; + + if let Some(dle) = plan.downcast_ref::() { + let specialized = dle.to_task_specialized(d_ctx.task_index); + return Ok(Transformed::yes(specialized)); + } + + Ok(Transformed::no(plan)) + })?; + Ok((transformed.data, work_unit_feed_declarations)) + } + + /// Returns as many URLs as the task count for the stage this [StageCoordinator] + /// is managing. These URLs can be: + /// - assigned randomly, if the user did not provide any custom routing. + /// - chosen by the user, if they provided an implementation for the + /// [TaskEstimator::route_tasks] method. 
+ pub(super) fn routed_urls(&self) -> Result> { + let session_config = self.task_ctx.session_config(); + let d_cfg = DistributedConfig::from_config_options(session_config.options())?; + let worker_resolver = get_distributed_worker_resolver(session_config)?; + let task_estimator = &d_cfg.__private_task_estimator; + let available_urls = worker_resolver.get_urls()?; + + let routed_urls = match task_estimator.route_tasks(&TaskRoutingContext { + task_ctx: Arc::clone(self.task_ctx), + plan: self.plan, + task_count: self.task_count, + available_urls: &available_urls, + }) { + Ok(Some(routed_urls)) => routed_urls, + // If the user has not defined custom routing with a `route_tasks` implementation, we + // default to round-robin task assignation from a randomized starting point. + Ok(None) => { + let start_idx = rand::rng().random_range(0..available_urls.len()); + (0..self.task_count) + .map(|i| available_urls[(start_idx + i) % available_urls.len()].clone()) + .collect() + } + Err(e) => return exec_err!("error routing tasks to workers: {e}"), + }; + + if routed_urls.len() != self.task_count { + return exec_err!( + "number of tasks ({}) was not equal to number of urls ({}) at execution time", + self.task_count, + routed_urls.len() + ); + } + Ok(routed_urls) + } } -/// DataFusion metrics system is pretty limited from an API standpoint. This intermediate struct -/// bridges the gaps that are not satisfied by upstream API for measuring latency. 
-pub(super) struct LatencyMetric { - max: Time, - avg: Time, - max_latency_micros: AtomicU64, - sum_latency_micros: AtomicU64, - count_latency_micros: AtomicU64, +fn keep_stream_alive(notify: Arc) -> impl Stream + 'static { + futures::stream::once(notify.notified_owned()).filter_map(|()| futures::future::ready(None)) } -impl Drop for LatencyMetric { +pub(super) struct NotifyGuard(Arc); + +impl Drop for NotifyGuard { fn drop(&mut self) { - self.max.add_duration(Duration::from_micros( - self.max_latency_micros.load(Ordering::Relaxed), - )); - self.avg.add_duration(Duration::from_micros( - self.sum_latency_micros.load(Ordering::Relaxed) - / self.count_latency_micros.load(Ordering::Relaxed).max(1), - )); + self.0.notify_waiters(); } } -impl LatencyMetric { - pub(super) fn new( - name: impl Display, - builder: impl Fn(MetricBuilder) -> MetricBuilder, - metrics: &ExecutionPlanMetricsSet, - ) -> Self { - let max = Time::new(); - builder(MetricBuilder::new(metrics)).build(MetricValue::Time { - name: format!("{name}_max").into(), - time: max.clone(), - }); - let avg = Time::new(); - builder(MetricBuilder::new(metrics)).build(MetricValue::Time { - name: format!("{name}_avg").into(), - time: avg.clone(), - }); +/// Metrics that measure network details about communications between [DistributedExec] and a worker. +#[derive(Clone)] +pub(super) struct CoordinatorToWorkerMetrics { + pub(super) plan_bytes_sent: Count, + pub(super) plan_send_latency: Arc, + pub(super) instantiation_time: u64, +} + +impl CoordinatorToWorkerMetrics { + pub(super) fn new(metrics: &ExecutionPlanMetricsSet) -> Self { Self { - max, - avg, - max_latency_micros: AtomicU64::new(0), - sum_latency_micros: AtomicU64::new(0), - count_latency_micros: AtomicU64::new(0), + // Metric that measures the total sum of bytes worth of subplans sent. 
+ plan_bytes_sent: MetricBuilder::new(metrics) + .with_label(Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, "0")) + .global_counter("plan_bytes_sent"), + // Latency statistics about the network calls issued to the workers for feeding subplans. + plan_send_latency: Arc::new(LatencyMetric::new( + "plan_send_latency", + |b| b.with_label(Label::new(DISTRIBUTED_DATAFUSION_TASK_ID_LABEL, "0")), + metrics, + )), + instantiation_time: now_ns(), } } - - fn record(&self, start: &Instant) { - let micros = start.elapsed().as_micros() as u64; - self.max_latency_micros.fetch_max(micros, Ordering::Relaxed); - self.sum_latency_micros.fetch_add(micros, Ordering::Relaxed); - self.count_latency_micros.fetch_add(1, Ordering::Relaxed); - } } diff --git a/src/execution_plans/metrics.rs b/src/execution_plans/metrics.rs index d7567e2b..0fefdf3e 100644 --- a/src/execution_plans/metrics.rs +++ b/src/execution_plans/metrics.rs @@ -25,10 +25,6 @@ impl MetricsWrapperExec { pub(crate) fn inner(&self) -> &Arc { &self.inner } - - pub(crate) fn inner_arc(&self) -> Arc { - Arc::clone(&self.inner) - } } /// MetricsWrapperExec is invisible during display. diff --git a/src/metrics/task_metrics_rewriter.rs b/src/metrics/task_metrics_rewriter.rs index 6bcee2ca..3ba5c88c 100644 --- a/src/metrics/task_metrics_rewriter.rs +++ b/src/metrics/task_metrics_rewriter.rs @@ -18,7 +18,6 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::internal_err; use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet}; use std::sync::Arc; -use std::vec; /// Format to use when displaying metrics for a distributed plan. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -51,37 +50,25 @@ pub async fn rewrite_distributed_plan_with_metrics( return Ok(plan); }; - // Check that the plan was executed before waiting — if not, prepared_plan() returns an - // error immediately rather than waiting forever for metrics that will never arrive. 
- let prepared = distributed_exec.prepared_plan()?; - distributed_exec.wait_for_metrics().await; let Some(metrics_collection) = distributed_exec.metrics_store.clone() else { return Ok(plan); }; - let task_metrics = collect_plan_metrics(&prepared)?; + let head_stage = distributed_exec.head_stage()?; + let task_metrics = collect_plan_metrics(&head_stage)?; // Rewrite the DistributedExec's child plan with metrics. let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics( format.to_rewrite_ctx(0), // Task id is 0 for the DistributedExec plan - plan.children()[0].clone(), + distributed_exec.plan_for_viz()?, task_metrics, )?; - let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?; - - let transformed = plan.transform_down(|plan| { - // After `rewrite_local_plan_with_metrics` above, every node (including network - // boundaries) is wrapped in a `MetricsWrapperExec`. Peek through the wrapper so we - // can still recognize a network boundary by downcasting the inner node. - let inner = plan - .downcast_ref::() - .map(|w| w.inner_arc()) - .unwrap_or_else(|| Arc::clone(&plan)); + let transformed = dist_exec_plan_with_metrics.transform_down(|plan| { // Transform all stages using NetworkShuffleExec and NetworkCoalesceExec as barriers. - if let Some(network_boundary) = inner.as_network_boundary() { + if let Some(network_boundary) = plan.as_network_boundary() { let Stage::Local(stage) = network_boundary.input_stage() else { return plan_err!("Stage was not in Local state"); }; @@ -102,7 +89,7 @@ pub async fn rewrite_distributed_plan_with_metrics( Ok(Transformed::no(plan)) })?; - Ok(transformed.data) + plan.with_new_children(vec![transformed.data]) } /// Extra information for rewriting local plans. diff --git a/src/worker/generated/worker.rs b/src/worker/generated/worker.rs index d5594617..fe7a0137 100644 --- a/src/worker/generated/worker.rs +++ b/src/worker/generated/worker.rs @@ -1,7 +1,7 @@ // This file is @generated by prost-build. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct CoordinatorToWorkerMsg { - #[prost(oneof = "coordinator_to_worker_msg::Inner", tags = "1, 2")] + #[prost(oneof = "coordinator_to_worker_msg::Inner", tags = "1, 2, 3")] pub inner: ::core::option::Option, } /// Nested message and enum types in `CoordinatorToWorkerMsg`. @@ -17,6 +17,9 @@ pub mod coordinator_to_worker_msg { /// be executed within a partition, for example, a stream of file addresses that should be read. #[prost(message, tag = "2")] WorkUnitBatch(super::WorkUnitBatch), + /// Signals an EOS for WorkUnits. After this message is received, no more WorkUnits will be sent. + #[prost(bool, tag = "3")] + WorkUnitEos(bool), } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/src/worker/impl_coordinator_channel.rs b/src/worker/impl_coordinator_channel.rs index 2a264d8c..3a8b237b 100644 --- a/src/worker/impl_coordinator_channel.rs +++ b/src/worker/impl_coordinator_channel.rs @@ -121,28 +121,51 @@ impl Worker { })?; // Continue reading remaining messages (work unit feed data) in the background. 
- let mut work_unit_senders = remote_work_unit_feed_registry.senders; + let mut work_unit_senders = Some(remote_work_unit_feed_registry.senders); + let task_data_entries = Arc::clone(&self.task_data_entries); #[allow(clippy::disallowed_methods)] tokio::spawn(async move { let mut body = body.map_ok(set_work_unit_received_time); while let Some(Ok(msg)) = body.next().await { - let Some(Inner::WorkUnitBatch(msg)) = msg.inner else { + let Some(msg) = msg.inner else { continue; }; - for msg in msg.batch { - let Ok(id) = deserialize_uuid(&msg.id) else { - continue; - }; - let partition = msg.partition as usize; - let Some(tx) = work_unit_senders.get(&(id, partition)) else { - continue; - }; - if tx.send(Ok(msg)).is_err() { - work_unit_senders.remove(&(id, partition)); + match msg { + Inner::SetPlanRequest(_) => { + // SetPlanRequest should be the first already polled message in the stream, + // if some reached here it means that something is wrong. continue; } + Inner::WorkUnitBatch(msg) => { + let Some(work_unit_senders) = work_unit_senders.as_mut() else { + continue; + }; + for wu in msg.batch { + let Ok(id) = deserialize_uuid(&wu.id) else { + continue; + }; + let partition = wu.partition as usize; + let Some(tx) = work_unit_senders.get(&(id, partition)) else { + continue; + }; + if tx.send(Ok(wu)).is_err() { + // Channel closed, this sender needs to be dropped, as none will ever + // be listening on the other side. + work_unit_senders.remove(&(id, partition)); + continue; + } + } + } + Inner::WorkUnitEos(_) => { + // No further work unit message will be received here, so drop all the + // sender sides so that receiver sides see an EOS upon draining the + // remaining messages. + let _ = work_unit_senders.take(); + } } } + #[allow(clippy::disallowed_methods)] + tokio::spawn(async move { task_data_entries.invalidate(&key).await }); }); // Stream back the metrics once the task finishes executing. 
diff --git a/src/worker/worker.proto b/src/worker/worker.proto index c87670c8..bc1e3412 100644 --- a/src/worker/worker.proto +++ b/src/worker/worker.proto @@ -21,6 +21,8 @@ message CoordinatorToWorkerMsg { // set_plan_request. A work unit feed is a per-partition stream of information that tells the node what should // be executed within a partition, for example, a stream of file addresses that should be read. WorkUnitBatch work_unit_batch = 2; + // Signals an EOS for WorkUnits. After this message is received, no more WorkUnits will be sent. + bool work_unit_eos = 3; } } diff --git a/tests/stateful_data_cleanup.rs b/tests/stateful_data_cleanup.rs index 5719ab40..a3fe7bca 100644 --- a/tests/stateful_data_cleanup.rs +++ b/tests/stateful_data_cleanup.rs @@ -5,12 +5,13 @@ mod tests { use datafusion::physical_plan::execute_stream; use datafusion::prelude::SessionContext; use datafusion_distributed::test_utils::localhost::start_localhost_context; - use datafusion_distributed::{DefaultSessionBuilder, DistributedExt}; + use datafusion_distributed::{DefaultSessionBuilder, DistributedExt, Worker}; use datafusion_distributed_benchmarks::datasets::{register_tables, tpch}; use futures::TryStreamExt; use std::fs; use std::path::Path; use std::time::Duration; + use test_case::test_case; use tokio::sync::OnceCell; use tokio::time::timeout; @@ -19,19 +20,17 @@ mod tests { const TPCH_DATA_PARTS: usize = 16; const CARDINALITY_TASK_COUNT_FACTOR: f64 = 1.0; - #[tokio::test] - async fn no_pending_tasks_if_query_completes() -> Result<()> { - let (d_ctx, _guard, workers) = + #[test_case(false; "metrics_disabled")] + #[test_case(true; "metrics_enabled")] + #[tokio::test(flavor = "multi_thread")] + async fn no_pending_tasks_if_dynamic_query_completes(collect_metrics: bool) -> Result<()> { + let (mut d_ctx, _guard, workers) = start_localhost_context(NUM_WORKERS, DefaultSessionBuilder).await; + d_ctx.set_distributed_metrics_collection(collect_metrics)?; + run_tpch_query(d_ctx, "q1").await?; - 
for (i, worker) in workers.iter().enumerate() { - let tasks_running = worker.tasks_running().await; - assert_eq!( - tasks_running, 0, - "Expected Worker {i} to have 0 tasks running, but got {tasks_running}" - ) - } + assert_no_tasks_running_eventually(&workers).await; Ok(()) } @@ -43,25 +42,31 @@ mod tests { let _ = timeout(Duration::from_millis(100), run_tpch_query(d_ctx, "q1")).await; + assert_no_tasks_running_eventually(&workers).await; + + Ok(()) + } + + /// Polls until every worker reports 0 running tasks, or fails after 5s. Task entries are + /// torn down asynchronously once the coordinator->worker channel disconnects (shortly after + /// the query's output stream is dropped), so cleanup is not observable synchronously the + /// instant the query future resolves — hence the poll rather than an immediate assert. + async fn assert_no_tasks_running_eventually(workers: &[Worker]) { let start = Instant::now(); - let mut tasks_running = 0; - while start.elapsed() < Duration::from_secs(5) { - tokio::time::sleep(Duration::from_millis(100)).await; - tasks_running = 0; - for worker in &workers { + loop { + let mut tasks_running = 0; + for worker in workers { tasks_running += worker.tasks_running().await; } if tasks_running == 0 { - return Ok(()); + return; } + assert!( + start.elapsed() < Duration::from_secs(5), + "Expected 0 tasks running across workers, but still had {tasks_running} after 5s" + ); + tokio::time::sleep(Duration::from_millis(50)).await; } - - assert_eq!( - tasks_running, 0, - "Expected to have 0 tasks running, but got {tasks_running}" - ); - - Ok(()) } async fn run_tpch_query(d_ctx: SessionContext, query_id: &str) -> Result<()> {