@@ -44,14 +44,19 @@ def forward(self, x):
4444 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
4545 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
4646 torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
47- # TODO: Enable this serialization issues are fixed
48- # deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
47+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
4948 # Check Pyt and TRT exported program outputs
5049 cos_sim = cosine_similarity (model (input ), trt_module (input )[0 ])
5150 assertions .assertTrue (
5251 cos_sim > COSINE_THRESHOLD ,
5352 msg = f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
5453 )
54+ # Check Pyt and deserialized TRT exported program outputs
55+ cos_sim = cosine_similarity (model (input ), deser_trt_module (input )[0 ])
56+ assertions .assertTrue (
57+ cos_sim > COSINE_THRESHOLD ,
58+ msg = f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
59+ )
5560 # TODO: Enable this serialization issues are fixed
5661 # # Check Pyt and deserialized TRT exported program outputs
5762 # cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
@@ -95,9 +100,8 @@ def forward(self, x):
95100
96101 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
97102 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
98- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
99- # TODO: Enable this serialization issues are fixed
100- # deser_trt_module = torchtrt.load("./trt.ep").module()
103+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
104+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
101105 # Check Pyt and TRT exported program outputs
102106 outputs_pyt = model (input )
103107 outputs_trt = trt_module (input )
@@ -108,15 +112,14 @@ def forward(self, x):
108112 msg = f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
109113 )
110114
111- # TODO: Enable this serialization issues are fixed
112- # # Check Pyt and deserialized TRT exported program outputs
113- # outputs_trt_deser = deser_trt_module(input)
114- # for idx in range(len(outputs_pyt)):
115- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
116- # assertions.assertTrue(
117- # cos_sim > COSINE_THRESHOLD,
118- # msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
119- # )
115+ # Check Pyt and deserialized TRT exported program outputs
116+ outputs_trt_deser = deser_trt_module (input )
117+ for idx in range (len (outputs_pyt )):
118+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
119+ assertions .assertTrue (
120+ cos_sim > COSINE_THRESHOLD ,
121+ msg = f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
122+ )
120123
121124
122125@pytest .mark .unit
@@ -152,9 +155,8 @@ def forward(self, x):
152155
153156 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
154157 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
155- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
156- # TODO: Enable this serialization issues are fixed
157- # deser_trt_module = torchtrt.load("./trt.ep").module()
158+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
159+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
158160 # Check Pyt and TRT exported program outputs
159161 outputs_pyt = model (input )
160162 outputs_trt = trt_module (input )
@@ -165,15 +167,14 @@ def forward(self, x):
165167 msg = f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
166168 )
167169
168- # TODO: Enable this serialization issues are fixed
169- # # Check Pyt and deserialized TRT exported program outputs
170- # outputs_trt_deser = deser_trt_module(input)
171- # for idx in range(len(outputs_pyt)):
172- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
173- # assertions.assertTrue(
174- # cos_sim > COSINE_THRESHOLD,
175- # msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
176- # )
170+ # Check Pyt and deserialized TRT exported program outputs
171+ outputs_trt_deser = deser_trt_module (input )
172+ for idx in range (len (outputs_pyt )):
173+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
174+ assertions .assertTrue (
175+ cos_sim > COSINE_THRESHOLD ,
176+ msg = f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
177+ )
177178
178179
179180@pytest .mark .unit
@@ -212,9 +213,8 @@ def forward(self, x):
212213
213214 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
214215 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
215- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
216- # TODO: Enable this serialization issues are fixed
217- # deser_trt_module = torchtrt.load("./trt.ep").module()
216+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
217+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
218218 outputs_pyt = model (input )
219219 outputs_trt = trt_module (input )
220220 for idx in range (len (outputs_pyt )):
@@ -224,14 +224,13 @@ def forward(self, x):
224224 msg = f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
225225 )
226226
227- # TODO: Enable this serialization issues are fixed
228- # outputs_trt_deser = deser_trt_module(input)
229- # for idx in range(len(outputs_pyt)):
230- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
231- # assertions.assertTrue(
232- # cos_sim > COSINE_THRESHOLD,
233- # msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
234- # )
227+ outputs_trt_deser = deser_trt_module (input )
228+ for idx in range (len (outputs_pyt )):
229+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
230+ assertions .assertTrue (
231+ cos_sim > COSINE_THRESHOLD ,
232+ msg = f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
233+ )
235234
236235
237236@pytest .mark .unit
@@ -254,9 +253,8 @@ def test_resnet18(ir):
254253
255254 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
256255 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
257- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
258- # TODO: Enable this serialization issues are fixed
259- # deser_trt_module = torchtrt.load("./trt.ep").module()
256+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
257+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
260258 outputs_pyt = model (input )
261259 outputs_trt = trt_module (input )
262260 cos_sim = cosine_similarity (outputs_pyt , outputs_trt [0 ])
@@ -265,13 +263,13 @@ def test_resnet18(ir):
265263 msg = f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
266264 )
267265
268- # TODO: Enable this serialization issues are fixed
269- # outputs_trt_deser = deser_trt_module(input)
270- # cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
271- # assertions.assertTrue(
272- # cos_sim > COSINE_THRESHOLD,
273- # msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
274- # )
266+ outputs_trt_deser = deser_trt_module ( input )
267+
268+ cos_sim = cosine_similarity (outputs_pyt , outputs_trt_deser [0 ])
269+ assertions .assertTrue (
270+ cos_sim > COSINE_THRESHOLD ,
271+ msg = f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
272+ )
275273
276274
277275@pytest .mark .unit
@@ -310,9 +308,8 @@ def forward(self, x):
310308 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
311309 trt_module = torchtrt .dynamo .compile (exp_program , ** compile_spec )
312310
313- torchtrt .save (trt_module , "./trt.ep" , inputs = [input ])
314- # TODO: Enable this serialization issues are fixed
315- # deser_trt_module = torchtrt.load("./trt.ep").module()
311+ torchtrt .save (trt_module , "/tmp/trt.ep" , inputs = [input ])
312+ deser_trt_module = torchtrt .load ("/tmp/trt.ep" ).module ()
316313 outputs_pyt = model (input )
317314 outputs_trt = trt_module (input )
318315
@@ -323,14 +320,13 @@ def forward(self, x):
323320 msg = f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
324321 )
325322
326- # TODO: Enable this serialization issues are fixed
327- # outputs_trt_deser = deser_trt_module(input)
328- # for idx in range(len(outputs_pyt)):
329- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
330- # assertions.assertTrue(
331- # cos_sim > COSINE_THRESHOLD,
332- # msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
333- # )
323+ outputs_trt_deser = deser_trt_module (input )
324+ for idx in range (len (outputs_pyt )):
325+ cos_sim = cosine_similarity (outputs_pyt [idx ], outputs_trt_deser [idx ])
326+ assertions .assertTrue (
327+ cos_sim > COSINE_THRESHOLD ,
328+ msg = f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
329+ )
334330
335331
336332@pytest .mark .unit
@@ -361,9 +357,9 @@ def forward(self, x):
361357 )
362358 outputs_trt = trt_gm (input )
363359 # Save it as torchscript representation
364- torchtrt .save (trt_gm , ". /trt.ts" , output_format = "torchscript" , inputs = [input ])
360+ torchtrt .save (trt_gm , "/tmp /trt.ts" , output_format = "torchscript" , inputs = [input ])
365361
366- trt_ts_module = torchtrt .load (". /trt.ts" )
362+ trt_ts_module = torchtrt .load ("/tmp /trt.ts" )
367363 outputs_trt_deser = trt_ts_module (input )
368364
369365 cos_sim = cosine_similarity (outputs_trt , outputs_trt_deser )
0 commit comments