Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
from torch import nn
from transformers import Qwen3NextConfig

from tensorrt_llm._utils import get_sm_version

# Default: FlashInfer GDN prefill ON. Set TLLM_USE_FLASHINFER_GDN_PREFILL=0 to
# fall back to the vendored Triton chunk_gated_delta_rule.
if os.getenv("TLLM_USE_FLASHINFER_GDN_PREFILL", "1") == "1":
# The FlashInfer GDN prefill kernel (gdn_prefill_launcher.cu) only supports
# device major version 9 (SM90/Hopper); on other archs it aborts at load. Gate
# the selection on SM90 so non-Hopper GPUs (e.g. Blackwell SM120) use the
# device-agnostic Triton path.
if os.getenv("TLLM_USE_FLASHINFER_GDN_PREFILL", "1") == "1" and get_sm_version() == 90:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The FlashInfer GDN prefill kernel Requires SM90 (Hopper) or SM100 or SM103 (Blackwell) architecture.

from tensorrt_llm._torch.modules.fla.flashinfer_chunk import chunk_gated_delta_rule
else:
from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule
Expand Down
Loading