@@ -46,6 +46,30 @@ def test_compile_script(self):
4646 self .assertTrue (same < 2e-3 )
4747
4848
49+ class TestPTtoTRTtoPT (ModelTestCase ):
50+
51+ def setUp (self ):
52+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
53+ self .ts_model = torch .jit .script (self .model )
54+
55+ def test_pt_to_trt_to_pt (self ):
56+ compile_spec = {
57+ "input_shapes" : [self .input .shape ],
58+ "device" : {
59+ "device_type" : trtorch .DeviceType .GPU ,
60+ "gpu_id" : 0 ,
61+ "dla_core" : 0 ,
62+ "allow_gpu_fallback" : False ,
63+ "disable_tf32" : False
64+ }
65+ }
66+
67+ trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , compile_spec )
68+ trt_mod = trtorch .embed_engine_in_new_module (trt_engine )
69+ same = (trt_mod (self .input ) - self .ts_model (self .input )).abs ().max ()
70+ self .assertTrue (same < 2e-3 )
71+
72+
4973class TestCheckMethodOpSupport (unittest .TestCase ):
5074
5175 def setUp (self ):
@@ -59,13 +83,13 @@ def test_check_support(self):
5983class TestLoggingAPIs (unittest .TestCase ):
6084
6185 def test_logging_prefix (self ):
62- new_prefix = "TEST "
86+ new_prefix = "Python API Test: "
6387 trtorch .logging .set_logging_prefix (new_prefix )
6488 logging_prefix = trtorch .logging .get_logging_prefix ()
6589 self .assertEqual (new_prefix , logging_prefix )
6690
6791 def test_reportable_log_level (self ):
68- new_level = trtorch .logging .Level .Warning
92+ new_level = trtorch .logging .Level .Error
6993 trtorch .logging .set_reportable_log_level (new_level )
7094 level = trtorch .logging .get_reportable_log_level ()
7195 self .assertEqual (new_level , level )
@@ -78,10 +102,11 @@ def test_is_colored_output_on(self):
78102
79103def test_suite ():
80104 suite = unittest .TestSuite ()
105+ suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
81106 suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
82107 suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
108+ suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
83109 suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
84- suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
85110
86111 return suite
87112
0 commit comments