From 6abe1c383bb60b9d7394d55e6eecd5238e213dd9 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Wed, 10 May 2023 19:03:40 +0000
Subject: [PATCH 1/8] remove unused

---
 py/torch_tensorrt/fx/fx2trt.py            | 46 ++++++++++++
 py/torch_tensorrt/fx/lower.py             |  3 +
 py/torch_tensorrt/fx/lower_setting.py     |  6 ++
 py/torch_tensorrt/fx/passes/pass_utils.py | 87 ++++++++++++-----------
 4 files changed, 100 insertions(+), 42 deletions(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index d0a6bdf0a1..76414c1e2b 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -163,6 +163,9 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
         tactic_sources=None,
+        max_aux_streams=None,
+        version_compatible=False,
+        optimization_level=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -225,6 +228,18 @@ def run(
             if profiling_verbosity
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         )
+
+        if trt.__version__ >= "8.6":
+            if max_aux_streams is not None:
+                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
+                builder_config.max_aux_streams = max_aux_streams
+            if version_compatible:
+                _LOGGER.info(f"Using version compatible")
+                builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
+            if optimization_level is not None:
+                _LOGGER.info(f"Using optimization level {optimization_level}")
+                builder_config.builder_optimization_level = optimization_level
+
         if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
@@ -251,6 +266,34 @@ def run(
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
+        import os
+        def get_file_name(org):
+            file_name = org
+            i = 0
+            while os.path.exists(os.path.abspath(file_name)):
+                i += 1
+                file_name = org + str(i)
+            return file_name
+
+        engine_file = os.environ.get('TORCH_FX_DUMP_ENGINE')
+        if engine_file:
+            dump_file = get_file_name(engine_file)
+            print(f'Dumping engine to {dump_file}')
+            s = engine.serialize()
+            with open(dump_file, 'wb') as f:
+                f.write(s)
+        engine_info_file = os.environ.get('TORCH_FX_DUMP_ENGINE_INFO')
+        if engine_info_file:
+            inspector = engine.create_engine_inspector()
+            engine_info = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
+            if engine_info is None or len(engine_info) == 0:
+                raise Exception('Engine info is empty')
+            else:
+                dump_file = get_file_name(engine_info_file)
+                print(f'Dumping engine info to {dump_file}')
+                with open(dump_file, 'w') as f:
+                    f.write(engine_info)
+
         serialized_cache = (
             bytearray(cache.serialize())
             if builder_config.get_timing_cache()
@@ -259,6 +302,9 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
+        _LOGGER.info(
+            f"TRT Engine uses: {engine.device_memory_size} Memory"
+        )
 
         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index f96f1db6b9..95198b463b 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -138,6 +138,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
             tactic_sources=self.lower_setting.tactic_sources,
+            max_aux_streams=self.lower_setting.max_aux_streams,
+            version_compatible=self.lower_setting.version_compatible,
+            optimization_level=self.lower_setting.optimization_level,
         )
 
         # Update timing cache file if needed
diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py
index 07e7bf0dac..0bbd51ad12 100644
--- a/py/torch_tensorrt/fx/lower_setting.py
+++ b/py/torch_tensorrt/fx/lower_setting.py
@@ -74,6 +74,9 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: absolute tolerance for correctness check
     correctness_rtol: relative tolerance for correctness check
     use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+    max_aux_streams: max number of aux stream to use
+    version_compatible: enable version compatible feature
+    optimization_level: builder optimization level
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -101,3 +104,6 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
     use_experimental_rt: bool = False
+    max_aux_streams: Optional[int] = None
+    version_compatible: bool = False
+    optimization_level: Optional[int] = None
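(Aside: the three LowerSetting fields added above default to off, so existing lowering flows are unchanged, and TRTInterpreter.run() only consults them when trt.__version__ >= "8.6". A minimal opt-in sketch using only what these hunks declare; the concrete values are illustrative, not recommendations.)

```python
from torch_tensorrt.fx.lower_setting import LowerSetting

# All other LowerSetting fields keep their defaults.
setting = LowerSetting(
    max_aux_streams=2,        # forwarded to builder_config.max_aux_streams
    version_compatible=True,  # sets trt.BuilderFlag.VERSION_COMPATIBLE
    optimization_level=5,     # forwarded to builder_config.builder_optimization_level
)
# Per the lower.py hunk above, Lowerer passes these three settings
# straight through to TRTInterpreter.run() when building the engine.
```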
diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index fabc92881d..14d22fb350 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -100,7 +100,7 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
 
 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
-def validate_inference(rtol=None, atol=None):
+def validate_inference(rtol=None, atol=None, suppress_accuracy_check_failure=True):
     def _validate_inference(pass_: PassFunc) -> PassFunc:
         """
         Wraps a pass function to validate that its inference results before and
@@ -114,48 +114,51 @@ def pass_with_validation(
             *args,
             **kwargs,
         ) -> fx.GraphModule:
-            res0 = module(*input)
-            processed_module = pass_(module, input, *args, **kwargs)
-            res1 = processed_module(*input)
-
-            tensor_res_0 = _collect_tensors(res0)
-            tensor_res_1 = _collect_tensors(res1)
-            relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
-
-            for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
-                kwargs2 = {"equal_nan": True}
-                if rtol:
-                    kwargs2["rtol"] = rtol
-                if atol:
-                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
-                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
-                )
-                # If tensors are on different devices, make sure to compare
-                # their copies that are on the same device.
-                if x.get_device() != y.get_device():
-                    x = x.cpu()
-                    y = y.cpu()
-                try:
-                    torch.testing.assert_close(x, y, **kwargs2)
-                except Exception as e:
-                    if relax_accuracy_check_failure:
-                        _LOGGER.error(f"{e}")
-                        kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
-                        kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
-                        new_atol = kwargs2["atol"]
-                        new_rtol = kwargs2["rtol"]
-                        _LOGGER.info(
-                            f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
-                        )
+            if suppress_accuracy_check_failure:
+                return pass_(module, input, *args, **kwargs)
+            else:
+                res0 = module(*input)
+                processed_module = pass_(module, input, *args, **kwargs)
+                res1 = processed_module(*input)
+
+                tensor_res_0 = _collect_tensors(res0)
+                tensor_res_1 = _collect_tensors(res1)
+                relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
+
+                for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
+                    kwargs2 = {"equal_nan": True}
+                    if rtol:
+                        kwargs2["rtol"] = rtol
+                    if atol:
+                        kwargs2["atol"] = atol
+                    kwargs2[
+                        "msg"
+                    ] = (
+                        lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
+                    )
+                    # If tensors are on different devices, make sure to compare
+                    # their copies that are on the same device.
+                    if x.get_device() != y.get_device():
+                        x = x.cpu()
+                        y = y.cpu()
+                    try:
                         torch.testing.assert_close(x, y, **kwargs2)
-                        return processed_module
-                    else:
-                        raise e
-
-            return processed_module
+                    except Exception as e:
+                        if relax_accuracy_check_failure:
+                            _LOGGER.error(f"{e}")
+                            kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
+                            kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
+                            new_atol = kwargs2["atol"]
+                            new_rtol = kwargs2["rtol"]
+                            _LOGGER.info(
+                                f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
+                            )
+                            torch.testing.assert_close(x, y, **kwargs2)
+                            return processed_module
+                        else:
+                            raise e
+
+                return processed_module
 
     return pass_with_validation

From e700dfeef29bceceb7e17c9e2109a53eb24cb6e2 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Wed, 17 May 2023 17:22:36 +0000
Subject: [PATCH 2/8] update unit

---
 py/torch_tensorrt/fx/fx2trt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 76414c1e2b..026c59510d 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -303,7 +303,7 @@ def get_file_name(org):
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
         _LOGGER.info(
-            f"TRT Engine uses: {engine.device_memory_size} Memory"
+            f"TRT Engine uses: {engine.device_memory_size} bytes of Memory"
         )
 
         return TRTInterpreterResult(
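(Aside: PATCH 1/8 also teaches TRTInterpreter.run() to honor two debugging environment variables, dumping the serialized engine and its inspector JSON to incrementally numbered files; PATCH 5/8 below removes this logic again. While it existed, it would have been driven roughly like this; the paths are illustrative.)

```python
import os

# Must be set in the process before TRTInterpreter.run() builds the engine.
os.environ["TORCH_FX_DUMP_ENGINE"] = "/tmp/fx_engine.trt"             # serialized engine bytes
os.environ["TORCH_FX_DUMP_ENGINE_INFO"] = "/tmp/fx_engine_info.json"  # engine inspector JSON
```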
From ccaf0c06c6dc61bbf4f3234ffecbd5824f8b15a6 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Thu, 18 May 2023 17:03:30 -0700
Subject: [PATCH 5/8] remove engine dump

---
 py/torch_tensorrt/fx/fx2trt.py | 28 +---------------------------
 1 file changed, 1 insertion(+), 27 deletions(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 026c59510d..4035a036a8 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -266,33 +266,7 @@ def run(
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
-        import os
-        def get_file_name(org):
-            file_name = org
-            i = 0
-            while os.path.exists(os.path.abspath(file_name)):
-                i += 1
-                file_name = org + str(i)
-            return file_name
-
-        engine_file = os.environ.get('TORCH_FX_DUMP_ENGINE')
-        if engine_file:
-            dump_file = get_file_name(engine_file)
-            print(f'Dumping engine to {dump_file}')
-            s = engine.serialize()
-            with open(dump_file, 'wb') as f:
-                f.write(s)
-        engine_info_file = os.environ.get('TORCH_FX_DUMP_ENGINE_INFO')
-        if engine_info_file:
-            inspector = engine.create_engine_inspector()
-            engine_info = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
-            if engine_info is None or len(engine_info) == 0:
-                raise Exception('Engine info is empty')
-            else:
-                dump_file = get_file_name(engine_info_file)
-                print(f'Dumping engine info to {dump_file}')
-                with open(dump_file, 'w') as f:
-                    f.write(engine_info)
+
 
         serialized_cache = (
             bytearray(cache.serialize())

From 1ea8ed8911c771a5180cd9c77cf48aad4ab5681b Mon Sep 17 00:00:00 2001
From: Tin-Yin Lai <132402475+wu6u3tw@users.noreply.github.com>
Date: Thu, 18 May 2023 17:17:21 -0700
Subject: [PATCH 6/8] Update fx2trt.py

clean spacing
---
 py/torch_tensorrt/fx/fx2trt.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index ec46ec962b..d2974decd7 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -272,8 +272,6 @@ def run(
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
-
-
         serialized_cache = (
             bytearray(cache.serialize())
             if builder_config.get_timing_cache()

From cbbeb475e4206c344ed55632b3b6f5dc5ca9fa7f Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Mon, 22 May 2023 17:53:44 +0000
Subject: [PATCH 7/8] fix format

---
 py/torch_tensorrt/fx/passes/pass_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index 966ca3a311..4f1d650c6e 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -150,7 +150,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
 
 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
-def validate_inference(rtol=None, atol=None, suppress_accuracy_check_failure=True, run_alternative_batch_size: int = -1)-> "Decorator":
+def validate_inference(
+    rtol=None, atol=None, suppress_accuracy_check_failure=True,
+    run_alternative_batch_size: int = -1
+)-> "Decorator":
     """
     Returns a decorator on a PassFunc to sanity check the model outputs
     difference before/after the transformation is within tolerance.

From fc15ac5630946d86dab9b3e02b9543c124ea98a9 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Mon, 22 May 2023 18:01:29 +0000
Subject: [PATCH 8/8] fix format

---
 py/torch_tensorrt/fx/fx2trt.py            |  4 +---
 py/torch_tensorrt/fx/passes/pass_utils.py | 10 ++++++----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index d2974decd7..ec53fa928d 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -280,9 +280,7 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(
-            f"TRT Engine uses: {engine.device_memory_size} bytes of Memory"
-        )
+        _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")
 
         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index 4f1d650c6e..fedb45fdf3 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -151,9 +151,11 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
 def validate_inference(
-    rtol=None, atol=None, suppress_accuracy_check_failure=True,
-    run_alternative_batch_size: int = -1
-)-> "Decorator":
+    rtol=None,
+    atol=None,
+    suppress_accuracy_check_failure=True,
+    run_alternative_batch_size: int = -1,
+) -> "Decorator":
     """
     Returns a decorator on a PassFunc to sanity check the model outputs
     difference before/after the transformation is within tolerance.
@@ -183,7 +185,7 @@ def pass_with_validation(
             *args,
             **kwargs,
         ) -> fx.GraphModule:
-            if suppress_accuracy_check_failure: 
+            if suppress_accuracy_check_failure:
                 return pass_(module, input, *args, **kwargs)
             else:
                 res0 = module(*input)
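(Aside: after PATCH 8/8, validate_inference defaults to suppress_accuracy_check_failure=True, meaning the decorator simply runs the wrapped pass and skips the before/after output comparison; a pass must opt back in to be checked. A sketch under that assumption follows; the pass name, body, and tolerances are illustrative, and run_alternative_batch_size is left at its default since its behavior is not shown in these hunks.)

```python
import torch.fx as fx

from torch_tensorrt.fx.passes.pass_utils import validate_inference

# Opt back into the accuracy check that the new default suppresses.
@validate_inference(rtol=1e-2, atol=1e-2, suppress_accuracy_check_failure=False)
def my_rewrite_pass(module: fx.GraphModule, input) -> fx.GraphModule:
    # ... graph rewrite would go here ...
    module.recompile()  # keep the generated forward() in sync with the graph
    return module
```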