Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod children_helpers;
mod map_last_stream;
mod on_drop_stream;
mod once_lock;
mod task_context_helpers;
mod uuid;

pub(crate) use children_helpers::require_one_child;
pub(crate) use map_last_stream::map_last_stream;
pub(crate) use on_drop_stream::on_drop_stream;
pub(crate) use once_lock::OnceLockResult;
pub(crate) use task_context_helpers::task_ctx_with_extension;
pub(crate) use uuid::{deserialize_uuid, serialize_uuid};
5 changes: 5 additions & 0 deletions src/common/once_lock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use datafusion::error::DataFusionError;
use std::sync::{Arc, OnceLock};

/// A [OnceLock] that holds a clonable result.
///
/// The error side is stored behind an [Arc], so a failure recorded in the lock can be handed
/// out to every reader by cloning the [Arc] rather than the underlying [DataFusionError].
pub(crate) type OnceLockResult<T> = OnceLock<Result<T, Arc<DataFusionError>>>;
4 changes: 2 additions & 2 deletions src/execution_plans/broadcast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::require_one_child;
use crate::common::{OnceLockResult, require_one_child};
use crossbeam_queue::SegQueue;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::runtime::SpawnedTask;
Expand Down Expand Up @@ -73,7 +73,7 @@ pub struct BroadcastExec {
input: Arc<dyn ExecutionPlan>,
consumer_task_count: usize,
properties: Arc<PlanProperties>,
queues: Vec<OnceLock<Result<StreamAndTask, Arc<DataFusionError>>>>,
queues: Vec<OnceLockResult<StreamAndTask>>,
}

/// Value stored in each slot of `BroadcastExec::queues`: a [SegQueue] of record batch streams
/// paired with a shared handle to a [SpawnedTask].
type StreamAndTask = (SegQueue<SendableRecordBatchStream>, Arc<SpawnedTask<()>>);
Expand Down
2 changes: 1 addition & 1 deletion src/execution_plans/network_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl ExecutionPlan for NetworkBroadcastExec {
&context,
)?;

let stream = worker_connection.stream_partition(off + partition, |_meta| {})?;
let stream = worker_connection.execute(off + partition)?;
streams.push(stream);
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_plans/network_coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ impl ExecutionPlan for NetworkCoalesceExec {
&context,
)?;

let stream = worker_connection.stream_partition(target_partition, |_meta| {})?;
let stream = worker_connection.execute(target_partition)?;

Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
Expand Down
2 changes: 1 addition & 1 deletion src/execution_plans/network_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl ExecutionPlan for NetworkShuffleExec {
&context,
)?;

let stream = worker_connection.stream_partition(off + partition, |_meta| {})?;
let stream = worker_connection.execute(off + partition)?;
streams.push(stream);
}

Expand Down
1 change: 1 addition & 0 deletions src/worker/impl_coordinator_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl Worker {
task_count: request.task_count as usize,
}))
.with_extension(Arc::new(LocalWorkerContext {
task_data_entries: Arc::clone(&self.task_data_entries),
self_url: Url::parse(&request.target_worker_url)
.map_err(|e| DataFusionError::External(Box::new(e)))?,
}))
Expand Down
278 changes: 153 additions & 125 deletions src/worker/impl_execute_task.rs

Large diffs are not rendered by default.

169 changes: 140 additions & 29 deletions src/worker/worker_connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::common::{on_drop_stream, serialize_uuid};
use crate::common::{OnceLockResult, on_drop_stream, serialize_uuid};
use crate::metrics::LatencyMetricExt;
use crate::networking::get_distributed_channel_resolver;
use crate::passthrough_headers::get_passthrough_headers;
use crate::protobuf::{datafusion_error_to_tonic_status, map_flight_to_datafusion_error};
use crate::stage::RemoteStage;
use crate::worker::generated::worker::FlightAppMetadata;
use crate::worker::generated::worker::{ExecuteTaskRequest, TaskKey};
use crate::worker::impl_execute_task::execute_local_task;
use crate::worker::worker_service::TaskDataEntries;
use crate::{BytesMetricExt, ChannelResolver, DistributedConfig};
use arrow_flight::FlightData;
use arrow_flight::decode::FlightRecordBatchStream;
Expand All @@ -14,12 +16,13 @@ use dashmap::DashMap;
use datafusion::arrow::array::RecordBatch;
use datafusion::common::instant::Instant;
use datafusion::common::runtime::SpawnedTask;
use datafusion::common::{DataFusionError, Result, internal_err};
use datafusion::common::{DataFusionError, Result, internal_datafusion_err, internal_err};
use datafusion::execution::TaskContext;
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion::physical_expr_common::metrics::{ExecutionPlanMetricsSet, MetricValue};
use datafusion::physical_plan::metrics::{MetricBuilder, Time};
use futures::{Stream, TryStreamExt};
use futures::stream::BoxStream;
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use pin_project::{pin_project, pinned_drop};
use prost::Message;
Expand All @@ -28,12 +31,11 @@ use std::fmt::{Debug, Formatter};
use std::ops::Range;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::Notify;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_stream::StreamExt;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use tonic::metadata::MetadataMap;
Expand All @@ -47,9 +49,10 @@ use url::Url;
/// This information can be used for executing tasks locally, bypassing gRPC comms, if the task that
/// needs to be remotely executed happens to be owned by this same worker.
pub(crate) struct LocalWorkerContext {
/// The registry of in-flight tasks owned by the [crate::Worker] in the current scope. It is
/// what gets queried when a task is executed locally instead of over gRPC.
pub(crate) task_data_entries: Arc<TaskDataEntries>,
/// The URL of the [crate::Worker] in scope. When trying to reach a target URL that happens
/// to be the same as this one, local comms are preferred instead.
#[allow(dead_code)]
pub(crate) self_url: Url,
}

Expand All @@ -59,7 +62,7 @@ pub(crate) struct LocalWorkerContext {
/// it will initialize the corresponding position in the vector matching the provided `target_task`
/// index.
pub(crate) struct WorkerConnectionPool {
connections: Vec<OnceLock<Result<WorkerConnection, Arc<DataFusionError>>>>,
connections: Vec<OnceLockResult<Box<dyn WorkerConnection + Sync + Send>>>,
pub(crate) metrics: ExecutionPlanMetricsSet,
}

Expand All @@ -86,35 +89,60 @@ impl WorkerConnectionPool {
target_partitions: Range<usize>,
target_task: usize,
ctx: &Arc<TaskContext>,
) -> Result<&WorkerConnection> {
) -> Result<&(dyn WorkerConnection + Sync + Send)> {
let Some(worker_connection) = self.connections.get(target_task) else {
return internal_err!(
"WorkerConnections: Task index {target_task} not found, only have {} tasks",
self.connections.len()
);
};
ctx.session_config().get_extension::<LocalWorkerContext>();

let conn = worker_connection.get_or_init(|| {
WorkerConnection::init(
input_stage,
target_partitions,
target_task,
ctx,
&self.metrics,
)
.map_err(Arc::new)
let Some(target_url) = input_stage.workers.get(target_task) else {
internal_err!("input_stage.workers[{target_task}] out of range.")?
};
if let Some(lw_ctx) = ctx.session_config().get_extension::<LocalWorkerContext>()
&& &lw_ctx.self_url == target_url
{
// Instead of making a gRPC call to ourselves, better to just use local comms.
Ok(Box::new(LocalWorkerConnection::init(
input_stage,
target_partitions,
target_task,
lw_ctx,
&self.metrics,
)) as Box<_>)
} else {
// We are trying to reach a URL different from ours, so use normal gRPC streams.
RemoteWorkerConnection::init(
input_stage,
target_partitions,
target_task,
ctx,
&self.metrics,
)
.map(|v| Box::new(v) as Box<_>)
.map_err(Arc::new)
}
});

match conn {
Ok(v) => Ok(v),
Ok(v) => Ok(v.as_ref()),
Err(err) => Err(DataFusionError::Shared(Arc::clone(err))),
}
}
}

/// Either a [FlightData] frame paired with its [FlightAppMetadata], or the [Status] describing
/// a failure.
type WorkerMsg = Result<(FlightData, FlightAppMetadata), Status>;

/// Abstraction that allows treating remote and local comms as equal. Network boundaries do not
/// care if the stream comes over the wire or locally.
pub(crate) trait WorkerConnection {
/// Streams the specified partition as [RecordBatch]es. Consumers do not care whether the
/// implementation pulls data over the network or from local comms.
///
/// The implementations in this file hand out the stream of a given partition at most once:
/// requesting a partition that is unknown or was already consumed returns an error.
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>>;
}

/// Represents a connection to one [Worker]. Network boundaries will use this for streaming
/// data from single partitions while the actual network communication is handling all the partitions
/// under the hood.
Expand All @@ -126,7 +154,7 @@ type WorkerMsg = Result<(FlightData, FlightAppMetadata), Status>;
/// the same underlying TCP connection, there is still some overhead in having one gRPC stream per
/// partition vs. a single gRPC stream interleaving multiple partitions. The whole serialized plan
/// needs to be sent over the wire on every gRPC call, so the fewer gRPC calls we make the better.
pub(crate) struct WorkerConnection {
struct RemoteWorkerConnection {
task: Arc<SpawnedTask<()>>,
not_consumed_streams: Arc<AtomicUsize>,
cancel_token: CancellationToken,
Expand All @@ -140,7 +168,7 @@ pub(crate) struct WorkerConnection {
elapsed_compute: Time,
}

impl WorkerConnection {
impl RemoteWorkerConnection {
fn init(
input_stage: &RemoteStage,
target_partition_range: Range<usize>,
Expand Down Expand Up @@ -337,7 +365,9 @@ impl WorkerConnection {
elapsed_compute: elapsed_compute_clone,
})
}
}

impl WorkerConnection for RemoteWorkerConnection {
/// Streams the provided `partition` from the remote worker.
///
/// Note that this does not issue a network request, the actual network request happened before
Expand All @@ -348,11 +378,7 @@ impl WorkerConnection {
///
/// When the returned stream is dropped (e.g., due to query cancellation), the background task
/// pulling from the Flight stream will be cancelled promptly.
pub(crate) fn stream_partition(
&self,
partition: usize,
on_metadata: impl Fn(FlightAppMetadata) + Send + Sync + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + 'static> {
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>> {
let Some((_, partition_receiver)) = self.per_partition_rx.remove(&partition) else {
return internal_err!(
"WorkerConnection has no stream for target partition {partition}. Was it already consumed?"
Expand All @@ -365,12 +391,11 @@ impl WorkerConnection {
let stream = stream.map_err(|err| FlightError::Tonic(Box::new(err)));
let reservation = Arc::clone(&self.memory_reservation);
let mem_available_notify = Arc::clone(&self.mem_available_notify);
let stream = stream.map_ok(move |(data, meta)| {
let stream = stream.map_ok(move |(data, _meta)| {
reservation.shrink(data.encoded_len());
// Wake the demux task in case it is blocked on the byte budget.
mem_available_notify.notify_one();
let _ = &task; // <- keep the task that polls data from the network alive.
on_metadata(meta);
data
});
let stream = FlightRecordBatchStream::new_from_flight_data(stream);
Expand All @@ -384,7 +409,93 @@ impl WorkerConnection {
if remaining_streams == 0 {
cancel_token.cancel();
}
}))
})
.boxed())
}
}

/// Equivalent to [RemoteWorkerConnection], but pulls data from the local registry of tasks
/// rather than going across a gRPC interface.
pub(crate) struct LocalWorkerConnection {
/// First partition of the range this connection was initialized with. A requested partition
/// is translated into an index of `local_streams` by subtracting this value.
partition_start: usize,
/// One slot per partition of the target range, in range order. A slot holds the stream of its
/// partition until `execute` takes it; from then on the slot is `None`.
local_streams: Vec<Mutex<Option<BoxStream<'static, Result<RecordBatch>>>>>,
}

impl LocalWorkerConnection {
    /// Builds a connection that serves `target_partition_range` of task `target_task` in
    /// `input_stage` straight from this worker's own task registry, with no gRPC involved.
    ///
    /// One single-partition [ExecuteTaskRequest] is issued per partition of the range, and the
    /// resulting streams are kept in range order so that each one can be claimed later on.
    /// Every call also adds 1 to the `local_connections_used` global counter of `metrics`.
    fn init(
        input_stage: &RemoteStage,
        target_partition_range: Range<usize>,
        target_task: usize,
        lw_ctx: Arc<LocalWorkerContext>,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        let usage_counter = MetricBuilder::new(metrics).global_counter("local_connections_used");
        usage_counter.add(1);

        // All the per-partition requests below target the same task.
        let task_key = TaskKey {
            query_id: serialize_uuid(&input_stage.query_id),
            stage_id: input_stage.num as u64,
            task_number: target_task as u64,
        };
        let partition_start = target_partition_range.start;

        let local_streams: Vec<_> = target_partition_range
            .map(|partition| {
                let registry = Arc::clone(&lw_ctx.task_data_entries);
                let single_partition_request = ExecuteTaskRequest {
                    task_key: Some(task_key.clone()),
                    target_partition_start: partition as u64,
                    target_partition_end: (partition + 1) as u64,
                };

                // Look the entry up in the registry right away by spawning the lookup as a task.
                // Deferring it until the returned `BoxStream` is first polled would leave room
                // for Moka's TTL to evict the entry in the meantime, and the stream would then
                // find nothing to read.
                //
                // Spawning only instantiates the streams; it does not start polling them.
                let pending_streams = SpawnedTask::spawn(async move {
                    let (streams, _) =
                        execute_local_task(&registry, single_partition_request).await?;
                    Ok::<_, DataFusionError>(streams)
                });

                let partition_stream = async move {
                    let mut streams = pending_streams
                        .await
                        .map_err(|err| internal_datafusion_err!("{err}"))??;
                    // A single partition was requested, so a single stream must come back.
                    match streams.len() {
                        1 => Ok(streams.swap_remove(0)),
                        _ => internal_err!("Expected exactly 1 local stream"),
                    }
                }
                .try_flatten_stream()
                .boxed();

                Mutex::new(Some(partition_stream))
            })
            .collect();

        Self {
            partition_start,
            local_streams,
        }
    }
}

impl WorkerConnection for LocalWorkerConnection {
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>> {
let relative_i = partition - self.partition_start;
Comment thread
gabotechs marked this conversation as resolved.
Outdated
let Some(slot) = self.local_streams.get(relative_i) else {
return internal_err!(
"LocalWorkerConnection has no stream for partition {partition}. Was it already consumed?"
);
};
slot.lock().unwrap().take().ok_or_else(|| {
internal_datafusion_err!(
"LocalWorkerConnection stream for partition {partition} was already consumed"
)
})
}
}

Expand All @@ -408,7 +519,7 @@ impl Clone for WorkerConnectionPool {
}
}

impl Debug for WorkerConnection {
impl Debug for RemoteWorkerConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WorkerConnection").finish()
}
Expand Down
Loading
Loading