From 339ef497359ba06fdce27e8a2cf7c6bf6b67f020 Mon Sep 17 00:00:00 2001 From: Dongfeng Yu Date: Wed, 3 Jun 2026 02:07:58 +0000 Subject: [PATCH] [NVBUG-6250866][bugfix] fix DeepEP intranode combine fallback Signed-off-by: Dongfeng Yu [NVBUG-6250866][bugfix] wire DeepEP patch into FetchContent Signed-off-by: Dongfeng Yu --- 3rdparty/fetch_content.json | 3 +- .../deep_ep_intranode_combine_fix.patch | 35 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 3rdparty/patches/deep_ep_intranode_combine_fix.patch diff --git a/3rdparty/fetch_content.json b/3rdparty/fetch_content.json index e6218c2629c2..6a4679db5262 100644 --- a/3rdparty/fetch_content.json +++ b/3rdparty/fetch_content.json @@ -26,7 +26,8 @@ "display_name": "deep_ep", "git_repository": "${github_base_url}/deepseek-ai/DeepEP", "git_tag": "5be51b228a7c82dbdb213ea58e77bffd12b38af8", - "use_url": true + "use_url": true, + "patch_file": "patches/deep_ep_intranode_combine_fix.patch" }, { "name": "deepgemm", diff --git a/3rdparty/patches/deep_ep_intranode_combine_fix.patch b/3rdparty/patches/deep_ep_intranode_combine_fix.patch new file mode 100644 index 000000000000..fbed0107f81b --- /dev/null +++ b/3rdparty/patches/deep_ep_intranode_combine_fix.patch @@ -0,0 +1,35 @@ +--- a/csrc/kernels/intranode.cu ++++ b/csrc/kernels/intranode.cu +@@ -844,9 +844,15 @@ + + #ifndef DISABLE_SM90_FEATURES + // Wait TMA arrival ++ // hidden_int4 is not always divisible by a warp. The final tile can have ++ // only a subset of lanes active, so synchronize only participating lanes. ++ auto const tile_start = i - lane_id; ++ auto const active_lanes = min(32, hidden_int4 - tile_start); ++ auto const sync_mask = active_lanes == 32 ? 0xffffffffu : ((1u << active_lanes) - 1u); ++ + if (lane_id == 0) + tma_store_wait(); +- __syncwarp(); ++ __syncwarp(sync_mask); + + // Write into TMA buffer + auto tma_stage_idx = (i / 32) % kNumStages; +@@ -854,13 +860,13 @@ + + // Issue TMA + tma_store_fence(); +- __syncwarp(); ++ __syncwarp(sync_mask); + if (lane_id == 0) { + auto tma_bytes = min(32, hidden_int4 - i) * static_cast(sizeof(int4)); + tma_store_1d(reinterpret_cast(tma_buffer) + tma_stage_idx * 32, + recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false); + } +- __syncwarp(); ++ __syncwarp(sync_mask); + #else + recv_int4[token_idx * hidden_int4 + i] = out_int4; + #endif