
[None][fix] guard CUDA graph capture against ADP asymmetric batch-None deadlock #14986

Open
longcheng-nv wants to merge 1 commit into NVIDIA:main from longcheng-nv:fix/cuda-graph-capture-adp-batch-deadlock

Collaborator

@longcheng-nv longcheng-nv commented Jun 5, 2026

Summary

  • Bug: _capture_generation_cuda_graphs in PyTorchModelEngine lacks the cross-rank batch-None check that _general_warmup_impl and _run_autotuner_warmup already have.
  • Impact: DEP mode + enable_attention_dp=true + CUDA graph batch sizes ≥ 16 → permanent distributed deadlock during engine warmup. Process hangs before producing any Python output (run.log = 0 bytes), 100% reproducible.
  • Fix: Mirror the _assert_all_tp_ranks_have_warmup_batch guard (already present in the two other warmup paths) into _capture_generation_cuda_graphs.

Root Cause

_capture_generation_cuda_graphs iterates batch sizes in reverse order (largest first, so smaller graphs can reuse the memory pool). Under attention-DP, KV-cache capacity differs per TP rank. _create_cuda_graph_warmup_request returns None on ranks that lack space. Without a cross-rank check:

  • Ranks where batch is None silently continue
  • Other ranks enter forward() containing tp_comm collectives (NCCL allreduce / DEP alltoall)
  • Skipped ranks never reach the collective → permanent deadlock
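
The divergence described above can be reproduced without GPUs in a toy model. The sketch below is an illustration only: threads stand in for TP ranks, and a `threading.Barrier` with a timeout stands in for the collective inside forward(). `simulate_capture` and its parameters are hypothetical names, not part of TensorRT-LLM, and a real NCCL collective has no such timeout, which is why the real hang is permanent.

```python
import threading


def simulate_capture(has_batch_per_rank, timeout_s=0.5):
    """Toy model of one unguarded capture step; returns one outcome per rank.

    Outcomes: 'skipped' (rank had no batch), 'completed' (collective finished),
    'stuck' (rank waited for peers that never arrived).
    """
    world_size = len(has_batch_per_rank)
    collective = threading.Barrier(world_size)  # needs every rank to arrive
    outcome = [None] * world_size

    def rank_main(rank):
        if not has_batch_per_rank[rank]:
            outcome[rank] = "skipped"  # the unguarded `continue`
            return
        try:
            collective.wait(timeout=timeout_s)  # stand-in for allreduce/alltoall
            outcome[rank] = "completed"
        except threading.BrokenBarrierError:
            outcome[rank] = "stuck"  # a real collective would block forever

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome


if __name__ == "__main__":
    # Rank 2 lacks KV-cache space: the other three ranks never get past the collective.
    print(simulate_capture([True, True, False, True]))
```

With symmetric inputs (all `True` or all `False`) every rank finishes; only the mixed case leaves ranks waiting.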

The two other warmup paths already guard this with _assert_all_tp_ranks_have_warmup_batch (added in an earlier PR). This PR closes the remaining gap.

Fix

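# Before: each rank decides on its own, so ranks can diverge
if batch is None:
    continue

# After
# Single-rank path: no peers to diverge from, safe to skip silently
if batch is None and self.mapping.tp_size <= 1:
    continue
# Raises RuntimeError if only some TP ranks have batch=None
self._assert_all_tp_ranks_have_warmup_batch(batch, bs)
if batch is None:
    continue  # all ranks agree: skip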
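
The PR does not show the body of _assert_all_tp_ranks_have_warmup_batch, so the following is only a guess at the decision it has to make, under the assumption that each rank's "do I have a batch?" flag has already been gathered from all TP ranks (for example with an all-gather). The function name `check_warmup_batch_symmetry` and its signature are hypothetical.

```python
from typing import Sequence


def check_warmup_batch_symmetry(has_batch_by_rank: Sequence[bool], batch_size: int) -> bool:
    """Decide what to do with one warmup batch size, given every rank's flag.

    Returns True if all ranks have a batch (proceed to forward), False if no
    rank has one (all ranks may skip together), and raises RuntimeError when
    the ranks disagree, since proceeding would strand some of them.
    """
    missing = [rank for rank, ok in enumerate(has_batch_by_rank) if not ok]
    if missing and len(missing) != len(has_batch_by_rank):
        raise RuntimeError(
            f"Warmup batch for batch_size={batch_size} is None on TP ranks "
            f"{missing} but valid on the others; aborting before forward()."
        )
    return not missing


if __name__ == "__main__":
    print(check_warmup_batch_symmetry([True, True, True, True], 32))      # proceed
    print(check_warmup_batch_symmetry([False, False, False, False], 32))  # skip together
```

The important property is that the decision is a pure function of the gathered flags, so every rank reaches the same verdict and either all proceed, all skip, or all fail.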

Tests

Added tests/unittest/_torch/executor/test_cuda_graph_capture_adp_guard.py with four unit tests:

  1. Structural: guard call is present in source (regression sentinel)
  2. Asymmetric None raises: some ranks None, others valid → RuntimeError before forward()
  3. All None skips: all ranks None → graceful skip, no raise
  4. All valid proceeds: all ranks valid → forward() called normally

Test plan

  • Run pytest tests/unittest/_torch/executor/test_cuda_graph_capture_adp_guard.py (no GPU needed)
    → 4/4 passed (1.61 s wall time on an 8×B300 host)
  • End-to-end: DEP + enable_attention_dp=true + cuda_graph_config.batch_sizes=[1,2,...,32] no longer deadlocks at warmup
    → Verified via DSv4 Pro DEP GSM8K eval (8×B300, MTP=3, GVR ON, BS up to 32):
    engine initialized cleanly in 452 s, inferred all 1319 problems without hang.
    Accuracy: 96.51% average (flexible-extract 96.51, strict-match 96.51),
    matching TEP mode (96.63%) within ±0.5 pp statistical error.
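
As a rough sanity check on the "±0.5 pp" figure, the arithmetic below assumes it refers to one binomial standard error on 1319 independent problems. The PR does not say how the error was computed, so this reading is an assumption.

```python
import math

n_problems = 1319   # GSM8K problems evaluated, from the test plan
acc_dep = 0.9651    # DEP accuracy reported above
acc_tep = 0.9663    # TEP accuracy reported above


def binomial_se_pp(p: float, n: int) -> float:
    """Standard error of a proportion, in percentage points."""
    return 100.0 * math.sqrt(p * (1.0 - p) / n)


se_pp = binomial_se_pp(acc_dep, n_problems)
gap_pp = 100.0 * abs(acc_tep - acc_dep)
print(f"standard error: {se_pp:.2f} pp, DEP/TEP gap: {gap_pp:.2f} pp")
```

Under that assumption the standard error comes out at roughly 0.5 pp, and the 0.12 pp gap between DEP and TEP sits well inside it.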

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • Improved CUDA graph generation to detect and handle asymmetric KV-cache capacity across distributed ranks, preventing potential deadlocks in multi-rank configurations.
  • Tests

    • Added comprehensive unit tests validating correct behavior under asymmetric batch conditions in distributed tensor-parallel execution scenarios.

…e deadlock

_capture_generation_cuda_graphs iterates batch sizes in reverse order
(largest first, so smaller graphs can reuse the memory pool).  Under
attention-DP (enable_attention_dp=true), KV-cache capacity can differ
across TP ranks.  _create_cuda_graph_warmup_request returns None on
ranks that lack space for the requested batch size.

Without a cross-rank check, ranks where batch=None silently `continue`
while other ranks enter forward() containing tp_comm collectives
(NCCL allreduce / DEP alltoall).  The skipped ranks never reach the
collective, causing a permanent distributed deadlock.  The process
hangs before producing any Python output (run.log = 0 bytes) and must
be killed externally.

The identical scenario is already guarded in _general_warmup_impl and
_run_autotuner_warmup via _assert_all_tp_ranks_have_warmup_batch.
Apply the same pattern to _capture_generation_cuda_graphs:
- If tp_size <= 1: single-rank path, safe to skip silently (unchanged).
- If tp_size > 1: call _assert_all_tp_ranks_have_warmup_batch to detect
  asymmetry and raise RuntimeError before entering forward(), then skip
  only when all ranks agree the batch is None.

Add unit tests covering the three cases: asymmetric None raises,
all-None skips gracefully, all-valid proceeds to forward().  A fourth
structural test asserts the guard call is present in the source to
catch future regressions.

Made-with: Claude Code
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com>
@longcheng-nv longcheng-nv requested a review from a team as a code owner June 5, 2026 03:08
@longcheng-nv longcheng-nv requested a review from achartier June 5, 2026 03:08
Contributor

coderabbitai Bot commented Jun 5, 2026


Walkthrough

This PR prevents potential deadlocks in CUDA graph generation when TP ranks have asymmetric KV-cache capacity. The engine now asserts batch consistency across ranks before skipping warmup, and comprehensive tests validate the asymmetric detection, graceful all-None skip, and normal all-valid progression paths.

Changes

Deadlock Prevention in CUDA Graph Warmup

| Layer / File(s) | Summary |
|---|---|
| Fail-fast batch consistency assertion in CUDA graph warmup: tensorrt_llm/_torch/pyexecutor/model_engine.py | When batch is None due to insufficient KV cache and tp_size > 1, the engine calls _assert_all_tp_ranks_have_warmup_batch to detect asymmetric KV-cache capacity across TP ranks and fail fast instead of silently continuing. |
| Comprehensive test validation of asymmetric KV-cache handling: tests/unittest/_torch/executor/test_cuda_graph_capture_adp_guard.py | Test suite validates that the guard detects asymmetric batch=None across ranks (raises RuntimeError), all ranks with batch=None skip gracefully without raising, and all ranks with valid batch proceed normally. Includes structural assertion of guard presence, mocked PyTorchModelEngine setup, and three behavioral scenarios. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • achartier
  • reasonsolo
  • yuxianq
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding a guard to prevent CUDA graph capture deadlock in attention-DP mode with asymmetric batch-None conditions. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description comprehensively covers the bug, root cause, fix implementation, and test coverage with clear examples and verification results. |


@longcheng-nv longcheng-nv removed the request for review from achartier June 5, 2026 05:01
@longcheng-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #52252 [ run ] triggered by Bot. Commit: 40c1dd0 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #52252 [ run ] completed with state SUCCESS. Commit: 40c1dd0
/LLM/main/L0_MergeRequest_PR pipeline #41566 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@chenglong92

/bot run --disable-fail-fast



3 participants