
fix/feat: Add support for 64bit Tensor inputs FX aten [9 / x] #2021


Closed
wants to merge 3 commits
Changes from all commits
@@ -10,3 +10,4 @@
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
USE_PYTHON_RUNTIME = None
TRUNCATE_LONG_AND_DOUBLE = False
13 changes: 9 additions & 4 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -5,12 +5,13 @@
from functools import partial

from typing import Any, Optional, Sequence
from torch_tensorrt import EngineCapability, Device
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability
Comment on lines +8 to +9 (Collaborator Author):

These changes were required to avoid circular import errors in Python. Updates to #1983 could potentially fix this issue.
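
To make the failure mode concrete, here is a small, runnable sketch of the circular-import pattern being avoided. The package name, file layout, and import ordering are hypothetical stand-ins, not the actual torch_tensorrt internals; the point is that the replacement imports bind to the modules that define the names rather than to the package __init__, which may still be mid-initialization when the subpackage is imported.

import importlib
import os
import sys
import tempfile
import textwrap

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "pkg"))

def write(rel, src):
    with open(os.path.join(root, rel), "w") as f:
        f.write(textwrap.dedent(src))

write("pkg/__init__.py", """\
    from pkg import backend                   # subpackage imported first...
    from pkg._enums import EngineCapability   # ...so this re-export is not yet bound
""")
write("pkg/_enums.py", "class EngineCapability:\n    pass\n")
write("pkg/backend.py", """\
    # from pkg import EngineCapability        # would raise ImportError (circular import)
    from pkg._enums import EngineCapability   # OK: imports from the defining module
""")

sys.path.insert(0, root)
pkg = importlib.import_module("pkg")  # succeeds; the commented-out import in backend.py
print(pkg.EngineCapability)           # would fail on the partially initialized "pkg"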

from torch_tensorrt.fx.utils import LowerPrecision

from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
from torch_tensorrt.dynamo.backend._defaults import (
from torch_tensorrt.dynamo._defaults import (
PRECISION,
DEBUG,
WORKSPACE_SIZE,
@@ -20,6 +21,7 @@
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
TRUNCATE_LONG_AND_DOUBLE,
)


@@ -43,7 +45,7 @@ def compile(
dla_local_dram_size=1073741824,
dla_global_dram_size=536870912,
calibrator=None,
truncate_long_and_double=False,
truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation=False,
min_block_size=MIN_BLOCK_SIZE,
torch_executed_ops=[],
@@ -62,7 +64,7 @@
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures}"
+ "truncate_long_and_double, torch_executed_ops, pass_through_build_failures}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -103,6 +105,7 @@
version_compatible=version_compatible,
optimization_level=optimization_level,
use_python_runtime=use_python_runtime,
truncate_long_and_double=truncate_long_and_double,
**kwargs,
)

@@ -130,6 +133,7 @@ def create_backend(
version_compatible: bool = VERSION_COMPATIBLE,
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
**kwargs,
):
"""Create torch.compile backend given specified arguments
@@ -163,5 +167,6 @@ def create_backend(
version_compatible=version_compatible,
optimization_level=optimization_level,
use_python_runtime=use_python_runtime,
truncate_long_and_double=truncate_long_and_double,
**kwargs,
)
13 changes: 12 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -4,7 +4,7 @@
from functools import partial
import torch._dynamo as td

from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.common import CompilationSettings
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
@@ -16,6 +16,7 @@
get_submod_inputs,
)
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
from torch_tensorrt.dynamo.common import repair_long_or_double_inputs
from torch_tensorrt.dynamo.backend.conversion import convert_module

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
@@ -134,6 +135,16 @@ def _compile_module(
partitioned_module, submodule, sample_inputs
)

# Ensure all submodule inputs do not require a gradient
for param in submodule_inputs:
param.requires_grad = False

# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module, submodule, submodule_inputs, name
)

# Create TRT Module from submodule
trt_mod = convert_module(
submodule,
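
For readers skimming the hunk above: repair_long_or_double_inputs (added to torch_tensorrt.dynamo.common in this PR) rewrites 64-bit submodule inputs so that the TensorRT conversion path, which does not natively handle int64/float64, receives 32-bit equivalents; that is the motivation for the truncate_long_and_double flag. The snippet below is only a minimal sketch of the dtype policy on sample inputs, not the library function itself, which also has to keep the partitioned graph consistent (hence the partitioned_module/submodule/name arguments).

import torch

def truncate_64bit_sample_inputs(inputs):
    # Sketch only: downcast 64-bit tensors to the nearest TRT-supported dtype.
    downcast = {torch.int64: torch.int32, torch.float64: torch.float32}
    return [t.to(downcast[t.dtype]) if t.dtype in downcast else t for t in inputs]

sample = [
    torch.randint(-40, 40, (16, 7, 5), dtype=torch.long),
    torch.rand(16, 7, 5, dtype=torch.double),
]
print([t.dtype for t in truncate_64bit_sample_inputs(sample)])
# [torch.int32, torch.float32]
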
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -2,8 +2,8 @@
import torch
import io
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
from torch_tensorrt.dynamo.common import (
CompilationSettings,
InputTensorSpec,
TRTInterpreter,
)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -3,8 +3,8 @@

import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
113 changes: 112 additions & 1 deletion py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
@@ -4,7 +4,7 @@
from copy import deepcopy
from torch_tensorrt.dynamo import compile
from utils import lower_graph_testing
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT


class TestTRTModuleNextCompilation(TestCase):
@@ -169,5 +169,116 @@ def forward(self, x, y):
)


class Test64BitInput(TestCase):
def test_float64_input_full_support(self):
class FullySupportedMultiOp(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.mean.dim(
torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
)

fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)

self.assertEquals(
len(list(partitioned_graph.named_children())),
1,
"All operators are supported, there should be one segment",
)

inputs = [
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
]

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph,
inputs,
min_block_size=1,
pass_through_build_failures=True,
truncate_long_and_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"TRT outputs don't match with the original model.",
)

def test_int64_input_partial_support(self):
class PartiallySupportedMultiOp(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.div.Tensor_mode(
x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
)

fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
unexpected_ops = {torch.ops.aten.add.Tensor}

inputs = [
torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
]

(unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
fx_graph,
inputs,
unexpected_ops=unexpected_ops,
min_block_size=1,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
testing_partitioning=True,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)
self.assertEquals(
len(partitioned_graphs),
1,
"Without control flow breaks, there should only be a single graph",
)
self.assertEquals(
len(list(partitioned_graphs[0].named_children())),
1,
"Certain operators are set to run in Torch, expected 1 segment",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph,
inputs,
min_block_size=1,
pass_through_build_failures=True,
truncate_long_and_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
@@ -3,7 +3,7 @@
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from torch_tensorrt.dynamo import compile
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT


class TestLowering(TestCase):
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/backend/utils.py
@@ -2,10 +2,10 @@
import logging
from dataclasses import replace, fields

from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.common import CompilationSettings, use_python_runtime_parser
from typing import Any, Union, Sequence, Dict
from torch_tensorrt import _Input, Device
from ..common_utils import use_python_runtime_parser
from torch_tensorrt import _Input
from torch_tensorrt._Device import Device


logger = logging.getLogger(__name__)
@@ -1,6 +1,11 @@
import logging
from typing import Optional

from ._settings import CompilationSettings
from .input_tensor_spec import InputTensorSpec
from .fx2trt import TRTInterpreter, TRTInterpreterResult
from .truncate_long_and_double import repair_long_or_double_inputs


logger = logging.getLogger(__name__)

@@ -2,7 +2,7 @@
from typing import Optional, Sequence

from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.dynamo.backend._defaults import (
from torch_tensorrt.dynamo._defaults import (
PRECISION,
DEBUG,
WORKSPACE_SIZE,
@@ -12,6 +12,7 @@
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
TRUNCATE_LONG_AND_DOUBLE,
)


@@ -27,3 +28,4 @@ class CompilationSettings:
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
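
As a usage note (illustrative, inferred only from the imports visible in this diff): the keyword passed to compile() ends up as a field on this dataclass, which parse_dynamo_kwargs in backend/utils.py folds in; with dataclasses, and assuming all fields have defaults as the visible ones do, that typically looks like the following.

from dataclasses import replace

from torch_tensorrt.dynamo.common import CompilationSettings

defaults = CompilationSettings()  # truncate_long_and_double defaults to False
enabled = replace(defaults, truncate_long_and_double=True)
assert enabled.truncate_long_and_double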