From 60b433912436bb85fd2129c2d72c0053ebd08e93 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Fri, 21 Oct 2022 16:35:22 -0700 Subject: [PATCH] Changes done internally at Facebook afdc533da031a64e162bb08c8629ff38739e24f8 Wei Wei [fx2trt] disable dispatch trace leaf node test c22f691e6eae1b06ecd301eb6285b32d5dc9717c Mike Iovine [fx2trt] Support dict inputs in acc tracer 8c05a3c57b1f5c63108b979ef8c61411525d0b1f Mike Iovine [fx2trt] Support namedtuple access in acc tracer getattr 1580805d827eb40c941e769b0b99e7c6a3ed6f89 Wei Wei [fx2trt] add reshape unit test baab27b81b1275de92fdaf760a158ce951564d33 Donglin Xia Register avg_pool3d for acc_op in acc_op.py ae4c4e2c3c18d78542140fcc30e1c24f7c647ef3 Wei Wei [aten2trt] init check-in 87ef03338c9a25c5a610a2eb590345e8935f8d75 Wei Wei [aten2trt] add binary ops 2bb168517ace7e638cffc7a241b1cbf528790b92 Mike Iovine [fx2trt] Add acc normalization blocklist 137a3977ffeb03d0387e8a95ff2f32f3d15b3de8 Wei Wei [aten2trt] resnet support fef54c237589a70c007c861e2d59c4052e3de054 Kefei Lu [easy] fx2xxx: fix fuse_parallel_linear which changes getitem slices from tuple to list 4b062ef361cd7797e72c51bb4dc41766aca7b6db Kefei Lu fx2trt: fix bad reshape pattern x.reshape(y.size(0), ...) 49573920892bb2fe75fe011a8cad9887bdc8bd04 Alex Beloi [FX] add tracing for torch.detach 42c54d69c68dc58ac348647acada88b1e5634b40 Fei Kou Fix clamping float32 boundary values e013621dedf5960f81b915cef8d2ce19ca349a7a Kefei Lu trt lower: change preset application logic to in-place instead of immutable update adc9f8ff48c01a0ce70080c930221ac81f048563 Kefei Lu [easy]: fix another instance of [slice(), ...] to (slice(), ...) e9cc5f4b676df502a80a4b85586096e4a3e6a9d6 Charles David Hernandez [docs] fix broken links f06174dbb190df4ea488ca99a81d4884b5ed3aa2 wwei6 [fx2trt] compile --- docs/_sources/tutorials/ptq.rst.txt | 6 +- .../fx/converters/acc_ops_converters.py | 8 +- .../fx/passes/lower_basic_pass.py | 150 +++++++++++++++++- .../fx/passes/lower_pass_manager_builder.py | 2 + py/torch_tensorrt/fx/passes/pass_utils.py | 11 +- .../fx/test/converters/acc_op/test_clamp.py | 1 + .../test/passes/test_fix_reshape_batch_dim.py | 51 ++++++ .../fx/test/tracer/test_acc_tracer.py | 25 +++ .../fx/test/tracer/test_dispatch_tracer.py | 2 +- .../fx/tracer/acc_tracer/acc_ops.py | 6 + 10 files changed, 254 insertions(+), 8 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py diff --git a/docs/_sources/tutorials/ptq.rst.txt b/docs/_sources/tutorials/ptq.rst.txt index b62457109f..75b27c8409 100644 --- a/docs/_sources/tutorials/ptq.rst.txt +++ b/docs/_sources/tutorials/ptq.rst.txt @@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well. From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on -CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq +CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq .. _writing_ptq_python: @@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode. calibrator=calibrator) If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient. -For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py -and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py +For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py +and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py Citations ^^^^^^^^^^^ diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 9135ebc98a..f5278b1d07 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -2854,8 +2854,12 @@ def add_clamp(network, input, val, op, name): else: acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions acc_ops_clamp_tensor = ( - val - * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)) + ( + val + * torch.ones( + acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype) + ) + ) .cpu() .numpy() ) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index f7f554e1c6..844fa24238 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -1,10 +1,13 @@ import copy +import logging import operator import warnings -from typing import Any +from typing import Any, Optional import torch import torch.fx +import torch.fx as fx +import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch.fx.experimental.const_fold import split_const_subgraphs from ..observer import observable @@ -13,6 +16,8 @@ from ..tracer.acc_tracer.acc_utils import get_attr from .pass_utils import log_before_after, validate_inference +_LOGGER = logging.getLogger(__name__) + # Create an alias for module input type to avoid littering pyre-ignore for Any # throughout the file. Input = Any @@ -460,3 +465,146 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input): gm.graph.lint() gm.recompile() return gm + + +def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule: + """\ + TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256), + since the dynamic shape of the reshape comes from the dynamic shape of + another node (y). The compilation will fail with various memory related + errors, depending on the size of the input tensor. + + This pass fixes the issue by finding this reshape pattern, checking that: + + x.size(0) == y.size(0) + + And then replaces reshape's batch size from y.size(0) to x.size(0). + """ + + def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]: + """\ + Try to find the reshape op's batch size as an input node. + + Match below graph structure and return `node_y`: + node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}}) + """ + if ( + maybe_reshape.op != "call_function" + or maybe_reshape.target != acc_ops.reshape + ): + return None + shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None) + if not shape: + return None + batch_size = shape[0] + if isinstance(batch_size, fx.Node): + return batch_size + return None + + def get_reshape_batch_size_inferred_source( + batch_size_node: fx.Node, + ) -> Optional[fx.Node]: + """\ + Given a node representing the batch size used for reshape op, we want + to know if it is coming from below pattern: + + batch_size_node = src.size()[0] + + or in IR graph: + + src -> size(input=_) -> getitem(input=_, idx=0) + ^ ~~~ batch_size_node + + If so, return `src`. Otherwise, return `None`. + """ + if ( + batch_size_node.op != "call_function" + or batch_size_node.target != acc_ops.getitem + or batch_size_node.kwargs["idx"] != 0 + ): + return None + maybe_size: fx.Node = batch_size_node.all_input_nodes[0] + if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size: + return None + return maybe_size.all_input_nodes[0] + + maybe_reshape: fx.Node + for maybe_reshape in mod.graph.nodes: + reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node( + maybe_reshape + ) + if not reshape_batch_size: + continue + reshape_batch_size_inferred_source: Optional[ + fx.Node + ] = get_reshape_batch_size_inferred_source(reshape_batch_size) + if not reshape_batch_size_inferred_source: + continue + + reshape_input: fx.Node = maybe_reshape.kwargs["input"] + if reshape_input == reshape_batch_size_inferred_source: + continue + + if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source): + continue + + _LOGGER.info( + f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}" + ) + + # Step 1: create a node to compute batch size, using the tensor which + # is being reshaped: reshape_input.size()[0]. This batch size is now + # derived from reshape_input, the same node as the reshape op's input. + with mod.graph.inserting_before(maybe_reshape): + reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function( + acc_ops.getitem, + kwargs={ + "idx": 0, + "input": maybe_reshape.graph.call_function( + acc_ops.size, + kwargs={ + "input": reshape_input, + }, + ), + }, + ) + + # Step 2: update `maybe_reshape`'s shape argument to be + # (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER) + maybe_reshape.kwargs = { + **maybe_reshape.kwargs, + "acc_out_ty": acc_utils.build_raw_tensor_meta( + shape=( + reshape_batch_size_2, + *(maybe_reshape.kwargs["acc_out_ty"].shape[1:]), + ) + ), + } + + mod.graph.eliminate_dead_code() + mod.recompile() + return mod + + +def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool: + """\ + Check that x.size(0) == y.size(0) + """ + x_size, y_size = _get_shape(x), _get_shape(y) + return ( + x_size + and y_size + # now both are non-empty + and x_size[0] == y_size[0] + ) + + +def _get_shape(node: fx.Node) -> Optional[torch.Size]: + if ( + not getattr(node, "meta", None) + or not node.meta.get("tensor_meta", None) + or not getattr(node.meta["tensor_meta"], "shape", None) + ): + # shape info not available + return None + return node.meta["tensor_meta"].shape diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 877029cd44..c4bb927b85 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -17,6 +17,7 @@ from .graph_opts import common_subexpression_elimination from .lower_basic_pass import ( + fix_reshape_batch_dim, replace_mutable_op, replace_op_with_indices, run_const_fold, @@ -112,6 +113,7 @@ def graph_optimization_pass(self) -> PassManager: passes.append( inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) ) + passes.append(fix_reshape_batch_dim) return PassManager.build_from_passlist(passes) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 9db173f1e1..78e9ec1b22 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -1,3 +1,4 @@ +import io import logging import tempfile from functools import wraps @@ -233,15 +234,23 @@ def log_before_after(pass_: PassFunc) -> PassFunc: def pass_with_before_after_log( module: fx.GraphModule, input: Input ) -> fx.GraphModule: + before_io = io.StringIO() + after_io = io.StringIO() with tempfile.NamedTemporaryFile( mode="w", encoding="utf-8", delete=False, ) as f: - _LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}") print(f"[{pass_}] Before:\n{module.graph}", file=f) + print(module.graph, file=before_io) + module = pass_(module, input) print(f"[{pass_}] After:\n{module.graph}", file=f) + print(module.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + _LOGGER.info( + f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}" + ) return module return pass_with_before_after_log diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py index 7c166c1fe0..e59153d5c9 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py @@ -12,6 +12,7 @@ class TestClampConverter(AccTestCase): param("min", min=0.5), param("max", max=0.5), param("minBiggerThanMax", min=1, max=0), + param("float32Boundary", min=-3.4028234663852886e38), ] ) def test_clamp( diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py new file mode 100644 index 0000000000..bd04692ad5 --- /dev/null +++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py @@ -0,0 +1,51 @@ +# Owner(s): ["oncall: gpu_enablement"] + +import logging +from copy import deepcopy + +import torch +import torch.fx as fx +import torch.nn as nn + +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim +from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer + +_LOGGER = logging.getLogger(__name__) + + +class TestFixReshapeBatchDim(TestCase): + def test_fix_reshape_batch_dim(self): + class Repro(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return y.view(x.size(0), -1, 3) + + mod = Repro() + modt = fx.symbolic_trace(mod) + inp = [ + torch.rand([10, 60]), + torch.rand([10, 60]), + ] + mod(*inp) + mod_acc_traced = acc_tracer.trace(modt, inp) + mod_fixed = fix_reshape_batch_dim(deepcopy(mod_acc_traced)) + + expected_graph = r""" +graph(): + %x : [#users=0] = placeholder[target=x] + %y : [#users=2] = placeholder[target=y] + %size : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.size](args = (), kwargs = {input: %y}) + %getitem_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.getitem](args = (), kwargs = {idx: 0, input: %size}) + %reshape : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)}) + return reshape +""" + assert ( + str(mod_fixed.graph).strip() == expected_graph.strip() + ), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}" + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index c3779ef933..23b7329669 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -2566,6 +2566,31 @@ def forward(self, x: List[torch.Tensor]) -> torch.Tensor: # Make sure we didn't convert to the acc version self.assertEqual(node.target, operator.getitem) + def test_detach(self): + class TestModule(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.detach(x) + + m = TestModule() + sample_inputs = [torch.randn(8)] + traced = acc_tracer.trace(m, sample_inputs) + + placeholder = output = None + for node in traced.graph.nodes: + if node.op == "placeholder": + assert placeholder is None + placeholder = node + elif node.op == "output": + assert output is None + output = node + else: + raise RuntimeError(f"Unexpected Node {node.format_node()}") + + self.assertIsNotNone(placeholder) + self.assertIsNotNone(output) + + self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs))) + def test_all_acc_ops_registered(self): self.assertEqual( acc_normalizer._acc_ops, diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py index ce2832e9a7..5f02051166 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py @@ -162,7 +162,7 @@ def f(x): inputs = [torch.ones(32, 3, 224, 224)] inputs = [i.cuda().half() for i in inputs] torchdynamo.reset() - dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_aten_compiler_fp16)(mod) + dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod) dynamo_aten_output = dynamo_aten_mod(*inputs) torchdynamo.reset() diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 8309db3cf3..1bdc7bf704 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -582,7 +582,9 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: @register_acc_op_properties(AccOpProperty.pointwise) @register_acc_op_mapping(op_and_target=("call_function", torch.clamp)) +@register_acc_op_mapping(op_and_target=("call_function", torch.clip)) @register_acc_op_mapping(op_and_target=("call_method", "clamp")) +@register_acc_op_mapping(op_and_target=("call_method", "clip")) @register_acc_op def clamp(*, input, min=None, max=None): return torch.clamp(input=input, min=min, max=max) @@ -818,6 +820,10 @@ def matmul(*, input, other): @register_custom_acc_mapper_fn( op_and_target=("call_method", "detach"), arg_replacement_tuples=[("input", "input")] ) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.detach), + arg_replacement_tuples=[("input", "input")], +) def dropout_mapper(node: torch.fx.Node, mod: nn.Module): """ Remove dropout node and directly map its input to output.