diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst
index 46fadcb905..6863f62317 100644
--- a/docsrc/user_guide/saving_models.rst
+++ b/docsrc/user_guide/saving_models.rst
@@ -22,8 +22,8 @@ The following code illustrates this approach.
     model = MyModel().eval().cuda()
     inputs = torch.randn((1, 3, 224, 224)).cuda()
     trt_gm = torch_tensorrt.compile(model, inputs, ir="dynamo") # Output is a torch.fx.GraphModule
-    trt_script_model = torch.jit.trace(trt_gm, inputs)
-    torch.jit.save(trt_script_model, "trt_model.ts")
+    trt_traced_model = torch_tensorrt.dynamo.serialize(trt_gm, inputs)
+    torch.jit.save(trt_traced_model, "trt_model.ts")

     # Later, you can load it and run inference
     model = torch.jit.load("trt_model.ts").cuda()
@@ -37,21 +37,19 @@ b) ExportedProgram

     import torch
     import torch_tensorrt
-    from torch_tensorrt.dynamo.export import transform, create_exported_program

     model = MyModel().eval().cuda()
     inputs = torch.randn((1, 3, 224, 224)).cuda()
     trt_gm = torch_tensorrt.compile(model, inputs, ir="dynamo") # Output is a torch.fx.GraphModule
     # Transform and create an exported program
-    trt_gm = transform(trt_gm, inputs)
-    trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
-    torch._export.save(trt_exp_program, "trt_model.ep")
+    trt_exp_program = torch_tensorrt.dynamo.serialize(trt_gm, inputs, call_spec, ir="exported_program")
+    torch.export.save(trt_exp_program, "trt_model.ep")

     # Later, you can load it and run inference
-    model = torch._export.load("trt_model.ep")
+    model = torch.export.load("trt_model.ep")
     model(inputs)

-`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
+`torch_tensorrt.dynamo.serialize` (with `ir="exported_program"`) inlines the submodules within a GraphModule into their corresponding nodes, stitches all the nodes together, and creates an ExportedProgram.
 This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).

 NOTE: This way of saving the models using `ExportedProgram` is experimental.
 Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
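Since the `.rst` snippet above leaves `call_spec` undefined, a minimal end-to-end sketch of both serialization paths may help. It assumes the `dynamo.trace`/`dynamo.compile` flow used in the tests further below (including that both accept an `inputs` list of tensors, as the test `compile_spec` suggests); `MyModel` is a placeholder user module and the paths are illustrative:

    import torch
    import torch_tensorrt as torchtrt

    model = MyModel().eval().cuda()               # placeholder user module
    input = torch.randn((1, 3, 224, 224)).cuda()

    # Trace first so the CallSpec is available for ExportedProgram serialization
    exp_program = torchtrt.dynamo.trace(model, inputs=[input])
    trt_gm = torchtrt.dynamo.compile(exp_program, inputs=[input])

    # Path (a): TorchScript (serialize() defaults to ir="torchscript")
    trt_ts = torchtrt.dynamo.serialize(trt_gm, [input])
    torch.jit.save(trt_ts, "/tmp/trt_model.ts")

    # Path (b): ExportedProgram, reusing the call_spec recorded during tracing
    trt_ep = torchtrt.dynamo.serialize(
        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
    )
    torch.export.save(trt_ep, "/tmp/trt_model.ep")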
diff --git a/py/requirements.txt b/py/requirements.txt
index aa957c3743..9a6d0eb90d 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -1,9 +1,8 @@
 numpy
 packaging
 pybind11==2.6.2
---extra-index-url https://download.pytorch.org/whl/nightly/cu121
-torch>=2.1.0,<2.2.0
-torchvision>=0.16.0,<0.17.0
+torch==2.1.0
+torchvision==0.16.0
 --extra-index-url https://pypi.ngc.nvidia.com
 tensorrt==8.6.1
 pyyaml
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 63cc2af10a..1e39ea7fb1 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -16,3 +16,4 @@
         DYNAMO_CONVERTERS,
         dynamo_tensorrt_converter,
     )
+    from .export import serialize
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index 5394c1382e..7e6598b1ad 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -46,7 +46,7 @@


 def compile(
-    exported_program: ExportedProgram,
+    exported_program: Union[torch.fx.GraphModule, ExportedProgram],
    inputs: Any,
    *,
    device: Optional[Union[Device, torch.device, str]] = DEVICE,
@@ -86,7 +86,15 @@ def compile(
     inputs = prepare_inputs(inputs)
     device = to_torch_tensorrt_device(device)

-    gm = exported_program.module()
+    if isinstance(exported_program, torch.fx.GraphModule):
+        gm = exported_program
+    elif isinstance(exported_program, ExportedProgram):
+        gm = exported_program.module()
+    else:
+        raise AssertionError(
+            f"Input graph should either be an ExportedProgram or a GraphModule but got type {type(exported_program)}"
+        )
+
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
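With the `compile()` change above, either input type reaches the same lowering path. A rough sketch of the two accepted call styles (`MyModel` is again a placeholder; whether an arbitrary symbolically traced GraphModule lowers cleanly depends on the ops it contains):

    import torch
    import torch_tensorrt as torchtrt

    model = MyModel().eval().cuda()               # placeholder user module
    input = torch.randn((1, 3, 224, 224)).cuda()

    # ExportedProgram input (existing behavior): compile() calls .module() on it
    exp_program = torchtrt.dynamo.trace(model, inputs=[input])
    trt_gm = torchtrt.dynamo.compile(exp_program, inputs=[input])

    # GraphModule input (new): used as-is, skipping the .module() unwrap
    gm = torch.fx.symbolic_trace(model)
    trt_gm2 = torchtrt.dynamo.compile(gm, inputs=[input])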
diff --git a/py/torch_tensorrt/dynamo/export.py b/py/torch_tensorrt/dynamo/export.py
index 9bd1dbddb3..91573f4491 100644
--- a/py/torch_tensorrt/dynamo/export.py
+++ b/py/torch_tensorrt/dynamo/export.py
@@ -1,6 +1,6 @@
 import copy
 import operator
-from typing import Any, Dict, Sequence, Tuple, Union, cast
+from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast

 import torch
 from torch._export.exported_program import CallSpec
@@ -10,28 +10,42 @@
 from torch_tensorrt.dynamo import partitioning


-def transform(
-    gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
-    # Run shape analysis
-    _, outputs_map = partitioning.run_shape_analysis(gm, inputs)
-
-    # Inline TensorRT submodules
-    inline_trt_modules(gm, outputs_map)
-
-    # Inline pytorch submodules
-    inline_torch_modules(gm)
-
-    # Lift constant buffers and parameters in the graph
-    # torch.export serialization expects them to be lifted
-    lift_constant_pass(gm)
-
-    # Clean the graph
-    gm.delete_all_unused_submodules()
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-
-    return gm
+def serialize(
+    gm: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
+    call_spec: Optional[CallSpec] = None,
+    ir: str = "torchscript",
+) -> Union[ExportedProgram, torch.jit.ScriptModule]:
+    if ir == "torchscript":
+        return torch.jit.trace(gm, inputs)
+    elif ir == "exported_program":
+        assert call_spec
+        # Run shape analysis
+        _, outputs_map = partitioning.run_shape_analysis(gm, inputs)
+
+        # Inline TensorRT submodules
+        inline_trt_modules(gm, outputs_map)
+
+        # Inline pytorch submodules
+        inline_torch_modules(gm)
+
+        # Lift constant buffers and parameters in the graph
+        # torch.export serialization expects them to be lifted
+        lift_constant_pass(gm)
+
+        # Clean the graph
+        gm.delete_all_unused_submodules()
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+
+        # Create an exported program with the TRT GraphModule
+        exp_program = create_trt_exp_program(gm, call_spec)
+
+        return exp_program
+    else:
+        raise ValueError(
+            f"Invalid ir: {ir} provided for serialization. Options include torchscript | exported_program"
+        )
@@ -115,7 +129,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:

             # Copy all nodes in the submodule into gm and
             # store the output node of this submodule which is now present in gm
-
             submodule_output = gm.graph.graph_copy(submodule.graph, val_map)

             # Get their references (since we copied) in the parent graph (gm)
@@ -174,9 +187,7 @@ def copy_submodule_attributes(


 def create_trt_exp_program(
-    gm: torch.fx.GraphModule,
-    call_spec: CallSpec,
-    state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
+    gm: torch.fx.GraphModule, call_spec: CallSpec
 ) -> ExportedProgram:
     """Creates a new Exported Program. This function takes a torch.fx.GraphModule which has TRT engines
     and constructs an Exported Program object with the new IO node names, call_spec and state_dict
@@ -208,7 +219,7 @@ def create_trt_exp_program(
     )

     trt_exp_program = ExportedProgram(
-        gm, gm.graph, trt_graph_signature, call_spec, state_dict, {}, [], []
+        gm, gm.graph, trt_graph_signature, call_spec, gm.state_dict(), {}, [], []
     )

     return trt_exp_program
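For reference, the round-trip the `exported_program` branch is designed for looks like this (a sketch based on the save/load calls in the tests below; `trt_exp_program` and `input` are assumed to exist from the earlier steps, and the /tmp path is illustrative):

    import torch

    # trt_exp_program comes from serialize(trt_gm, inputs, call_spec, ir="exported_program")
    torch._export.save(trt_exp_program, "/tmp/trt.ep")

    # Because the TRT and PyTorch submodules were inlined and constants lifted,
    # the deserialized program can be called directly
    deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
    outputs = deser_trt_exp_program(input)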
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index 5e0dc7406c..122503bb00 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -6,7 +6,6 @@
 import torch_tensorrt as torchtrt
 import torchvision.models as models
 from torch._export.serde.serialize import deserialize, serialize
-from torch_tensorrt.dynamo.export import create_trt_exp_program, transform
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

 assertions = unittest.TestCase()
@@ -45,9 +44,8 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
     serialized_prog = serialize(trt_exp_program)
     deserialized_prog = deserialize(*serialized_prog)
@@ -100,11 +98,9 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
-
     serialized_prog = serialize(trt_exp_program)
     deserialized_prog = deserialize(*serialized_prog)

     # Check Pyt and TRT exported program outputs
@@ -161,11 +157,9 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
-
     torch._export.save(trt_exp_program, "/tmp/trt.ep")
     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
@@ -224,11 +218,9 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
-
     torch._export.save(trt_exp_program, "/tmp/trt.ep")
     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
@@ -270,9 +262,8 @@ def test_resnet18_save_load(ir):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
     torch._export.save(trt_exp_program, "/tmp/trt.ep")
     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
@@ -291,59 +282,3 @@ def test_resnet18_save_load(ir):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
-
-
-# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
-# @pytest.mark.unit
-# def test_hybrid_conv_fallback(ir):
-#     """
-#     This tests export save and load functionality on a hybrid
-#     model where a conv (a weighted layer) has been forced to fallback to Pytorch.
-#     """
-
-#     class MyModule(torch.nn.Module):
-#         def __init__(self):
-#             super().__init__()
-#             self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
-#             self.relu = torch.nn.ReLU()
-
-#         def forward(self, x):
-#             conv = self.conv(x)
-#             relu = self.relu(conv)
-#             mul = relu * 0.5
-#             return mul
-
-#     model = MyModule().eval().cuda()
-#     input = torch.randn((1, 3, 224, 224)).to("cuda")
-
-#     compile_spec = {
-#         "inputs": [
-#             torchtrt.Input(
-#                 input.shape, dtype=torch.float, format=torch.contiguous_format
-#             )
-#         ],
-#         "ir": ir,
-#         "min_block_size": 1,
-#         "torch_executed_ops": "torch.ops.aten.convolution.default",
-#     }
-
-#     trt_exp_program = torchtrt.compile(model, **compile_spec)
-#     torch._export.save(trt_exp_program, "/tmp/trt.ep")
-#     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
-
-#     outputs_pyt = model(input)
-#     outputs_trt = trt_exp_program(input)
-#     for idx in range(len(outputs_pyt)):
-#         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
-#         assertions.assertTrue(
-#             cos_sim > COSINE_THRESHOLD,
-#             msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-#         )
-
-#     outputs_trt_deser = deser_trt_exp_program(input)
-#     for idx in range(len(outputs_pyt)):
-#         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
-#         assertions.assertTrue(
-#             cos_sim > COSINE_THRESHOLD,
-#             msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-#         )
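The accuracy assertions throughout these tests follow one pattern: compare each PyTorch/TensorRT output pair by cosine similarity against `COSINE_THRESHOLD`, both imported from `torch_tensorrt.dynamo.utils` as in the test file. A minimal standalone sketch of that check (it assumes `model` and `deser_trt_exp_program` exist and that their outputs are indexable sequences, as the tests do):

    from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

    outputs_pyt = model(input)
    outputs_trt = deser_trt_exp_program(input)
    for idx in range(len(outputs_pyt)):
        # A cosine similarity near 1.0 means the TRT output closely matches PyTorch
        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
        assert (
            cos_sim > COSINE_THRESHOLD
        ), f"Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}"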