File tree Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -475,14 +475,13 @@ def forward(self, x):
475475 optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
476476 torch_model_results = fx_graph (* inputs ).detach ().cpu ()
477477
478- max_diff = float (
479- torch .max (torch .abs (optimized_model_results - torch_model_results ))
480- )
481- self .assertAlmostEqual (
482- max_diff ,
483- 0 ,
484- DECIMALS_OF_AGREEMENT ,
485- f"Select_scatter TRT outputs don't match with the original model." ,
478+ optimized_model_results_shape = optimized_model_results .size ()
479+ torch_model_results_shape = torch_model_results .size ()
480+
481+ self .assertEquals (
482+ optimized_model_results_shape ,
483+ torch_model_results_shape ,
484+ f"The optimized model results shape and torch model results shape should be equal in empty_like" ,
486485 )
487486
488487
You can’t perform that action at this time.
0 commit comments