Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,6 @@ def forward(

@triton.jit
def _convert_req_index_to_global_index_kernel(
qo_indptr, # int32 [num_requests]
kv_indptr, # int32 [num_requests+1]
page_kv_indptr, # int32 [num_requests+1]
kv_indices, # int32 [num_requests * max_num_blocks_per_req]
Expand All @@ -1315,33 +1314,33 @@ def _convert_req_index_to_global_index_kernel(
kv_end = tl.load(kv_indptr + batch_id + 1)
out_kv_start = tl.load(page_kv_indptr + batch_id)
kv_len = kv_end - kv_start
qo_start = tl.load(qo_indptr + batch_id)
qo_end = tl.load(qo_indptr + batch_id + 1)

for token_id in range(qo_start, qo_end):
# Load token indices for this tile
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
tok = tl.load(ti_ptr) # int32

# Guard block_table access
valid_mask = (indice_id < kv_len) & (indice_id < NUM_TOPK_TOKENS)
out_val = tl.load(
kv_indices + kv_start + tok,
mask=valid_mask,
other=0,
)

# Store results
out_ptr_ij = out_kv_indices + out_kv_start + indice_id
tl.store(
out_ptr_ij,
out_val,
mask=valid_mask,
)
# Decode-only: token_id == batch_id. Don't read qo_indptr (cu_seqlens_q),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

token_id == batch_id this is not True for mtp

@JiaoliangYu JiaoliangYu Jun 29, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# NOTE: MTP (max_seqlen_q > 1) uses triton_convert_req_index_to_global_index_dsa_prefill instead

if attn_metadata.max_seqlen_q > 1: triton_gather_kv_indices_sparse( attn_metadata.sparse_kv_indptr, attn_metadata.token_to_seq_idxs, topk_indices, attn_metadata.kv_indices, attn_metadata.kv_indptr, NUM_TOPK_TOKENS=topk_tokens, out=sparse_kv_indices_buffer, ) else: triton_convert_req_index_to_global_index( attn_metadata.cu_seqlens_q, attn_metadata.kv_indptr, attn_metadata.sparse_kv_indptr, attn_metadata.kv_indices, topk_indices, NUM_TOPK_TOKENS=topk_tokens, out=sparse_kv_indices_buffer, )

mtp does not use this kernel, right?

# which can be stale/cumulative under CUDA-graph replay.
token_id = batch_id

# Load token indices for this tile
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
tok = tl.load(ti_ptr) # int32

# Guard block_table access
valid_mask = (indice_id < kv_len) & (indice_id < NUM_TOPK_TOKENS)
out_val = tl.load(
kv_indices + kv_start + tok,
mask=valid_mask,
other=0,
)

# Store results
out_ptr_ij = out_kv_indices + out_kv_start + indice_id
tl.store(
out_ptr_ij,
out_val,
mask=valid_mask,
)


def triton_convert_req_index_to_global_index(
qo_indptr: torch.Tensor, # int32 [num_tokens + 1]
kv_indptr: torch.Tensor, # int32 [num_tokens + 1]
page_kv_indptr: torch.Tensor, # int32 [num_tokens + 1]
kv_indices: torch.Tensor, # int32 [total_kv_seqlen]
Expand Down Expand Up @@ -1373,7 +1372,6 @@ def triton_convert_req_index_to_global_index(
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

# Ensure contiguous tensors on the same device
qo_indptr_c = qo_indptr.contiguous()
kv_indptr_c = kv_indptr.contiguous()
kv_indices_c = kv_indices.contiguous()
token_indices_c = token_indices.contiguous()
Expand All @@ -1391,7 +1389,6 @@ def triton_convert_req_index_to_global_index(
grid = (num_batch, tiles_per_row)

_convert_req_index_to_global_index_kernel[grid](
qo_indptr_c,
kv_indptr_c,
page_kv_indptr_c,
kv_indices_c,
Expand Down
1 change: 0 additions & 1 deletion atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,6 @@ def sparse_attn_indexer(
)
else:
triton_convert_req_index_to_global_index(
attn_metadata.cu_seqlens_q,
attn_metadata.kv_indptr,
attn_metadata.sparse_kv_indptr,
attn_metadata.kv_indices,
Expand Down
Loading