@@ -230,6 +230,53 @@ def test_is_colored_output_on(self):
230230 self .assertTrue (color )
231231
232232
233+ class TestDevice (unittest .TestCase ):
234+
235+ def test_from_string_constructor (self ):
236+ device = trtorch .Device ("cuda:0" )
237+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
238+ self .assertEqual (device .gpu_id , 0 )
239+
240+ device = trtorch .Device ("gpu:1" )
241+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
242+ self .assertEqual (device .gpu_id , 1 )
243+
244+ def test_from_string_constructor_dla (self ):
245+ device = trtorch .Device ("dla:0" )
246+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
247+ self .assertEqual (device .gpu_id , 0 )
248+ self .assertEqual (device .dla_core , 0 )
249+
250+ device = trtorch .Device ("dla:1" , allow_gpu_fallback = True )
251+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
252+ self .assertEqual (device .gpu_id , 0 )
253+ self .assertEqual (device .dla_core , 1 )
254+ self .assertEqual (device .allow_gpu_fallback , True )
255+
256+ def test_kwargs_gpu (self ):
257+ device = trtorch .Device (gpu_id = 0 )
258+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
259+ self .assertEqual (device .gpu_id , 0 )
260+
261+ def test_kwargs_dla_and_settings (self ):
262+ device = trtorch .Device (dla_core = 1 , allow_gpu_fallback = False )
263+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
264+ self .assertEqual (device .gpu_id , 0 )
265+ self .assertEqual (device .dla_core , 1 )
266+ self .assertEqual (device .allow_gpu_fallback , False )
267+
268+ device = trtorch .Device (gpu_id = 1 , dla_core = 0 , allow_gpu_fallback = True )
269+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
270+ self .assertEqual (device .gpu_id , 1 )
271+ self .assertEqual (device .dla_core , 0 )
272+ self .assertEqual (device .allow_gpu_fallback , True )
273+
274+ def test_from_torch (self ):
275+ device = trtorch .Device ._from_torch_device (torch .device ("cuda:0" ))
276+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
277+ self .assertEqual (device .gpu_id , 0 )
278+
279+
233280def test_suite ():
234281 suite = unittest .TestSuite ()
235282 suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
@@ -242,6 +289,7 @@ def test_suite():
242289 suite .addTest (
243290 TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
244291 suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
292+ suite .addTest (unittest .makeSuite (TestDevice ))
245293
246294 return suite
247295
0 commit comments