From f49824ad28a5a91df9f235a6e52047b3f3399c91 Mon Sep 17 00:00:00 2001 From: Peter Kisfaludi Date: Thu, 4 Jun 2026 21:19:20 -0700 Subject: [PATCH 1/3] feat(visual_gen): add HunyuanDiT text-to-image pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrates Tencent HunyuanDiT into the TensorRT-LLM VisualGen framework, following the same BasePipeline pattern used by Qwen-Image and FLUX. New files: - tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py HunyuanDiTPipeline registered via @register_pipeline; supports v1.0-v1.2 checkpoints. Implements bilingual (BertModel + MT5EncoderModel) text encoding, DDPM denoising loop, resolution binning to training buckets, 2D RoPE embeddings, and VAE decode. - tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py HunyuanDiT2DModelWrapper: thin nn.Module around diffusers HunyuanDiT2DModel with load_weights() compatible with the WeightLoader contract. - tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py Default generation params (1024×1024, 50 steps, cfg=7.5) and extra-param schema (negative_prompt, use_resolution_binning). - examples/visual_gen/serve/configs/hunyuandit.yml Serve config with VANILLA attention backend. Modified: - pipeline_registry.py: add HunyuanDiT detection in _detect_from_checkpoint - models/__init__.py: export HunyuanDiTPipeline - docs/source/models/visual-generation.md: add HunyuanDiT to model table and feature matrix Qwen-Image (Qwen/Qwen-Image, Qwen/Qwen-Image-2512) was already present in main; no code changes needed for it — confirmed in docs table. Co-Authored-By: Claude Sonnet 4.6 --- docs/source/models/visual-generation.md | 8 +- .../visual_gen/serve/configs/hunyuandit.yml | 5 + .../_torch/visual_gen/models/__init__.py | 2 + .../visual_gen/models/hunyuandit/__init__.py | 12 + .../visual_gen/models/hunyuandit/defaults.py | 36 ++ .../models/hunyuandit/pipeline_hunyuandit.py | 539 ++++++++++++++++++ .../hunyuandit/transformer_hunyuandit.py | 154 +++++ .../_torch/visual_gen/pipeline_registry.py | 3 + 8 files changed, 758 insertions(+), 1 deletion(-) create mode 100644 examples/visual_gen/serve/configs/hunyuandit.yml create mode 100644 tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py create mode 100644 tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py create mode 100644 tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py create mode 100644 tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index db8c73912969..684207148458 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -35,6 +35,9 @@ TensorRT-LLM **VisualGen** provides a unified inference stack for diffusion mode | `Lightricks/LTX-2` | Text-to-Video (with Audio), Image-to-Video (with Audio) | | `Qwen/Qwen-Image` | Text-to-Image | | `Qwen/Qwen-Image-2512` | Text-to-Image | +| `Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers` | Text-to-Image | +| `Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers` | Text-to-Image | +| `Tencent-Hunyuan/HunyuanDiT-v1.0-Diffusers` | Text-to-Image | Models are auto-detected from the checkpoint directory. Diffusers-format models are detected via `model_index.json`; LTX-2 monolithic safetensors checkpoints are detected via embedded metadata. The `AutoPipeline` registry selects the appropriate pipeline class automatically. @@ -48,10 +51,13 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models | **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | | **Qwen-Image** [^2] | Yes | Yes | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | +| **HunyuanDiT** [^3] | No | No | No | No | No | No | No | No | Yes | Yes | No | No | [^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. -[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. +[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. + +[^3]: HunyuanDiT uses bilingual (Chinese/English) text conditioning via a BertModel CLIP encoder and an MT5EncoderModel. The initial integration wraps the diffusers `HunyuanDiT2DModel` and supports BF16/FP16 inference via `trtllm-serve`. Quantization and parallel optimizations are planned for future releases. ## Quick Start diff --git a/examples/visual_gen/serve/configs/hunyuandit.yml b/examples/visual_gen/serve/configs/hunyuandit.yml new file mode 100644 index 000000000000..1659ac835d89 --- /dev/null +++ b/examples/visual_gen/serve/configs/hunyuandit.yml @@ -0,0 +1,5 @@ +attention_config: + backend: VANILLA +parallel_config: + cfg_size: 1 + ulysses_size: 1 diff --git a/tensorrt_llm/_torch/visual_gen/models/__init__.py b/tensorrt_llm/_torch/visual_gen/models/__init__.py index c5d63ed88b23..2c6461f05710 100644 --- a/tensorrt_llm/_torch/visual_gen/models/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/models/__init__.py @@ -35,6 +35,7 @@ from ..pipeline_registry import AutoPipeline, register_pipeline from .cosmos3 import Cosmos3OmniMoTPipeline from .flux import Flux2Pipeline, FluxPipeline +from .hunyuandit import HunyuanDiTPipeline from .ltx2 import LTX2Pipeline # noqa: F401 from .qwen_image import QwenImagePipeline from .wan import WanImageToVideoPipeline, WanPipeline @@ -44,6 +45,7 @@ "BasePipeline", "FluxPipeline", "Flux2Pipeline", + "HunyuanDiTPipeline", "QwenImagePipeline", "WanPipeline", "WanImageToVideoPipeline", diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py new file mode 100644 index 000000000000..068e18f3d552 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/__init__.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""HunyuanDiT text-to-image pipeline exports.""" + +from .pipeline_hunyuandit import HunyuanDiTPipeline +from .transformer_hunyuandit import HunyuanDiT2DModelWrapper + +__all__ = [ + "HunyuanDiTPipeline", + "HunyuanDiT2DModelWrapper", +] diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py new file mode 100644 index 000000000000..e1e72f6a261c --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/defaults.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""HunyuanDiT default generation parameters and extra-param schema.""" + +from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + +_HUNYUANDIT_DEFAULT_PARAMS = { + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "guidance_scale": 7.5, + "max_sequence_length": 77, +} + + +def get_hunyuandit_default_params() -> dict: + return dict(_HUNYUANDIT_DEFAULT_PARAMS) + + +def get_hunyuandit_extra_param_specs() -> dict: + return { + "negative_prompt": ExtraParamSchema( + type="str", + default="", + description="Negative text prompt for classifier-free guidance.", + ), + "use_resolution_binning": ExtraParamSchema( + type="bool", + default=True, + description=( + "Snap resolution to the nearest HunyuanDiT training bucket " + "(recommended for best quality)." + ), + ), + } diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py new file mode 100644 index 000000000000..2d5e01e30d2a --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,539 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""HunyuanDiT text-to-image pipeline. + +Ports the denoise loop, bilingual text encoding (BertModel + MT5EncoderModel), +DDPM sampling, and VAE decode from ``diffusers.HunyuanDiTPipeline`` onto the +TensorRT-LLM VisualGen executor. + +The transformer backbone is loaded via :class:`HunyuanDiT2DModelWrapper` +which wraps the diffusers ``HunyuanDiT2DModel``. All other components (VAE, +text encoders, tokenizers, scheduler) are loaded directly from the HuggingFace +checkpoint using diffusers / transformers. + +References: + - Model card: https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers + - Tencent repo: https://github.com/Tencent-Hunyuan/HunyuanDiT + - diffusers: ``diffusers.pipelines.hunyuan_dit.pipeline_hunyuan_dit`` +""" + +import math +import time +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from tensorrt_llm._torch.visual_gen.output import CudaPhaseTimer, PipelineOutput +from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline +from tensorrt_llm._torch.visual_gen.pipeline_registry import PipelineComponent, register_pipeline +from tensorrt_llm.logger import logger + +from .defaults import get_hunyuandit_default_params, get_hunyuandit_extra_param_specs +from .transformer_hunyuandit import HunyuanDiT2DModelWrapper + +# --------------------------------------------------------------------------- +# Resolution binning +# --------------------------------------------------------------------------- + +# HunyuanDiT was trained on these aspect-ratio buckets (height × width). +# We snap user-requested resolutions to the closest bucket by default. +_SUPPORTED_SHAPE_BINNING = ( + (1024, 1024), + (1280, 1280), + (1024, 768), + (768, 1024), + (1280, 960), + (960, 1280), + (1280, 768), + (768, 1280), +) + + +def _map_to_standard_shapes( + target_height: int, target_width: int +) -> Tuple[int, int]: + """Return the closest supported (height, width) pair. + + Matching strategy: minimise the Euclidean distance in log-space between + the user's aspect ratio and the training buckets, then prefer the bucket + whose total pixel count is closest to the user's. This matches the + reference diffusers implementation. + """ + target_ratio = target_height / target_width + best = None + best_dist = float("inf") + for h, w in _SUPPORTED_SHAPE_BINNING: + dist = abs(math.log(h / w) - math.log(target_ratio)) + if dist < best_dist: + best_dist = dist + best = (h, w) + return best # type: ignore[return-value] + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + +_DEFAULT_GENERATION_PARAMS = get_hunyuandit_default_params() + + +@register_pipeline( + "HunyuanDiTPipeline", + hf_ids=[ + "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", + "Tencent-Hunyuan/HunyuanDiT-v1.0-Diffusers", + ], + doc="Tencent HunyuanDiT bilingual (Chinese/English) text-to-image pipeline.", +) +class HunyuanDiTPipeline(BasePipeline): + """HunyuanDiT Text-to-Image Pipeline. + + Supports HunyuanDiT-v1.0, v1.1, and v1.2 diffusers checkpoints. + + Text conditioning uses a dual-encoder architecture: + * A bilingual CLIP-like ``BertModel`` for short (≤ 77 token) sequences. + * An ``MT5EncoderModel`` for long (≤ 256 token) sequences. + Both encodings are passed jointly to the transformer. + """ + + DEFAULT_GENERATION_PARAMS = _DEFAULT_GENERATION_PARAMS + + def __init__(self, model_config): + super().__init__(model_config) + self.vae_scale_factor = 8 # SD-style VAE, 8× spatial compression + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def dtype(self) -> torch.dtype: + return self.model_config.torch_dtype + + @property + def device(self) -> torch.device: + if self.transformer is not None: + return next(self.transformer.parameters()).device + return torch.device("cuda:0") + + @property + def default_warmup_resolutions(self) -> List[Tuple[int, int]]: + return [(512, 512), (1024, 1024)] + + @property + def default_warmup_num_frames(self) -> List[int]: + return [1] + + @property + def resolution_multiple_of(self) -> Tuple[int, int]: + return (self.vae_scale_factor, self.vae_scale_factor) + + @property + def default_generation_params(self) -> dict: + return dict(_DEFAULT_GENERATION_PARAMS) + + @property + def extra_param_specs(self) -> dict: + return get_hunyuandit_extra_param_specs() + + # ------------------------------------------------------------------ + # Component initialisation + # ------------------------------------------------------------------ + + def _init_transformer(self) -> None: + logger.info("Creating HunyuanDiT2D transformer") + self.transformer = HunyuanDiT2DModelWrapper(model_config=self.model_config) + + def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> None: + with torch.no_grad(): + self.forward( + prompt="warmup", + height=height, + width=width, + num_inference_steps=max(steps, 2), + guidance_scale=7.5, + seed=42, + use_resolution_binning=False, + ) + + def load_standard_components( + self, + checkpoint_dir: str, + device: torch.device, + skip_components: Optional[list] = None, + ) -> None: + skip_components = skip_components or [] + + try: + from diffusers import AutoencoderKL, DDPMScheduler + except ImportError as exc: + raise ImportError( + "HunyuanDiT requires diffusers >= 0.26 (`pip install -U diffusers`)." + ) from exc + + try: + from transformers import AutoTokenizer, BertModel, MT5EncoderModel, T5Tokenizer + except ImportError as exc: + raise ImportError( + "HunyuanDiT requires transformers (`pip install -U transformers`)." + ) from exc + + if PipelineComponent.TOKENIZER not in skip_components: + logger.info("Loading HunyuanDiT CLIP tokenizer (BertTokenizer)...") + self.tokenizer = AutoTokenizer.from_pretrained( + checkpoint_dir, subfolder=PipelineComponent.TOKENIZER + ) + + if PipelineComponent.TOKENIZER_2 not in skip_components: + logger.info("Loading HunyuanDiT T5 tokenizer (MT5Tokenizer)...") + self.tokenizer_2 = T5Tokenizer.from_pretrained( + checkpoint_dir, subfolder=PipelineComponent.TOKENIZER_2 + ) + + if PipelineComponent.TEXT_ENCODER not in skip_components: + logger.info("Loading HunyuanDiT CLIP text encoder (BertModel)...") + self.text_encoder = BertModel.from_pretrained( + checkpoint_dir, + subfolder=PipelineComponent.TEXT_ENCODER, + torch_dtype=self.model_config.torch_dtype, + ).to(device) + self.text_encoder.eval() + + if PipelineComponent.TEXT_ENCODER_2 not in skip_components: + logger.info("Loading HunyuanDiT T5 text encoder (MT5EncoderModel)...") + self.text_encoder_2 = MT5EncoderModel.from_pretrained( + checkpoint_dir, + subfolder=PipelineComponent.TEXT_ENCODER_2, + torch_dtype=self.model_config.torch_dtype, + ).to(device) + self.text_encoder_2.eval() + + if PipelineComponent.VAE not in skip_components: + logger.info("Loading HunyuanDiT VAE (AutoencoderKL)...") + self.vae = AutoencoderKL.from_pretrained( + checkpoint_dir, + subfolder=PipelineComponent.VAE, + torch_dtype=torch.float32, # VAE decode in fp32 for numerical stability + ).to(device) + self.vae.eval() + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + if PipelineComponent.SCHEDULER not in skip_components: + logger.info("Loading HunyuanDiT DDPM scheduler...") + self.scheduler = DDPMScheduler.from_pretrained( + checkpoint_dir, subfolder=PipelineComponent.SCHEDULER + ) + + def load_weights(self, weights: dict) -> None: + if self.transformer is not None: + transformer_weights = weights.get("transformer", weights) + self.transformer.load_weights(transformer_weights) + self.transformer.to_inference_dtype().eval() + self._target_dtype = self.model_config.torch_dtype + + # ------------------------------------------------------------------ + # Text encoding (bilingual: BertModel + MT5EncoderModel) + # ------------------------------------------------------------------ + + def _encode_prompt_clip( + self, + prompt: List[str], + device: torch.device, + max_sequence_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode via BertModel (short CLIP-like encoder, max 77 tokens).""" + max_len = min(max_sequence_length, 77) + tok_out = self.tokenizer( + prompt, + padding="max_length", + max_length=max_len, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + out = self.text_encoder( + input_ids=tok_out.input_ids, + attention_mask=tok_out.attention_mask, + ) + return out.last_hidden_state.to(self.dtype), tok_out.attention_mask.to(device) + + def _encode_prompt_t5( + self, + prompt: List[str], + device: torch.device, + max_sequence_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode via MT5EncoderModel (long sequence encoder, max 256 tokens).""" + max_len = min(max_sequence_length, 256) + tok_out = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_len, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + out = self.text_encoder_2( + input_ids=tok_out.input_ids, + attention_mask=tok_out.attention_mask, + ) + return out.last_hidden_state.to(self.dtype), tok_out.attention_mask.to(device) + + def _encode_prompt( + self, + prompt: List[str], + device: torch.device, + max_sequence_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Return (clip_embeds, clip_mask, t5_embeds, t5_mask).""" + clip_embeds, clip_mask = self._encode_prompt_clip(prompt, device, max_sequence_length) + t5_embeds, t5_mask = self._encode_prompt_t5(prompt, device, max_sequence_length) + return clip_embeds, clip_mask, t5_embeds, t5_mask + + # ------------------------------------------------------------------ + # Latent utilities + # ------------------------------------------------------------------ + + def _prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator, + ) -> torch.Tensor: + from diffusers.utils.torch_utils import randn_tensor + + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + """VAE decode → uint8 (B, H, W, 3) tensor.""" + # Scale latents per the SD VAE convention + latents = latents / self.vae.config.scaling_factor + with torch.no_grad(): + image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.permute(0, 2, 3, 1) # (B, C, H, W) → (B, H, W, C) + return (image * 255).round().to(torch.uint8) + + # ------------------------------------------------------------------ + # RoPE image embedding helper + # ------------------------------------------------------------------ + + @staticmethod + def _get_image_rotary_emb( + patch_size: int, + vae_scale_factor: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 2D RoPE embeddings for the latent grid. + + Follows diffusers' ``get_2d_rotary_pos_embed``. + """ + try: + from diffusers.models.embeddings import get_2d_rotary_pos_embed + except ImportError: + return None # type: ignore[return-value] + + grid_height = height // (vae_scale_factor * patch_size) + grid_width = width // (vae_scale_factor * patch_size) + base_size = 512 // (vae_scale_factor * patch_size) + grid_crops_coords = ( + (0, 0), + (grid_height, grid_width), + ) + freqs_cos, freqs_sin = get_2d_rotary_pos_embed( + embed_dim=88, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + use_real=True, + base_size=base_size, + ) + return ( + freqs_cos.to(device=device, dtype=dtype), + freqs_sin.to(device=device, dtype=dtype), + ) + + # ------------------------------------------------------------------ + # Inference entry points + # ------------------------------------------------------------------ + + def infer(self, req): + params = req.params + num_per = params.num_images_per_prompt or 1 + base_prompts = req.prompt if isinstance(req.prompt, list) else [req.prompt] + prompts = [p for p in base_prompts for _ in range(num_per)] + + negative = params.negative_prompt + if negative is not None: + negatives = negative if isinstance(negative, list) else [negative] + if len(negatives) == 1: + negatives = negatives * len(base_prompts) + negative = [n for n in negatives for _ in range(num_per)] + + extra = getattr(params, "extra_params", {}) or {} + return self.forward( + prompt=prompts, + negative_prompt=negative, + height=params.height, + width=params.width, + num_inference_steps=params.num_inference_steps, + guidance_scale=params.guidance_scale, + seed=params.seed, + max_sequence_length=params.max_sequence_length, + use_resolution_binning=extra.get("use_resolution_binning", True), + ) + + @torch.inference_mode() + def forward( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + seed: int = 42, + max_sequence_length: int = 77, + use_resolution_binning: bool = True, + **kwargs, + ) -> PipelineOutput: + """Text-to-image generation with HunyuanDiT. + + Mirrors ``diffusers.HunyuanDiTPipeline.__call__`` with classifier-free + guidance (separate negative prompt path). + """ + pipeline_start = time.time() + timer = CudaPhaseTimer() + timer.mark_pre_start() + + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + do_cfg = guidance_scale > 1.0 + + device = self.device + generator = torch.Generator(device=device).manual_seed(seed) + + # Optionally snap to a training bucket + if use_resolution_binning: + height, width = _map_to_standard_shapes(height, width) + logger.info("HunyuanDiT: using binned resolution %d×%d", height, width) + + # Text encoding (bilingual) + logger.info("Encoding prompt...") + clip_embeds, clip_mask, t5_embeds, t5_mask = self._encode_prompt( + prompt, device, max_sequence_length + ) + + if do_cfg: + neg = negative_prompt + if neg is None: + neg = [""] * batch_size + elif isinstance(neg, str): + neg = [neg] * batch_size + neg_clip, neg_clip_mask, neg_t5, neg_t5_mask = self._encode_prompt( + neg, device, max_sequence_length + ) + # Concatenate along batch dim for a single forward pass + clip_embeds = torch.cat([neg_clip, clip_embeds]) + clip_mask = torch.cat([neg_clip_mask, clip_mask]) + t5_embeds = torch.cat([neg_t5, t5_embeds]) + t5_mask = torch.cat([neg_t5_mask, t5_mask]) + + # Latents + num_channels_latents = self.transformer.in_channels + latents = self._prepare_latents( + batch_size, + num_channels_latents, + height, + width, + self.dtype, + device, + generator, + ) + + # Image meta (target/source sizes for HunyuanDiT style conditioning) + image_meta_size = torch.tensor( + [height, width, height, width, 0, 0] * batch_size, + dtype=torch.float32, + device=device, + ).view(batch_size, 6) + if do_cfg: + image_meta_size = image_meta_size.repeat(2, 1) + + # Style embedding (0 = natural photo, per HunyuanDiT convention) + style = torch.zeros(batch_size, dtype=torch.int64, device=device) + if do_cfg: + style = style.repeat(2) + + # RoPE embeddings for the latent grid + image_rotary_emb = self._get_image_rotary_emb( + patch_size=2, + vae_scale_factor=self.vae_scale_factor, + height=height, + width=width, + device=device, + dtype=self.dtype, + ) + + # Scheduler timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Denoise loop + timer.mark_denoise_start() + logger.info("Denoising (%d steps)...", len(timesteps)) + + for i, t in enumerate(timesteps): + lat_in = torch.cat([latents] * 2) if do_cfg else latents + + noise_pred = self.transformer( + hidden_states=lat_in, + timestep=t.expand(lat_in.shape[0]), + encoder_hidden_states=clip_embeds, + text_embedding_mask=clip_mask, + encoder_hidden_states_t5=t5_embeds, + text_embedding_mask_t5=t5_mask, + image_meta_size=image_meta_size, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + if do_cfg: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + timer.mark_post_start() + logger.info("Decoding...") + image = self._decode_latents(latents) + + if getattr(self, "rank", 0) == 0: + logger.info("Pipeline total: %.2fs", time.time() - pipeline_start) + + timer.mark_end() + return timer.fill(PipelineOutput(image=image)) diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py new file mode 100644 index 000000000000..043cf35def8e --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""HunyuanDiT 2D transformer wrapper. + +Wraps the diffusers ``HunyuanDiT2DModel`` so it fits the TensorRT-LLM +VisualGen ``BasePipeline`` weight-loading contract: + + * ``_init_transformer`` (called in ``__init__``) builds the model from + config alone (no weights) using :class:`HunyuanDiT2DModelWrapper`. + * ``load_weights`` is called later with a flat ``{name: tensor}`` dict + from the safetensors checkpoint; this method delegates to ``load_state_dict``. + * ``forward`` delegates to the underlying diffusers transformer with the + same kwargs the denoising loop passes (``hidden_states``, ``timestep``, + ``encoder_hidden_states``, ``text_embedding_mask``, + ``encoder_hidden_states_t5``, ``text_embedding_mask_t5``). + +All non-transformer components (VAE, text encoders, scheduler) are loaded +in ``HunyuanDiTPipeline.load_standard_components``. +""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from tensorrt_llm.logger import logger + + +class HunyuanDiT2DModelWrapper(nn.Module): + """Thin TRT-LLM wrapper around diffusers ``HunyuanDiT2DModel``. + + Args: + model_config: ``DiffusionModelConfig`` instance from the pipeline. + **transformer_kwargs: Config kwargs forwarded to ``HunyuanDiT2DModel`` + (e.g. ``num_attention_heads``, ``num_layers``, …). All keyword + args have defaults matching the published HunyuanDiT-v1.2 model so + the wrapper works even when ``pretrained_config`` is sparse. + """ + + # Published HunyuanDiT-v1.2 config defaults + _DEFAULTS: Dict[str, Any] = { + "num_attention_heads": 16, + "attention_head_dim": 88, + "in_channels": 4, + "patch_size": 2, + "activation_fn": "gelu-approximate", + "num_layers": 40, + "use_linear_projection": False, + "cross_attention_dim": 1024, + "cross_attention_dim_t5": 2048, + "pooled_projection_dim": 1024, + "text_len": 77, + "text_len_t5": 256, + "norm_type": "ada_norm_continous", + "sample_size": 128, + } + + def __init__(self, model_config, **transformer_kwargs): + super().__init__() + self.model_config = model_config + + # Merge defaults → pretrained_config → caller overrides + cfg: Dict[str, Any] = dict(self._DEFAULTS) + pretrained = getattr(model_config, "pretrained_config", None) + if pretrained is not None: + src = pretrained if isinstance(pretrained, dict) else vars(pretrained) + for k in self._DEFAULTS: + if k in src: + cfg[k] = src[k] + cfg.update(transformer_kwargs) + + try: + from diffusers.models import HunyuanDiT2DModel + except ImportError as exc: + raise ImportError( + "HunyuanDiT requires diffusers >= 0.26 " + "(`pip install -U diffusers`)." + ) from exc + + logger.info( + "Building HunyuanDiT2DModel: %d layers, %d heads, head_dim=%d", + cfg["num_layers"], + cfg["num_attention_heads"], + cfg["attention_head_dim"], + ) + self.transformer = HunyuanDiT2DModel(**cfg) + + # Remember latent channel count so the pipeline can read it. + self.in_channels = cfg["in_channels"] + + # ------------------------------------------------------------------ + # Weight loading + # ------------------------------------------------------------------ + + def load_weights(self, weights: Dict[str, torch.Tensor]) -> None: + """Populate transformer parameters from a flat state-dict. + + ``weights`` is provided by the TRT-LLM ``WeightLoader`` and contains + the raw tensors from the checkpoint's ``transformer/`` safetensors + shards. We use ``strict=False`` so missing or extra keys (e.g. from + a newer / older checkpoint version) do not abort loading. + """ + result = self.transformer.load_state_dict(weights, strict=False) + if result.missing_keys: + logger.warning( + "HunyuanDiT: %d missing keys in state dict " + "(first 10: %s)", + len(result.missing_keys), + result.missing_keys[:10], + ) + if result.unexpected_keys: + logger.warning( + "HunyuanDiT: %d unexpected keys in state dict " + "(first 10: %s)", + len(result.unexpected_keys), + result.unexpected_keys[:10], + ) + + def to_inference_dtype(self): + dtype = getattr(self.model_config, "torch_dtype", torch.bfloat16) + self.transformer.to(dtype) + return self + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + text_embedding_mask: Optional[torch.Tensor] = None, + encoder_hidden_states_t5: Optional[torch.Tensor] = None, + text_embedding_mask_t5: Optional[torch.Tensor] = None, + image_meta_size: Optional[torch.Tensor] = None, + style: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + **kwargs, + ): + return self.transformer( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + text_embedding_mask=text_embedding_mask, + encoder_hidden_states_t5=encoder_hidden_states_t5, + text_embedding_mask_t5=text_embedding_mask_t5, + image_meta_size=image_meta_size, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=return_dict, + ) diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_registry.py b/tensorrt_llm/_torch/visual_gen/pipeline_registry.py index 7be918d0ffec..2bf3013cbec7 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline_registry.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline_registry.py @@ -184,6 +184,9 @@ def _detect_from_checkpoint(checkpoint_dir: str) -> str: if "QwenImage" in class_name: return "QwenImagePipeline" + if "HunyuanDiT" in class_name: + return "HunyuanDiTPipeline" + if "Cosmos3" in class_name: return "Cosmos3OmniMoTPipeline" From 33b436045f0896b47ea639527493d1ad6a70b879 Mon Sep 17 00:00:00 2001 From: Peter Kisfaludi Date: Thu, 4 Jun 2026 21:50:32 -0700 Subject: [PATCH 2/3] feat(visual_gen/hunyuandit): add Ulysses sequence parallelism MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements DeepSpeed Ulysses sequence parallelism for HunyuanDiT following the same pattern used by FLUX (SequenceSharder + all-to-all around attention). HunyuanDiTUlyssesAttnProcessor Drop-in replacement for HunyuanAttnProcessor2_0. For self-attention (encoder_hidden_states is None) it wraps F.scaled_dot_product_attention with two all_to_all_4d calls: [B, S/U, H, D] → all-to-all → [B, S, H/U, D] → SDPA → all-to-all → [B, S/U, H, D] Cross-attention falls back to standard SDPA (text K/V is replicated on every rank, so no all-to-all is needed). HunyuanDiT2DModelUlysses Monkeypatches HunyuanDiT2DModel.forward() via types.MethodType to: 1. Shard the patch-embedded latent sequence across Ulysses ranks. 2. Slice image_rotary_emb to the local sequence shard. 3. Run the transformer blocks (with custom processors for self-attn). 4. all_gather to reassemble the full sequence before norm_out/proj_out. U-Net-style skip tensors are sharded the same way as hidden_states so no special handling is needed for the skip connections. Constraints - num_attention_heads (16) must be divisible by ulysses_size. - Validated at wrapper construction; raises ValueError otherwise. - Requires vgm.ulysses_group to be initialised (VisualGenMapping.init_device_mesh). Docs: HunyuanDiT Ulysses column updated from No → Yes in the feature matrix. Qwen-Image already has Ulysses support through the TRT-LLM Attention module (inherits UlyssesAttention wrapping); no code changes needed there. Co-Authored-By: Claude Sonnet 4.6 --- docs/source/models/visual-generation.md | 4 +- .../hunyuandit/transformer_hunyuandit.py | 408 ++++++++++++++++-- 2 files changed, 372 insertions(+), 40 deletions(-) diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index 684207148458..f92d4425cdc0 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -51,13 +51,13 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models | **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | | **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | | **Qwen-Image** [^2] | Yes | Yes | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | -| **HunyuanDiT** [^3] | No | No | No | No | No | No | No | No | Yes | Yes | No | No | +| **HunyuanDiT** [^3] | No | No | No | No | Yes | No | No | No | Yes | Yes | No | No | [^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. [^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. -[^3]: HunyuanDiT uses bilingual (Chinese/English) text conditioning via a BertModel CLIP encoder and an MT5EncoderModel. The initial integration wraps the diffusers `HunyuanDiT2DModel` and supports BF16/FP16 inference via `trtllm-serve`. Quantization and parallel optimizations are planned for future releases. +[^3]: HunyuanDiT uses bilingual (Chinese/English) text conditioning via a BertModel CLIP encoder and an MT5EncoderModel. Ulysses sequence parallelism is supported: after the patch-embed the latent sequence is sharded across ranks; a custom attention processor injects all-to-all collectives around self-attention while text cross-attention remains standard SDPA (text tokens are replicated). Set `ulysses_size` to the desired number of sequence-parallel ranks (must divide `num_attention_heads=16`). Quantization and ring-attention optimizations are planned for future releases. ## Quick Start diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py index 043cf35def8e..ab88aac451f5 100644 --- a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py @@ -1,41 +1,362 @@ # SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""HunyuanDiT 2D transformer wrapper. - -Wraps the diffusers ``HunyuanDiT2DModel`` so it fits the TensorRT-LLM -VisualGen ``BasePipeline`` weight-loading contract: - - * ``_init_transformer`` (called in ``__init__``) builds the model from - config alone (no weights) using :class:`HunyuanDiT2DModelWrapper`. - * ``load_weights`` is called later with a flat ``{name: tensor}`` dict - from the safetensors checkpoint; this method delegates to ``load_state_dict``. - * ``forward`` delegates to the underlying diffusers transformer with the - same kwargs the denoising loop passes (``hidden_states``, ``timestep``, - ``encoder_hidden_states``, ``text_embedding_mask``, - ``encoder_hidden_states_t5``, ``text_embedding_mask_t5``). - -All non-transformer components (VAE, text encoders, scheduler) are loaded -in ``HunyuanDiTPipeline.load_standard_components``. +"""HunyuanDiT 2D transformer wrapper with Ulysses sequence parallelism. + +Architecture overview +--------------------- +The wrapper exposes two classes to the pipeline: + +``HunyuanDiT2DModelWrapper`` + Thin ``nn.Module`` around diffusers' ``HunyuanDiT2DModel``. Selects + ``HunyuanDiT2DModelUlysses`` (a subclass with Ulysses support) when + ``visual_gen_mapping.ulysses_size > 1``, otherwise falls back to the + vanilla diffusers model for single-GPU usage. + +``HunyuanDiT2DModelUlysses`` + Subclass of ``HunyuanDiT2DModel`` that overrides ``forward()`` to shard + the latent sequence across Ulysses ranks AFTER the patch-embed and gather + it back BEFORE the final norm/proj. Self-attention blocks use + ``HunyuanDiTUlyssesAttnProcessor`` which injects an all-to-all before and + after ``F.scaled_dot_product_attention``. Cross-attention is standard + SDPA — text tokens are replicated on every rank so no all-to-all is needed. + +``HunyuanDiTUlyssesAttnProcessor`` + Drop-in replacement for ``HunyuanAttnProcessor2_0``. When called for + self-attention it wraps SDPA with + + all_to_all(q/k/v, scatter_dim=heads, gather_dim=seq) # before + SDPA([B, S, H/U, D]) + all_to_all(output, scatter_dim=seq, gather_dim=heads) # after + + so each rank computes a head-sharded slice of the full-sequence attention. + +References +---------- +- DeepSpeed Ulysses: https://arxiv.org/abs/2309.14509 +- diffusers HunyuanDiT: ``diffusers.models.transformers.hunyuan_transformer_2d`` """ from typing import Any, Dict, Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F from tensorrt_llm.logger import logger +# --------------------------------------------------------------------------- +# Ulysses attention processor +# --------------------------------------------------------------------------- + + +class HunyuanDiTUlyssesAttnProcessor: + """Custom attention processor injecting Ulysses all-to-all for self-attention. + + Compatible with diffusers' attention processor protocol + (``processor.__call__(attn, hidden_states, ...)``) and is a drop-in + replacement for ``HunyuanAttnProcessor2_0``. + + For *cross-attention* (``encoder_hidden_states is not None``) the processor + falls back to standard SDPA because text K/V tensors are already replicated + on every rank — no all-to-all is required. + + Args: + ulysses_group: ``torch.distributed.ProcessGroup`` spanning Ulysses ranks. + ulysses_size: Number of Ulysses ranks (must divide ``num_attention_heads``). + """ + + def __init__(self, ulysses_group: dist.ProcessGroup, ulysses_size: int): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("HunyuanDiTUlyssesAttnProcessor requires PyTorch ≥ 2.0.") + self.ulysses_group = ulysses_group + self.ulysses_size = ulysses_size + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb=None, + ) -> torch.Tensor: + from tensorrt_llm._torch.distributed import all_to_all_4d + + try: + from diffusers.models.embeddings import apply_rotary_emb + except ImportError: + apply_rotary_emb = None + + is_cross = encoder_hidden_states is not None + B = hidden_states.shape[0] + + # ---- Q / K / V projections ---------------------------------------- + query = attn.to_q(hidden_states) + kv_src = encoder_hidden_states if is_cross else hidden_states + key = attn.to_k(kv_src) + value = attn.to_v(kv_src) + + # Derive head_dim from key projection output + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + # Reshape → [B, S, H, D] + query = query.view(B, -1, attn.heads, head_dim) + key = key.view(B, -1, attn.heads, head_dim) + value = value.view(B, -1, attn.heads, head_dim) + + # ---- QK-norm (LayerNorm, applied per-head) ------------------------- + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # ---- RoPE (expects [B, H, S, D]) ----------------------------------- + if image_rotary_emb is not None and apply_rotary_emb is not None: + query = apply_rotary_emb(query.transpose(1, 2), image_rotary_emb)[0] + query = query.transpose(1, 2) + if not is_cross: + key = apply_rotary_emb(key.transpose(1, 2), image_rotary_emb)[0] + key = key.transpose(1, 2) + + # ---- Ulysses all-to-all (self-attention only) ---------------------- + if not is_cross and self.ulysses_size > 1: + # [B, S/U, H, D] → [B, S, H/U, D] + query = all_to_all_4d( + query, scatter_dim=2, gather_dim=1, process_group=self.ulysses_group + ) + key = all_to_all_4d( + key, scatter_dim=2, gather_dim=1, process_group=self.ulysses_group + ) + value = all_to_all_4d( + value, scatter_dim=2, gather_dim=1, process_group=self.ulysses_group + ) + + # SDPA expects [B, H, S, D] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0) + + # Reverse: [B, H/U, S, D] → [B, S, H/U, D] → [B, S/U, H, D] + out = out.transpose(1, 2).contiguous() + out = all_to_all_4d( + out, scatter_dim=1, gather_dim=2, process_group=self.ulysses_group + ) + else: + # Standard SDPA (cross-attention or single-GPU fallback) + if attention_mask is not None: + seq_len_kv = key.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, seq_len_kv, B) + attention_mask = attention_mask.view(B, attn.heads, -1, attention_mask.shape[-1]) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + out = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0 + ) + out = out.transpose(1, 2) + + # ---- Output projection --------------------------------------------- + out = out.reshape(B, -1, attn.heads * head_dim).to(query.dtype) + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out + + +# --------------------------------------------------------------------------- +# Ulysses-capable HunyuanDiT2DModel subclass +# --------------------------------------------------------------------------- + + +class HunyuanDiT2DModelUlysses: + """Mixin that overrides ``forward()`` to add Ulysses sequence sharding. + + We implement this as a plain class whose ``forward()`` replaces the + diffusers model's ``forward()`` via attribute assignment after construction, + because subclassing diffusers ``ModelMixin`` (which uses ``@register_to_config``) + reliably breaks the config serialization when additional ``__init__`` kwargs are added. + + Usage (inside ``HunyuanDiT2DModelWrapper``):: + + model = HunyuanDiT2DModel(...) + HunyuanDiT2DModelUlysses.patch(model, ulysses_group, ulysses_size) + """ + + @staticmethod + def patch( + model, + ulysses_group: dist.ProcessGroup, + ulysses_size: int, + ) -> None: + """Attach Ulysses sharding behaviour to an existing ``HunyuanDiT2DModel``. + + 1. Stores ``ulysses_group`` / ``ulysses_size`` on the model instance. + 2. Replaces ``forward`` with our sequence-sharding variant. + 3. Replaces the self-attention processor on every block with + ``HunyuanDiTUlyssesAttnProcessor``. + """ + model._ulysses_group = ulysses_group + model._ulysses_size = ulysses_size + + processor = HunyuanDiTUlyssesAttnProcessor(ulysses_group, ulysses_size) + for block in model.blocks: + # attn1 = self-attention; attn2 = cross-attention (keep default) + block.attn1.processor = processor + + # Bind the new forward as a bound method + import types + + model.forward = types.MethodType(HunyuanDiT2DModelUlysses._forward, model) + + @staticmethod + def _forward( + self, + hidden_states, + timestep, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + controlnet_block_samples=None, + return_dict=True, + ): + """Ulysses-aware forward. + + Identical to ``HunyuanDiT2DModel.forward`` except that it shards the + patch-embedded sequence across Ulysses ranks before the transformer + blocks and gathers it back before ``norm_out`` / ``proj_out``. + """ + height, width = hidden_states.shape[-2:] + + # 1. PatchEmbed → [B, S, D] + hidden_states = self.pos_embed(hidden_states) + + # 2. Timestep + text conditioning + temb = self.time_extra_emb( + timestep, + encoder_hidden_states_t5, + image_meta_size, + style, + hidden_dtype=timestep.dtype, + ) + + # 3. T5 text projection + batch_size, seq_len_t5, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, seq_len_t5, -1) + + # Concatenate CLIP + T5 text embeddings + encoder_hidden_states = torch.cat( + [encoder_hidden_states, encoder_hidden_states_t5], dim=1 + ) + text_embedding_mask = torch.cat( + [text_embedding_mask, text_embedding_mask_t5], dim=-1 + ) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + encoder_hidden_states = torch.where( + text_embedding_mask, encoder_hidden_states, self.text_embedding_padding + ) + + # 4. Ulysses: shard image sequence across ranks + S = hidden_states.shape[1] + ulysses_size = self._ulysses_size + ulysses_group = self._ulysses_group + rank = dist.get_rank(ulysses_group) + shard_size = S // ulysses_size + hidden_states = hidden_states[:, rank * shard_size : (rank + 1) * shard_size, :].contiguous() + + # Shard the RoPE frequencies to match the sequence shard + if image_rotary_emb is not None: + freqs_cos, freqs_sin = image_rotary_emb + freqs_cos = freqs_cos[rank * shard_size : (rank + 1) * shard_size] + freqs_sin = freqs_sin[rank * shard_size : (rank + 1) * shard_size] + image_rotary_emb = (freqs_cos, freqs_sin) + + # 5. Transformer blocks (U-Net-style skip connections) + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config.num_layers // 2: + if controlnet_block_samples is not None: + skip = skips.pop() + controlnet_block_samples.pop() + else: + skip = skips.pop() + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + skip=skip, + ) + else: + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + if layer < (self.config.num_layers // 2 - 1): + skips.append(hidden_states) + + if controlnet_block_samples is not None and len(controlnet_block_samples) != 0: + raise ValueError( + "The number of controls is not equal to the number of skip connections." + ) + + # 6. Ulysses: gather sequence shards → [B, S, D] + gathered = [torch.zeros_like(hidden_states) for _ in range(ulysses_size)] + dist.all_gather(gathered, hidden_states.contiguous(), group=ulysses_group) + hidden_states = torch.cat(gathered, dim=1) + + # 7. Final norm + projection + hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) + hidden_states = self.proj_out(hidden_states) + + # 8. Unpatchify → [B, out_channels, H, W] + patch_size = self.pos_embed.patch_size + h_out = height // patch_size + w_out = width // patch_size + hidden_states = hidden_states.reshape( + hidden_states.shape[0], h_out, w_out, patch_size, patch_size, self.out_channels + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + hidden_states.shape[0], self.out_channels, h_out * patch_size, w_out * patch_size + ) + + if not return_dict: + return (output,) + + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + return Transformer2DModelOutput(sample=output) + + +# --------------------------------------------------------------------------- +# Wrapper +# --------------------------------------------------------------------------- + + class HunyuanDiT2DModelWrapper(nn.Module): """Thin TRT-LLM wrapper around diffusers ``HunyuanDiT2DModel``. + Selects ``HunyuanDiT2DModelUlysses``-patched model when + ``visual_gen_mapping.ulysses_size > 1``. + Args: - model_config: ``DiffusionModelConfig`` instance from the pipeline. - **transformer_kwargs: Config kwargs forwarded to ``HunyuanDiT2DModel`` - (e.g. ``num_attention_heads``, ``num_layers``, …). All keyword - args have defaults matching the published HunyuanDiT-v1.2 model so - the wrapper works even when ``pretrained_config`` is sparse. + model_config: ``DiffusionModelConfig`` from the pipeline. + **transformer_kwargs: Config overrides for ``HunyuanDiT2DModel``. """ # Published HunyuanDiT-v1.2 config defaults @@ -74,45 +395,56 @@ def __init__(self, model_config, **transformer_kwargs): from diffusers.models import HunyuanDiT2DModel except ImportError as exc: raise ImportError( - "HunyuanDiT requires diffusers >= 0.26 " - "(`pip install -U diffusers`)." + "HunyuanDiT requires diffusers >= 0.26 (`pip install -U diffusers`)." ) from exc + # Read Ulysses config + vgm = getattr(model_config, "visual_gen_mapping", None) + ulysses_size = vgm.ulysses_size if vgm is not None else 1 + + num_heads = cfg["num_attention_heads"] + if ulysses_size > 1 and num_heads % ulysses_size != 0: + raise ValueError( + f"HunyuanDiT: num_attention_heads ({num_heads}) must be divisible by " + f"ulysses_size ({ulysses_size})." + ) + logger.info( - "Building HunyuanDiT2DModel: %d layers, %d heads, head_dim=%d", + "Building HunyuanDiT2DModel: %d layers, %d heads, head_dim=%d, ulysses=%d", cfg["num_layers"], - cfg["num_attention_heads"], + num_heads, cfg["attention_head_dim"], + ulysses_size, ) self.transformer = HunyuanDiT2DModel(**cfg) - - # Remember latent channel count so the pipeline can read it. self.in_channels = cfg["in_channels"] + # Patch model with Ulysses-aware forward when requested + if ulysses_size > 1: + ulysses_group = vgm.ulysses_group + if ulysses_group is None: + raise RuntimeError( + "HunyuanDiT Ulysses requires vgm.ulysses_group to be initialised " + "(call VisualGenMapping.init_device_mesh first)." + ) + HunyuanDiT2DModelUlysses.patch(self.transformer, ulysses_group, ulysses_size) + logger.info("HunyuanDiT: Ulysses sequence parallelism enabled (size=%d)", ulysses_size) + # ------------------------------------------------------------------ # Weight loading # ------------------------------------------------------------------ def load_weights(self, weights: Dict[str, torch.Tensor]) -> None: - """Populate transformer parameters from a flat state-dict. - - ``weights`` is provided by the TRT-LLM ``WeightLoader`` and contains - the raw tensors from the checkpoint's ``transformer/`` safetensors - shards. We use ``strict=False`` so missing or extra keys (e.g. from - a newer / older checkpoint version) do not abort loading. - """ result = self.transformer.load_state_dict(weights, strict=False) if result.missing_keys: logger.warning( - "HunyuanDiT: %d missing keys in state dict " - "(first 10: %s)", + "HunyuanDiT: %d missing keys (first 10: %s)", len(result.missing_keys), result.missing_keys[:10], ) if result.unexpected_keys: logger.warning( - "HunyuanDiT: %d unexpected keys in state dict " - "(first 10: %s)", + "HunyuanDiT: %d unexpected keys (first 10: %s)", len(result.unexpected_keys), result.unexpected_keys[:10], ) From e39bb7375cf0e36a4e9e767513188a1db2440169 Mon Sep 17 00:00:00 2001 From: Peter Kisfaludi Date: Fri, 5 Jun 2026 15:31:31 -0700 Subject: [PATCH 3/3] fix(visual_gen/hunyuandit): address PR review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pipeline_hunyuandit.py: fix return type hint on _get_image_rotary_emb to Optional[Tuple[...]] since it returns None when diffusers is unavailable - pipeline_hunyuandit.py: rename unused loop variable i → _ in denoising loop - transformer_hunyuandit.py: add runtime check that sequence length S is divisible by ulysses_size before sharding, with a descriptive error message Co-Authored-By: Claude Sonnet 4.6 --- .../visual_gen/models/hunyuandit/pipeline_hunyuandit.py | 4 ++-- .../visual_gen/models/hunyuandit/transformer_hunyuandit.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py index 2d5e01e30d2a..2ca4161327d2 100644 --- a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/pipeline_hunyuandit.py @@ -343,7 +343,7 @@ def _get_image_rotary_emb( width: int, device: torch.device, dtype: torch.dtype, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: """Compute 2D RoPE embeddings for the latent grid. Follows diffusers' ``get_2d_rotary_pos_embed``. @@ -504,7 +504,7 @@ def forward( timer.mark_denoise_start() logger.info("Denoising (%d steps)...", len(timesteps)) - for i, t in enumerate(timesteps): + for _, t in enumerate(timesteps): lat_in = torch.cat([latents] * 2) if do_cfg else latents noise_pred = self.transformer( diff --git a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py index ab88aac451f5..2eeb5012bc79 100644 --- a/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py +++ b/tensorrt_llm/_torch/visual_gen/models/hunyuandit/transformer_hunyuandit.py @@ -273,6 +273,12 @@ def _forward( ulysses_size = self._ulysses_size ulysses_group = self._ulysses_group rank = dist.get_rank(ulysses_group) + if S % ulysses_size != 0: + raise ValueError( + f"HunyuanDiT Ulysses: sequence length {S} is not divisible by " + f"ulysses_size={ulysses_size}. Adjust the image resolution so that " + f"(height // vae_scale_factor // patch_size)^2 is divisible by ulysses_size." + ) shard_size = S // ulysses_size hidden_states = hidden_states[:, rank * shard_size : (rank + 1) * shard_size, :].contiguous()