From d14234e2b04db6310143a1c13ffb50a42081618a Mon Sep 17 00:00:00 2001 From: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com> Date: Thu, 4 Jun 2026 12:57:03 -0700 Subject: [PATCH] [nvbugs/6266705][fix] Gate FlashInfer GDN prefill to SM90, fall back to Triton elsewhere Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com> --- tensorrt_llm/_torch/modules/mamba/gdn_mixer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py index a6882d11fbae..9235c36f460f 100644 --- a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py @@ -12,9 +12,15 @@ from torch import nn from transformers import Qwen3NextConfig +from tensorrt_llm._utils import get_sm_version + # Default: FlashInfer GDN prefill ON. Set TLLM_USE_FLASHINFER_GDN_PREFILL=0 to # fall back to the vendored Triton chunk_gated_delta_rule. -if os.getenv("TLLM_USE_FLASHINFER_GDN_PREFILL", "1") == "1": +# The FlashInfer GDN prefill kernel (gdn_prefill_launcher.cu) only supports +# device major version 9 (SM90/Hopper); on other archs it aborts at load. Gate +# the selection on SM90 so non-Hopper GPUs (e.g. Blackwell SM120) use the +# device-agnostic Triton path. +if os.getenv("TLLM_USE_FLASHINFER_GDN_PREFILL", "1") == "1" and get_sm_version() == 90: from tensorrt_llm._torch.modules.fla.flashinfer_chunk import chunk_gated_delta_rule else: from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule