fix : (GLM-5-2-FP8)do not use buffer out of cudagraph when unnecessary#1391
fix : (GLM-5-2-FP8)do not use buffer out of cudagraph when unnecessary#1391JiaoliangYu wants to merge 3 commits into
Conversation
6d94ad1 to
1a0d5b8
Compare
| out_val, | ||
| mask=valid_mask, | ||
| ) | ||
| # Decode-only: token_id == batch_id. Don't read qo_indptr (cu_seqlens_q), |
There was a problem hiding this comment.
token_id == batch_id this is not True for mtp
There was a problem hiding this comment.
ATOM/atom/model_ops/attention_mla.py
Line 1381 in f797dd5
if attn_metadata.max_seqlen_q > 1: triton_gather_kv_indices_sparse( attn_metadata.sparse_kv_indptr, attn_metadata.token_to_seq_idxs, topk_indices, attn_metadata.kv_indices, attn_metadata.kv_indptr, NUM_TOPK_TOKENS=topk_tokens, out=sparse_kv_indices_buffer, ) else: triton_convert_req_index_to_global_index( attn_metadata.cu_seqlens_q, attn_metadata.kv_indptr, attn_metadata.sparse_kv_indptr, attn_metadata.kv_indices, topk_indices, NUM_TOPK_TOKENS=topk_tokens, out=sparse_kv_indices_buffer, )
mtp does not use this kernel, right?
Motivation
GLM-5.2-FP8 benchmark/serving (isl=1024 osl=1024, higher concurrency) intermittently dies with a GPU Memory access fault ... Reason: Unknown (process exit -6). The faulting kernel is the sparse-MLA decode index converter _convert_req_index_to_global_index_kernel (MEMORY_VIOLATION). The failure is non-deterministic (the same commit passes and fails across CI runs) and only happens under CUDA/HIP-graph capture — with --enforce-eager it never reproduces (145k+ eager decode steps, zero faults).
Under CUDA-graph replay, the convert kernel reads a stale/cumulative cu_seqlens_q (qo_indptr) instead of the decode arange. For decode each request has exactly one query token, so cu_seqlens_q must be [0,1,2,...,bs] (qo_end[i] = i+1 ≤ bs). At the fault, the registers show qo_end values far larger than bs (e.g. 860/875 with bs ≤ 512), i.e. a prefill-shaped cumulative layout.
That drives the loop variable token_id out of range, so the kernel reads a row of token_indices (a per-step torch.empty buffer) that the indexer never filled. Those positions hold the -1 "invalid" sentinel, which leaks into a lane the row mask treats as valid; the unbounded kv_indices + kv_start + tok load with tok = -1 then underflows to kv_indices_base - 4 and hits an unmapped page. This is confirmed byte-exact from the debug-agent dump (faulting address == kv_indices base − 4, kv_start == 0, tok == -1).
The Python metadata path writes cu_seqlens_q correctly (decode arange) every step on the main stream before replay; static analysis shows no logic bug. The staleness is a CUDA/HIP-graph runtime artifact on gfx950/MI355X — a captured kernel observing stale device memory under replay.
Technical Details
The convert kernel is decode-only (max_seqlen_q == 1), so token_id == batch_id by construction. Derive the row directly from tl.program_id(0) instead of reading qo_indptr/cu_seqlens_q. The kernel no longer depends on the buffer that goes stale under graph replay, so the fault cannot occur regardless of the underlying runtime root cause — it is a correctness-preserving change, not a band-aid.
Test Plan
Accuracy CI && nightly Benchmark
Test Result
Submission Checklist