Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 217 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import math
import random
from abc import ABC, abstractmethod
from collections import namedtuple
from collections import OrderedDict, namedtuple
from dataclasses import MISSING, astuple, dataclass, field, fields, replace
from typing import TYPE_CHECKING, Dict, List, Set, Tuple

Expand Down Expand Up @@ -178,6 +178,21 @@ def create(
kv_cache_manager has block reuse enabled; DefaultADPRouter
otherwise.
"""
if attention_dp_config is not None and getattr(
attention_dp_config, "kv_cache_routing_conversation_affinity", False
):
# Explicit conversation_id -> rank affinity. Independent of the KV
# cache manager (works with or without block reuse, though it is
# most beneficial with reuse on), so it is checked before the
# KV-cache-aware path and takes precedence when both are enabled.
return ConversationAwareADPRouter(
dist=dist,
max_sessions=getattr(
attention_dp_config, "kv_cache_routing_max_sessions", 1 << 16
),
fair_share_multiplier=attention_dp_config.kv_cache_routing_fair_share_multiplier,
)

if (
attention_dp_config is not None
and attention_dp_config.enable_kv_cache_aware_routing
Expand Down Expand Up @@ -726,3 +741,204 @@ def _sort_key(req_item):
)

return all_ranks_new_requests, expected_num_active_requests


class ConversationAwareADPRouter(ADPRouter):
    """Conversation-affinity request router for attention data parallelism.

    Routes the *first* request of each conversation round-robin across ranks,
    then pins every later request carrying the same ``conversation_id`` to the
    rank that served that conversation's first turn. This keeps a multi-turn
    conversation's growing KV-cache prefix on a single rank (maximizing block
    reuse, minimizing recompute and cross-rank migration) while still spreading
    the birth of new conversations evenly.

    Contrast with :class:`KVCacheAwareADPRouter`, which *infers* affinity from
    probed prefix-match lengths: that affinity is lost as soon as a
    conversation's blocks are evicted -- the request then re-routes by load and
    the conversation "migrates" ranks. This router keeps an **explicit**
    ``conversation_id -> rank`` map, so stickiness is deterministic and survives
    cache eviction. It is inspired by the serve-level ConversationRouter
    (``tensorrt_llm/serve/router.py``) and the first-turn-round-robin idea from
    PR #14744, but applied at the intra-instance ADP-rank level.

    ``conversation_id`` is read from
    ``req.py_disaggregated_params.conversation_id`` (populated by the serve-side
    conversation / KV-aware router from the ``X-Session-ID`` header; see
    PR #14744). When it is absent -- header not sent, non-disaggregated, or the
    serve-side propagation is not present -- the request falls back to
    load-balanced round-robin and is *not* recorded, so behaviour degrades
    gracefully to ``DefaultADPRouter``-style spreading.

    Determinism: :meth:`route_requests` runs locally on every TP rank with no
    broadcast, so the round-robin cursor and the conversation->rank map MUST
    evolve identically on every rank. They do, because every rank processes the
    same ``new_requests`` in the same order. Any divergence would deadlock the
    distributed allgather protocol -- this is the same invariant the warmup /
    first-turn-round-robin cursors rely on above.
    """

    # Default LRU cap on the conversation->rank map (entries are ~tens of
    # bytes each). Bounds memory on long-running servers as conversations churn.
    DEFAULT_MAX_SESSIONS = 1 << 16

    def __init__(
        self,
        dist: "Distributed",
        max_sessions: int = DEFAULT_MAX_SESSIONS,
        fair_share_multiplier: float = 2.0,
    ):
        """Create the router.

        Args:
            dist: Distributed handle forwarded to the base router.
            max_sessions: LRU cap on tracked conversations; clamped to >= 1.
            fair_share_multiplier: Slack factor applied to the ceil fair share
                when computing the loose per-rank cap; clamped to >= 1.0.
        """
        super().__init__(dist)
        # conversation_id -> rank, LRU-ordered (most-recently-routed last).
        self._conv_to_rank: "OrderedDict[str, int]" = OrderedDict()
        self._max_sessions = max(1, int(max_sessions))
        # Loose per-rank cap = fair_share_multiplier * ceil fair-share. Sticky
        # concentration is allowed up to this but NEVER beyond, so the returned
        # expected_num_active_requests stays >= every rank's active count (the
        # ADP invariant asserted in py_executor._pad_attention_dp_dummy_request).
        # Clamping to >= 1.0 guarantees the cap is never below the fair share,
        # which is what makes "some rank is always under the cap" hold.
        self._fair_share_multiplier = max(1.0, float(fair_share_multiplier))
        # Round-robin cursor for first-turn / unkeyed requests. Mutated
        # identically on every rank (route_requests is deterministic), exactly
        # like KVCacheAwareADPRouter._first_turn_rr_counter; divergence deadlocks.
        self._rr_counter = 0

    def create_rank_state(
        self,
        active_requests: list[LlmRequest],
        new_requests: list[RequestQueueItem],
    ) -> RankState:
        """Summarize this rank's load as a ``RankState``.

        ``new_requests`` is accepted for interface compatibility and is not
        read here. The token count uses ``total_input_len_cp`` when
        ``dist.has_cp_helix`` is set and ``py_orig_prompt_len`` otherwise.
        """
        if self.dist.has_cp_helix:
            num_active_tokens = sum(req.total_input_len_cp for req in active_requests)
        else:
            num_active_tokens = sum(req.py_orig_prompt_len for req in active_requests)
        return RankState(
            rank=self.dist.tp_rank,
            num_active_requests=len(active_requests),
            num_active_tokens=num_active_tokens,
        )

    @staticmethod
    def _conversation_id(req_item) -> "str | None":
        """Return the request's conversation id, or None when unavailable.

        Read from ``py_disaggregated_params.conversation_id`` (serve-side
        propagated from the X-Session-ID header). Empty strings are treated as
        absent so they fall through to the load-balanced path.
        """
        req = getattr(req_item, "request", None)
        if req is None:
            return None
        disagg = getattr(req, "py_disaggregated_params", None)
        if disagg is None:
            return None
        conv_id = getattr(disagg, "conversation_id", None)
        return conv_id if conv_id else None

    def _record_home(self, conv_id: str, rank: int) -> None:
        """Bind/refresh a conversation's home rank with LRU eviction."""
        self._conv_to_rank[conv_id] = rank
        # Plain assignment does not reorder an existing key, so move it
        # explicitly to the most-recently-used end.
        self._conv_to_rank.move_to_end(conv_id)
        while len(self._conv_to_rank) > self._max_sessions:
            self._conv_to_rank.popitem(last=False)

    def route_requests(
        self,
        all_rank_states: list[RankState],
        new_requests: list[RequestQueueItem],
        max_num_active_requests: int,
    ) -> Tuple[Dict[int, List[RequestQueueItem]], int]:
        """Assign ``new_requests`` to ranks with conversation affinity.

        Args:
            all_rank_states: One ``RankState`` per rank. The per-rank counters
                are indexed by list position while the result dict is keyed by
                ``RankState.rank``.
                NOTE(review): this assumes the list is ordered by rank with
                ranks ``0..tp_size-1`` -- confirm against the caller.
            new_requests: Requests to place in this scheduling step.
            max_num_active_requests: Hard per-rank limit on active requests.

        Returns:
            ``(all_ranks_new_requests, expected_num_active_requests)`` where the
            first maps rank -> requests assigned to it and the second is the
            loose per-rank cap, which is >= every rank's resulting active count.
        """
        tp_size = len(all_rank_states)
        all_ranks_new_requests: Dict[int, List[RequestQueueItem]] = {
            s.rank: [] for s in all_rank_states
        }
        all_ranks_num_active_requests = [s.num_active_requests for s in all_rank_states]

        def get_relax_value(req_item):
            # Sort key: False sorts before True, so requests with
            # attention_dp_relax == False are placed first; requests without
            # scheduling params count as relaxed.
            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
            if scheduling_params is None:
                return True
            return scheduling_params.attention_dp_relax

        sorted_requests = sorted(new_requests, key=get_relax_value)

        # 1) Honour an explicit attention_dp_rank first (strict placement),
        #    matching DefaultADPRouter / KVCacheAwareADPRouter.
        remaining_unscheduled: List[RequestQueueItem] = []
        for req_item in sorted_requests:
            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
            target_dp_rank = (
                scheduling_params.attention_dp_rank if scheduling_params is not None else None
            )
            if (
                target_dp_rank is not None
                # Reject out-of-range targets instead of crashing: a value
                # >= tp_size would raise IndexError, and a negative one would
                # wrap on the list and then raise KeyError on the rank-keyed
                # dict. Such requests fall through to load-balanced routing.
                and 0 <= target_dp_rank < tp_size
                and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests
            ):
                all_ranks_num_active_requests[target_dp_rank] += 1
                all_ranks_new_requests[target_dp_rank].append(req_item)
            else:
                remaining_unscheduled.append(req_item)

        # 2) Loose per-rank cap = fair_share_multiplier * ceil fair-share
        #    (mirrors KVCacheAwareADPRouter). No rank may exceed the returned
        #    expected_num_active_requests -- exceeding it breaks the ADP
        #    padding invariant (py_executor._pad_attention_dp_dummy_request)
        #    and crashes the executor. The multiplier gives sticky
        #    concentration slack before a conversation overflows off its home
        #    rank.
        num_new_requests_all_ranks = len(remaining_unscheduled)
        total_num_active_requests = sum(all_ranks_num_active_requests)
        fair_share = (
            total_num_active_requests + num_new_requests_all_ranks + tp_size - 1
        ) // tp_size
        expected_num_active_requests = max(
            math.ceil(self._fair_share_multiplier * fair_share),
            max(all_ranks_num_active_requests),
        )
        # Placement cap: additionally respect the hard per-rank limit that
        # step 1 enforces, so sticky concentration cannot push a rank past
        # max_num_active_requests while other ranks still have room. It is
        # never above `expected`, and the least-loaded fallback below always
        # picks a rank strictly under `expected` (the total load is at most
        # tp_size * fair_share), so the padding invariant still holds.
        placement_cap = min(expected_num_active_requests, max_num_active_requests)

        def _least_loaded(soft_cap: int) -> int:
            """Lowest-active rank under soft_cap (deterministic), else global min."""
            cands = [r for r in range(tp_size) if all_ranks_num_active_requests[r] < soft_cap]
            if not cands:
                cands = list(range(tp_size))
            # Ties break on the rank index so every TP rank picks the same one.
            return min(cands, key=lambda r: (all_ranks_num_active_requests[r], r))

        def _next_rr(soft_cap: int) -> int:
            """Round-robin to the next rank under soft_cap; advance the cursor."""
            for _ in range(tp_size):
                r = self._rr_counter % tp_size
                self._rr_counter = (self._rr_counter + 1) % tp_size
                if all_ranks_num_active_requests[r] < soft_cap:
                    return r
            return _least_loaded(soft_cap)

        for req_item in remaining_unscheduled:
            conv_id = self._conversation_id(req_item)
            rank = None

            if conv_id is not None and conv_id in self._conv_to_rank:
                home = self._conv_to_rank[conv_id]
                # Sticky: return the conversation to its home rank while that
                # rank is under the placement cap.
                if all_ranks_num_active_requests[home] < placement_cap:
                    rank = home
                    self._record_home(conv_id, home)  # LRU touch
                else:
                    # Home saturated this batch -> overflow WITHOUT rebinding,
                    # so the conversation returns home next batch. Still
                    # refresh its LRU position: the conversation is active and
                    # must not be evicted ahead of idle ones.
                    self._conv_to_rank.move_to_end(conv_id)

            if rank is None:
                # First turn of a new conversation, sticky-overflow, or no
                # conversation_id -> round-robin spread under the soft cap.
                rank = _next_rr(placement_cap)
                if conv_id is not None and conv_id not in self._conv_to_rank:
                    # Bind this new conversation's home to its first-turn rank.
                    self._record_home(conv_id, rank)

            all_ranks_new_requests[rank].append(req_item)
            all_ranks_num_active_requests[rank] += 1

        logger.debug(
            f"[adp_router][conv] new_reqs_per_rank="
            f"{[len(all_ranks_new_requests[r]) for r in range(tp_size)]} "
            f"tracked_convs={len(self._conv_to_rank)}"
        )

        return all_ranks_new_requests, expected_num_active_requests
4 changes: 4 additions & 0 deletions tensorrt_llm/disaggregated_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class DisaggregatedParams:
ctx_info_endpoint: Optional[str] = None
schedule_style: Optional[DisaggScheduleStyle] = None
ctx_usage: Optional[Dict[str, Any]] = None
# Multi-turn conversation id (from session headers such as X-Session-ID),
# carried through so worker-side consumers (e.g. the ADP router) can see
# the same id the disagg orchestrator routed on.
conversation_id: Optional[str] = None

# E-P Disaggregated Params
multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
Expand Down
22 changes: 22 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,28 @@ class AttentionDpConfig(StrictBaseModel):
"scatter requests that would otherwise consolidate on a single warm "
"rank, wasting prefill. Default False preserves pre-warmup routing. "
"Only used when enable_kv_cache_aware_routing is True.")
kv_cache_routing_conversation_affinity: bool = Field(
default=False,
description=
"Enable explicit conversation-affinity routing for attention DP. When "
"True, the first request of each conversation is round-robined across "
"ranks and every subsequent request carrying the same conversation_id "
"(read from disaggregated_params, populated from the X-Session-ID "
"header) is pinned to that conversation's first-turn rank. This keeps a "
"multi-turn conversation's KV-cache prefix on one rank (maximizing "
"block reuse, minimizing cross-rank migration). Unlike "
"enable_kv_cache_aware_routing (affinity inferred from prefix-match "
"length, which is lost when blocks are evicted), the conversation->rank "
"map is explicit and survives eviction. Falls back to load-balanced "
"round-robin when no conversation_id is available. Takes precedence "
"over enable_kv_cache_aware_routing when both are set.")
kv_cache_routing_max_sessions: int = Field(
default=65536,
description=
"LRU cap on the conversation->rank map used by conversation-affinity "
"routing. The oldest conversations are evicted once more than this many "
"are tracked, bounding memory on long-running servers. Only used when "
"kv_cache_routing_conversation_affinity is True.")

@model_validator(mode='after')
def validate_attention_dp_config(self) -> 'AttentionDpConfig':
Expand Down
2 changes: 2 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,7 @@ def to_disaggregated_params(
ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint,
schedule_style=tllm_disagg_params.schedule_style,
ctx_usage=ctx_usage,
conversation_id=tllm_disagg_params.conversation_id,
)


Expand All @@ -1335,6 +1336,7 @@ def to_llm_disaggregated_params(
ctx_info_endpoint=disaggregated_params.ctx_info_endpoint,
schedule_style=disaggregated_params.schedule_style,
ctx_usage=None if ctx_usage is None else ctx_usage.model_dump(),
conversation_id=disaggregated_params.conversation_id,
)


Expand Down
Loading
Loading