From d648a2ed17bc8ccb54ca4a20ade6b7641b3ab8b8 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Fri, 5 Jun 2026 05:32:37 -0700 Subject: [PATCH] [None][fix] configure DeepGEMM PDL during engine init Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../_torch/custom_ops/torch_custom_ops.py | 2 -- tensorrt_llm/_torch/pyexecutor/model_engine.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 9b2e208b7973..f6fb7c38913f 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -55,8 +55,6 @@ # BufferKind is bound from C++; see cpp/tensorrt_llm/thop/outputTensor.h (torch_ext::BufferKind). from tensorrt_llm.bindings.internal.thop import BufferKind -deep_gemm.set_pdl(get_env_enable_pdl()) - # Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous. @torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", )) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 8d341103cbc9..e09eb9f88c59 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -173,6 +173,20 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], return result +_DEEP_GEMM_PDL_CONFIGURED = False + + +def _configure_deep_gemm_pdl() -> None: + global _DEEP_GEMM_PDL_CONFIGURED + if _DEEP_GEMM_PDL_CONFIGURED: + return + + from tensorrt_llm import deep_gemm + + deep_gemm.set_pdl(os.environ.get("TRTLLM_ENABLE_PDL", "1") == "1") + _DEEP_GEMM_PDL_CONFIGURED = True + + class PyTorchModelEngine(ModelEngine): def __init__( @@ -192,6 +206,8 @@ def __init__( model_weights_memory_tag: Optional[str] = None, model_weights_restore_mode=None, ): + _configure_deep_gemm_pdl() + self.forward_pass_callable = None self.ub_buffers = None (