diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index ee98185f4..f320fad1a 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -1288,7 +1288,6 @@ def forward( @triton.jit def _convert_req_index_to_global_index_kernel( - qo_indptr, # int32 [num_requests] kv_indptr, # int32 [num_requests+1] page_kv_indptr, # int32 [num_requests+1] kv_indices, # int32 [num_requests * max_num_blocks_per_req] @@ -1315,33 +1314,33 @@ def _convert_req_index_to_global_index_kernel( kv_end = tl.load(kv_indptr + batch_id + 1) out_kv_start = tl.load(page_kv_indptr + batch_id) kv_len = kv_end - kv_start - qo_start = tl.load(qo_indptr + batch_id) - qo_end = tl.load(qo_indptr + batch_id + 1) - - for token_id in range(qo_start, qo_end): - # Load token indices for this tile - ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 - tok = tl.load(ti_ptr) # int32 - - # Guard block_table access - valid_mask = (indice_id < kv_len) & (indice_id < NUM_TOPK_TOKENS) - out_val = tl.load( - kv_indices + kv_start + tok, - mask=valid_mask, - other=0, - ) - # Store results - out_ptr_ij = out_kv_indices + out_kv_start + indice_id - tl.store( - out_ptr_ij, - out_val, - mask=valid_mask, - ) + # Decode-only: token_id == batch_id. Don't read qo_indptr (cu_seqlens_q), + # which can be stale/cumulative under CUDA-graph replay. + token_id = batch_id + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Guard block_table access + valid_mask = (indice_id < kv_len) & (indice_id < NUM_TOPK_TOKENS) + out_val = tl.load( + kv_indices + kv_start + tok, + mask=valid_mask, + other=0, + ) + + # Store results + out_ptr_ij = out_kv_indices + out_kv_start + indice_id + tl.store( + out_ptr_ij, + out_val, + mask=valid_mask, + ) def triton_convert_req_index_to_global_index( - qo_indptr: torch.Tensor, # int32 [num_tokens + 1] kv_indptr: torch.Tensor, # int32 [num_tokens + 1] page_kv_indptr: torch.Tensor, # int32 [num_tokens + 1] kv_indices: torch.Tensor, # int32 [total_kv_seqlen] @@ -1373,7 +1372,6 @@ def triton_convert_req_index_to_global_index( tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N # Ensure contiguous tensors on the same device - qo_indptr_c = qo_indptr.contiguous() kv_indptr_c = kv_indptr.contiguous() kv_indices_c = kv_indices.contiguous() token_indices_c = token_indices.contiguous() @@ -1391,7 +1389,6 @@ def triton_convert_req_index_to_global_index( grid = (num_batch, tiles_per_row) _convert_req_index_to_global_index_kernel[grid]( - qo_indptr_c, kv_indptr_c, page_kv_indptr_c, kv_indices_c, diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index ff47a5fcc..ef5fef1c6 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1301,7 +1301,6 @@ def sparse_attn_indexer( ) else: triton_convert_req_index_to_global_index( - attn_metadata.cu_seqlens_q, attn_metadata.kv_indptr, attn_metadata.sparse_kv_indptr, attn_metadata.kv_indices,