diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 9b2e208b7973..f6fb7c38913f 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -55,8 +55,6 @@ # BufferKind is bound from C++; see cpp/tensorrt_llm/thop/outputTensor.h (torch_ext::BufferKind). from tensorrt_llm.bindings.internal.thop import BufferKind -deep_gemm.set_pdl(get_env_enable_pdl()) - # Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous. @torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", )) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 8d341103cbc9..e09eb9f88c59 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -173,6 +173,20 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], return result +_DEEP_GEMM_PDL_CONFIGURED = False + + +def _configure_deep_gemm_pdl() -> None: + global _DEEP_GEMM_PDL_CONFIGURED + if _DEEP_GEMM_PDL_CONFIGURED: + return + + from tensorrt_llm import deep_gemm + + deep_gemm.set_pdl(os.environ.get("TRTLLM_ENABLE_PDL", "1") == "1") + _DEEP_GEMM_PDL_CONFIGURED = True + + class PyTorchModelEngine(ModelEngine): def __init__( @@ -192,6 +206,8 @@ def __init__( model_weights_memory_tag: Optional[str] = None, model_weights_restore_mode=None, ): + _configure_deep_gemm_pdl() + self.forward_pass_callable = None self.ub_buffers = None (