diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py b/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
index b717fc38ca1e..f2a84ec0416b 100644
--- a/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
@@ -21,9 +21,9 @@
 import math
 import random
 from abc import ABC, abstractmethod
-from collections import namedtuple
+from collections import OrderedDict, namedtuple
 from dataclasses import MISSING, astuple, dataclass, field, fields, replace
-from typing import TYPE_CHECKING, Dict, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from tensorrt_llm.logger import logger
 
@@ -178,6 +178,20 @@ def create(
             kv_cache_manager has block reuse enabled; DefaultADPRouter otherwise.
         """
 
+        if (
+            attention_dp_config is not None
+            and attention_dp_config.kv_cache_routing_conversation_affinity
+        ):
+            # Explicit conversation_id -> rank affinity. Independent of the KV
+            # cache manager (works with or without block reuse, though it is
+            # most beneficial with reuse on), so it is checked before the
+            # KV-cache-aware path and takes precedence when both are enabled.
+            return ConversationAwareADPRouter(
+                dist=dist,
+                max_sessions=attention_dp_config.kv_cache_routing_max_sessions,
+                fair_share_multiplier=attention_dp_config.kv_cache_routing_fair_share_multiplier,
+            )
+
         if (
             attention_dp_config is not None
             and attention_dp_config.enable_kv_cache_aware_routing
@@ -234,6 +248,58 @@ def gather_all_rank_states(
         responses = self.dist.tp_allgather(local_state.serialize())
         return [RankState.deserialize(data=resp) for resp in responses]
 
+    @staticmethod
+    def _assign_explicit_dp_ranks(
+        requests: List["RequestQueueItem"],
+        all_ranks_new_requests: Dict[int, List["RequestQueueItem"]],
+        all_ranks_num_active_requests: List[int],
+        max_num_active_requests: int,
+    ) -> List["RequestQueueItem"]:
+        """Place requests carrying an explicit ``attention_dp_rank`` on that rank
+        (while it is under ``max_num_active_requests``) and return the rest for
+        load-balanced assignment. Mutates the ``all_ranks_*`` accumulators in
+        place. Shared by Default and ConversationAware routers.
+        """
+        remaining: List["RequestQueueItem"] = []
+        for req_item in requests:
+            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
+            target_dp_rank = (
+                scheduling_params.attention_dp_rank if scheduling_params is not None else None
+            )
+            if (
+                target_dp_rank is not None
+                and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests
+            ):
+                all_ranks_num_active_requests[target_dp_rank] += 1
+                all_ranks_new_requests[target_dp_rank].append(req_item)
+            else:
+                remaining.append(req_item)
+        return remaining
+
+    @staticmethod
+    def _expected_num_active_requests(
+        all_ranks_num_active_requests: List[int],
+        num_new_requests: int,
+        tp_size: int,
+        *,
+        multiplier: float = 1.0,
+        hard_cap: Optional[int] = None,
+    ) -> int:
+        """Loose per-rank target = ``ceil(multiplier * fair_share)``, floored at the
+        busiest rank's current load and optionally capped at ``hard_cap``
+        (``max_num_active_requests``). It is the soft cap that bounds per-rank
+        assignment so no rank exceeds the returned value -- the ADP padding
+        invariant asserted in py_executor._pad_attention_dp_dummy_request.
+        """
+        fair_share = (
+            sum(all_ranks_num_active_requests) + num_new_requests + tp_size - 1
+        ) // tp_size
+        expected = max(
+            math.ceil(multiplier * fair_share),
+            max(all_ranks_num_active_requests),
+        )
+        return min(expected, hard_cap) if hard_cap is not None else expected
+
     @abstractmethod
     def route_requests(
         self,
@@ -310,33 +376,21 @@ def get_relax_value(req_item):
 
         sorted_requests = sorted(new_requests, key=get_relax_value)
 
-        remaining_unscheduled = []
-        for req_item in sorted_requests:
-            scheduled = False
-            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
-            if scheduling_params is not None:
-                target_dp_rank = scheduling_params.attention_dp_rank
-                if (
-                    target_dp_rank is not None
-                    and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests
-                ):
-                    all_ranks_num_active_requests[target_dp_rank] += 1
-                    scheduled = True
-                    all_ranks_new_requests[target_dp_rank].append(req_item)
-
-            if not scheduled:
-                remaining_unscheduled.append(req_item)
+        remaining_unscheduled = self._assign_explicit_dp_ranks(
+            sorted_requests,
+            all_ranks_new_requests,
+            all_ranks_num_active_requests,
+            max_num_active_requests,
+        )
 
         num_new_requests_all_ranks = len(remaining_unscheduled)
-        total_num_active_requests = sum(all_ranks_num_active_requests)
         # Cap at max_num_active_requests so the per-rank target never
         # exceeds what the rank can physically schedule.
-        expected_num_active_requests = min(
-            max(
-                (total_num_active_requests + num_new_requests_all_ranks + tp_size - 1) // tp_size,
-                max(all_ranks_num_active_requests),
-            ),
-            max_num_active_requests,
+        expected_num_active_requests = self._expected_num_active_requests(
+            all_ranks_num_active_requests,
+            num_new_requests_all_ranks,
+            tp_size,
+            hard_cap=max_num_active_requests,
         )
 
         all_ranks_new_requests = self._balance_requests_across_ranks(
@@ -643,16 +697,12 @@ def _sort_key(req_item):
         # max_num_active_requests so the per-rank target never exceeds what
         # the rank can physically schedule (matches DefaultADPRouter).
         num_new_requests_all_ranks = len(remaining_unscheduled)
-        total_num_active_requests = sum(all_ranks_num_active_requests)
-        fair_share = (
-            total_num_active_requests + num_new_requests_all_ranks + tp_size - 1
-        ) // tp_size
-        expected_num_active_requests = min(
-            max(
-                math.ceil(self.fair_share_multiplier * fair_share),
-                max(all_ranks_num_active_requests),
-            ),
-            max_num_active_requests,
+        expected_num_active_requests = self._expected_num_active_requests(
+            all_ranks_num_active_requests,
+            num_new_requests_all_ranks,
+            tp_size,
+            multiplier=self.fair_share_multiplier,
+            hard_cap=max_num_active_requests,
         )
         eligible_ranks = [
             rank
@@ -726,3 +776,183 @@ def _sort_key(req_item):
         )
 
         return all_ranks_new_requests, expected_num_active_requests
+
+
+class ConversationAwareADPRouter(ADPRouter):
+    """Pins each conversation to a single attention-DP rank: the first request
+    of a conversation is round-robined, and every later request with the same
+    ``conversation_id`` returns to that rank, keeping the conversation's
+    KV-cache prefix on one rank. Falls back to load-balanced round-robin when no
+    ``conversation_id`` is present.
+    """
+
+    # Default LRU cap on the conversation->rank map (entries are ~tens of
+    # bytes each). Bounds memory on long-running servers as conversations churn.
+    DEFAULT_MAX_SESSIONS = 1 << 16
+
+    def __init__(
+        self,
+        dist: "Distributed",
+        max_sessions: int = DEFAULT_MAX_SESSIONS,
+        fair_share_multiplier: float = 2.0,
+    ):
+        """Create the router.
+
+        Args:
+            dist: Distributed handle forwarded to ``ADPRouter.__init__``.
+            max_sessions: LRU cap on tracked conversations (clamped to >= 1).
+            fair_share_multiplier: Slack factor applied to the fair-share
+                soft cap (clamped to >= 1.0).
+        """
+        super().__init__(dist)
+        self._conv_to_rank: "OrderedDict[str, int]" = OrderedDict()
+        self._max_sessions = max(1, int(max_sessions))
+        self._fair_share_multiplier = max(1.0, float(fair_share_multiplier))
+        self._round_robin_cursor = 0
+
+    def create_rank_state(
+        self,
+        active_requests: list[LlmRequest],
+        new_requests: list[RequestQueueItem],
+    ) -> RankState:
+        """Report this rank's active-request count and active prompt-token total."""
+        if self.dist.has_cp_helix:
+            num_active_tokens = sum(req.total_input_len_cp for req in active_requests)
+        else:
+            num_active_tokens = sum(req.py_orig_prompt_len for req in active_requests)
+        return RankState(
+            rank=self.dist.tp_rank,
+            num_active_requests=len(active_requests),
+            num_active_tokens=num_active_tokens,
+        )
+
+    @staticmethod
+    def _conversation_id(req_item) -> "str | None":
+        """Return the request's non-empty ``conversation_id``, or ``None`` if absent."""
+        req = getattr(req_item, "request", None)
+        if req is None:
+            return None
+        disagg = getattr(req, "py_disaggregated_params", None)
+        if disagg is None:
+            return None
+        conv_id = getattr(disagg, "conversation_id", None)
+        return conv_id if conv_id else None
+
+    def _record_target_rank(self, conv_id: str, rank: int) -> None:
+        """Bind/refresh the rank a conversation is pinned to (LRU-touch + evict)."""
+        self._conv_to_rank[conv_id] = rank
+        self._conv_to_rank.move_to_end(conv_id)
+        while len(self._conv_to_rank) > self._max_sessions:
+            self._conv_to_rank.popitem(last=False)
+
+    def route_requests(
+        self,
+        all_rank_states: list[RankState],
+        new_requests: list[RequestQueueItem],
+        max_num_active_requests: int,
+    ) -> Tuple[Dict[int, List[RequestQueueItem]], int]:
+        """Assign ``new_requests`` to ranks with conversation affinity.
+
+        Explicit ``attention_dp_rank`` placements are honoured first. A
+        request whose ``conversation_id`` is already bound returns to its
+        home rank while that rank is under ``max_num_active_requests``;
+        everything else is round-robined under the soft fair-share cap.
+
+        Returns:
+            ``(rank -> assigned requests, expected_num_active_requests)``,
+            where the expected value is >= every rank's resulting load.
+        """
+        tp_size = len(all_rank_states)
+        all_ranks_new_requests: Dict[int, List[RequestQueueItem]] = {
+            s.rank: [] for s in all_rank_states
+        }
+        all_ranks_num_active_requests = [s.num_active_requests for s in all_rank_states]
+
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        sorted_requests = sorted(new_requests, key=get_relax_value)
+
+        # 1) Honour an explicit attention_dp_rank first (strict placement).
+        remaining_unscheduled = self._assign_explicit_dp_ranks(
+            sorted_requests,
+            all_ranks_new_requests,
+            all_ranks_num_active_requests,
+            max_num_active_requests,
+        )
+
+        # 2) Soft cap for spreading new conversations across ranks, clamped to
+        #    the hard cap (max_num_active_requests) like the other routers so
+        #    round-robin overflow can never land on an already-saturated rank.
+        #    Sticky returns may exceed the soft cap (they are bounded by the
+        #    hard cap instead), so it is re-bumped after the loop.
+        expected_num_active_requests = self._expected_num_active_requests(
+            all_ranks_num_active_requests,
+            len(remaining_unscheduled),
+            tp_size,
+            multiplier=self._fair_share_multiplier,
+            hard_cap=max_num_active_requests,
+        )
+
+        def _least_loaded(soft_cap: int) -> int:
+            """Lowest-active rank under soft_cap (deterministic), else global min."""
+            cands = [r for r in range(tp_size) if all_ranks_num_active_requests[r] < soft_cap]
+            if not cands:
+                cands = list(range(tp_size))
+            return min(cands, key=lambda r: (all_ranks_num_active_requests[r], r))
+
+        def _next_rr(soft_cap: int) -> int:
+            """Round-robin to the next rank under soft_cap; advance the cursor."""
+            for _ in range(tp_size):
+                r = self._round_robin_cursor % tp_size
+                self._round_robin_cursor = (self._round_robin_cursor + 1) % tp_size
+                if all_ranks_num_active_requests[r] < soft_cap:
+                    return r
+            return _least_loaded(soft_cap)
+
+        for req_item in remaining_unscheduled:
+            conv_id = self._conversation_id(req_item)
+            rank = None
+
+            if conv_id is not None and conv_id in self._conv_to_rank:
+                target_rank = self._conv_to_rank[conv_id]
+                # Sticky: return the conversation to its target rank while that
+                # rank is under the HARD cap (max_num_active_requests), NOT the
+                # soft fair-share `expected`. Stickiness wins over balance so a
+                # conversation's KV prefix stays resident on one rank instead of
+                # migrating when the rank gets moderately busy.
+                if all_ranks_num_active_requests[target_rank] < max_num_active_requests:
+                    rank = target_rank
+                    self._record_target_rank(conv_id, target_rank)  # LRU touch
+                # else: target saturated this batch -> fall through to overflow
+                # WITHOUT rebinding, so the conversation returns next batch.
+
+            if rank is None:
+                # First turn of a new conversation, sticky-overflow, or no
+                # conversation_id -> round-robin spread under the soft cap.
+                rank = _next_rr(expected_num_active_requests)
+                if conv_id is not None and conv_id not in self._conv_to_rank:
+                    # Bind this new conversation to its first-turn rank.
+                    self._record_target_rank(conv_id, rank)
+
+            all_ranks_new_requests[rank].append(req_item)
+            all_ranks_num_active_requests[rank] += 1
+
+        # Sticky returns use the hard cap, so a rank may now exceed the pre-loop
+        # soft `expected`. Re-bump so the returned value covers the actual
+        # per-rank max -- _pad_attention_dp_dummy_request asserts
+        # expected >= len(active_requests) on every rank.
+        expected_num_active_requests = max(
+            expected_num_active_requests, max(all_ranks_num_active_requests)
+        )
+
+        logger.debug(
+            f"[adp_router][conv] new_reqs_per_rank="
+            f"{[len(all_ranks_new_requests[r]) for r in range(tp_size)]} "
+            f"tracked_convs={len(self._conv_to_rank)}"
+        )
+
+        return all_ranks_new_requests, expected_num_active_requests
diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py
index 6d6459a8afda..93b9a38f2215 100644
--- a/tensorrt_llm/disaggregated_params.py
+++ b/tensorrt_llm/disaggregated_params.py
@@ -55,6 +55,10 @@ class DisaggregatedParams:
     ctx_info_endpoint: Optional[str] = None
     schedule_style: Optional[DisaggScheduleStyle] = None
     ctx_usage: Optional[Dict[str, Any]] = None
+    # Multi-turn conversation id (from session headers such as X-Session-ID),
+    # carried through so worker-side consumers (e.g. the ADP router) can see
+    # the same id the disagg orchestrator routed on.
+    conversation_id: Optional[str] = None
 
     # E-P Disaggregated Params
     multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 673a139372f9..91b7cffdefec 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -861,6 +861,28 @@ class AttentionDpConfig(StrictBaseModel):
         "scatter requests that would otherwise consolidate on a single warm "
         "rank, wasting prefill. Default False preserves pre-warmup routing. "
         "Only used when enable_kv_cache_aware_routing is True.")
+    kv_cache_routing_conversation_affinity: bool = Field(
+        default=False,
+        description=
+        "Enable explicit conversation-affinity routing for attention DP. When "
+        "True, the first request of each conversation is round-robined across "
+        "ranks and every subsequent request carrying the same conversation_id "
+        "(read from disaggregated_params, populated from the X-Session-ID "
+        "header) is pinned to that conversation's first-turn rank. This keeps a "
+        "multi-turn conversation's KV-cache prefix on one rank (maximizing "
+        "block reuse, minimizing cross-rank migration). Unlike "
+        "enable_kv_cache_aware_routing (affinity inferred from prefix-match "
+        "length, which is lost when blocks are evicted), the conversation->rank "
+        "map is explicit and survives eviction. Falls back to load-balanced "
+        "round-robin when no conversation_id is available. Takes precedence "
+        "over enable_kv_cache_aware_routing when both are set.")
+    kv_cache_routing_max_sessions: int = Field(
+        default=65536,
+        description=
+        "LRU cap on the conversation->rank map used by conversation-affinity "
+        "routing. The oldest conversations are evicted once more than this many "
+        "are tracked, bounding memory on long-running servers. Only used when "
+        "kv_cache_routing_conversation_affinity is True.")
 
     @model_validator(mode='after')
     def validate_attention_dp_config(self) -> 'AttentionDpConfig':
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 96316d2eb27d..beba7e7d46d7 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -1311,6 +1311,7 @@ def to_disaggregated_params(
         ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint,
         schedule_style=tllm_disagg_params.schedule_style,
         ctx_usage=ctx_usage,
+        conversation_id=tllm_disagg_params.conversation_id,
     )
 
 
@@ -1335,6 +1336,7 @@ def to_llm_disaggregated_params(
         ctx_info_endpoint=disaggregated_params.ctx_info_endpoint,
         schedule_style=disaggregated_params.schedule_style,
         ctx_usage=None if ctx_usage is None else ctx_usage.model_dump(),
+        conversation_id=disaggregated_params.conversation_id,
     )
 
 
diff --git a/tests/unittest/_torch/executor/test_adp_router.py b/tests/unittest/_torch/executor/test_adp_router.py
index 4ead152409b7..b258e622b848 100644
--- a/tests/unittest/_torch/executor/test_adp_router.py
+++ b/tests/unittest/_torch/executor/test_adp_router.py
@@ -16,6 +16,7 @@
 from tensorrt_llm._torch.pyexecutor.scheduler import FCFSWaitingQueue
 from tensorrt_llm._torch.pyexecutor.scheduler.adp_router import (
     ADPRouter,
+    ConversationAwareADPRouter,
     DefaultADPRouter,
     KVCacheAwareADPRouter,
     RankIterStatsPayload,
@@ -1085,3 +1086,188 @@ def test_fair_share_multiplier_caps_per_rank(self):
         )
         assert len(result_strict[0]) == 2
         assert sum(len(result_strict[r]) for r in range(1, tp_size)) == 6
+
+
+def _make_conv_request_item(
+    req_id,
+    conversation_id,
+    target_dp_rank=None,
+    attention_dp_relax=True,
+    num_tokens=10,
+):
+    """Mock RequestQueueItem carrying a conversation_id on disaggregated_params."""
+    item = MagicMock()
+    item.id = req_id
+    item.child_req_ids = None
+    scheduling_params = MagicMock()
+    scheduling_params.attention_dp_rank = target_dp_rank
+    scheduling_params.attention_dp_relax = attention_dp_relax
+    item.request = MagicMock()
+    item.request.py_scheduling_params = scheduling_params
+    item.request.input_token_ids = list(range(num_tokens))
+    item.request.py_orig_prompt_len = num_tokens
+    if conversation_id is None:
+        item.request.py_disaggregated_params = None
+    else:
+        disagg = MagicMock()
+        disagg.conversation_id = conversation_id
+        item.request.py_disaggregated_params = disagg
+    return item
+
+
+class TestConversationAwareADPRouter:
+    """Routing behavior of ConversationAwareADPRouter (explicit conv->rank)."""
+
+    @staticmethod
+    def _router(tp_size=4, max_sessions=1 << 16):
+        return ConversationAwareADPRouter(
+            dist=_mock_dist(tp_size=tp_size), max_sessions=max_sessions
+        )
+
+    @staticmethod
+    def _states(tp_size, active=None):
+        active = active or [0] * tp_size
+        return [
+            RankState(rank=r, num_active_requests=active[r], num_active_tokens=active[r] * 10)
+            for r in range(tp_size)
+        ]
+
+    @staticmethod
+    def _route(router, states, items, cap=1000):
+        assign, _ = router.route_requests(states, items, max_num_active_requests=cap)
+        pos = {}
+        for rank, lst in assign.items():
+            for it in lst:
+                pos[it.id] = rank
+        return pos
+
+    def test_interface_compliance(self):
+        assert isinstance(self._router(), ADPRouter)
+
+    def test_first_turn_round_robin(self):
+        """First request of each conversation spreads one-per-rank (RR)."""
+        router = self._router(tp_size=4)
+        items = [_make_conv_request_item(i, f"conv{i}") for i in range(4)]
+        pos = self._route(router, self._states(4), items)
+        assert sorted(pos.values()) == [0, 1, 2, 3]
+
+    def test_subsequent_turns_are_sticky(self):
+        """Later turns of a conversation return to its first-turn rank."""
+        router = self._router(tp_size=4)
+        convs = ["A", "B", "C", "D"]
+        first = [_make_conv_request_item(i, convs[i]) for i in range(4)]
+        pos1 = self._route(router, self._states(4), first)
+        home = {convs[i]: pos1[i] for i in range(4)}
+        # A new batch of second turns (fresh ids) for the same conversations.
+        second = [_make_conv_request_item(100 + i, convs[i]) for i in range(4)]
+        pos2 = self._route(router, self._states(4), second)
+        for i in range(4):
+            assert pos2[100 + i] == home[convs[i]]
+
+    def test_no_conversation_id_falls_back_and_is_not_recorded(self):
+        router = self._router(tp_size=4)
+        items = [_make_conv_request_item(i, None) for i in range(8)]
+        pos = self._route(router, self._states(4), items)
+        assert len(pos) == 8
+        # Nothing recorded for conversation-less requests.
+        assert dict(router._conv_to_rank) == {}
+        # Round-robin spread keeps ranks within one of each other.
+        counts = [sum(1 for v in pos.values() if v == r) for r in range(4)]
+        assert max(counts) - min(counts) <= 1
+
+    def test_routing_is_deterministic_across_ranks(self):
+        """Two independent instances (= two TP ranks) must agree exactly, or
+        the no-broadcast allgather protocol would deadlock."""
+        convs = ["A", "B", "C", None, "A", "B", "E", None, "C"]
+
+        def run():
+            router = self._router(tp_size=4)
+            items = [_make_conv_request_item(i, convs[i]) for i in range(len(convs))]
+            return self._route(router, self._states(4), items)
+
+        assert run() == run()
+
+    def test_lru_eviction_bounds_map(self):
+        router = self._router(tp_size=2, max_sessions=2)
+        items = [_make_conv_request_item(i, f"c{i}") for i in range(5)]
+        self._route(router, self._states(2), items)
+        assert len(router._conv_to_rank) == 2
+        assert set(router._conv_to_rank) == {"c3", "c4"}
+
+    def test_explicit_target_dp_rank_respected(self):
+        router = self._router(tp_size=4)
+        item = _make_conv_request_item(1, "A", target_dp_rank=2, attention_dp_relax=False)
+        pos = self._route(router, self._states(4), [item])
+        assert pos[1] == 2
+
+    def test_sticky_overflow_keeps_mapping(self):
+        """When the home rank is saturated at the HARD cap (max_num_active_requests),
+        the turn overflows off home but the conversation stays mapped home (so it
+        returns home next batch once that rank has capacity). Stickiness uses the
+        hard cap -- not the soft fair-share -- so a conversation's KV prefix stays
+        resident on one rank in the common case."""
+        router = ConversationAwareADPRouter(dist=_mock_dist(tp_size=2))
+        # Seed conv A -> rank 0.
+        self._route(router, self._states(2), [_make_conv_request_item(1, "A")])
+        home = router._conv_to_rank["A"]
+        # Fill the home rank to the hard cap, route A again -> must overflow off home...
+        states = self._states(2)
+        states[home] = RankState(rank=home, num_active_requests=5, num_active_tokens=50)
+        pos = self._route(router, states, [_make_conv_request_item(2, "A")], cap=5)
+        assert pos[2] != home
+        # ...but the mapping is unchanged so it can return home later.
+        assert router._conv_to_rank["A"] == home
+
+    def test_create_rank_state(self):
+        router = ConversationAwareADPRouter(dist=_mock_dist(tp_rank=2))
+        req1 = Mock(py_orig_prompt_len=100)
+        req2 = Mock(py_orig_prompt_len=50)
+        state = router.create_rank_state(active_requests=[req1, req2], new_requests=[])
+        assert state.rank == 2
+        assert state.num_active_requests == 2
+        assert state.num_active_tokens == 150
+
+    def test_factory_selects_conversation_router(self):
+        cfg = MagicMock()
+        cfg.kv_cache_routing_conversation_affinity = True
+        cfg.kv_cache_routing_max_sessions = 8
+        router = ADPRouter.create(dist=_mock_dist(), kv_cache_manager=None, attention_dp_config=cfg)
+        assert isinstance(router, ConversationAwareADPRouter)
+        assert router._max_sessions == 8
+
+    def test_factory_default_when_disabled(self):
+        cfg = MagicMock()
+        cfg.kv_cache_routing_conversation_affinity = False
+        cfg.enable_kv_cache_aware_routing = False
+        router = ADPRouter.create(dist=_mock_dist(), kv_cache_manager=None, attention_dp_config=cfg)
+        assert isinstance(router, DefaultADPRouter)
+
+    def test_returned_expected_covers_every_rank(self):
+        """Regression: returned expected_num_active_requests must be >= every
+        rank's post-assignment active count, else
+        py_executor._pad_attention_dp_dummy_request asserts and the executor
+        hangs. A sticky storm (one conversation hammered) must not push its home
+        rank past expected."""
+        router = self._router(tp_size=8)
+        states = self._states(8)
+        # 40 requests for the SAME conversation: hard-cap stickiness concentrates
+        # all of them on the home rank, pushing it well past the pre-loop soft
+        # `expected`. The post-loop re-bump must lift `expected` to cover the home
+        # rank's actual count, else _pad_attention_dp_dummy_request asserts.
+        items = [_make_conv_request_item(i, "A") for i in range(40)]
+        assign, expected = router.route_requests(states, items, max_num_active_requests=256)
+        final = [states[r].num_active_requests + len(assign[r]) for r in range(8)]
+        assert all(expected >= f for f in final), (expected, final)
+        assert sum(len(v) for v in assign.values()) == 40  # nothing dropped
+
+    def test_overflow_respects_hard_cap(self):
+        """Regression: round-robin overflow must skip a rank that is already at
+        max_num_active_requests, even when the cursor points at it and the
+        fair-share multiplier would otherwise lift the soft cap above it."""
+        router = self._router(tp_size=2)
+        states = self._states(2, active=[5, 3])
+        items = [_make_conv_request_item(i, None) for i in range(2)]
+        assign, expected = router.route_requests(states, items, max_num_active_requests=5)
+        assert len(assign[0]) == 0
+        assert len(assign[1]) == 2
+        assert expected == 5
diff --git a/tests/unittest/disaggregated/test_disaggregated_params.py b/tests/unittest/disaggregated/test_disaggregated_params.py
index c20f81fcc30f..655fbcabbd13 100644
--- a/tests/unittest/disaggregated/test_disaggregated_params.py
+++ b/tests/unittest/disaggregated/test_disaggregated_params.py
@@ -64,6 +64,7 @@ def test_to_disaggregated_params():
                 "cached_tokens": 4,
             },
         },
+        conversation_id="conv-abc",
     )
 
     openai_params = to_disaggregated_params(llm_params)
@@ -74,6 +75,7 @@ def test_to_disaggregated_params():
     assert openai_params.ctx_info_endpoint == "tcp://10.0.0.1:5000"
     assert openai_params.ctx_usage.prompt_tokens == 10
     assert openai_params.ctx_usage.prompt_tokens_details.cached_tokens == 4
+    assert openai_params.conversation_id == "conv-abc"
 
 
 def test_to_llm_disaggregated_params():
@@ -94,6 +96,7 @@ def test_to_llm_disaggregated_params():
             total_tokens=10,
             prompt_tokens_details=PromptTokensDetails(cached_tokens=4),
         ),
+        conversation_id="conv-xyz",
     )
 
     llm_params = to_llm_disaggregated_params(openai_params)
@@ -103,6 +106,26 @@ def test_to_llm_disaggregated_params():
     assert llm_params.ctx_info_endpoint == "tcp://10.0.0.1:5000"
     assert llm_params.ctx_usage["prompt_tokens"] == 10
     assert llm_params.ctx_usage["prompt_tokens_details"]["cached_tokens"] == 4
+    assert llm_params.conversation_id == "conv-xyz"
+
+
+def test_disaggregated_params_conversation_id():
+    """conversation_id defaults to None and survives the serve<->llm round-trip."""
+    from tensorrt_llm.serve.openai_protocol import DisaggregatedParams as OpenAIDisaggregatedParams
+    from tensorrt_llm.serve.openai_protocol import (
+        to_disaggregated_params,
+        to_llm_disaggregated_params,
+    )
+
+    assert DisaggregatedParams().conversation_id is None
+
+    # serve -> llm -> serve preserves the conversation id end to end.
+    openai_params = OpenAIDisaggregatedParams(
+        request_type="context_only", conversation_id="conv-roundtrip"
+    )
+    llm_params = to_llm_disaggregated_params(openai_params)
+    assert llm_params.conversation_id == "conv-roundtrip"
+    assert to_disaggregated_params(llm_params).conversation_id == "conv-roundtrip"
 
 
 @patch("tensorrt_llm.disaggregated_params.tllme")