feat(visual_gen): add HunyuanDiT text-to-image pipeline with Ulysses parallelism #14991

Open
pkisfaludi-nv wants to merge 2 commits into NVIDIA:main from pkisfaludi-nv:feat/hunyuandit-visual-gen

Conversation


@pkisfaludi-nv pkisfaludi-nv commented Jun 5, 2026

Summary

  • New pipeline: HunyuanDiTPipeline for Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers (and v1.1, v1.0), following the existing BasePipeline pattern
  • Bilingual text encoding: BERT-based CLIP encoder (text_encoder) + MT5EncoderModel (text_encoder_2), both tokenizers, concatenated for cross-attention
  • Ulysses sequence parallelism: custom attention processor injects all_to_all_4d around self-attention; cross-attention uses standard SDPA (text replicated); image_rotary_emb (2D RoPE) sliced per rank; all_gather after the final block
  • Resolution binning: aspect-ratio snapping to training bucket resolutions via log-space distance
  • Config: examples/visual_gen/serve/configs/hunyuandit.yml
  • Docs: supported models table + feature matrix + footnote in visual-generation.md
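
The resolution-binning bullet above can be illustrated with a minimal sketch. The bucket list and the function name `snap_to_bucket` are hypothetical and are not taken from the PR; the only idea borrowed from the description is that the requested aspect ratio is compared against bucket aspect ratios in log space.

```python
import math

# Hypothetical (height, width) training buckets, for illustration only.
# The real bucket list lives in the pipeline and is not reproduced here.
EXAMPLE_BUCKETS = [
    (1024, 1024),
    (1280, 768),
    (768, 1280),
    (1152, 864),
    (864, 1152),
]


def snap_to_bucket(height, width, buckets=EXAMPLE_BUCKETS):
    """Return the bucket whose aspect ratio is closest in log space."""
    if height <= 0 or width <= 0:
        raise ValueError("height and width must be positive")
    target = math.log(height / width)
    # |log(a) - log(b)| treats a 2:1 and a 1:2 deviation symmetrically,
    # which a plain ratio difference would not.
    return min(buckets, key=lambda hw: abs(math.log(hw[0] / hw[1]) - target))


print(snap_to_bucket(1080, 1920))
```

Working in log space means a request that is twice as wide as a bucket is penalised exactly as much as one that is twice as tall.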

Architecture notes

HunyuanDiT uses a diffusers HunyuanDiT2DModel backbone with U-Net skip connections. Because the model uses diffusers' HunyuanAttnProcessor2_0 rather than TRT-LLM's Attention module, Ulysses is implemented via:

  1. A custom HunyuanDiTUlyssesAttnProcessor replacing the self-attention processor in each block
  2. A patched forward() (via types.MethodType) that shards the latent sequence after patch-embed and gathers before the final norm/proj
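
As a rough illustration of the second item, the sketch below rebinds `forward` on a single instance with `types.MethodType`. `ToyModel`, `patch_forward`, and the sharding rule are hypothetical stand-ins, not the PR's code; the real patch operates on `HunyuanDiT2DModel` and uses collective communication.

```python
import types


class ToyModel:
    """Stand-in for a third-party model whose source we do not edit."""

    def forward(self, tokens):
        return [t * 2 for t in tokens]


def patch_forward(model, rank, world_size):
    """Rebind model.forward so this instance only processes its own shard."""
    original_forward = model.forward  # bound method, captured before patching

    def sharded_forward(self, tokens):
        if len(tokens) % world_size != 0:
            raise ValueError("sequence length must be divisible by world_size")
        shard = len(tokens) // world_size
        local = tokens[rank * shard:(rank + 1) * shard]
        return original_forward(local)

    # MethodType binds the function to this instance only; the class and
    # any other instances keep the original forward().
    model.forward = types.MethodType(sharded_forward, model)
    return model


patched = patch_forward(ToyModel(), rank=1, world_size=2)
print(patched.forward([1, 2, 3, 4]))
```

Patching the instance rather than the class keeps the change local to the wrapped model, at the cost of coupling the patch to the upstream `forward` signature.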

Constraint: num_attention_heads=16 must be divisible by ulysses_size.

Files changed

tensorrt_llm/_torch/visual_gen/models/hunyuandit/   (new)
  __init__.py
  defaults.py
  transformer_hunyuandit.py
  pipeline_hunyuandit.py
tensorrt_llm/_torch/visual_gen/models/__init__.py   (modified)
tensorrt_llm/_torch/visual_gen/pipeline_registry.py (modified)
examples/visual_gen/serve/configs/hunyuandit.yml    (new)
docs/source/models/visual-generation.md             (modified)

Test plan

  • Import check: python -c "from tensorrt_llm._torch.visual_gen.pipeline_registry import PIPELINE_REGISTRY; assert 'HunyuanDiTPipeline' in PIPELINE_REGISTRY"
  • Single-GPU inference: trtllm-serve visual_gen --model Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers
  • Multi-GPU Ulysses: mpirun -n 2 ... ulysses_size: 2
  • trtllm-serve /v1/images/generations endpoint

🤖 Generated with Claude Code

Summary by CodeRabbit

New Features

  • Added HunyuanDiT text-to-image generation model with support for multiple checkpoint versions
  • Automatic resolution optimization for improved generation results
  • Multilingual text conditioning for broader language support
  • Enhanced distributed inference capabilities

Documentation

  • Added HunyuanDiT model documentation including supported checkpoints and feature capabilities matrix
  • Configuration examples provided for model deployment and setup

pkisfaludi-nv and others added 2 commits June 4, 2026 21:19
Integrates Tencent HunyuanDiT into the TensorRT-LLM VisualGen framework,
following the same BasePipeline pattern used by Qwen-Image and FLUX.

New files:
- tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py
  HunyuanDiTPipeline registered via @register_pipeline; supports v1.0-v1.2
  checkpoints. Implements bilingual (BertModel + MT5EncoderModel) text
  encoding, DDPM denoising loop, resolution binning to training buckets,
  2D RoPE embeddings, and VAE decode.
- tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py
  HunyuanDiT2DModelWrapper: thin nn.Module around diffusers HunyuanDiT2DModel
  with load_weights() compatible with the WeightLoader contract.
- tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py
  Default generation params (1024×1024, 50 steps, cfg=7.5) and extra-param
  schema (negative_prompt, use_resolution_binning).
- examples/visual_gen/serve/configs/hunyuandit.yml
  Serve config with VANILLA attention backend.

Modified:
- pipeline_registry.py: add HunyuanDiT detection in _detect_from_checkpoint
- models/__init__.py: export HunyuanDiTPipeline
- docs/source/models/visual-generation.md: add HunyuanDiT to model table
  and feature matrix

Qwen-Image (Qwen/Qwen-Image, Qwen/Qwen-Image-2512) was already present in
main; no code changes needed for it — confirmed in docs table.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Implements DeepSpeed Ulysses sequence parallelism for HunyuanDiT
following the same pattern used by FLUX (SequenceSharder + all-to-all
around attention).

HunyuanDiTUlyssesAttnProcessor
  Drop-in replacement for HunyuanAttnProcessor2_0. For self-attention
  (encoder_hidden_states is None) it wraps F.scaled_dot_product_attention
  with two all_to_all_4d calls:
    [B, S/U, H, D] → all-to-all → [B, S, H/U, D] → SDPA → all-to-all → [B, S/U, H, D]
  Cross-attention falls back to standard SDPA (text K/V is replicated on
  every rank, so no all-to-all is needed).
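
The layout change around self-attention can be simulated in a single process. This is a sketch under simplifying assumptions: the batch and head-dimension axes are dropped, nested lists stand in for tensors, and the function names are hypothetical; it does not call `all_to_all_4d` or any distributed API.

```python
def seq_to_head_shards(per_rank, ulysses_size):
    """Simulate [S/U, H] per rank -> [S, H/U] per rank.

    per_rank[r][s][h] is the element rank r holds at local position s, head h.
    """
    num_heads = len(per_rank[0][0])
    if num_heads % ulysses_size != 0:
        raise ValueError("num_heads must be divisible by ulysses_size")
    heads_per_rank = num_heads // ulysses_size
    out = []
    for dst in range(ulysses_size):
        h0 = dst * heads_per_rank
        # Concatenate every source rank's sequence shard in rank order,
        # keeping only the heads owned by `dst`.
        out.append([row[h0:h0 + heads_per_rank]
                    for src in range(ulysses_size)
                    for row in per_rank[src]])
    return out


def head_to_seq_shards(per_rank, ulysses_size):
    """Inverse layout change: [S, H/U] per rank -> [S/U, H] per rank."""
    shard = len(per_rank[0]) // ulysses_size
    out = []
    for dst in range(ulysses_size):
        rows = []
        for s in range(dst * shard, (dst + 1) * shard):
            # Re-join the head slices that each rank holds for position s.
            row = []
            for src in range(ulysses_size):
                row.extend(per_rank[src][s])
            rows.append(row)
        out.append(rows)
    return out
```

After the first exchange every rank sees the full sequence for its subset of heads, so attention over all positions can run locally; the second exchange restores the sequence-sharded layout expected by the rest of the block.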

HunyuanDiT2DModelUlysses
  Monkeypatches HunyuanDiT2DModel.forward() via types.MethodType to:
    1. Shard the patch-embedded latent sequence across Ulysses ranks.
    2. Slice image_rotary_emb to the local sequence shard.
    3. Run the transformer blocks (with custom processors for self-attn).
    4. all_gather to reassemble the full sequence before norm_out/proj_out.
  U-Net-style skip tensors are sharded the same way as hidden_states so
  no special handling is needed for the skip connections.
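
A toy single-process check of why sharding hidden states, rotary embeddings, and skip tensors identically is sufficient for position-wise work. All names are hypothetical, `toy_block` is a deliberately simplified stand-in for a block with no attention, and `all_gather` here is a plain concatenation rather than a collective call.

```python
def shard(seq, rank, ulysses_size):
    """Return the contiguous slice of `seq` owned by `rank`."""
    if len(seq) % ulysses_size != 0:
        raise ValueError("sequence length must be divisible by ulysses_size")
    n = len(seq) // ulysses_size
    return seq[rank * n:(rank + 1) * n]


def all_gather(shards):
    """Concatenate per-rank shards in rank order (single-process stand-in)."""
    return [x for s in shards for x in s]


def toy_block(hidden, cos, sin, skip):
    """Position-wise stand-in for a block that also consumes a skip tensor."""
    return [h * c + s + k for h, c, s, k in zip(hidden, cos, sin, skip)]


hidden, cos, sin, skip = [1, 2, 3, 4], [1, 0, 1, 0], [0, 1, 0, 1], [10, 20, 30, 40]
reference = toy_block(hidden, cos, sin, skip)
gathered = all_gather([
    toy_block(shard(hidden, r, 2), shard(cos, r, 2), shard(sin, r, 2), shard(skip, r, 2))
    for r in range(2)
])
print(gathered == reference)
```

Self-attention is the one step that is not position-wise, which is why it alone needs the all-to-all exchange.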

Constraints
  - num_attention_heads (16) must be divisible by ulysses_size.
  - Validated at wrapper construction; raises ValueError otherwise.
  - Requires vgm.ulysses_group to be initialised (VisualGenMapping.init_device_mesh).
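
The divisibility constraint can be sketched as a small validation helper. The function names and error messages are hypothetical; only the rule (16 heads, divisible by `ulysses_size`, `ValueError` otherwise) comes from the commit message.

```python
def validate_ulysses_size(num_attention_heads, ulysses_size):
    """Return heads per rank, or raise ValueError if the split is uneven."""
    if ulysses_size < 1:
        raise ValueError(f"ulysses_size must be >= 1, got {ulysses_size}")
    if num_attention_heads % ulysses_size != 0:
        raise ValueError(
            f"num_attention_heads ({num_attention_heads}) must be divisible "
            f"by ulysses_size ({ulysses_size})"
        )
    return num_attention_heads // ulysses_size


def valid_ulysses_sizes(num_attention_heads):
    """List every ulysses_size that divides the head count evenly."""
    return [u for u in range(1, num_attention_heads + 1)
            if num_attention_heads % u == 0]


print(valid_ulysses_sizes(16))
```

For 16 heads the admissible sizes are the divisors of 16, so a 3- or 6-GPU Ulysses group would be rejected at construction time.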

Docs: HunyuanDiT Ulysses column updated from No → Yes in the feature matrix.

Qwen-Image already has Ulysses support through the TRT-LLM Attention module
(inherits UlyssesAttention wrapping); no code changes needed there.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

coderabbitai Bot commented Jun 5, 2026

Walkthrough

This PR adds HunyuanDiT text-to-image generation to TensorRT-LLM's VisualGen module, including dual-language prompt conditioning, resolution binning, and optional distributed sequence parallelism via Ulysses, with supporting documentation and checkpoint detection.

Changes

HunyuanDiT VisualGen Implementation

| Layer | File(s) | Summary |
|---|---|---|
| Module structure and exports | tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py, tensorrt_llm/_torch/visual_gen/models/__init__.py | Creates the hunyuandit module package with SPDX headers and exports HunyuanDiTPipeline and HunyuanDiT2DModelWrapper for public consumption; updates parent models __init__.py to expose the new pipeline. |
| Configuration and generation defaults | examples/visual_gen/serve/configs/hunyuandit.yml, tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py | Defines YAML configuration with VANILLA attention backend and parallel config (cfg_size=1, ulysses_size=1); provides default generation parameters (1024×1024 resolution, 50 steps, 7.5 guidance, 77 max tokens) and extra parameter specs for negative prompts and resolution binning toggle. |
| Transformer wrapper with Ulysses parallelism | tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py | Implements HunyuanDiT2DModelWrapper wrapping the diffusers transformer with conditional Ulysses sequence parallelism: HunyuanDiTUlyssesAttnProcessor injects all-to-all sharding+gathering for self-attention via SDPA, HunyuanDiT2DModelUlysses patches forward to shard/gather image token sequences across ranks with adjusted RoPE frequencies, wrapper validates divisibility constraints, conditionally applies Ulysses patch, loads weights with logging, and converts to inference dtype. |
| Pipeline generation flow | tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py | Implements HunyuanDiTPipeline with resolution aspect-ratio binning to snap dimensions to training buckets, loads dual text encoders (BERT for ≤77 tokens, MT5 for ≤256 tokens), VAE, and DDPM scheduler, supports classifier-free guidance via negative-prompt encoding, prepares latents from VAE downsampling, computes 2D RoPE embeddings, executes denoising loop with transformer noise prediction and CFG, and returns generated uint8 images with rank-0 timing logs. |
| Documentation and checkpoint detection | docs/source/models/visual-generation.md, tensorrt_llm/_torch/visual_gen/pipeline_registry.py | Adds HunyuanDiT v1.2/v1.1/v1.0 checkpoint IDs to supported models table and feature matrix with footnote explaining bilingual text conditioning and Ulysses parallelism constraints; extends AutoPipeline._detect_from_checkpoint() to recognize and dispatch HunyuanDiT model class names. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

VisualGen

Suggested reviewers

  • Shixiaowei02
  • zhenhuaw-me
  • yibinl-nvidia
  • NVShreyas
  • chang-l
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely summarizes the main changes: adding a HunyuanDiT text-to-image pipeline with Ulysses parallelism support. |
| Description check | ✅ Passed | The description provides a comprehensive summary, architecture notes, file changes, and test plan, but lacks explicit PR title format following the template and doesn't fully complete the PR checklist. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Actionable comments posted: 1

🧹 Nitpick comments (2)
tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py (1)

507-529: ⚡ Quick win

Unused loop variable i.

The loop counter i is assigned but never used in the loop body, which triggers Ruff B007. Rename it to _, or drop enumerate and iterate over timesteps directly.

♻️ Proposed fix
-        for i, t in enumerate(timesteps):
+        for _, t in enumerate(timesteps):
             lat_in = torch.cat([latents] * 2) if do_cfg else latents
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py`
around lines 507 - 529, The for-loop in pipeline_hunyuandit.py currently
declares an unused loop variable "i" (for i, t in enumerate(timesteps)) which
triggers Ruff B007; change the loop to use an underscore for the unused index
(for _, t in enumerate(timesteps)) or drop enumerate if not needed, updating the
loop that constructs lat_in, calls self.transformer(...), performs CFG handling
and calls self.scheduler.step so only used variables remain.
tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py (1)

9-12: ⚡ Quick win

Sort __all__ entries alphabetically.

As per coding guidelines, maintain sorted __all__ lists for consistency across the codebase.

📋 Proposed fix
 __all__ = [
-    "HunyuanDiTPipeline",
     "HunyuanDiT2DModelWrapper",
+    "HunyuanDiTPipeline",
 ]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py` around lines 9
- 12, The __all__ list in the module exports is unsorted; reorder the entries
alphabetically so exports are consistent across the codebase — change the list
containing "HunyuanDiTPipeline" and "HunyuanDiT2DModelWrapper" to alphabetical
order (place "HunyuanDiT2DModelWrapper" before "HunyuanDiTPipeline") while
keeping the same string names and no other modifications.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9bd57c0d-3159-4d02-abd7-66cbeb72038c

📥 Commits

Reviewing files that changed from the base of the PR and between 81e86a5 and 33b4360.

📒 Files selected for processing (8)
  • docs/source/models/visual-generation.md
  • examples/visual_gen/serve/configs/hunyuandit.yml
  • tensorrt_llm/_torch/visual_gen/models/__init__.py
  • tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py
  • tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py
  • tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py
  • tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py
  • tensorrt_llm/_torch/visual_gen/pipeline_registry.py

Comment on lines +338 to +373
    @staticmethod
    def _get_image_rotary_emb(
        patch_size: int,
        vae_scale_factor: int,
        height: int,
        width: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 2D RoPE embeddings for the latent grid.

        Follows diffusers' ``get_2d_rotary_pos_embed``.
        """
        try:
            from diffusers.models.embeddings import get_2d_rotary_pos_embed
        except ImportError:
            return None  # type: ignore[return-value]

        grid_height = height // (vae_scale_factor * patch_size)
        grid_width = width // (vae_scale_factor * patch_size)
        base_size = 512 // (vae_scale_factor * patch_size)
        grid_crops_coords = (
            (0, 0),
            (grid_height, grid_width),
        )
        freqs_cos, freqs_sin = get_2d_rotary_pos_embed(
            embed_dim=88,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            use_real=True,
            base_size=base_size,
        )
        return (
            freqs_cos.to(device=device, dtype=dtype),
            freqs_sin.to(device=device, dtype=dtype),
        )

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Type hint mismatch: function can return None but signature declares Tuple[torch.Tensor, torch.Tensor].

Line 354 returns None when get_2d_rotary_pos_embed cannot be imported, but the function signature at line 339 declares the return type as Tuple[torch.Tensor, torch.Tensor]. Callers (lines 280-284 in transformer, line 490 in this file) check if image_rotary_emb is not None, so the None return is intentional, but the type hint should reflect this.

🔧 Proposed fix
     @staticmethod
     def _get_image_rotary_emb(
         patch_size: int,
         vae_scale_factor: int,
         height: int,
         width: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
         """Compute 2D RoPE embeddings for the latent grid.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py`
around lines 338 - 373, The _get_image_rotary_emb function currently returns
None when diffusers' get_2d_rotary_pos_embed is unavailable but its signature is
Tuple[torch.Tensor, torch.Tensor]; change the return annotation to
Optional[Tuple[torch.Tensor, torch.Tensor]] (import Optional from typing) and
remove the type: ignore on the None return so the signature matches callers
(e.g., where image_rotary_emb is checked for None); update any docstring/type
comments accordingly so static checkers understand None is a valid return.
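
The pattern the comment asks for, an optional dependency whose absence is reflected in the return type, can be sketched generically. `optional_import` is a hypothetical helper, not part of the PR or of diffusers; it only shows how an `Optional[...]` annotation lets callers and static checkers agree that `None` is a valid result.

```python
import importlib
from typing import Any, Callable, Optional


def optional_import(module_name: str, attr: str) -> Optional[Callable[..., Any]]:
    """Return `module_name.attr`, or None if the module or attribute is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The annotation already admits None, so no type-ignore is needed.
        return None
    return getattr(module, attr, None)


sqrt = optional_import("math", "sqrt")
if sqrt is not None:  # callers must guard, exactly as the type says
    print(sqrt(9.0))
```

With the annotation in place, a caller that forgets the `is not None` guard is flagged by a type checker instead of failing at runtime.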
