Skip to content

Commit b2ec12f

Browse files
added smoke test for max-autotune (#1349)
Co-authored-by: agunapal <[email protected]>
1 parent 8bcc106 commit b2ec12f

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/smoke_test/smoke_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import platform
88
import importlib
99
import subprocess
10+
import torch._dynamo
11+
import torch.nn as nn
12+
import torch.nn.functional as F
1013

1114
gpu_arch_ver = os.getenv("MATRIX_GPU_ARCH_VERSION")
1215
gpu_arch_type = os.getenv("MATRIX_GPU_ARCH_TYPE")
@@ -33,6 +36,21 @@
3336
},
3437
]
3538

39+
class Net(nn.Module):
40+
def __init__(self):
41+
super(Net, self).__init__()
42+
self.conv1 = nn.Conv2d(1, 32, 3, 1)
43+
self.conv2 = nn.Conv2d(32, 64, 3, 1)
44+
self.fc1 = nn.Linear(9216, 1)
45+
46+
def forward(self, x):
47+
x = self.conv1(x)
48+
x = self.conv2(x)
49+
x = F.max_pool2d(x, 2)
50+
x = torch.flatten(x, 1)
51+
output = self.fc1(x)
52+
return output
53+
3654
def check_version(package: str) -> None:
3755
# only makes sense to check nightly package where dates are known
3856
if channel == "nightly":
@@ -175,6 +193,14 @@ def foo(x: torch.Tensor) -> torch.Tensor:
175193
x_pt2 = torch.compile(foo)(x)
176194
print(torch.allclose(x_eager, x_pt2))
177195

196+
# Reset torch dynamo since we are changing mode
197+
torch._dynamo.reset()
198+
dtype = torch.float32
199+
torch.set_float32_matmul_precision('high')
200+
print(f"Testing smoke_test_compile with mode 'max-autotune' for {dtype}")
201+
x = torch.rand(64, 1, 28, 28, device="cuda").type(torch.float32)
202+
model = Net().to(device="cuda")
203+
x_pt2 = torch.compile(model, mode="max-autotune")(x)
178204

179205
def smoke_test_modules():
180206
for module in MODULES:

0 commit comments

Comments
 (0)