Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
import logging
import operator
from typing import Dict, Sequence, Tuple, Union
Expand All @@ -23,7 +24,9 @@ def getitem_validator(getitem_node: Node) -> bool:

# TODO: Subsequent evaluators should be registered here with their own validators
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
@dynamo_tensorrt_converter(builtins.getattr)
@dynamo_tensorrt_converter(torch.ops.aten.detach.default)
@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
def generic_evaluator(
ctx: ConversionContext,
target: Target,
Expand Down
33 changes: 33 additions & 0 deletions tests/py/dynamo/conversion/test_arange_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestArangeConverter(DispatchTestCase):
@parameterized.expand(
[
(0, 5, 1),
(1, 5, 2),
(3, 5, 3),
(5, 0, -1),
(5, 1, -2),
(5, 3, -3),
]
)
def test_arange(self, start, end, step):
class Arange(nn.Module):
def forward(self, x):
return torch.ops.aten.arange.start_step(start, x.shape[0], step)

inputs = [torch.randn(end, 1)]
self.run_test(
Arange(),
inputs,
)


if __name__ == "__main__":
run_tests()