Skip to content

Commit 85d7b25

Browse files
committed
test: support automatic plugin feature with different dimensions and add flashinfer.rmsnorm support test case
1 parent 8943fb9 commit 85d7b25

File tree

6 files changed

+69
-14
lines changed

6 files changed

+69
-14
lines changed

.github/workflows/build-test-linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ jobs:
142142
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
143143
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
144144
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
145+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/flashinfer_plugin.py
145146
popd
146147
147148
tests-py-dynamo-fe:

py/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ pybind11==2.6.2
55
torch>=2.8.0.dev,<2.9.0
66
torchvision>=0.22.0.dev,<0.23.0
77
--extra-index-url https://pypi.ngc.nvidia.com
8-
pyyaml
8+
pyyaml

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import logging
23
from types import FunctionType
34
from typing import Any, Callable, Tuple
@@ -130,16 +131,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
130131
output = torch_op(*fake_args, **kwargs)
131132

132133
# We assume that number of dimensions are the same in torch op
133-
shape_calc_fns = [None] * args[0].ndim
134-
for i in range(args[0].ndim):
135-
input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
134+
shape_calc_fns = [None] * output.ndim
135+
136+
for i in range(output.ndim):
137+
input_node_expr = list(
138+
itertools.chain.from_iterable(
139+
[sym.node.expr for sym in syms_arg] for syms_arg in syms_args
140+
)
141+
)
142+
136143
shape_calc_fns[i] = lambdify(
137144
tuple(input_node_expr), output.shape[i].node.expr, "math"
138145
)
139146

140147
out_desc = tensor_args[0].like()
141148
for i in range(out_desc.ndim):
142-
input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
149+
input_shape_expr = list(
150+
itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
151+
)
152+
143153
if output.shape[i].node.expr is None:
144154
raise ValueError(f"output.shape[{i}].node.expr cannot be None")
145155
out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc]

tests/py/dynamo/automatic_plugin/test_automatic_plugin.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
8181

8282
if __name__ == "__main__":
8383
run_tests()
84-
85-
# Example Usage
86-
# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
87-
# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
88-
89-
# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
90-
91-
# print("C (Addition):", C)
92-
# print("D (Multiplication):", D)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
3+
flashinfer = pytest.importorskip("flashinfer")
4+
import torch
5+
import torch.nn as nn
6+
import torch_tensorrt
7+
from parameterized import parameterized
8+
from torch.testing._internal.common_utils import run_tests
9+
from torch_tensorrt._enums import dtype
10+
11+
from ..conversion.harness import DispatchTestCase
12+
13+
14+
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
15+
def flashinfer_rmsnorm(
16+
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
17+
) -> torch.Tensor:
18+
return flashinfer.norm.rmsnorm(input, weight)
19+
20+
21+
@torch.library.register_fake("flashinfer::rmsnorm")
22+
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
23+
return input
24+
25+
26+
torch_tensorrt.dynamo.conversion.plugins.custom_op(
27+
"flashinfer::rmsnorm", supports_dynamic_shapes=True
28+
)
29+
30+
31+
class TestAutomaticPlugin(DispatchTestCase):
32+
@parameterized.expand(
33+
[
34+
((64, 64), (64,), torch.float16),
35+
((256, 256), (256,), torch.float16),
36+
]
37+
)
38+
def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
39+
class rmsnorm(nn.Module):
40+
def forward(self, input, weight):
41+
return torch.ops.flashinfer.rmsnorm.default(input, weight)
42+
43+
inputs = [
44+
torch.randn(input_shape, device="cuda", dtype=data_type),
45+
torch.randn(weight_shape, device="cuda", dtype=data_type),
46+
]
47+
48+
self.run_test(rmsnorm(), inputs, precision=dtype.f16)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

tests/py/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pytest>=8.2.1
88
pytest-xdist>=3.6.1
99
pyyaml
1010
timm>=1.0.3
11+
flashinfer-python; python_version < "3.13"
1112
transformers==4.49.0
1213
nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13"
1314
--extra-index-url https://pypi.nvidia.com

0 commit comments

Comments
 (0)