From 24fa1055b3cfd17d5daf60c05d597722433636eb Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 3 Aug 2023 15:26:00 -0700
Subject: [PATCH 1/5] fix: Add enabled/disabled sets for decompositions

- Add sets to selectively enable or disable decompositions in Torch
- Add new runtime argument `enable_experimental_decompositions` to enable
  all core aten decompositions, or a pre-selected subset thereof
- Improve documentation of compilation settings overall
---
 py/torch_tensorrt/dynamo/_defaults.py        |   1 +
 py/torch_tensorrt/dynamo/_settings.py        |  23 ++
 py/torch_tensorrt/dynamo/backend/backends.py |   2 +-
 py/torch_tensorrt/dynamo/compile.py          |  10 +-
 .../dynamo/lowering/_decompositions.py       | 232 ++++++++++++++++--
 5 files changed, 245 insertions(+), 23 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 20a6acb7ff..ec67a7a358 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -11,3 +11,4 @@
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
+ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 4be44cd779..327140df6e 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -14,11 +14,33 @@
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
 )
 
 
 @dataclass
 class CompilationSettings:
+    """Compilation settings for Torch-TensorRT Dynamo Paths
+
+    Args:
+        precision (torch.dtype): Model Layer precision
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT-Engine Block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+    """
+
     precision: torch.dtype = PRECISION
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
@@ -31,3 +53,4 @@ class CompilationSettings:
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 6efbf89e34..2b761970a1 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -48,7 +48,7 @@ def aot_torch_tensorrt_aten_backend(
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
+        decompositions=get_decompositions(settings.enable_experimental_decompositions),
     )
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index 86e7dd6688..70bb604b84 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -24,6 +24,7 @@
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
 from torch_tensorrt.dynamo.lowering._fusers import (
@@ -63,6 +64,7 @@ def compile(
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
@@ -72,9 +74,10 @@ def compile(
     logger.warning(
         "The Dynamo backend is an experimental feature, for which only the "
-        + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
+        "following arguments are supported: "
+        "{enabled_precisions, debug, workspace_size, min_block_size, "
+        "torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
+        "enable_experimental_decompositions}"
     )
 
     if not isinstance(inputs, collections.abc.Sequence):
@@ -115,6 +118,7 @@ def compile(
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
         "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
     }
 
     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 666d04e779..d2ef325f28 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,12 +1,183 @@
-from typing import Any, Callable, Dict
-
+from typing import Callable, Dict, Set
 import torch
-from torch._decomp import OpOverload, core_aten_decompositions, register_decomposition
-
-DECOMPOSITIONS: Dict[OpOverload, Callable[..., Any]] = {**core_aten_decompositions()}
+from torch._decomp import (
+    register_decomposition,
+    core_aten_decompositions,
+    get_decompositions as get_torch_decompositions,
+)
 
 aten = torch.ops.aten
 
+_core_aten_decompositions: Dict[
+    torch._ops.OpOverload, Callable
+] = core_aten_decompositions()
+enabled_decompositions: Set[torch._ops.OpOverload] = {
+    aten._adaptive_avg_pool2d_backward,
+    aten.addcdiv,
+    aten.addcdiv_,
+    aten.addcmul,
+    aten.addcmul_,
+    aten.addr,
+    aten.aminmax,
+    aten.arange.default,
+    aten.arange.start,
+    aten.avg_pool2d_backward,
+    aten.binary_cross_entropy,
+    aten.binary_cross_entropy_backward,
+    aten.binary_cross_entropy_with_logits,
+    aten.celu,
+    aten.col2im,
+    aten.count_nonzero,
+    aten.cudnn_batch_norm,
+    aten.cudnn_batch_norm_backward,
+    aten.deg2rad,
+    aten.detach,
+    aten.diag_embed,
+    aten.diagonal_backward,
+    aten.dot,
+    aten.elu,
+    aten.elu_backward,
+    aten._embedding_bag,
+    aten.embedding_dense_backward,
+    aten._euclidean_dist.default,
+    aten.expand_as,
+    aten.eye,
+    aten.fill,
+    aten.frac,
+    aten._fused_moving_avg_obs_fq_helper,
+    aten.gelu,
+    aten.gelu_backward,
+    aten.glu_backward,
+    aten.grid_sampler_2d,
+    aten.hardshrink,
+    aten.hardshrink_backward,
+    aten.hardsigmoid,
+    aten.hardsigmoid_backward,
+    aten.hardswish,
+    aten.hardswish_,
+    aten.hardswish_backward,
+    aten.hardtanh,
+    aten.hardtanh_,
+    aten.hardtanh_backward,
+    aten.heaviside,
+    aten.huber_loss,
+    aten.huber_loss_backward,
+    aten.im2col,
+    aten.index_add,
+    aten.index_add_,
+    aten.index_copy,
+    aten.index_copy_,
+    aten.index_fill,
+    aten.index_fill_,
+    aten.index_select,
+    aten.isneginf,
+    aten.isposinf,
+    aten.l1_loss,
+    aten.leaky_relu,
+    aten.leaky_relu_,
+    aten.leaky_relu_backward,
+    aten.lerp,
+    aten.linspace,
+    aten.logaddexp,
+    aten.logaddexp2,
+    aten.logit,
+    aten.logit_backward,
+    aten.log_sigmoid_backward,
+    aten.log_sigmoid_forward,
+    aten._log_softmax,
+    aten._log_softmax_backward_data,
+    aten.logspace,
+    aten.logsumexp.default,
+    aten.masked_fill,
+    aten.masked_fill_,
+    aten.max_pool2d_with_indices_backward,
+    aten.mish,
+    aten.mse_loss,
+    aten.mse_loss_backward,
+    aten.mv,
+    aten.mvlgamma,
+    aten.nansum,
+    aten.nan_to_num,
+    aten.narrow,
+    # TODO: Disable the below operators once freezing is done
+    aten.native_batch_norm,
+    aten.native_batch_norm_backward,
+    aten._native_batch_norm_legit,
+    aten._native_batch_norm_legit_functional,
+    aten._native_batch_norm_legit_no_training,
+    aten.native_dropout_backward,
+    aten.native_group_norm,
+    aten.native_group_norm_backward,
+    aten.native_layer_norm,
+    aten.native_layer_norm_backward,
+    aten.new_empty,
+    aten.new_full,
+    aten.new_ones,
+    aten.new_zeros,
+    aten.nll_loss_backward,
+    aten.nll_loss_forward,
+    aten.norm,
+    aten.ones,
+    aten.ones_like,
+    aten._prelu_kernel,
+    aten._prelu_kernel_backward,
+    aten._reshape_alias,
+    aten.rad2deg,
+    aten.renorm,
+    aten.renorm_,
+    aten.rot90,
+    aten.rsub.Scalar,
+    aten.rsub.Tensor,
+    aten.select_backward,
+    aten.select_scatter,
+    aten.sgn,
+    aten.sigmoid_backward,
+    aten.silu,
+    aten.silu_,
+    aten.silu_backward,
+    aten.sinc,
+    aten.slice_backward,
+    aten.smooth_l1_loss,
+    aten.smooth_l1_loss_backward,
+    aten.soft_margin_loss,
+    aten.soft_margin_loss_backward,
+    aten._softmax,
+    aten._softmax_backward_data,
+    aten.softplus,
+    aten.softplus_backward,
+    aten.softshrink,
+    aten.softshrink_backward,
+    aten.special_entr,
+    aten.special_log_ndtr,
+    aten.special_xlog1py,
+    aten.stack,
+    aten.t,
+    aten.tanh_backward,
+    aten.threshold,
+    aten.threshold_backward,
+    aten.trace,
+    aten.transpose.int,
+    aten.tril.default,
+    aten.triu.default,
+    aten.unfold,
+    aten.unfold_backward,
+    aten.unfold_copy,
+    aten.upsample_bilinear2d,
+    aten.upsample_bilinear2d.vec,
+    aten.upsample_nearest2d_backward,
+    aten.xlogy,
+    aten.zero,
+    aten.zero_,
+    aten.zeros,
+    aten.zeros_like,
+}
+disabled_decompositions: Set[torch._ops.OpOverload] = {}
+
+ENABLED_TORCH_DECOMPOSITIONS: Dict[
+    torch._ops.OpOverload, Callable
+] = get_torch_decompositions(enabled_decompositions)
+TORCH_TRT_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}
+
 
 def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
     """Replace inplace operation with functional equivalent
@@ -14,8 +185,8 @@ def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
     https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
     """
 
-    @register_decomposition(aten_op, registry=DECOMPOSITIONS)  # type: ignore[misc]
-    def inplace_op(*args: Any, **kwargs: Any) -> Any:
+    @register_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
+    def inplace_op(*args, **kwargs):
         out = outplace_op(*args, **kwargs)
         return args[0].copy_(out)
 
@@ -37,32 +208,34 @@ def inplace_op(*args: Any, **kwargs: Any) -> Any:
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
 
 
-@register_decomposition(aten.std, registry=DECOMPOSITIONS)  # type: ignore[misc]
-def std_replacement(*args: Any, **kwargs: Any) -> torch.Tensor:
+@register_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
+def std_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.sqrt(torch.var(*args, **kwargs))
 
 
-@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)  # type: ignore[misc]
-def rsqrt_replacement(*args: Any, **kwargs: Any) -> torch.Tensor:
+@register_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
+def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
 
 
-@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS)  # type: ignore[misc]
-def unsafe_view_replacement(x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+@register_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
+def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     return torch.reshape(x, *args, **kwargs)
 
 
-@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_decomposition(
+    torch.ops.aten.lift_fresh_copy, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(aten.alias, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
 def alias_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_decomposition(torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS)
 def addmm_replacement(
     input_: torch.Tensor,
     mat1: torch.Tensor,
@@ -76,12 +249,33 @@ def addmm_replacement(
     )
 
 
-@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_decomposition(
+    torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def reciprocal_replacement(
     input_: torch.Tensor,
 ) -> torch.Tensor:
     return torch.div(1, input_)
 
 
-def get_decompositions() -> Dict[OpOverload, Callable[..., Any]]:
-    return DECOMPOSITIONS
+def get_decompositions(
+    enable_experimental_decompositions: bool = False,
+) -> Dict[torch._ops.OpOverload, Callable]:
+    if enable_experimental_decompositions:
+        CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[torch._ops.OpOverload, Callable] = {
+            decomp: _core_aten_decompositions[decomp]
+            for decomp in _core_aten_decompositions
+            if (
+                decomp not in TORCH_TRT_DECOMPOSITIONS
+                and decomp not in disabled_decompositions
+            )
+        }
+        return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
+    else:
+        duplicate_registrations = set(ENABLED_TORCH_DECOMPOSITIONS.keys()).intersection(
+            set(TORCH_TRT_DECOMPOSITIONS.keys())
+        )
+        assert (
+            not duplicate_registrations
+        ), f"Detected duplicate decompositions on: {duplicate_registrations}"
+        return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}
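For reference, a minimal sketch of how the new argument is exercised from user
code, assuming the `compile` entrypoint above is exposed as
`torch_tensorrt.dynamo.compile`; `MyModel` stands in for any hypothetical
`nn.Module`:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # hypothetical user module
    inputs = [torch.randn((1, 3, 224, 224), device="cuda")]

    # Default path: lower with the curated `enabled_decompositions` subset
    # plus the custom Torch-TRT replacement decompositions
    trt_gm = torch_tensorrt.dynamo.compile(model, inputs=inputs)

    # Experimental path: lower with all core aten decompositions instead
    trt_gm_exp = torch_tensorrt.dynamo.compile(
        model,
        inputs=inputs,
        enable_experimental_decompositions=True,
    )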
From a14fb592a51b4f3154845b98ec57f499efb15eba Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 7 Aug 2023 17:38:59 -0700
Subject: [PATCH 2/5] feat: Add registration-time warnings for decomp.

- Add decorator-wrapper to perform import-time checks on decompositions
  and alert the user if any custom decompositions conflict with existing
  registered or specified operators
- Simplify code logic for dictionary merging in `get_decompositions`
  function
- Add safety logic to ensure invariants about the decompositions are not
  violated
---
 .../dynamo/lowering/_decompositions.py | 103 +++++++++++++-----
 1 file changed, 76 insertions(+), 27 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index d2ef325f28..e5930db498 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,17 +1,19 @@
-from typing import Callable, Dict, Set
+from typing import Any, Callable, Dict, Set
 import torch
+import logging
 from torch._decomp import (
     register_decomposition,
     core_aten_decompositions,
     get_decompositions as get_torch_decompositions,
 )
+from torch._ops import OpOverload
 
 aten = torch.ops.aten
 
 _core_aten_decompositions: Dict[
-    torch._ops.OpOverload, Callable
+    OpOverload, Callable
 ] = core_aten_decompositions()
-enabled_decompositions: Set[torch._ops.OpOverload] = {
+torch_enabled_decompositions: Set[OpOverload] = {
     aten._adaptive_avg_pool2d_backward,
     aten.addcdiv,
     aten.addcdiv_,
@@ -171,12 +173,66 @@
     aten.zeros,
     aten.zeros_like,
 }
-disabled_decompositions: Set[torch._ops.OpOverload] = {}
+torch_disabled_decompositions: Set[OpOverload] = {}
+
 
 ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    torch._ops.OpOverload, Callable
-] = get_torch_decompositions(enabled_decompositions)
-TORCH_TRT_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}
+    OpOverload, Callable
+] = get_torch_decompositions(torch_enabled_decompositions)
+TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable] = {}
+
+
+logger = logging.getLogger(__name__)
+
+
+def check_decomp_set_invariants():
+    """Validates no overlap between enabled and disabled decomposition sets"""
+    overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
+
+    if overlap:
+        raise AssertionError(
+            f"Detected {overlap} registered in both torch_enabled_decompositions "
+            "and torch_disabled_decompositions. Ensure all operator(s) are in "
+            "at most one of the two sets."
+        )
+
+
+check_decomp_set_invariants()
+
+
+def register_torch_trt_decomposition(aten_op, registry=None):
+    """Checks if the decomposition already exists in one of the sets,
+    then registers the decomposition via the Torch utility
+
+    Alerts the user if the decomposition already exists, before registering
+    Logs a message if the user attempts to register a decomposition
+    which is present in the set of explicitly disabled decompositions
+    """
+    if aten_op in torch_enabled_decompositions:
+        logger.warning(
+            f"Detected custom decomposition for {aten_op}, which conflicts "
+            "with an existing Torch decomposition in torch_enabled_decompositions. "
+            "The custom implementation will take precedence."
+        )
+    elif aten_op in torch_disabled_decompositions:
+        logger.info(
+            f"Detected custom decomposition for {aten_op}, which is present "
+            "in torch_disabled_decompositions."
+        )
+
+    # Conflicts with _core_aten_decompositions will only occur if
+    # enable_experimental_decompositions is True in get_decompositions
+    if aten_op in _core_aten_decompositions:
+        logger.debug(
+            f"Detected custom decomposition for {aten_op}, which conflicts "
+            "with an existing Torch decomposition in core_aten_decompositions. "
+            "The custom implementation will take precedence."
+        )
+
+    def register(fn: Callable) -> Callable:
+        return register_decomposition(aten_op=aten_op, registry=registry)(fn)
+
+    return register
 
 
 def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
@@ -185,7 +241,7 @@ def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
     https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
     """
 
-    @register_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
+    @register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
     def inplace_op(*args, **kwargs):
         out = outplace_op(*args, **kwargs)
         return args[0].copy_(out)
@@ -208,34 +264,36 @@ def inplace_op(*args, **kwargs):
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
 
 
-@register_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
 def std_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.sqrt(torch.var(*args, **kwargs))
 
 
-@register_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
 def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
 
 
-@register_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
 def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     return torch.reshape(x, *args, **kwargs)
 
 
-@register_decomposition(
+@register_torch_trt_decomposition(
     torch.ops.aten.lift_fresh_copy, registry=TORCH_TRT_DECOMPOSITIONS
 )
 def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
 def alias_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(
+    torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def addmm_replacement(
     input_: torch.Tensor,
     mat1: torch.Tensor,
@@ -249,7 +307,7 @@ def addmm_replacement(
     )
 
 
-@register_decomposition(
+@register_torch_trt_decomposition(
     torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
 def reciprocal_replacement(
@@ -260,22 +318,13 @@ def reciprocal_replacement(
 
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
-) -> Dict[torch._ops.OpOverload, Callable]:
+) -> Dict[OpOverload, Callable]:
     if enable_experimental_decompositions:
-        CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[torch._ops.OpOverload, Callable] = {
+        CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable] = {
             decomp: _core_aten_decompositions[decomp]
             for decomp in _core_aten_decompositions
-            if (
-                decomp not in TORCH_TRT_DECOMPOSITIONS
-                and decomp not in disabled_decompositions
-            )
+            if decomp not in torch_disabled_decompositions
         }
         return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
     else:
-        duplicate_registrations = set(ENABLED_TORCH_DECOMPOSITIONS.keys()).intersection(
-            set(TORCH_TRT_DECOMPOSITIONS.keys())
-        )
-        assert (
-            not duplicate_registrations
-        ), f"Detected duplicate decompositions on: {duplicate_registrations}"
         return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}
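As an illustration of the registration flow this decorator enables, a sketch
that mirrors the `std_replacement` rewrite above; the import paths assume the
module layout shown in this diff:

    import torch
    from torch_tensorrt.dynamo.lowering._decompositions import (
        TORCH_TRT_DECOMPOSITIONS,
        register_torch_trt_decomposition,
    )

    aten = torch.ops.aten

    # Registers into the Torch-TRT registry; if the op already appears in
    # one of the decomposition sets, a warning/info/debug message is logged
    # at import time before registration proceeds
    @register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
    def std_replacement(*args, **kwargs) -> torch.Tensor:
        return torch.sqrt(torch.var(*args, **kwargs))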
From 78349078b0ff86fdd2b07b101acd5cdd08cb04ac Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 8 Aug 2023 13:47:55 -0700
Subject: [PATCH 3/5] fix: Add additional decompositions for SD

---
 py/torch_tensorrt/dynamo/lowering/_decompositions.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index e5930db498..ec8290861f 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -172,6 +172,12 @@
     aten.zero_,
     aten.zeros,
     aten.zeros_like,
+    # Non-default convenience decompositions
+    aten.clamp_min,
+    aten.clamp_max,
+    aten.linalg_vector_norm,
+    aten.full,
+    aten.repeat,
 }
 torch_disabled_decompositions: Set[OpOverload] = {}

From 25191a9ad7aa86f17ce3b4adb05d886252616875 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 8 Aug 2023 18:41:30 -0700
Subject: [PATCH 4/5] mypy: Reformatting

---
 py/torch_tensorrt/dynamo/_defaults.py  |  1 +
 py/torch_tensorrt/dynamo/_settings.py  |  3 +-
 py/torch_tensorrt/dynamo/compile.py    |  2 +-
 .../dynamo/lowering/_decompositions.py | 41 ++++++++++---------
 4 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index ec67a7a358..45bc198724 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -10,5 +10,6 @@
 OPTIMIZATION_LEVEL = None
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
+
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 327140df6e..bbf041efbf 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -4,6 +4,7 @@
 import torch
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -14,7 +15,6 @@
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
-    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
 )
 
 
@@ -51,6 +51,7 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
+
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index 70bb604b84..eb051e93e9 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -14,6 +14,7 @@
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -24,7 +25,6 @@
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
-    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
 from torch_tensorrt.dynamo.lowering._fusers import (
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index ec8290861f..efb64247dd 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,17 +1,16 @@
-from typing import Any, Callable, Dict, Set
-import torch
 import logging
-from torch._decomp import (
-    register_decomposition,
-    core_aten_decompositions,
-    get_decompositions as get_torch_decompositions,
-)
+from typing import Any, Callable, Dict, Optional, Set
+
+import torch
+from torch._decomp import core_aten_decompositions
+from torch._decomp import get_decompositions as get_torch_decompositions
+from torch._decomp import register_decomposition
 from torch._ops import OpOverload
 
 aten = torch.ops.aten
 
 _core_aten_decompositions: Dict[
-    OpOverload, Callable
+    OpOverload, Callable[[Any], Any]
 ] = core_aten_decompositions()
 torch_enabled_decompositions: Set[OpOverload] = {
     aten._adaptive_avg_pool2d_backward,
@@ -179,19 +178,19 @@
     aten.full,
     aten.repeat,
 }
-torch_disabled_decompositions: Set[OpOverload] = {}
+torch_disabled_decompositions: Set[OpOverload] = set()
 
 
 ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable
+    OpOverload, Callable[[Any], Any]
 ] = get_torch_decompositions(torch_enabled_decompositions)
-TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable] = {}
+TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
 
 
 logger = logging.getLogger(__name__)
 
 
-def check_decomp_set_invariants():
+def check_decomp_set_invariants() -> None:
     """Validates no overlap between enabled and disabled decomposition sets"""
     overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
 
@@ -206,7 +205,9 @@ def check_decomp_set_invariants():
 check_decomp_set_invariants()
 
 
-def register_torch_trt_decomposition(aten_op, registry=None):
+def register_torch_trt_decomposition(
+    aten_op: OpOverload, registry: Optional[Any] = None
+) -> Callable[[Any], Any]:
     """Checks if the decomposition already exists in one of the sets,
     then registers the decomposition via the Torch utility
@@ -235,7 +236,7 @@ def register_torch_trt_decomposition(aten_op, registry=None):
             "The custom implementation will take precedence."
         )
 
-    def register(fn: Callable) -> Callable:
+    def register(fn: Callable[[Any], Any]) -> Any:
         return register_decomposition(aten_op=aten_op, registry=registry)(fn)
 
     return register
@@ -248,7 +249,7 @@ def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
     """
 
     @register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
-    def inplace_op(*args, **kwargs):
+    def inplace_op(*args, **kwargs):  # type: ignore
         out = outplace_op(*args, **kwargs)
         return args[0].copy_(out)
@@ -271,17 +272,17 @@ def inplace_op(*args, **kwargs):
 
 
 @register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
-def std_replacement(*args, **kwargs) -> torch.Tensor:
+def std_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.sqrt(torch.var(*args, **kwargs))
 
 
 @register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
-def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
+def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
 
 
 @register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
-def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.reshape(x, *args, **kwargs)
@@ -324,9 +325,9 @@ def reciprocal_replacement(
 
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
-) -> Dict[OpOverload, Callable]:
+) -> Dict[OpOverload, Callable[[Any], Any]]:
     if enable_experimental_decompositions:
-        CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable] = {
+        CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
             decomp: _core_aten_decompositions[decomp]
             for decomp in _core_aten_decompositions
             if decomp not in torch_disabled_decompositions
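A short sketch of how the two lowering paths now differ at call time; both
branches merge via dict-unpacking, so entries in `TORCH_TRT_DECOMPOSITIONS`
win on any key collision (import path assumed from this diff):

    from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions

    # Curated path: ENABLED_TORCH_DECOMPOSITIONS + Torch-TRT replacements
    default_decomps = get_decompositions(enable_experimental_decompositions=False)

    # Experimental path: all core aten decompositions, minus the ops in
    # torch_disabled_decompositions, + Torch-TRT replacements
    experimental_decomps = get_decompositions(enable_experimental_decompositions=True)

    print(len(default_decomps), len(experimental_decomps))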
From 2064f4ffdcb5f0d7591c432956a22e4f80f18466 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 8 Aug 2023 22:30:14 -0700
Subject: [PATCH 5/5] feat: Move decomposition default groups to new file

---
 py/torch_tensorrt/dynamo/_defaults.py        |   1 -
 py/torch_tensorrt/dynamo/_settings.py        |   1 -
 .../dynamo/lowering/_decomposition_groups.py | 200 +++++++++++++++++
 .../dynamo/lowering/_decompositions.py       | 206 +-----------------
 4 files changed, 209 insertions(+), 199 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 45bc198724..ec67a7a358 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -10,6 +10,5 @@
 OPTIMIZATION_LEVEL = None
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
-
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index bbf041efbf..6f17ad768b 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -51,7 +51,6 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
-
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
new file mode 100644
index 0000000000..60fef93e08
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -0,0 +1,200 @@
+from typing import Any, Callable, Dict, Set
+
+import torch
+from torch._decomp import core_aten_decompositions
+from torch._decomp import get_decompositions as get_torch_decompositions
+from torch._ops import OpOverload
+
+aten = torch.ops.aten
+
+_core_aten_decompositions: Dict[
+    OpOverload, Callable[[Any], Any]
+] = core_aten_decompositions()
+torch_enabled_decompositions: Set[OpOverload] = {
+    aten._adaptive_avg_pool2d_backward,
+    aten.addcdiv,
+    aten.addcdiv_,
+    aten.addcmul,
+    aten.addcmul_,
+    aten.addr,
+    aten.aminmax,
+    aten.arange.default,
+    aten.arange.start,
+    aten.avg_pool2d_backward,
+    aten.binary_cross_entropy,
+    aten.binary_cross_entropy_backward,
+    aten.binary_cross_entropy_with_logits,
+    aten.celu,
+    aten.col2im,
+    aten.count_nonzero,
+    aten.cudnn_batch_norm,
+    aten.cudnn_batch_norm_backward,
+    aten.deg2rad,
+    aten.detach,
+    aten.diag_embed,
+    aten.diagonal_backward,
+    aten.dot,
+    aten.elu,
+    aten.elu_backward,
+    aten._embedding_bag,
+    aten.embedding_dense_backward,
+    aten._euclidean_dist.default,
+    aten.expand_as,
+    aten.eye,
+    aten.fill,
+    aten.frac,
+    aten._fused_moving_avg_obs_fq_helper,
+    aten.gelu,
+    aten.gelu_backward,
+    aten.glu_backward,
+    aten.grid_sampler_2d,
+    aten.hardshrink,
+    aten.hardshrink_backward,
+    aten.hardsigmoid,
+    aten.hardsigmoid_backward,
+    aten.hardswish,
+    aten.hardswish_,
+    aten.hardswish_backward,
+    aten.hardtanh,
+    aten.hardtanh_,
+    aten.hardtanh_backward,
+    aten.heaviside,
+    aten.huber_loss,
+    aten.huber_loss_backward,
+    aten.im2col,
+    aten.index_add,
+    aten.index_add_,
+    aten.index_copy,
+    aten.index_copy_,
+    aten.index_fill,
+    aten.index_fill_,
+    aten.index_select,
+    aten.isneginf,
+    aten.isposinf,
+    aten.l1_loss,
+    aten.leaky_relu,
+    aten.leaky_relu_,
+    aten.leaky_relu_backward,
+    aten.lerp,
+    aten.linspace,
+    aten.logaddexp,
+    aten.logaddexp2,
+    aten.logit,
+    aten.logit_backward,
+    aten.log_sigmoid_backward,
+    aten.log_sigmoid_forward,
+    aten._log_softmax,
+    aten._log_softmax_backward_data,
+    aten.logspace,
+    aten.logsumexp.default,
+    aten.masked_fill,
+    aten.masked_fill_,
+    aten.max_pool2d_with_indices_backward,
+    aten.mish,
+    aten.mse_loss,
+    aten.mse_loss_backward,
+    aten.mv,
+    aten.mvlgamma,
+    aten.nansum,
+    aten.nan_to_num,
+    aten.narrow,
+    # TODO: Disable the below operators once freezing is done
+    aten.native_batch_norm,
+    aten.native_batch_norm_backward,
+    aten._native_batch_norm_legit,
+    aten._native_batch_norm_legit_functional,
+    aten._native_batch_norm_legit_no_training,
+    aten.native_dropout_backward,
+    aten.native_group_norm,
+    aten.native_group_norm_backward,
+    aten.native_layer_norm,
+    aten.native_layer_norm_backward,
+    aten.new_empty,
+    aten.new_full,
+    aten.new_ones,
+    aten.new_zeros,
+    aten.nll_loss_backward,
+    aten.nll_loss_forward,
+    aten.norm,
+    aten.ones,
+    aten.ones_like,
+    aten._prelu_kernel,
+    aten._prelu_kernel_backward,
+    aten._reshape_alias,
+    aten.rad2deg,
+    aten.renorm,
+    aten.renorm_,
+    aten.rot90,
+    aten.rsub.Scalar,
+    aten.rsub.Tensor,
+    aten.select_backward,
+    aten.select_scatter,
+    aten.sgn,
+    aten.sigmoid_backward,
+    aten.silu,
+    aten.silu_,
+    aten.silu_backward,
+    aten.sinc,
+    aten.slice_backward,
+    aten.smooth_l1_loss,
+    aten.smooth_l1_loss_backward,
+    aten.soft_margin_loss,
+    aten.soft_margin_loss_backward,
+    aten._softmax,
+    aten._softmax_backward_data,
+    aten.softplus,
+    aten.softplus_backward,
+    aten.softshrink,
+    aten.softshrink_backward,
+    aten.special_entr,
+    aten.special_log_ndtr,
+    aten.special_xlog1py,
+    aten.stack,
+    aten.t,
+    aten.tanh_backward,
+    aten.threshold,
+    aten.threshold_backward,
+    aten.trace,
+    aten.transpose.int,
+    aten.tril.default,
+    aten.triu.default,
+    aten.unfold,
+    aten.unfold_backward,
+    aten.unfold_copy,
+    aten.upsample_bilinear2d,
+    aten.upsample_bilinear2d.vec,
+    aten.upsample_nearest2d_backward,
+    aten.xlogy,
+    aten.zero,
+    aten.zero_,
+    aten.zeros,
+    aten.zeros_like,
+    # Non-default convenience decompositions
+    aten.clamp_min,
+    aten.clamp_max,
+    aten.linalg_vector_norm,
+    aten.full,
+    aten.repeat,
+}
+torch_disabled_decompositions: Set[OpOverload] = set()
+
+
+ENABLED_TORCH_DECOMPOSITIONS: Dict[
+    OpOverload, Callable[[Any], Any]
+] = get_torch_decompositions(torch_enabled_decompositions)
+TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
+
+
+def check_decomp_set_invariants() -> None:
+    """Validates no overlap between enabled and disabled decomposition sets"""
+    overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
+
+    if overlap:
+        raise AssertionError(
+            f"Detected {overlap} registered in both torch_enabled_decompositions "
+            "and torch_disabled_decompositions. Ensure all operator(s) are in "
+            "at most one of the two sets."
+        )
+
+
+check_decomp_set_invariants()
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index efb64247dd..57e1954575 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,210 +1,22 @@
 import logging
-from typing import Any, Callable, Dict, Optional, Set
+from typing import Any, Callable, Dict, Optional
 
 import torch
-from torch._decomp import core_aten_decompositions
-from torch._decomp import get_decompositions as get_torch_decompositions
 from torch._decomp import register_decomposition
 from torch._ops import OpOverload
 
-aten = torch.ops.aten
-
-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
-torch_enabled_decompositions: Set[OpOverload] = {
-    aten._adaptive_avg_pool2d_backward,
-    aten.addcdiv,
-    aten.addcdiv_,
-    aten.addcmul,
-    aten.addcmul_,
-    aten.addr,
-    aten.aminmax,
-    aten.arange.default,
-    aten.arange.start,
-    aten.avg_pool2d_backward,
-    aten.binary_cross_entropy,
-    aten.binary_cross_entropy_backward,
-    aten.binary_cross_entropy_with_logits,
-    aten.celu,
-    aten.col2im,
-    aten.count_nonzero,
-    aten.cudnn_batch_norm,
-    aten.cudnn_batch_norm_backward,
-    aten.deg2rad,
-    aten.detach,
-    aten.diag_embed,
-    aten.diagonal_backward,
-    aten.dot,
-    aten.elu,
-    aten.elu_backward,
-    aten._embedding_bag,
-    aten.embedding_dense_backward,
-    aten._euclidean_dist.default,
-    aten.expand_as,
-    aten.eye,
-    aten.fill,
-    aten.frac,
-    aten._fused_moving_avg_obs_fq_helper,
-    aten.gelu,
-    aten.gelu_backward,
-    aten.glu_backward,
-    aten.grid_sampler_2d,
-    aten.hardshrink,
-    aten.hardshrink_backward,
-    aten.hardsigmoid,
-    aten.hardsigmoid_backward,
-    aten.hardswish,
-    aten.hardswish_,
-    aten.hardswish_backward,
-    aten.hardtanh,
-    aten.hardtanh_,
-    aten.hardtanh_backward,
-    aten.heaviside,
-    aten.huber_loss,
-    aten.huber_loss_backward,
-    aten.im2col,
-    aten.index_add,
-    aten.index_add_,
-    aten.index_copy,
-    aten.index_copy_,
-    aten.index_fill,
-    aten.index_fill_,
-    aten.index_select,
-    aten.isneginf,
-    aten.isposinf,
-    aten.l1_loss,
-    aten.leaky_relu,
-    aten.leaky_relu_,
-    aten.leaky_relu_backward,
-    aten.lerp,
-    aten.linspace,
-    aten.logaddexp,
-    aten.logaddexp2,
-    aten.logit,
-    aten.logit_backward,
-    aten.log_sigmoid_backward,
-    aten.log_sigmoid_forward,
-    aten._log_softmax,
-    aten._log_softmax_backward_data,
-    aten.logspace,
-    aten.logsumexp.default,
-    aten.masked_fill,
-    aten.masked_fill_,
-    aten.max_pool2d_with_indices_backward,
-    aten.mish,
-    aten.mse_loss,
-    aten.mse_loss_backward,
-    aten.mv,
-    aten.mvlgamma,
-    aten.nansum,
-    aten.nan_to_num,
-    aten.narrow,
-    # TODO: Disable the below operators once freezing is done
-    aten.native_batch_norm,
-    aten.native_batch_norm_backward,
-    aten._native_batch_norm_legit,
-    aten._native_batch_norm_legit_functional,
-    aten._native_batch_norm_legit_no_training,
-    aten.native_dropout_backward,
-    aten.native_group_norm,
-    aten.native_group_norm_backward,
-    aten.native_layer_norm,
-    aten.native_layer_norm_backward,
-    aten.new_empty,
-    aten.new_full,
-    aten.new_ones,
-    aten.new_zeros,
-    aten.nll_loss_backward,
-    aten.nll_loss_forward,
-    aten.norm,
-    aten.ones,
-    aten.ones_like,
-    aten._prelu_kernel,
-    aten._prelu_kernel_backward,
-    aten._reshape_alias,
-    aten.rad2deg,
-    aten.renorm,
-    aten.renorm_,
-    aten.rot90,
-    aten.rsub.Scalar,
-    aten.rsub.Tensor,
-    aten.select_backward,
-    aten.select_scatter,
-    aten.sgn,
-    aten.sigmoid_backward,
-    aten.silu,
-    aten.silu_,
-    aten.silu_backward,
-    aten.sinc,
-    aten.slice_backward,
-    aten.smooth_l1_loss,
-    aten.smooth_l1_loss_backward,
-    aten.soft_margin_loss,
-    aten.soft_margin_loss_backward,
-    aten._softmax,
-    aten._softmax_backward_data,
-    aten.softplus,
-    aten.softplus_backward,
-    aten.softshrink,
-    aten.softshrink_backward,
-    aten.special_entr,
-    aten.special_log_ndtr,
-    aten.special_xlog1py,
-    aten.stack,
-    aten.t,
-    aten.tanh_backward,
-    aten.threshold,
-    aten.threshold_backward,
-    aten.trace,
-    aten.transpose.int,
-    aten.tril.default,
-    aten.triu.default,
-    aten.unfold,
-    aten.unfold_backward,
-    aten.unfold_copy,
-    aten.upsample_bilinear2d,
-    aten.upsample_bilinear2d.vec,
-    aten.upsample_nearest2d_backward,
-    aten.xlogy,
-    aten.zero,
-    aten.zero_,
-    aten.zeros,
-    aten.zeros_like,
-    # Non-default convenience decompositions
-    aten.clamp_min,
-    aten.clamp_max,
-    aten.linalg_vector_norm,
-    aten.full,
-    aten.repeat,
-}
-torch_disabled_decompositions: Set[OpOverload] = set()
-
-
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
-TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
-
+from ._decomposition_groups import (
+    ENABLED_TORCH_DECOMPOSITIONS,
+    TORCH_TRT_DECOMPOSITIONS,
+    _core_aten_decompositions,
+    aten,
+    torch_disabled_decompositions,
+    torch_enabled_decompositions,
+)
 
 logger = logging.getLogger(__name__)
 
 
-def check_decomp_set_invariants() -> None:
-    """Validates no overlap between enabled and disabled decomposition sets"""
-    overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
-
-    if overlap:
-        raise AssertionError(
-            f"Detected {overlap} registered in both torch_enabled_decompositions "
-            "and torch_disabled_decompositions. Ensure all operator(s) are in "
-            "at most one of the two sets."
-        )
-
-
-check_decomp_set_invariants()
-
-
 def register_torch_trt_decomposition(
     aten_op: OpOverload, registry: Optional[Any] = None
 ) -> Callable[[Any], Any]:
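With the groups split into their own module, downstream code can inspect or
extend the default sets without pulling in the registration logic; a sketch
assuming the new module path introduced by this patch:

    import torch
    from torch_tensorrt.dynamo.lowering._decomposition_groups import (
        torch_disabled_decompositions,
        torch_enabled_decompositions,
    )

    aten = torch.ops.aten

    # Membership checks use the same aten op objects as the set literals above
    print(aten.gelu in torch_enabled_decompositions)   # True, per this diff
    print(aten.gelu in torch_disabled_decompositions)  # False; the set is empty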