Skip to content

fix : (GLM-5-2-FP8)do not use buffer out of cudagraph when unnecessary#1391

Open
JiaoliangYu wants to merge 3 commits into
ROCm:mainfrom
JiaoliangYu:fix/sparse-mla-convert-tok-bounds
Open

fix : (GLM-5-2-FP8)do not use buffer out of cudagraph when unnecessary#1391
JiaoliangYu wants to merge 3 commits into
ROCm:mainfrom
JiaoliangYu:fix/sparse-mla-convert-tok-bounds

Conversation

@JiaoliangYu

Copy link
Copy Markdown
Contributor

Motivation

GLM-5.2-FP8 benchmark/serving (isl=1024 osl=1024, higher concurrency) intermittently dies with a GPU Memory access fault ... Reason: Unknown (process exit -6). The faulting kernel is the sparse-MLA decode index converter _convert_req_index_to_global_index_kernel (MEMORY_VIOLATION). The failure is non-deterministic (the same commit passes and fails across CI runs) and only happens under CUDA/HIP-graph capture — with --enforce-eager it never reproduces (145k+ eager decode steps, zero faults).

Under CUDA-graph replay, the convert kernel reads a stale/cumulative cu_seqlens_q (qo_indptr) instead of the decode arange. For decode each request has exactly one query token, so cu_seqlens_q must be [0,1,2,...,bs] (qo_end[i] = i+1 ≤ bs). At the fault, the registers show qo_end values far larger than bs (e.g. 860/875 with bs ≤ 512), i.e. a prefill-shaped cumulative layout.

That drives the loop variable token_id out of range, so the kernel reads a row of token_indices (a per-step torch.empty buffer) that the indexer never filled. Those positions hold the -1 "invalid" sentinel, which leaks into a lane the row mask treats as valid; the unbounded kv_indices + kv_start + tok load with tok = -1 then underflows to kv_indices_base - 4 and hits an unmapped page. This is confirmed byte-exact from the debug-agent dump (faulting address == kv_indices base − 4, kv_start == 0, tok == -1).

The Python metadata path writes cu_seqlens_q correctly (decode arange) every step on the main stream before replay; static analysis shows no logic bug. The staleness is a CUDA/HIP-graph runtime artifact on gfx950/MI355X — a captured kernel observing stale device memory under replay.

Technical Details

The convert kernel is decode-only (max_seqlen_q == 1), so token_id == batch_id by construction. Derive the row directly from tl.program_id(0) instead of reading qo_indptr/cu_seqlens_q. The kernel no longer depends on the buffer that goes stale under graph replay, so the fault cannot occur regardless of the underlying runtime root cause — it is a correctness-preserving change, not a band-aid.

Test Plan

Accuracy CI && nightly Benchmark

Test Result

Submission Checklist

@zufayu zufayu requested a review from jiayyu June 29, 2026 02:54
@JiaoliangYu JiaoliangYu changed the title fix : do not use buffer out of cudagraph when unnecessary fix : (GLM-5-2-FP8)do not use buffer out of cudagraph when unnecessary Jun 29, 2026
@JiaoliangYu JiaoliangYu force-pushed the fix/sparse-mla-convert-tok-bounds branch from 6d94ad1 to 1a0d5b8 Compare June 29, 2026 03:00
@JiaoliangYu JiaoliangYu marked this pull request as ready for review June 29, 2026 03:01
@JiaoliangYu JiaoliangYu marked this pull request as draft June 29, 2026 03:01
@JiaoliangYu JiaoliangYu marked this pull request as ready for review June 29, 2026 06:33
out_val,
mask=valid_mask,
)
# Decode-only: token_id == batch_id. Don't read qo_indptr (cu_seqlens_q),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

token_id == batch_id this is not True for mtp

@JiaoliangYu JiaoliangYu Jun 29, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# NOTE: MTP (max_seqlen_q > 1) uses triton_convert_req_index_to_global_index_dsa_prefill instead

if attn_metadata.max_seqlen_q > 1: triton_gather_kv_indices_sparse( attn_metadata.sparse_kv_indptr, attn_metadata.token_to_seq_idxs, topk_indices, attn_metadata.kv_indices, attn_metadata.kv_indptr, NUM_TOPK_TOKENS=topk_tokens, out=sparse_kv_indices_buffer, ) else: triton_convert_req_index_to_global_index( attn_metadata.cu_seqlens_q, attn_metadata.kv_indptr, attn_metadata.sparse_kv_indptr, attn_metadata.kv_indices, topk_indices, NUM_TOPK_TOKENS=topk_tokens, out=sparse_kv_indices_buffer, )

mtp does not use this kernel, right?

@JiaoliangYu JiaoliangYu requested a review from valarLip June 30, 2026 03:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants