[https://nvbugs/6266705][fix] Gate the FlashInfer import-time selection on get_sm_version() == 90 (in…#14973
Conversation
…to Triton elsewhere Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughThe GDN prefill kernel selection logic is updated to restrict FlashInfer usage to SM90 GPUs only. The change adds a ChangesGDN prefill kernel selection gating
🎯 2 (Simple) | ⏱️ ~10 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
| # device major version 9 (SM90/Hopper); on other archs it aborts at load. Gate | ||
| # the selection on SM90 so non-Hopper GPUs (e.g. Blackwell SM120) use the | ||
| # device-agnostic Triton path. | ||
| if os.getenv("TLLM_USE_FLASHINFER_GDN_PREFILL", "1") == "1" and get_sm_version() == 90: |
There was a problem hiding this comment.
The FlashInfer GDN prefill kernel Requires SM90 (Hopper) or SM100 or SM103 (Blackwell) architecture.
Summary
gdn_mixer.pyboundchunk_gated_delta_ruleto the SM90-only FlashInfer GDN prefill kernel at import time with no device-arch guard, so on SM120 (Blackwell RTX PRO 6000) prefill aborted with "delta rule kernel does not support this device major version: 12" during model load.get_sm_version() == 90(in addition to the existing env flag); all other archs fall back to the device-agnostic Tritonchunk_gated_delta_rule. Verified EXIT_CODE=0 (model loads and serves).Test plan
Links
Summary by CodeRabbit