
[FX] Changes done internally at Facebook #1288

Merged: 1 commit, Aug 19, 2022
6 changes: 3 additions & 3 deletions docsrc/tutorials/getting_started_with_fx_path.rst
@@ -34,7 +34,7 @@ Torch-TensorRT (FX Path) is in ``Beta`` phase and always recommended to work with

Converting a PyTorch Model to TensorRT Engine
---------------------------------------------
-In general, users are welcome to use the ``lower_to_trt()`` to finish the conversion from a model to tensorRT engine. It is a wrapper API that consists of the major steps needed to finish this converison. Please refer to ``lower_example.py`` file in ``examples/fx``.
+In general, users are welcome to use ``compile()`` to finish the conversion from a model to a TensorRT engine. It is a wrapper API that covers the major steps needed to finish this conversion. Please refer to the ``lower_example.py`` file in ``examples/fx``.

In this section, we will go through an example to illustrate the major steps that FX path uses. Users can refer to ``fx2trt_example.py`` file in ``examples/fx``.
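A rough sketch of the wrapper usage, based on the call sites in this PR (``examples/fx/lower_example.py``); the batch size and precision are illustrative, and a CUDA device with a TensorRT install is assumed:

```python
import torch

if torch.cuda.is_available():
    import torchvision
    from torch_tensorrt.fx.lower import compile
    from torch_tensorrt.fx.utils import LowerPrecision

    model = torchvision.models.resnet18().cuda().eval()
    inputs = [torch.randn(32, 3, 224, 224, device="cuda")]

    # compile() wraps the major lowering steps: acc_tracer trace, split,
    # TensorRT interpretation, and engine build.
    trt_model = compile(
        model,
        inputs,
        max_batch_size=32,
        lower_precision=LowerPrecision.FP16,
    )
else:
    trt_model = None  # sketch only: no GPU available in this environment
```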

@@ -60,9 +60,9 @@ symbolically traced variables cannot be used as inputs to control flow
This means the model contains dynamic control flow. Please refer to section “Dynamic Control Flow” in `FX guide <https://pytorch.org/docs/stable/fx.html#dynamic-control-flow>`_.
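The error above comes from FX symbolic tracing hitting data-dependent control flow. A minimal reproduction in plain ``torch.fx`` (no TensorRT needed; the module here is a made-up example):

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        # Branching on a tensor's value is dynamic control flow:
        # the tracer sees a Proxy, not a concrete boolean.
        if x.sum() > 0:
            return x + 1
        return x - 1


try:
    torch.fx.symbolic_trace(M())
    traced = True
except Exception:  # torch.fx raises a TraceError here
    traced = False
print(traced)  # False
```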

* **Step 2: Build TensorRT engine**
-There are `two different modes <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#explicit-implicit-batch>`_ for how TensorRT handles batch dimension, explicit batch dimension and implicit batch dimension. This mode was used by early versions of TensorRT, and is now deprecated but continues to be supported for backwards compatibility. In explicit batch mode, all dimensions are explicit and can be dynamic, that is their length can change at execution time. Many new features, such as dynamic shapes and loops, are available only in this mode. User can still choose to use implicit batch mode when they set ``explicit_batch_dimension=False`` in ``lower_to_trt()``. We do not recommend to use it since it will lack of support in future TensorRT versions.
+There are `two different modes <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#explicit-implicit-batch>`_ for how TensorRT handles the batch dimension: explicit batch dimension and implicit batch dimension. Implicit batch mode was used by early versions of TensorRT and is now deprecated, but it continues to be supported for backwards compatibility. In explicit batch mode, all dimensions are explicit and can be dynamic, that is, their length can change at execution time. Many new features, such as dynamic shapes and loops, are available only in this mode. Users can still choose implicit batch mode by setting ``explicit_batch_dimension=False`` in ``compile()``. We do not recommend using it, since it will lack support in future TensorRT versions.

-Explicit batch is the default mode and it must be set for dynamic shape. For most of vision task, user can choose to enable ``dynamic_batch`` in ``lower_to_trt()`` if they want to get the similar effects as implicit mode where only batch dimension changes. It has some requirements:
+Explicit batch is the default mode, and it must be set for dynamic shape. For most vision tasks, users can enable ``dynamic_batch`` in ``compile()`` to get effects similar to implicit mode, where only the batch dimension changes. It has some requirements:
1. Shapes of inputs, outputs and activations are fixed except batch dimension.
2. Inputs, outputs and activations have batch dimension as the major dimension.
3. All the operators in the model do not modify batch dimension (permute, transpose, split, etc.) or compute over batch dimension (sum, softmax, etc.).
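As an illustration of requirement 3 in plain PyTorch (shapes are hypothetical):

```python
import torch

x = torch.randn(4, 8)           # dim 0 is the batch dimension

fine = torch.softmax(x, dim=1)  # computes over the feature dim: allowed
fine2 = x + 1                   # elementwise: batch dim untouched, allowed

bad = x.sum(dim=0)              # computes over the batch dim: breaks dynamic_batch
bad2 = x.permute(1, 0)          # moves the batch dim off axis 0: breaks dynamic_batch

print(fine.shape, bad.shape)    # torch.Size([4, 8]) torch.Size([8])
```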
4 changes: 2 additions & 2 deletions examples/fx/lower_example.py
@@ -4,7 +4,7 @@

import torch
import torchvision
-from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.lower import compile
from torch_tensorrt.fx.utils import LowerPrecision


@@ -183,7 +183,7 @@ def run_configuration_benchmark(
time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
elif not conf.jit:
# Run lowering eager mode benchmark
-lowered_module = lower_to_trt(
+lowered_module = compile(
module,
input,
max_batch_size=conf.batch_size,
4 changes: 2 additions & 2 deletions examples/fx/torchdynamo_example.py
@@ -5,7 +5,7 @@
import torch
import torchdynamo
import torchvision
-from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.lower import compile
from torch_tensorrt.fx.utils import LowerPrecision
from torchdynamo.optimizations import backends

@@ -197,7 +197,7 @@ def run_configuration_benchmark(

if conf.trt:
# Run lowering eager mode benchmark
-lowered_module = lower_to_trt(
+lowered_module = compile(
module,
input,
max_batch_size=conf.batch_size,
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_compile.py
@@ -7,7 +7,7 @@
from enum import Enum

import torch_tensorrt.fx
-from torch_tensorrt.fx.lower import lower_to_trt
+import torch_tensorrt.fx.lower
from torch_tensorrt.fx.utils import LowerPrecision


@@ -140,7 +140,7 @@ def compile(
else:
raise ValueError(f"Precision {enabled_precisions} not supported on FX")

-return lower_to_trt(
+return torch_tensorrt.fx.lower.compile(
module,
inputs,
lower_precision=lower_precision,
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/__init__.py
@@ -1,4 +1,6 @@
from .converters import * # noqa: F403 F401
import logging

from .converter_registry import ( # noqa
CONVERTERS,
NO_EXPLICIT_BATCH_DIM_SUPPORT,
@@ -9,3 +9,5 @@
from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
from .lower_setting import LowerSetting # noqa
from .trt_module import TRTModule # noqa

logging.basicConfig(level=logging.INFO)
40 changes: 22 additions & 18 deletions py/torch_tensorrt/fx/lower.py
@@ -25,7 +25,7 @@
Input = Sequence[Any]


-def lower_to_trt(
+def compile(
module: nn.Module,
input,
max_batch_size: int = 2048,
@@ -216,28 +216,32 @@ def create(
)
)

-    @decorate_method(validate_inference(atol=1e-1, rtol=1e-1))
     def __call__(
         self,
         module: nn.Module,
         inputs: Input,
         additional_inputs: Optional[Input] = None,
     ) -> nn.Module:
-        module.eval()
-
-        if (
-            self.lower_pass_manager_builder.lower_setting.lower_precision
-            == LowerPrecision.FP16
-        ):
-            module.half()
-            inputs = tuple(
-                x.half() if x is not None and x.dtype == torch.float32 else x
-                for x in inputs
-            )
-        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
-            inputs, additional_inputs
-        )
-
-        lower_result = pm(module)
-
-        return lower_result
+        lower_setting = self.lower_pass_manager_builder.lower_setting
+        atol = lower_setting.correctness_atol
+        rtol = lower_setting.correctness_rtol
+
+        @validate_inference(atol=atol, rtol=rtol)
+        def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
+            module.eval()
+            if (
+                self.lower_pass_manager_builder.lower_setting.lower_precision
+                == LowerPrecision.FP16
+            ):
+                module.half()
+                inputs = tuple(
+                    x.half() if x is not None and x.dtype == torch.float32 else x
+                    for x in inputs
+                )
+            pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
+                inputs, additional_inputs
+            )
+            lower_result = pm(module)
+            return lower_result
+
+        return do_lower(module, inputs)
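This hunk replaces a class-level ``decorate_method(validate_inference(...))`` with hard-coded tolerances by decorating an inner function at call time, so the tolerances can come from ``LowerSetting``. A minimal pure-Python sketch of that pattern (names here are illustrative, not the real API):

```python
import functools


def validate_inference(atol, rtol):
    # Decorator factory: the real code compares lowered vs. original outputs;
    # here we only record which tolerances were applied.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            wrapper.used = (atol, rtol)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


class Lowerer:
    def __init__(self, atol=0.1, rtol=0.1):
        self.atol = atol
        self.rtol = rtol

    def __call__(self, module):
        # Decorate at call time: per-instance settings are honored, unlike a
        # decorator applied once at class-definition time.
        @validate_inference(atol=self.atol, rtol=self.rtol)
        def do_lower(m):
            return m

        result = do_lower(module)
        self.last_tolerances = do_lower.used
        return result


lowerer = Lowerer(atol=0.5, rtol=0.01)
lowerer("model")
print(lowerer.last_tolerances)  # (0.5, 0.01)
```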
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -70,6 +70,8 @@ class LowerSetting(LowerSettingBasic):
dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
tactic_sources: tactic sources for TensorRT kernel selection. Default to None,
meaning all possible tactic sources.
correctness_atol: absolute tolerance for correctness check
correctness_rtol: relative tolerance for correctness check
"""

input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -90,3 +92,5 @@ class LowerSetting(LowerSettingBasic):
opt_profile_replica: int = 1
dynamic_batch: bool = True
tactic_sources: Optional[int] = None
correctness_atol: float = 0.1
correctness_rtol: float = 0.1
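The new defaults (``correctness_atol=0.1``, ``correctness_rtol=0.1``) plug into the usual elementwise closeness test, |a - b| <= atol + rtol * |b|. A quick sketch of what that tolerance means:

```python
def is_close(a, b, atol=0.1, rtol=0.1):
    # Same per-element formula torch.allclose uses.
    return abs(a - b) <= atol + rtol * abs(b)


print(is_close(1.0, 1.15))  # True:  0.15 <= 0.1 + 0.1 * 1.15
print(is_close(1.0, 1.50))  # False: 0.50 >  0.1 + 0.1 * 1.50
```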
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py
@@ -1,5 +1,6 @@
import torch
import unittest

import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
2 changes: 0 additions & 2 deletions py/torch_tensorrt/fx/test/passes/test_graph_opts.py
@@ -8,8 +8,6 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-

_LOGGER: logging.Logger = logging.getLogger(__name__)

2 changes: 0 additions & 2 deletions py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py
@@ -10,8 +10,6 @@

import torch_tensorrt.fx.diagnostics as diag

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-

_LOGGER: logging.Logger = logging.getLogger(__name__)

5 changes: 5 additions & 0 deletions py/torch_tensorrt/fx/tools/trt_profiler_sorted.py
@@ -37,6 +37,11 @@ def profile_trt_module(
layer_info = json.loads(trt_mod.get_layer_info()) # pyre-ignore[29]
shape_map = {}
for layer in layer_info["Layers"]:
+        # If the entry is a str, verbose_profile was off in interpreter.run().
+        # Theoretically we could still print profiling information without
+        # shape information, but we skip it so verbose_profile controls output.
+        if type(layer) is str:
+            return
name = layer["Name"]
input_str = ", ".join(
[str(x.get("Dimensions", "[]")) for x in layer.get("Inputs", [])]
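The guard added above handles ``get_layer_info()`` returning bare layer-name strings instead of per-layer dicts when verbose profiling is off. A self-contained sketch of the same defensive pattern (the JSON shape here is illustrative, not the exact TensorRT output):

```python
import json


def shapes_from_layer_info(raw: str):
    layer_info = json.loads(raw)
    shape_map = {}
    for layer in layer_info["Layers"]:
        # Without verbose profiling, entries are plain strings, so there is
        # no shape information to collect; bail out instead of crashing.
        if type(layer) is str:
            return None
        shape_map[layer["Name"]] = [
            x.get("Dimensions", "[]") for x in layer.get("Inputs", [])
        ]
    return shape_map


verbose = '{"Layers": [{"Name": "conv1", "Inputs": [{"Dimensions": [1, 3, 224, 224]}]}]}'
terse = '{"Layers": ["conv1", "relu1"]}'

print(shapes_from_layer_info(verbose))  # {'conv1': [[1, 3, 224, 224]]}
print(shapes_from_layer_info(terse))    # None
```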