Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/visual_gen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ python models/wan_t2v.py

# With engine config (quant, parallelism, etc.)
python models/wan_t2v.py --visual_gen_args configs/wan2.2-t2v-fp4-1gpu.yaml
python models/ltx2.py --visual_gen_args configs/ltx2.yaml
```

Install deps from the repo root: `pip install -r requirements-dev.txt`.
Expand Down
33 changes: 33 additions & 0 deletions examples/visual_gen/configs/ltx2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 1-GPU LTX-2 text-to-video with audio.
# Shared by offline examples (--visual_gen_args) and trtllm-serve.
#
# LTX-2 stores the Gemma text encoder separately from the diffusion checkpoint.
# Set text_encoder_path to the local Gemma3 checkpoint directory before running
# against a local LTX-2 checkpoint.
pipeline_config:
text_encoder_path: google/gemma-3-12b-it

attention_config:
backend: VANILLA

parallel_config:
cfg_size: 1
ulysses_size: 1

cuda_graph_config:
enable: false
95 changes: 95 additions & 0 deletions examples/visual_gen/models/ltx2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LTX-2 Text-to-Video generation with audio.

Usage:
python ltx2.py
python ltx2.py --visual_gen_args ../configs/ltx2.yaml
"""

import argparse

from tensorrt_llm import VisualGen, VisualGenArgs


def main():
Comment thread
coderabbitai[bot] marked this conversation as resolved.
parser = argparse.ArgumentParser(description="LTX-2 Text-to-Video example")
parser.add_argument(
"--model",
type=str,
default="Lightricks/LTX-2",
help="Model path or HuggingFace Hub ID",
)
parser.add_argument(
"--visual_gen_args",
"--extra_visual_gen_options",
dest="visual_gen_args",
type=str,
default=None,
help="Path to YAML config (same as trtllm-serve --visual_gen_args)",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help=(
"Gemma3 text encoder path. Overrides pipeline_config.text_encoder_path "
"from --visual_gen_args when set."
),
)
parser.add_argument(
"--output_path",
type=str,
default="ltx2_t2v_output.mp4",
help="Path to save the output video",
)
args = parser.parse_args()

# LTX-2 requires pipeline_config.text_encoder_path for the Gemma3 text
# encoder. The YAML path is preferred for production configs; the default
# below keeps this script runnable as a minimal offline example.
extra_args = VisualGenArgs.from_yaml(args.visual_gen_args) if args.visual_gen_args else VisualGenArgs()
text_encoder_path = args.text_encoder_path
if text_encoder_path is None and not args.visual_gen_args:
text_encoder_path = "google/gemma-3-12b-it"
if text_encoder_path is not None:
extra_args.pipeline_config = {
**extra_args.pipeline_config,
"text_encoder_path": text_encoder_path,
}
visual_gen = VisualGen(model=args.model, args=extra_args)

# --- Model-specific: T2V request construction ---
# Start from LTX-2 defaults and override the main request shape explicitly.
params = visual_gen.default_params
params.height = 512
params.width = 768
params.num_frames = 121
params.frame_rate = 24.0
params.num_inference_steps = 40
params.guidance_scale = 4.0

output = visual_gen.generate(
inputs="A cinematic shot of a cat walking through a field of flowers",
params=params,
)

output.save(args.output_path)
print(f"Saved: {args.output_path}")


if __name__ == "__main__":
main()
Loading