From 8c47eebf91c2109c6eb1154c2efa1bb7e2c53125 Mon Sep 17 00:00:00 2001 From: yibinl-nvidia <109242046+yibinl-nvidia@users.noreply.github.com> Date: Fri, 5 Jun 2026 02:08:33 +0000 Subject: [PATCH] Persist LTX2 LoRA weight snapshots Signed-off-by: yibinl-nvidia <109242046+yibinl-nvidia@users.noreply.github.com> --- .../models/ltx2/pipeline_ltx2_two_stages.py | 562 +++++++++++++----- .../_torch/visual_gen/test_ltx2_pipeline.py | 257 ++++++++ 2 files changed, 685 insertions(+), 134 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py index 6a8e365d41cf..5b4ef2ebd362 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py @@ -4,7 +4,9 @@ import json import math +import os import time +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Union import safetensors.torch @@ -41,6 +43,8 @@ _FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2) # Baseline BF16 peak memory ~75 GiB, saving BF16 weights snopshot total ~108 GiB. _BF16_WEIGHTS_SNAPSHOT_FREE_MEMORY_THRESHOLD_GIB = 115.0 +_LTX2_PERSISTENT_LORA_WEIGHTS_ENV = "TRTLLM_LTX2_PERSISTENT_LORA_WEIGHTS" +_TRUE_ENV_VALUES = {"1", "true", "yes", "on"} # --------------------------------------------------------------------------- @@ -48,24 +52,35 @@ # --------------------------------------------------------------------------- -def _get_free_gpu_memory_gib() -> Optional[float]: - """Return free memory on the current CUDA device, or ``None`` if unavailable.""" +def _get_free_gpu_memory_gib( + device: Optional[Union[torch.device, str, int]] = None, +) -> Optional[float]: + """Return free memory on the requested CUDA device, or ``None`` if unavailable.""" if not torch.cuda.is_available(): return None try: - free_bytes, _ = torch.cuda.mem_get_info() + free_bytes, _ = torch.cuda.mem_get_info(device=device) except (RuntimeError, OSError) as exc: - logger.warning(f"Unable to query CUDA free memory for BF16 weight snapshots: {exc}") + logger.warning( + f"Unable to query CUDA free memory for BF16 weight snapshots on device {device}: {exc}" + ) return None return free_bytes / (1024**3) def _should_save_bf16_weights( + device: Optional[Union[torch.device, str, int]] = None, + preload_free_gib: Optional[float] = None, threshold_gib: float = _BF16_WEIGHTS_SNAPSHOT_FREE_MEMORY_THRESHOLD_GIB, ) -> bool: - free_gib = _get_free_gpu_memory_gib() + free_gib = preload_free_gib + source = "pre-load" + if free_gib is None: + free_gib = _get_free_gpu_memory_gib(device=device) + source = "current" + if free_gib is None: logger.debug("BF16 weight snapshots disabled: CUDA free memory is unavailable") return False @@ -73,12 +88,17 @@ def _should_save_bf16_weights( save_state = free_gib > threshold_gib relation = ">" if save_state else "<=" logger.debug( - f"BF16 weight snapshots {'enabled' if save_state else 'disabled'}: " - f"free GPU memory {free_gib:.2f} GiB {relation} {threshold_gib:.2f} GiB threshold" + f"BF16 weight snapshots {'enabled' if save_state else 'disabled'} " + f"on device {device}: {source} free GPU memory {free_gib:.2f} GiB " + f"{relation} {threshold_gib:.2f} GiB threshold" ) return save_state +def _persistent_lora_weights_enabled() -> bool: + return os.environ.get(_LTX2_PERSISTENT_LORA_WEIGHTS_ENV, "").strip().lower() in _TRUE_ENV_VALUES + + def _load_lora_deltas( lora_path: str, transformer: torch.nn.Module, @@ -367,6 +387,10 @@ def _requantize_fp8_weight( return qw, scale +def _scale_like(scale: torch.Tensor, reference: torch.Tensor) -> torch.Tensor: + return scale.to(device=reference.device, dtype=reference.dtype).reshape(reference.shape) + + def _apply_lora_deltas( module: torch.nn.Module, deltas: Dict[str, torch.Tensor], @@ -402,6 +426,7 @@ def _apply_lora_deltas( applied = 0 snapshot_required = 0 saved_state: Dict[str, Any] = {} + applied_deltas: Dict[str, torch.Tensor] = {} # Build a lookup that maps *clean* parameter names to the actual # Parameter objects. torch.compile wraps each block in an # OptimizedModule, inserting ``._orig_mod.`` into the parameter @@ -419,110 +444,122 @@ def _apply_lora_deltas( clean = raw_name.replace("._orig_mod.", ".") module_dict[clean] = mod - for name, delta in deltas.items(): - param_name = name if name in state else f"{name}.weight" - if param_name not in state: - continue + try: + for name, delta in deltas.items(): + param_name = name if name in state else f"{name}.weight" + if param_name not in state: + continue + + param = state[param_name] + base = param_name.rsplit(".weight", 1)[0] + + # --- same shape --------------------------------------------------- + if param.shape == delta.shape: + if param.dtype in _FP8_DTYPES: + # FP8: dequant -> apply -> requant + scale_key = f"{base}.weight_scale" + if scale_key not in state: + raise RuntimeError( + f"Cannot apply LoRA delta to FP8 param '{param_name}': missing {scale_key}." + ) + ws_param = state[scale_key] + out_f, in_f = delta.shape + is_per_tensor = ws_param.data.numel() == 1 + is_packed = not is_per_tensor and _is_fp8_scale_packed( + ws_param.data, out_f, in_f + ) - param = state[param_name] - base = param_name.rsplit(".weight", 1)[0] + saved_state[param_name] = param.data.clone() + saved_state[scale_key] = ws_param.data.clone() + + bf16 = _dequantize_fp8_weight( + param.data, + ws_param.data, + packed=is_packed, + ) + bf16.add_(delta.to(bf16.device, bf16.dtype), alpha=sign) - # --- same shape --------------------------------------------------- - if param.shape == delta.shape: - if param.dtype in _FP8_DTYPES: - # FP8: dequant → apply → requant + qw, new_scale = _requantize_fp8_weight( + bf16, + repack=is_packed, + per_tensor=is_per_tensor, + ) + new_scale = _scale_like(new_scale, ws_param.data) + param.data.copy_(qw) + ws_param.data.copy_(new_scale) + snapshot_required += 1 + else: + if save_bf16_weights and param.dtype == torch.bfloat16: + saved_state[param_name] = param.data.clone() + snapshot_required += 1 + # BF16: direct in-place addition, then restore by snapshot + # copy when memory allows or by subtracting the delta. + param.data.add_( + delta.to(param.device, param.dtype), + alpha=sign, + ) + applied_deltas[name] = delta + applied += 1 + + # --- packed FP4 (half last dim) ----------------------------------- + elif ( + param.ndim == 2 + and delta.ndim == 2 + and param.shape[0] == delta.shape[0] + and param.shape[1] * 2 == delta.shape[1] + ): scale_key = f"{base}.weight_scale" - if scale_key not in state: + scale2_key = f"{base}.weight_scale_2" + if scale_key not in state or scale2_key not in state: raise RuntimeError( - f"Cannot apply LoRA delta to FP8 param '{param_name}': missing {scale_key}." + f"Cannot apply LoRA delta to quantized param " + f"'{param_name}': missing {scale_key} or {scale2_key}." ) + ws_param = state[scale_key] - out_f, in_f = delta.shape - is_per_tensor = ws_param.data.numel() == 1 - is_packed = not is_per_tensor and _is_fp8_scale_packed(ws_param.data, out_f, in_f) + ws2_param = state[scale2_key] + out_features, in_features = delta.shape saved_state[param_name] = param.data.clone() saved_state[scale_key] = ws_param.data.clone() + saved_state[scale2_key] = ws2_param.data.clone() - bf16 = _dequantize_fp8_weight( + bf16 = _dequantize_fp4_weight( param.data, ws_param.data, - packed=is_packed, + ws2_param.data, + out_features, + in_features, ) bf16.add_(delta.to(bf16.device, bf16.dtype), alpha=sign) - qw, new_scale = _requantize_fp8_weight( - bf16, - repack=is_packed, - per_tensor=is_per_tensor, - ) - param.data.copy_(qw) - ws_param.data.copy_(new_scale) + linear_mod = module_dict.get(base) + if linear_mod is None or not isinstance(linear_mod, Linear): + raise RuntimeError( + f"Packed FP4 LoRA merge: could not find Linear module at '{base}'." + ) + + # Packed FP4: keep the LoRA-merged weight in BF16 for stage 2. + # This replaces packed FP4 storage (out, in//2) with BF16 storage + # (out, in) and swaps the parent Linear to plain F.linear. + param.data = bf16 + saved_state[f"__quant_method__{base}"] = linear_mod.quant_method + linear_mod.quant_method = UnquantizedLinearMethod() snapshot_required += 1 + applied_deltas[name] = delta + applied += 1 else: - if save_bf16_weights and param.dtype == torch.bfloat16: - saved_state[param_name] = param.data.clone() - snapshot_required += 1 - # BF16: direct in-place addition, then restore by snapshot copy - # when memory allows. Other dense weights restore by subtraction. - param.data.add_( - delta.to(param.device, param.dtype), - alpha=sign, - ) - applied += 1 - - # --- packed FP4 (half last dim) ----------------------------------- - elif ( - param.ndim == 2 - and delta.ndim == 2 - and param.shape[0] == delta.shape[0] - and param.shape[1] * 2 == delta.shape[1] - ): - scale_key = f"{base}.weight_scale" - scale2_key = f"{base}.weight_scale_2" - if scale_key not in state or scale2_key not in state: - raise RuntimeError( - f"Cannot apply LoRA delta to quantized param " - f"'{param_name}': missing {scale_key} or {scale2_key}." - ) - - ws_param = state[scale_key] - ws2_param = state[scale2_key] - out_features, in_features = delta.shape - - saved_state[param_name] = param.data.clone() - saved_state[scale_key] = ws_param.data.clone() - saved_state[scale2_key] = ws2_param.data.clone() - - bf16 = _dequantize_fp4_weight( - param.data, - ws_param.data, - ws2_param.data, - out_features, - in_features, - ) - bf16.add_(delta.to(bf16.device, bf16.dtype), alpha=sign) - - linear_mod = module_dict.get(base) - if linear_mod is None or not isinstance(linear_mod, Linear): - raise RuntimeError( - f"Packed FP4 LoRA merge: could not find Linear module at '{base}'." + logger.warning( + f"Shape mismatch for LoRA param '{param_name}': " + f"param={list(param.shape)}, delta={list(delta.shape)}. " + f"Skipping." ) - - # Packed FP4: keep the LoRA-merged weight in BF16 for stage 2. - # This replaces packed FP4 storage (out, in//2) with BF16 storage - # (out, in) and swaps the parent Linear to plain F.linear. - param.data = bf16 - saved_state[f"__quant_method__{base}"] = linear_mod.quant_method - linear_mod.quant_method = UnquantizedLinearMethod() - snapshot_required += 1 - applied += 1 - else: - logger.warning( - f"Shape mismatch for LoRA param '{param_name}': " - f"param={list(param.shape)}, delta={list(delta.shape)}. " - f"Skipping." - ) + except Exception: + if saved_state: + _restore_lora_state(module, saved_state) + if applied_deltas: + _subtract_dense_lora_deltas(module, applied_deltas, saved_state) + raise return applied, saved_state, snapshot_required @@ -614,6 +651,221 @@ def _restore_lora_state( param.data = data +@dataclass +class _PersistentLoRAParamState: + param_name: str + precision: str + weight_param: torch.nn.Parameter + original_weight: torch.Tensor + merged_weight: torch.Tensor + scale_params: Dict[str, torch.nn.Parameter] = field(default_factory=dict) + original_scales: Dict[str, torch.Tensor] = field(default_factory=dict) + merged_scales: Dict[str, torch.Tensor] = field(default_factory=dict) + linear_module: Optional[Linear] = None + original_quant_method: Optional[Any] = None + merged_quant_method: Optional[Any] = None + + +class _PersistentLoRAWeightCache: + """Keep unmerged and merged LoRA-touched weights resident. + + The cache is opt-in and is used only for LTX-2 Stage 2 distilled LoRA. + It removes per-request merge/unmerge math by rebinding parameter storage to + precomputed resident tensors. FP8 and FP4 keep exact original quantized + state. FP4's merged state is BF16 and swaps the parent Linear to + UnquantizedLinearMethod, matching the existing per-request Stage 2 path. + """ + + def __init__( + self, + entries: List[_PersistentLoRAParamState], + ) -> None: + self._entries = entries + self._bound_state = "original" + self.applied_count = len(entries) + + @staticmethod + def _module_state( + module: torch.nn.Module, + ) -> tuple[Dict[str, torch.nn.Parameter], Dict[str, torch.nn.Module]]: + state: Dict[str, torch.nn.Parameter] = {} + for raw_name, param in module.named_parameters(): + clean = raw_name.replace("._orig_mod.", ".") + state[clean] = param + + module_dict: Dict[str, torch.nn.Module] = {} + for raw_name, mod in module.named_modules(): + clean = raw_name.replace("._orig_mod.", ".") + module_dict[clean] = mod + + return state, module_dict + + @classmethod + def build( + cls, + module: torch.nn.Module, + deltas: Dict[str, torch.Tensor], + ) -> "_PersistentLoRAWeightCache": + state, module_dict = cls._module_state(module) + entries: List[_PersistentLoRAParamState] = [] + + for name, delta in deltas.items(): + param_name = name if name in state else f"{name}.weight" + if param_name not in state: + continue + + param = state[param_name] + base = param_name.rsplit(".weight", 1)[0] + + if param.shape == delta.shape: + if param.dtype in _FP8_DTYPES: + scale_key = f"{base}.weight_scale" + if scale_key not in state: + raise RuntimeError( + f"Cannot build persistent LoRA state for FP8 param " + f"'{param_name}': missing {scale_key}." + ) + + ws_param = state[scale_key] + out_f, in_f = delta.shape + is_per_tensor = ws_param.data.numel() == 1 + is_packed = not is_per_tensor and _is_fp8_scale_packed( + ws_param.data, out_f, in_f + ) + + bf16 = _dequantize_fp8_weight( + param.data, + ws_param.data, + packed=is_packed, + ) + bf16.add_(delta.to(bf16.device, bf16.dtype)) + qw, new_scale = _requantize_fp8_weight( + bf16, + repack=is_packed, + per_tensor=is_per_tensor, + ) + new_scale = _scale_like(new_scale, ws_param.data) + + entries.append( + _PersistentLoRAParamState( + param_name=param_name, + precision="fp8", + weight_param=param, + original_weight=param.data, + merged_weight=qw, + scale_params={scale_key: ws_param}, + original_scales={scale_key: ws_param.data}, + merged_scales={scale_key: new_scale}, + ) + ) + else: + merged = param.data.clone() + merged.add_(delta.to(merged.device, merged.dtype)) + precision = "bf16" if param.dtype == torch.bfloat16 else str(param.dtype) + entries.append( + _PersistentLoRAParamState( + param_name=param_name, + precision=precision, + weight_param=param, + original_weight=param.data, + merged_weight=merged, + ) + ) + continue + + if ( + param.ndim == 2 + and delta.ndim == 2 + and param.shape[0] == delta.shape[0] + and param.shape[1] * 2 == delta.shape[1] + ): + scale_key = f"{base}.weight_scale" + scale2_key = f"{base}.weight_scale_2" + if scale_key not in state or scale2_key not in state: + raise RuntimeError( + f"Cannot build persistent LoRA state for packed FP4 param " + f"'{param_name}': missing {scale_key} or {scale2_key}." + ) + + ws_param = state[scale_key] + ws2_param = state[scale2_key] + out_features, in_features = delta.shape + + bf16 = _dequantize_fp4_weight( + param.data, + ws_param.data, + ws2_param.data, + out_features, + in_features, + ) + bf16.add_(delta.to(bf16.device, bf16.dtype)) + + linear_mod = module_dict.get(base) + if linear_mod is None or not isinstance(linear_mod, Linear): + raise RuntimeError( + f"Packed FP4 persistent LoRA state: could not find " + f"Linear module at '{base}'." + ) + + entries.append( + _PersistentLoRAParamState( + param_name=param_name, + precision="fp4", + weight_param=param, + original_weight=param.data, + merged_weight=bf16, + scale_params={ + scale_key: ws_param, + scale2_key: ws2_param, + }, + original_scales={ + scale_key: ws_param.data, + scale2_key: ws2_param.data, + }, + merged_scales={ + scale_key: ws_param.data, + scale2_key: ws2_param.data, + }, + linear_module=linear_mod, + original_quant_method=linear_mod.quant_method, + merged_quant_method=UnquantizedLinearMethod(), + ) + ) + continue + + logger.warning( + f"Shape mismatch for persistent LoRA param '{param_name}': " + f"param={list(param.shape)}, delta={list(delta.shape)}. " + f"Skipping." + ) + + return cls(entries) + + def bind_original(self) -> None: + for entry in self._entries: + entry.weight_param.data = entry.original_weight + for scale_name, scale_param in entry.scale_params.items(): + scale_param.data = entry.original_scales[scale_name] + if entry.linear_module is not None: + entry.linear_module.quant_method = entry.original_quant_method + self._bound_state = "original" + + def bind_merged(self) -> None: + for entry in self._entries: + entry.weight_param.data = entry.merged_weight + for scale_name, scale_param in entry.scale_params.items(): + scale_param.data = entry.merged_scales[scale_name] + if entry.linear_module is not None: + entry.linear_module.quant_method = entry.merged_quant_method + self._bound_state = "merged" + + def precision_counts(self) -> Dict[str, int]: + counts: Dict[str, int] = {} + for entry in self._entries: + counts[entry.precision] = counts.get(entry.precision, 0) + 1 + return counts + + # --------------------------------------------------------------------------- # Pipeline # --------------------------------------------------------------------------- @@ -649,6 +901,9 @@ def load_standard_components( skip_components: Optional[list] = None, **kwargs, ) -> None: + # The BF16 snapshot threshold is a whole-pipeline capacity gate, so + # record it before loading model/runtime components that consume HBM. + self._bf16_snapshot_preload_free_gib = _get_free_gpu_memory_gib(device=device) super().load_standard_components( checkpoint_dir, device, @@ -697,6 +952,7 @@ def load_standard_components( # --- Distilled LoRA (pre-compute deltas) --- self._distilled_lora_deltas: Dict[str, torch.Tensor] = {} + self._distilled_lora_weight_cache: Optional[_PersistentLoRAWeightCache] = None if distilled_lora_path: logger.info(f"Loading distilled LoRA from {distilled_lora_path}...") self._distilled_lora_deltas = _load_lora_deltas( @@ -707,6 +963,26 @@ def load_standard_components( logger.info( f"Distilled LoRA ready: {len(self._distilled_lora_deltas)} parameter deltas" ) + if _persistent_lora_weights_enabled(): + try: + self._distilled_lora_weight_cache = _PersistentLoRAWeightCache.build( + self.transformer, + self._distilled_lora_deltas, + ) + except torch.cuda.OutOfMemoryError as exc: + logger.warning( + "Persistent LTX-2 LoRA weights disabled after CUDA OOM " + f"during cache build: {exc}" + ) + torch.cuda.empty_cache() + self._distilled_lora_weight_cache = None + else: + self._distilled_lora_weight_cache.bind_original() + logger.info( + "Persistent LTX-2 LoRA weights ready: " + f"{self._distilled_lora_weight_cache.applied_count} params, " + f"precision_counts={self._distilled_lora_weight_cache.precision_counts()}" + ) # ------------------------------------------------------------------ # Inference entry point @@ -845,26 +1121,40 @@ def forward( # ================================================================ # Stage 2: refinement denoising with distilled LoRA # ================================================================ - # For FP4 models (static-packed or dynamic), stage 2 always runs in - # BF16: the quant_method is swapped to UnquantizedLinearMethod inside - # _apply_lora_deltas and restored afterwards. FP8 handling is also - # unchanged and always restores from saved quantized state. BF16 weights - # save snapshots only when enough free GPU memory is available; - # otherwise they fall back to on-the-fly LoRA subtraction. - save_bf16_weights = _should_save_bf16_weights() - n, saved_lora_state, snapshot_required = _apply_lora_deltas( - self.transformer, - self._distilled_lora_deltas, - sign=1.0, - save_bf16_weights=save_bf16_weights, - ) - logger.info(f"Merged distilled LoRA ({n} params) for stage 2 (BF16 weights)") - - # Disable Ulysses for Stage 2: only rank 0 is active, so - # cross-rank collectives in the attention backend would hang. - self.transformer.set_ulysses_enabled(False) + # The default path merges LoRA per request and restores afterwards. + # When TRTLLM_LTX2_PERSISTENT_LORA_WEIGHTS is enabled, the resident + # cache already owns original and merged tensors. Stage 2 only rebinds + # pointers and FP4 quant_method state, so no per-request clone, merge, or + # unmerge math is needed. + lora_cache = self._distilled_lora_weight_cache + using_persistent_lora = lora_cache is not None + saved_lora_state: Dict[str, Any] = {} + snapshot_required = 0 + n = 0 stage2_start = time.time() try: + if using_persistent_lora: + lora_cache.bind_merged() + n = lora_cache.applied_count + logger.info(f"Bound persistent distilled LoRA ({n} params) for stage 2") + else: + transformer_device = next(self.transformer.parameters()).device + preload_free_gib = getattr(self, "_bf16_snapshot_preload_free_gib", None) + save_bf16_weights = _should_save_bf16_weights( + device=transformer_device, + preload_free_gib=preload_free_gib, + ) + n, saved_lora_state, snapshot_required = _apply_lora_deltas( + self.transformer, + self._distilled_lora_deltas, + sign=1.0, + save_bf16_weights=save_bf16_weights, + ) + logger.info(f"Merged distilled LoRA ({n} params) for stage 2 (BF16 weights)") + + # Disable Ulysses for Stage 2: only rank 0 is active, so + # cross-rank collectives in the attention backend would hang. + self.transformer.set_ulysses_enabled(False) video_latents, audio_latents = self._refinement_denoise( video_latents=video_latents, audio_latents=audio_latents, @@ -882,31 +1172,35 @@ def forward( stage2_denoise_time = time.time() - stage2_start logger.info(f"Stage 2 denoising time: {stage2_denoise_time:.2f}s (BF16 weights)") self.transformer.set_ulysses_enabled(True) - if snapshot_required and not saved_lora_state: - raise RuntimeError( - "LoRA state was not saved; cannot safely restore stage 2 weights." - ) - - snapshot_restored = 0 - if snapshot_required: - # Restore every LoRA-touched parameter from its snapshot. Packed - # FP4 also restores the original quant_method. - _restore_lora_state(self.transformer, saved_lora_state) - snapshot_restored = _count_saved_lora_weight_tensors(saved_lora_state) + if using_persistent_lora: + lora_cache.bind_original() + logger.info("Re-bound persistent distilled LoRA original weights after stage 2") + else: + if snapshot_required and not saved_lora_state: + raise RuntimeError( + "LoRA state was not saved; cannot safely restore stage 2 weights." + ) - # BF16 weights that were not snapshotted, plus any other dense - # floating-point weights, are restored by subtracting LoRA deltas. - dense_restored = _subtract_dense_lora_deltas( - self.transformer, - self._distilled_lora_deltas, - saved_lora_state, - ) - restored = snapshot_restored + dense_restored - if restored != n: - raise RuntimeError( - f"Restored {restored} LoRA-touched weights after stage 2, but {n} were applied." + snapshot_restored = 0 + if snapshot_required: + # Restore every LoRA-touched parameter from its snapshot. Packed + # FP4 also restores the original quant_method. + _restore_lora_state(self.transformer, saved_lora_state) + snapshot_restored = _count_saved_lora_weight_tensors(saved_lora_state) + + # BF16 weights that were not snapshotted are restored by + # subtracting LoRA deltas. FP8 and FP4 are exact snapshot restores. + dense_restored = _subtract_dense_lora_deltas( + self.transformer, + self._distilled_lora_deltas, + saved_lora_state, ) - logger.info("Un-merged distilled LoRA after stage 2") + restored = snapshot_restored + dense_restored + if restored != n: + raise RuntimeError( + f"Restored {restored} LoRA-touched weights after stage 2, but {n} were applied." + ) + logger.info("Un-merged distilled LoRA after stage 2") # ================================================================ # Decode diff --git a/tests/unittest/_torch/visual_gen/test_ltx2_pipeline.py b/tests/unittest/_torch/visual_gen/test_ltx2_pipeline.py index a5109b3cb01e..756c8b283e24 100644 --- a/tests/unittest/_torch/visual_gen/test_ltx2_pipeline.py +++ b/tests/unittest/_torch/visual_gen/test_ltx2_pipeline.py @@ -21,6 +21,7 @@ from tensorrt_llm._torch.modules.linear import Linear from tensorrt_llm._torch.visual_gen.config import DiffusionModelConfig +from tensorrt_llm._torch.visual_gen.models.ltx2 import pipeline_ltx2_two_stages as ltx2_two_stages from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2_FORCE_ONE_STAGE_ENV from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineComponent, PipelineLoader from tensorrt_llm.visual_gen.args import AttentionConfig, CacheDiTConfig, VisualGenArgs @@ -1104,6 +1105,262 @@ def ltx2_two_stage_assets_exist(): return True +class TestLTX2TwoStageLoRAHelpers: + """Test LTX-2 two-stage distilled LoRA helpers without model loading.""" + + def test_bf16_snapshot_gate_uses_preload_memory(self, monkeypatch): + """Pre-load memory is the whole-pipeline budget and wins over current memory.""" + queried_devices = [] + + def fake_get_free_gpu_memory_gib(device=None): + queried_devices.append(device) + return 1.0 + + monkeypatch.setattr( + ltx2_two_stages, + "_get_free_gpu_memory_gib", + fake_get_free_gpu_memory_gib, + ) + + assert ltx2_two_stages._should_save_bf16_weights( + device="cuda:1", + preload_free_gib=120.0, + threshold_gib=115.0, + ) + assert not ltx2_two_stages._should_save_bf16_weights( + device="cuda:1", + preload_free_gib=110.0, + threshold_gib=115.0, + ) + assert queried_devices == [] + + def test_bf16_snapshot_gate_fallback_passes_device(self, monkeypatch): + """Fallback free-memory query must target the transformer's CUDA device.""" + queried_devices = [] + + def fake_get_free_gpu_memory_gib(device=None): + queried_devices.append(device) + return 116.0 + + monkeypatch.setattr( + ltx2_two_stages, + "_get_free_gpu_memory_gib", + fake_get_free_gpu_memory_gib, + ) + + assert ltx2_two_stages._should_save_bf16_weights( + device="cuda:1", + preload_free_gib=None, + threshold_gib=115.0, + ) + assert queried_devices == ["cuda:1"] + + def test_persistent_bf16_cache_reuses_weight_storage(self): + """Persistent BF16 cache swaps between the same original and merged tensors.""" + + class TinyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter( + torch.tensor( + [[1.0, 2.0], [3.0, 4.0]], + dtype=torch.bfloat16, + ) + ) + + module = TinyModule() + original = module.weight.detach().clone() + delta = torch.full_like(module.weight, 0.5) + + cache = ltx2_two_stages._PersistentLoRAWeightCache.build( + module, + {"weight": delta}, + ) + assert cache.applied_count == 1 + assert cache.precision_counts() == {"bf16": 1} + + original_ptr = module.weight.data_ptr() + cache.bind_merged() + merged_ptr = module.weight.data_ptr() + assert merged_ptr != original_ptr + assert torch.allclose(module.weight, original + delta) + + cache.bind_original() + assert module.weight.data_ptr() == original_ptr + assert torch.equal(module.weight, original) + + cache.bind_merged() + assert module.weight.data_ptr() == merged_ptr + assert torch.allclose(module.weight, original + delta) + + cache.bind_original() + assert module.weight.data_ptr() == original_ptr + assert torch.equal(module.weight, original) + + def test_persistent_fp8_cache_reuses_weight_and_scale_storage(self): + """Persistent FP8 cache swaps both quantized weight and scale storage.""" + + class TinyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.proj = torch.nn.Module() + bf16_weight = torch.tensor( + [[1.0, 2.0], [3.0, 4.0]], + dtype=torch.bfloat16, + ) + weight_scale = torch.tensor(0.25, dtype=torch.float32) + self.proj.weight = torch.nn.Parameter( + (bf16_weight.float() / weight_scale).to(torch.float8_e4m3fn), + requires_grad=False, + ) + self.proj.weight_scale = torch.nn.Parameter( + weight_scale, + requires_grad=False, + ) + + module = TinyModule() + original_weight = module.proj.weight.detach().clone() + original_scale = module.proj.weight_scale.detach().clone() + delta = torch.full((2, 2), 0.5, dtype=torch.bfloat16) + + cache = ltx2_two_stages._PersistentLoRAWeightCache.build( + module, + {"proj.weight": delta}, + ) + assert cache.applied_count == 1 + assert cache.precision_counts() == {"fp8": 1} + + original_weight_ptr = module.proj.weight.data_ptr() + original_scale_ptr = module.proj.weight_scale.data_ptr() + cache.bind_merged() + merged_weight_ptr = module.proj.weight.data_ptr() + merged_scale_ptr = module.proj.weight_scale.data_ptr() + assert merged_weight_ptr != original_weight_ptr + assert merged_scale_ptr != original_scale_ptr + assert module.proj.weight.dtype == torch.float8_e4m3fn + + cache.bind_original() + assert module.proj.weight.data_ptr() == original_weight_ptr + assert module.proj.weight_scale.data_ptr() == original_scale_ptr + assert torch.equal(module.proj.weight, original_weight) + assert torch.equal(module.proj.weight_scale, original_scale) + + cache.bind_merged() + assert module.proj.weight.data_ptr() == merged_weight_ptr + assert module.proj.weight_scale.data_ptr() == merged_scale_ptr + + cache.bind_original() + assert module.proj.weight.data_ptr() == original_weight_ptr + assert module.proj.weight_scale.data_ptr() == original_scale_ptr + + def test_persistent_fp4_cache_swaps_quant_method_and_weight_storage(self, monkeypatch): + """Persistent FP4 cache binds merged BF16 weight and restores packed state.""" + + class TinyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = Linear( + 2, + 2, + bias=False, + dtype=torch.bfloat16, + skip_create_weights_in_init=True, + reduce_output=False, + ) + self.linear.weight = torch.nn.Parameter( + torch.zeros((2, 1), dtype=torch.uint8), + requires_grad=False, + ) + self.linear.weight_scale = torch.nn.Parameter( + torch.ones((128, 4), dtype=torch.uint8), + requires_grad=False, + ) + self.linear.weight_scale_2 = torch.nn.Parameter( + torch.ones((), dtype=torch.float32), + requires_grad=False, + ) + self.linear.quant_method = object() + + module = TinyModule() + original_quant_method = module.linear.quant_method + original_weight_ptr = module.linear.weight.data_ptr() + original_weight = module.linear.weight.detach().clone() + delta = torch.full((2, 2), 0.5, dtype=torch.bfloat16) + + def fake_dequantize_fp4_weight( + packed_weight, + interleaved_scale, + weight_scale_2, + out_features, + in_features, + ): + assert packed_weight.data_ptr() == original_weight_ptr + assert out_features == 2 + assert in_features == 2 + return torch.ones((2, 2), dtype=torch.bfloat16) + + monkeypatch.setattr( + ltx2_two_stages, + "_dequantize_fp4_weight", + fake_dequantize_fp4_weight, + ) + + cache = ltx2_two_stages._PersistentLoRAWeightCache.build( + module, + {"linear.weight": delta}, + ) + assert cache.applied_count == 1 + assert cache.precision_counts() == {"fp4": 1} + + cache.bind_merged() + merged_weight_ptr = module.linear.weight.data_ptr() + assert merged_weight_ptr != original_weight_ptr + assert module.linear.weight.dtype == torch.bfloat16 + assert torch.equal(module.linear.weight, torch.ones_like(delta) + delta) + assert isinstance(module.linear.quant_method, ltx2_two_stages.UnquantizedLinearMethod) + + cache.bind_original() + assert module.linear.weight.data_ptr() == original_weight_ptr + assert torch.equal(module.linear.weight, original_weight) + assert module.linear.quant_method is original_quant_method + + cache.bind_merged() + assert module.linear.weight.data_ptr() == merged_weight_ptr + cache.bind_original() + assert module.linear.weight.data_ptr() == original_weight_ptr + + def test_apply_lora_deltas_rolls_back_dense_weights_on_failure(self): + """A later merge failure must not leave earlier dense weights LoRA-merged.""" + + class TinyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.good = torch.nn.Parameter( + torch.tensor( + [[1.0, 2.0], [3.0, 4.0]], + dtype=torch.bfloat16, + ) + ) + self.bad = torch.nn.Parameter(torch.zeros((2, 1), dtype=torch.bfloat16)) + + module = TinyModule() + original_good = module.good.detach().clone() + deltas = { + "good": torch.full_like(module.good, 0.5), + "bad": torch.ones((2, 2), dtype=torch.bfloat16), + } + + with pytest.raises(RuntimeError, match="missing bad.weight_scale"): + ltx2_two_stages._apply_lora_deltas( + module, + deltas, + sign=1.0, + save_bf16_weights=False, + ) + + assert torch.equal(module.good, original_good) + + class TestLTX2TwoStagePipelineLoading: """Test two-stage pipeline loading via PipelineLoader."""