We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a031900 commit 73fd2c4Copy full SHA for 73fd2c4
tests/py/test_api.py
@@ -46,7 +46,13 @@ def test_compile_global(self):
46
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
47
self.assertTrue(same < 2e-2)
48
49
-
+ def test_compile_global_nn_mod(self):
50
+ trt_mod = torchtrt.compile(self.model,
51
+ inputs=[self.input],
52
+ device=torchtrt.Device(gpu_id=0),
53
+ enabled_precisions={torch.float})
54
+ same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
55
+ self.assertTrue(same < 2e-2)
56
57
def test_from_torch_tensor(self):
58
compile_spec = {
0 commit comments