diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 0e536ace8..59cce44c4 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -151,7 +151,7 @@ def _query_start_loc(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: device=device, ) cu_seqlens = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "cu_seqlens", None), + getattr(attn_inputs, "cu_seqlens_device", None), device=device, ) if cu_seqlens is not None and cu_seqlens.numel() > 1: @@ -183,7 +183,7 @@ def _query_start_loc(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: # Decode: query length is runtime step token count (usually 1 per sequence), # not prompt input_lengths. sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + getattr(attn_inputs, "sequence_lengths_plus_1_device", None), device=device, ) sequence_lengths = RTPForwardContext._non_empty_int32( @@ -262,7 +262,7 @@ def _state_indices( if is_prefill: prefix_lengths = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "prefix_lengths_d", None), + getattr(attn_inputs, "prefix_lengths_device", None), device=device, ) if prefix_lengths is None: @@ -283,7 +283,7 @@ def _state_indices( else: # RTP decode kernels use sequence_lengths_plus_1_d as canonical runtime value. sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + getattr(attn_inputs, "sequence_lengths_plus_1_device", None), device=device, ) if sequence_lengths_plus_1 is not None: @@ -621,7 +621,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: # For chunked prefill, prefix_lengths can remain per-chunk while # sequence_lengths_plus_1_d tracks the true cumulative context length. sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + getattr(attn_inputs, "sequence_lengths_plus_1_device", None), device=device, ) if sequence_lengths_plus_1 is not None: @@ -633,7 +633,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: ) return sequence_lengths_plus_1.contiguous() prefix_lengths = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "prefix_lengths_d", None), + getattr(attn_inputs, "prefix_lengths_device", None), device=device, ) if prefix_lengths is None: @@ -654,7 +654,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: return (prefix_lengths + input_lengths).contiguous() sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + getattr(attn_inputs, "sequence_lengths_plus_1_device", None), device=device, ) if sequence_lengths_plus_1 is not None: @@ -1052,8 +1052,19 @@ def _resolve_plugin_block_table( cg_bufs: dict | None, in_capture: bool, ) -> torch.Tensor | None: + # NOTE: kv_cache_block_id_device is RTP-LLM's cache-store physical block table. + # RTP-LLM does NOT refresh it inside the CUDA/HIP graph on replay -- by design only + # kv_cache_kernel_block_id_device is D2D-refreshed, and cache store runs outside the + # graph (see RTP-LLM cuda_graph_runner.cc and OpDefs.h). Returning it under capture + # bakes a stale block_table / slot_mapping into the graph, so every replay step + # reads/writes KV at the frozen capture-time physical blocks -> garbled output. + # Under capture we must rebuild the physical table from the (refreshed) kernel table. physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) - if physical_block_table is not None and physical_block_table.numel() > 0: + if ( + not in_capture + and physical_block_table is not None + and physical_block_table.numel() > 0 + ): return physical_block_table kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs) if kernel_block_table is None or kernel_block_table.numel() == 0: @@ -1796,8 +1807,19 @@ def _resolve_plugin_block_table( cg_bufs: dict | None, in_capture: bool, ) -> torch.Tensor | None: + # NOTE: kv_cache_block_id_device is RTP-LLM's cache-store physical block table. + # RTP-LLM does NOT refresh it inside the CUDA/HIP graph on replay -- by design only + # kv_cache_kernel_block_id_device is D2D-refreshed, and cache store runs outside the + # graph (see RTP-LLM cuda_graph_runner.cc and OpDefs.h). Returning it under capture + # bakes a stale block_table / slot_mapping into the graph, so every replay step + # reads/writes KV at the frozen capture-time physical blocks -> garbled output. + # Under capture we must rebuild the physical table from the (refreshed) kernel table. physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) - if physical_block_table is not None and physical_block_table.numel() > 0: + if ( + not in_capture + and physical_block_table is not None + and physical_block_table.numel() > 0 + ): return physical_block_table kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs) if kernel_block_table is None or kernel_block_table.numel() == 0: @@ -1871,7 +1893,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) if is_prefill: prefix_lengths = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "prefix_lengths_d", None), + getattr(attn_inputs, "prefix_lengths_device", None), device=device, ) if prefix_lengths is None: @@ -1896,7 +1918,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: ) if non_cuda_graph_mode: sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + getattr(attn_inputs, "sequence_lengths_plus_1_device", None), device=device, ) if sequence_lengths_plus_1 is not None: @@ -1923,7 +1945,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: if not non_cuda_graph_mode: sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + getattr(attn_inputs, "sequence_lengths_plus_1_device", None), device=device, ) if sequence_lengths_plus_1 is not None: