Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions atom/plugin/rtpllm/utils/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _query_start_loc(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
device=device,
)
cu_seqlens = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "cu_seqlens", None),
getattr(attn_inputs, "cu_seqlens_device", None),
device=device,
)
if cu_seqlens is not None and cu_seqlens.numel() > 1:
Expand Down Expand Up @@ -183,7 +183,7 @@ def _query_start_loc(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
# Decode: query length is runtime step token count (usually 1 per sequence),
# not prompt input_lengths.
sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "sequence_lengths_plus_1_d", None),
getattr(attn_inputs, "sequence_lengths_plus_1_device", None),
device=device,
)
sequence_lengths = RTPForwardContext._non_empty_int32(
Expand Down Expand Up @@ -262,7 +262,7 @@ def _state_indices(

if is_prefill:
prefix_lengths = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "prefix_lengths_d", None),
getattr(attn_inputs, "prefix_lengths_device", None),
device=device,
)
if prefix_lengths is None:
Expand All @@ -283,7 +283,7 @@ def _state_indices(
else:
# RTP decode kernels use sequence_lengths_plus_1_device as the canonical runtime value.
sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "sequence_lengths_plus_1_d", None),
getattr(attn_inputs, "sequence_lengths_plus_1_device", None),
device=device,
)
if sequence_lengths_plus_1 is not None:
Expand Down Expand Up @@ -621,7 +621,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
# For chunked prefill, prefix_lengths can remain per-chunk while
# sequence_lengths_plus_1_device tracks the true cumulative context length.
sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "sequence_lengths_plus_1_d", None),
getattr(attn_inputs, "sequence_lengths_plus_1_device", None),
device=device,
)
if sequence_lengths_plus_1 is not None:
Expand All @@ -633,7 +633,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
)
return sequence_lengths_plus_1.contiguous()
prefix_lengths = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "prefix_lengths_d", None),
getattr(attn_inputs, "prefix_lengths_device", None),
device=device,
)
if prefix_lengths is None:
Expand All @@ -654,7 +654,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
return (prefix_lengths + input_lengths).contiguous()

sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "sequence_lengths_plus_1_d", None),
getattr(attn_inputs, "sequence_lengths_plus_1_device", None),
device=device,
)
if sequence_lengths_plus_1 is not None:
Expand Down Expand Up @@ -1052,8 +1052,19 @@ def _resolve_plugin_block_table(
cg_bufs: dict | None,
in_capture: bool,
) -> torch.Tensor | None:
# NOTE: kv_cache_block_id_device is RTP-LLM's cache-store physical block table.
# RTP-LLM does NOT refresh it inside the CUDA/HIP graph on replay -- by design only
# kv_cache_kernel_block_id_device is D2D-refreshed, and cache store runs outside the
# graph (see RTP-LLM cuda_graph_runner.cc and OpDefs.h). Returning it under capture
# bakes a stale block_table / slot_mapping into the graph, so every replay step
# reads/writes KV at the frozen capture-time physical blocks -> garbled output.
# Under capture we must rebuild the physical table from the (refreshed) kernel table.
physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None)
if physical_block_table is not None and physical_block_table.numel() > 0:
if (
not in_capture
and physical_block_table is not None
and physical_block_table.numel() > 0
):
return physical_block_table
kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs)
if kernel_block_table is None or kernel_block_table.numel() == 0:
Expand Down Expand Up @@ -1796,8 +1807,19 @@ def _resolve_plugin_block_table(
cg_bufs: dict | None,
in_capture: bool,
) -> torch.Tensor | None:
# NOTE: kv_cache_block_id_device is RTP-LLM's cache-store physical block table.
# RTP-LLM does NOT refresh it inside the CUDA/HIP graph on replay -- by design only
# kv_cache_kernel_block_id_device is D2D-refreshed, and cache store runs outside the
# graph (see RTP-LLM cuda_graph_runner.cc and OpDefs.h). Returning it under capture
# bakes a stale block_table / slot_mapping into the graph, so every replay step
# reads/writes KV at the frozen capture-time physical blocks -> garbled output.
# Under capture we must rebuild the physical table from the (refreshed) kernel table.
physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None)
if physical_block_table is not None and physical_block_table.numel() > 0:
if (
not in_capture
and physical_block_table is not None
and physical_block_table.numel() > 0
):
return physical_block_table
kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs)
if kernel_block_table is None or kernel_block_table.numel() == 0:
Expand Down Expand Up @@ -1871,7 +1893,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
is_prefill = bool(getattr(attn_inputs, "is_prefill", False))
if is_prefill:
prefix_lengths = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "prefix_lengths_d", None),
getattr(attn_inputs, "prefix_lengths_device", None),
device=device,
)
if prefix_lengths is None:
Expand All @@ -1896,7 +1918,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:
)
if non_cuda_graph_mode:
sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "sequence_lengths_plus_1_d", None),
getattr(attn_inputs, "sequence_lengths_plus_1_device", None),
device=device,
)
if sequence_lengths_plus_1 is not None:
Expand All @@ -1923,7 +1945,7 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor:

if not non_cuda_graph_mode:
sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32(
getattr(attn_inputs, "sequence_lengths_plus_1_d", None),
getattr(attn_inputs, "sequence_lengths_plus_1_device", None),
device=device,
)
if sequence_lengths_plus_1 is not None:
Expand Down