@@ -45,6 +45,27 @@ def test_compile_script(self):
4545 same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
4646 self .assertTrue (same < 2e-3 )
4747
48+ class TestPTtoTRTtoPT (ModelTestCase ):
49+ def setUp (self ):
50+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
51+ self .ts_model = torch .jit .script (self .model )
52+
53+ def test_pt_to_trt_to_pt (self ):
54+ compile_spec = {
55+ "input_shapes" : [self .input .shape ],
56+ "device" : {
57+ "device_type" : trtorch .DeviceType .GPU ,
58+ "gpu_id" : 0 ,
59+ "dla_core" : 0 ,
60+ "allow_gpu_fallback" : False ,
61+ "disable_tf32" : False
62+ }
63+ }
64+
65+ trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , compile_spec )
66+ trt_mod = trtorch .embed_engine_in_new_module (trt_engine )
67+ same = (trt_mod (self .input ) - self .ts_model (self .input )).abs ().max ()
68+ self .assertTrue (same < 2e-3 )
4869
4970class TestCheckMethodOpSupport (unittest .TestCase ):
5071
@@ -59,13 +80,13 @@ def test_check_support(self):
5980class TestLoggingAPIs (unittest .TestCase ):
6081
6182 def test_logging_prefix (self ):
62- new_prefix = "TEST "
83+ new_prefix = "Python API Test: "
6384 trtorch .logging .set_logging_prefix (new_prefix )
6485 logging_prefix = trtorch .logging .get_logging_prefix ()
6586 self .assertEqual (new_prefix , logging_prefix )
6687
6788 def test_reportable_log_level (self ):
68- new_level = trtorch .logging .Level .Warning
89+ new_level = trtorch .logging .Level .Error
6990 trtorch .logging .set_reportable_log_level (new_level )
7091 level = trtorch .logging .get_reportable_log_level ()
7192 self .assertEqual (new_level , level )
@@ -78,10 +99,11 @@ def test_is_colored_output_on(self):
7899
79100def test_suite ():
80101 suite = unittest .TestSuite ()
102+ suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
81103 suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
82104 suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
105+ suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
83106 suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
84- suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
85107
86108 return suite
87109
0 commit comments