Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/worker/worker_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion::physical_expr_common::metrics::{ExecutionPlanMetricsSet, MetricValue};
use datafusion::physical_plan::metrics::{MetricBuilder, Time};
use futures::stream::BoxStream;
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use pin_project::{pin_project, pinned_drop};
use prost::Message;
Expand Down Expand Up @@ -160,6 +160,7 @@ struct RemoteWorkerConnection {
cancel_token: CancellationToken,
per_partition_rx: DashMap<usize, UnboundedReceiver<WorkerMsg>>,

first_poll_notify: Arc<Notify>,
// Signals the demux task that buffered memory has been freed by a consumer.
mem_available_notify: Arc<Notify>,

Expand Down Expand Up @@ -246,6 +247,9 @@ impl RemoteWorkerConnection {
let mem_available_notify = Arc::new(Notify::new());
let mem_available_notify_for_task = Arc::clone(&mem_available_notify);

let first_poll_notify = Arc::new(Notify::new());
let first_poll_notify_for_task = Arc::clone(&first_poll_notify);

// Cancellation token allows us to stop the background task promptly when all partition
// streams are dropped (e.g., when the query is cancelled).
let cancel_token = CancellationToken::new();
Expand All @@ -255,6 +259,12 @@ impl RemoteWorkerConnection {
// fan them out to the appropriate `per_partition_rx` based on the "partition" declared
// in each individual record batch flight metadata.
let task = SpawnedTask::spawn(async move {
tokio::select! {
biased;
_ = cancel.cancelled() => return,
_ = first_poll_notify_for_task.notified() => {}
}

let mut client = match channel_resolver.get_worker_client_for_url(&url).await {
Ok(v) => v,
Err(err) => {
Expand Down Expand Up @@ -364,6 +374,7 @@ impl RemoteWorkerConnection {
not_consumed_streams: Arc::new(AtomicUsize::new(per_partition_rx.len())),
per_partition_rx,
mem_available_notify,
first_poll_notify,

// metrics stuff
memory_reservation: memory_reservation_clone,
Expand All @@ -375,14 +386,17 @@ impl RemoteWorkerConnection {
impl WorkerConnection for RemoteWorkerConnection {
/// Streams the provided `partition` from the remote worker.
///
/// Note that this does not issue a network request, the actual network request happened before
/// in the init step, and is in charge of handling not only this `partition`, but also all the
/// partitions passed in `target_partition_range`. This method just streams all the record
/// batches belonging to the provided `partition` from an in-memory queue, but what populates
/// this queue is [WorkerConnection::init].
/// This method does not open any network connection itself. Instead, network communication is
/// delegated to the task spawned by [WorkerConnection::init], which is in charge of pulling data
/// not only for the requested `partition`, but also for every other partition in
/// `target_partition_range`. This method just streams the record batches belonging to the
/// provided `partition` from an in-memory queue.
///
/// The task that pulls data over the network stays idle until one of the streams returned by
/// this method is polled for the first time.
///
/// When the returned stream is dropped (e.g., due to query cancellation), the background task
/// pulling from the Flight stream will be cancelled promptly.
/// pulling from the Flight stream will be canceled promptly.
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>> {
let Some((_, partition_receiver)) = self.per_partition_rx.remove(&partition) else {
return internal_err!(
Expand All @@ -392,7 +406,13 @@ impl WorkerConnection for RemoteWorkerConnection {
let task = Arc::clone(&self.task);
let cancel_token = self.cancel_token.clone();

let stream = UnboundedReceiverStream::new(partition_receiver);
let first_poll_notify = Arc::clone(&self.first_poll_notify);
let stream = async move {
first_poll_notify.notify_one();
UnboundedReceiverStream::new(partition_receiver)
}
.flatten_stream();

let stream = stream.map_err(|err| FlightError::Tonic(Box::new(err)));
let reservation = Arc::clone(&self.memory_reservation);
let mem_available_notify = Arc::clone(&self.mem_available_notify);
Expand Down
12 changes: 6 additions & 6 deletions src/worker/worker_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use std::time::Duration;
use tonic::codegen::BoxStream;
use tonic::{Request, Response, Status, Streaming};

/// Idle time after which an entry in the task data cache is evicted (10 minutes).
///
/// Built with `Duration::from_secs` so that it also compiles on toolchains that predate the
/// stabilization of `Duration::from_mins`.
const TASK_CACHE_TTI: Duration = Duration::from_secs(10 * 60);

#[allow(clippy::type_complexity)]
#[derive(Clone, Default)]
pub(super) struct WorkerHooks {
Expand All @@ -35,9 +37,9 @@ pub(crate) type TaskDataEntries = Cache<TaskKey, Arc<SingleWriteMultiRead<Result
#[derive(Clone)]
pub struct Worker {
pub(super) runtime: Arc<RuntimeEnv>,
/// TTL-based cache for task execution data. Entries are automatically evicted after 60 seconds.
/// This prevents memory leaks from abandoned or incomplete queries while allowing concurrent
/// access to task results across multiple partition requests.
/// Time-to-idle cache for task execution data. Entries are automatically evicted once they
/// have been idle for `TASK_CACHE_TTI`. This prevents memory leaks from abandoned or
/// incomplete queries while allowing concurrent access to task results across multiple
/// partition requests.
pub(super) task_data_entries: Arc<TaskDataEntries>,
pub(super) session_builder: Arc<dyn WorkerSessionBuilder + Send + Sync>,
pub(super) hooks: WorkerHooks,
Expand All @@ -47,9 +49,7 @@ pub struct Worker {

impl Default for Worker {
fn default() -> Self {
let cache = Cache::builder()
.time_to_idle(Duration::from_secs(60))
.build();
let cache = Cache::builder().time_to_idle(TASK_CACHE_TTI).build();
Self {
runtime: Arc::new(RuntimeEnv::default()),
task_data_entries: Arc::new(cache),
Expand Down
Loading