From 9ed18a7043d52ae71cfc42f7597dd06f60eb1334 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Wed, 25 Sep 2024 15:23:38 +0900 Subject: [PATCH 1/5] chore: adjust log level in cpp trt logger --- core/runtime/TRTEngine.cpp | 7 +++---- core/runtime/TRTEngine.h | 5 +---- core/util/logging/TorchTRTLogger.cpp | 2 +- py/torch_tensorrt/dynamo/utils.py | 13 +++++++++++++ 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c2b9e6c35d..27e7d9d31a 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -201,10 +201,9 @@ TRTEngine::TRTEngine( } num_io = std::make_pair(inputs_size, outputs); } - -#ifndef NDEBUG - this->enable_profiling(); -#endif + if (util::logging::get_logger().get_reportable_log_level() == util::logging::LogLevel::kDEBUG) { + this->enable_profiling(); + } LOG_DEBUG(*this); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index ebd5645d59..4314cfe535 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -87,11 +87,8 @@ struct TRTEngine : torch::CustomClassHolder { // c10::List Run(c10::List inputs); void set_profiling_paths(); -#ifndef NDEBUG - bool profile_execution = true; -#else bool profile_execution = false; -#endif + std::string device_profile_path; std::string input_profile_path; std::string output_profile_path; diff --git a/core/util/logging/TorchTRTLogger.cpp b/core/util/logging/TorchTRTLogger.cpp index 97d0ef01db..cf7cb73705 100644 --- a/core/util/logging/TorchTRTLogger.cpp +++ b/core/util/logging/TorchTRTLogger.cpp @@ -125,7 +125,7 @@ namespace { TorchTRTLogger& get_global_logger() { #ifndef NDEBUG - static TorchTRTLogger global_logger("[Torch-TensorRT - Debug Build] - ", LogLevel::kDEBUG, true); + static TorchTRTLogger global_logger("[Torch-TensorRT - Debug Build] - ", LogLevel::kINFO, true); #else static TorchTRTLogger global_logger("[Torch-TensorRT] - ", LogLevel::kWARNING, false); #endif diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index ee11e597a1..1ef9f6481c 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -208,6 +208,19 @@ def set_log_level(parent_logger: Any, level: Any) -> None: if parent_logger: parent_logger.setLevel(level) + if level == logging.DEBUG: + log_level = trt.ILogger.Severity.VERBOSE + elif level == logging.INFO: + log_level = trt.ILogger.Severity.INFO + elif level == logging.WARNING: + log_level = trt.ILogger.Severity.WARNING + elif level == logging.ERROR: + log_level = trt.ILogger.Severity.ERROR + elif level == logging.CRITICAL: + log_level = trt.ILogger.Severity.INTERNAL_ERROR + + torch.ops.tensorrt.set_logging_level(int(log_level)) + def prepare_inputs( inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], From b8c3abe08d54f763adaaba051af81f809ebd4d80 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Thu, 26 Sep 2024 10:28:49 +0900 Subject: [PATCH 2/5] chore: invalid log level check --- py/torch_tensorrt/dynamo/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1ef9f6481c..65a74fbe26 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -204,6 +204,7 @@ def set_log_level(parent_logger: Any, level: Any) -> None: Sets the log level to the user provided level. This is used to set debug logging at a global level at entry points of tracing, dynamo and torch_compile compilation. + It also set log level for c++ torch trt logger """ if parent_logger: parent_logger.setLevel(level) @@ -218,6 +219,8 @@ def set_log_level(parent_logger: Any, level: Any) -> None: log_level = trt.ILogger.Severity.ERROR elif level == logging.CRITICAL: log_level = trt.ILogger.Severity.INTERNAL_ERROR + else: + raise AssertionError(f"{level} is valid log level") torch.ops.tensorrt.set_logging_level(int(log_level)) From 48f1f9377211707a272cb0fbb83d5357a379a489 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Thu, 26 Sep 2024 10:52:12 +0900 Subject: [PATCH 3/5] chore: enable profile for develop version --- core/runtime/TRTEngine.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 27e7d9d31a..4f90ce5653 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -201,9 +201,11 @@ TRTEngine::TRTEngine( } num_io = std::make_pair(inputs_size, outputs); } +#ifndef NDEBUG if (util::logging::get_logger().get_reportable_log_level() == util::logging::LogLevel::kDEBUG) { this->enable_profiling(); } +#endif LOG_DEBUG(*this); } From 6964b381986ede52d1d6cacd56265ac834e83938 Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Mon, 30 Sep 2024 14:28:07 +0900 Subject: [PATCH 4/5] chore: Only set native debugging level in set_log_level() --- core/runtime/TRTEngine.cpp | 5 ++--- core/runtime/TRTEngine.h | 5 ++++- core/util/logging/TorchTRTLogger.cpp | 2 +- py/torch_tensorrt/dynamo/utils.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 4f90ce5653..c2b9e6c35d 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -201,10 +201,9 @@ TRTEngine::TRTEngine( } num_io = std::make_pair(inputs_size, outputs); } + #ifndef NDEBUG - if (util::logging::get_logger().get_reportable_log_level() == util::logging::LogLevel::kDEBUG) { - this->enable_profiling(); - } + this->enable_profiling(); #endif LOG_DEBUG(*this); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 4314cfe535..ebd5645d59 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -87,8 +87,11 @@ struct TRTEngine : torch::CustomClassHolder { // c10::List Run(c10::List inputs); void set_profiling_paths(); +#ifndef NDEBUG + bool profile_execution = true; +#else bool profile_execution = false; - +#endif std::string device_profile_path; std::string input_profile_path; std::string output_profile_path; diff --git a/core/util/logging/TorchTRTLogger.cpp b/core/util/logging/TorchTRTLogger.cpp index cf7cb73705..97d0ef01db 100644 --- a/core/util/logging/TorchTRTLogger.cpp +++ b/core/util/logging/TorchTRTLogger.cpp @@ -125,7 +125,7 @@ namespace { TorchTRTLogger& get_global_logger() { #ifndef NDEBUG - static TorchTRTLogger global_logger("[Torch-TensorRT - Debug Build] - ", LogLevel::kINFO, true); + static TorchTRTLogger global_logger("[Torch-TensorRT - Debug Build] - ", LogLevel::kDEBUG, true); #else static TorchTRTLogger global_logger("[Torch-TensorRT] - ", LogLevel::kWARNING, false); #endif diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 65a74fbe26..1f73c08d0e 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -220,7 +220,7 @@ def set_log_level(parent_logger: Any, level: Any) -> None: elif level == logging.CRITICAL: log_level = trt.ILogger.Severity.INTERNAL_ERROR else: - raise AssertionError(f"{level} is valid log level") + raise AssertionError(f"{level} is not valid log level") torch.ops.tensorrt.set_logging_level(int(log_level)) From 0790de727011590a90f27bfe9515310348f6493a Mon Sep 17 00:00:00 2001 From: kee hyun an Date: Wed, 2 Oct 2024 12:04:28 +0900 Subject: [PATCH 5/5] chore: set native logging level when runtime is available --- py/torch_tensorrt/dynamo/utils.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1f73c08d0e..a85494239e 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -11,6 +11,7 @@ from torch._subclasses.fake_tensor import FakeTensor from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._defaults import default_device @@ -204,25 +205,26 @@ def set_log_level(parent_logger: Any, level: Any) -> None: Sets the log level to the user provided level. This is used to set debug logging at a global level at entry points of tracing, dynamo and torch_compile compilation. - It also set log level for c++ torch trt logger + And set log level for c++ torch trt logger if runtime is available. """ if parent_logger: parent_logger.setLevel(level) - if level == logging.DEBUG: - log_level = trt.ILogger.Severity.VERBOSE - elif level == logging.INFO: - log_level = trt.ILogger.Severity.INFO - elif level == logging.WARNING: - log_level = trt.ILogger.Severity.WARNING - elif level == logging.ERROR: - log_level = trt.ILogger.Severity.ERROR - elif level == logging.CRITICAL: - log_level = trt.ILogger.Severity.INTERNAL_ERROR - else: - raise AssertionError(f"{level} is not valid log level") + if ENABLED_FEATURES.torch_tensorrt_runtime: + if level == logging.DEBUG: + log_level = trt.ILogger.Severity.VERBOSE + elif level == logging.INFO: + log_level = trt.ILogger.Severity.INFO + elif level == logging.WARNING: + log_level = trt.ILogger.Severity.WARNING + elif level == logging.ERROR: + log_level = trt.ILogger.Severity.ERROR + elif level == logging.CRITICAL: + log_level = trt.ILogger.Severity.INTERNAL_ERROR + else: + raise AssertionError(f"{level} is not valid log level") - torch.ops.tensorrt.set_logging_level(int(log_level)) + torch.ops.tensorrt.set_logging_level(int(log_level)) def prepare_inputs(