Skip to content

Commit 73fd2c4

Browse files
committed
test(//py): Test nn.module compilation
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent a031900 commit 73fd2c4

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/py/test_api.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ def test_compile_global(self):
4646
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
4747
self.assertTrue(same < 2e-2)
4848

49-
49+
def test_compile_global_nn_mod(self):
50+
trt_mod = torchtrt.compile(self.model,
51+
inputs=[self.input],
52+
device=torchtrt.Device(gpu_id=0),
53+
enabled_precisions={torch.float})
54+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
55+
self.assertTrue(same < 2e-2)
5056

5157
def test_from_torch_tensor(self):
5258
compile_spec = {

0 commit comments

Comments
 (0)