Skip to content

Commit 55535ca

Browse files
committed
utility function to detect tegra platform
1 parent 3c9b77f commit 55535ca

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.utils import is_tegra_platform
67

78
from .accumulate_fp32_matmul import accumulate_fp32_matmul
89
from .constant_folding import constant_fold
@@ -21,13 +22,11 @@
2122
repair_input_as_output,
2223
fuse_prims_broadcast,
2324
replace_max_pool_with_indices,
24-
lower_scaled_dot_product_attention,
25-
view_to_reshape,
2625
remove_assert_nodes,
2726
accumulate_fp32_matmul,
2827
]
2928

30-
if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
29+
if not is_tegra_platform():
3130
pass_list.append(fuse_distributed_ops)
3231

3332
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,3 +806,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
806806
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
807807
)
808808
return output_dtypes
809+
810+
811+
def is_tegra_platform() -> bool:
    """Report whether the current CUDA device is treated as a Tegra platform.

    Detection is based solely on the compute capability reported for the
    current CUDA device: capabilities (8, 7) and (7, 2) are treated as Tegra.

    Returns:
        True if CUDA is available and the current device reports one of the
        capabilities above; False otherwise, including when no CUDA device
        is available.
    """
    # get_device_capability() raises when no usable CUDA device is present.
    # A host without a GPU is simply "not Tegra", so answer False rather than
    # letting the exception escape into import-time callers.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in ((8, 7), (7, 2))

0 commit comments

Comments
 (0)