From 5064fd087f6f8683bc866ff667d6fd98fc95e3c0 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 30 May 2026 08:36:32 +0200 Subject: [PATCH] No eager buffering in network connections --- src/worker/worker_connection_pool.rs | 36 +++++++++++++++++++++------- src/worker/worker_service.rs | 12 +++++----- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/worker/worker_connection_pool.rs b/src/worker/worker_connection_pool.rs index 2e1f4aed..204e0e72 100644 --- a/src/worker/worker_connection_pool.rs +++ b/src/worker/worker_connection_pool.rs @@ -22,7 +22,7 @@ use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::physical_expr_common::metrics::{ExecutionPlanMetricsSet, MetricValue}; use datafusion::physical_plan::metrics::{MetricBuilder, Time}; use futures::stream::BoxStream; -use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use http::Extensions; use pin_project::{pin_project, pinned_drop}; use prost::Message; @@ -160,6 +160,7 @@ struct RemoteWorkerConnection { cancel_token: CancellationToken, per_partition_rx: DashMap>, + first_poll_notify: Arc, // Signals the demux task that buffered memory has been freed by a consumer. mem_available_notify: Arc, @@ -246,6 +247,9 @@ impl RemoteWorkerConnection { let mem_available_notify = Arc::new(Notify::new()); let mem_available_notify_for_task = Arc::clone(&mem_available_notify); + let first_poll_notify = Arc::new(Notify::new()); + let first_poll_notify_for_task = Arc::clone(&first_poll_notify); + // Cancellation token allows us to stop the background task promptly when all partition // streams are dropped (e.g., when the query is cancelled). let cancel_token = CancellationToken::new(); @@ -255,6 +259,12 @@ impl RemoteWorkerConnection { // fan them out to the appropriate `per_partition_rx` based on the "partition" declared // in each individual record batch flight metadata. 
let task = SpawnedTask::spawn(async move { + tokio::select! { + biased; + _ = cancel.cancelled() => return, + _ = first_poll_notify_for_task.notified() => {} + } + let mut client = match channel_resolver.get_worker_client_for_url(&url).await { Ok(v) => v, Err(err) => { @@ -364,6 +374,7 @@ impl RemoteWorkerConnection { not_consumed_streams: Arc::new(AtomicUsize::new(per_partition_rx.len())), per_partition_rx, mem_available_notify, + first_poll_notify, // metrics stuff memory_reservation: memory_reservation_clone, @@ -375,14 +386,17 @@ impl RemoteWorkerConnection { impl WorkerConnection for RemoteWorkerConnection { /// Streams the provided `partition` from the remote worker. /// - /// Note that this does not issue a network request, the actual network request happened before - /// in the init step, and is in charge of handling not only this `partition`, but also all the - /// partitions passed in `target_partition_range`. This method just streams all the record - /// batches belonging to the provided `partition` from an in-memory queue, but what populates - /// this queue is [WorkerConnection::init]. + /// This method does not handle any network connection. Instead, the network comms are delegated + /// to the task spawned by [WorkerConnection::init], which is in charge of polling data not only + /// from the requested `partition`, but from any other partition in `target_partition_range`. + /// This method just streams all the record batches belonging to the provided `partition` from + /// an in-memory queue. + /// + /// The task that polls data over the network is held inactive until the first poll to the + /// stream returned by this method. /// /// When the returned stream is dropped (e.g., due to query cancellation), the background task - /// pulling from the Flight stream will be cancelled promptly. + /// pulling from the Flight stream will be canceled promptly.
fn execute(&self, partition: usize) -> Result>> { let Some((_, partition_receiver)) = self.per_partition_rx.remove(&partition) else { return internal_err!( @@ -392,7 +406,13 @@ impl WorkerConnection for RemoteWorkerConnection { let task = Arc::clone(&self.task); let cancel_token = self.cancel_token.clone(); - let stream = UnboundedReceiverStream::new(partition_receiver); + let first_poll_notify = Arc::clone(&self.first_poll_notify); + let stream = async move { + first_poll_notify.notify_one(); + UnboundedReceiverStream::new(partition_receiver) + } + .flatten_stream(); + let stream = stream.map_err(|err| FlightError::Tonic(Box::new(err))); let reservation = Arc::clone(&self.memory_reservation); let mem_available_notify = Arc::clone(&self.mem_available_notify); diff --git a/src/worker/worker_service.rs b/src/worker/worker_service.rs index 0c572588..eac2da3e 100644 --- a/src/worker/worker_service.rs +++ b/src/worker/worker_service.rs @@ -22,6 +22,8 @@ use std::time::Duration; use tonic::codegen::BoxStream; use tonic::{Request, Response, Status, Streaming}; +const TASK_CACHE_TTI: Duration = Duration::from_mins(10); + #[allow(clippy::type_complexity)] #[derive(Clone, Default)] pub(super) struct WorkerHooks { @@ -35,9 +37,9 @@ pub(crate) type TaskDataEntries = Cache, - /// TTL-based cache for task execution data. Entries are automatically evicted after 60 seconds. - /// This prevents memory leaks from abandoned or incomplete queries while allowing concurrent - /// access to task results across multiple partition requests. + /// Time-to-idle cache for task execution data. Entries are automatically evicted after going + /// unused for TASK_CACHE_TTI. This prevents memory leaks from abandoned or incomplete queries + /// while allowing concurrent access to task results across multiple partition requests.
pub(super) task_data_entries: Arc, pub(super) session_builder: Arc, pub(super) hooks: WorkerHooks, @@ -47,9 +49,7 @@ pub struct Worker { impl Default for Worker { fn default() -> Self { - let cache = Cache::builder() - .time_to_idle(Duration::from_secs(60)) - .build(); + let cache = Cache::builder().time_to_idle(TASK_CACHE_TTI).build(); Self { runtime: Arc::new(RuntimeEnv::default()), task_data_entries: Arc::new(cache),