@@ -362,6 +362,16 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
362
362
363
363
spy .assert_called_once ()
364
364
365
+ @image_sample_inputs
366
+ def test_simple_tensor_output_type (self , info , args_kwargs ):
367
+ (image_datapoint , * other_args ), kwargs = args_kwargs .load ()
368
+ image_simple_tensor = image_datapoint .as_subclass (torch .Tensor )
369
+
370
+ output = info .dispatcher (image_simple_tensor , * other_args , ** kwargs )
371
+
372
+ # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
373
+ assert type (output ) is torch .Tensor
374
+
365
375
@make_info_args_kwargs_parametrization (
366
376
[info for info in DISPATCHER_INFOS if info .pil_kernel_info is not None ],
367
377
args_kwargs_fn = lambda info : info .sample_inputs (datapoints .Image ),
@@ -381,6 +391,22 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on):
381
391
382
392
spy .assert_called_once ()
383
393
394
+ @make_info_args_kwargs_parametrization (
395
+ [info for info in DISPATCHER_INFOS if info .pil_kernel_info is not None ],
396
+ args_kwargs_fn = lambda info : info .sample_inputs (datapoints .Image ),
397
+ )
398
+ def test_pil_output_type (self , info , args_kwargs ):
399
+ (image_datapoint , * other_args ), kwargs = args_kwargs .load ()
400
+
401
+ if image_datapoint .ndim > 3 :
402
+ pytest .skip ("Input is batched" )
403
+
404
+ image_pil = F .to_image_pil (image_datapoint )
405
+
406
+ output = info .dispatcher (image_pil , * other_args , ** kwargs )
407
+
408
+ assert isinstance (output , PIL .Image .Image )
409
+
384
410
@make_info_args_kwargs_parametrization (
385
411
DISPATCHER_INFOS ,
386
412
args_kwargs_fn = lambda info : info .sample_inputs (),
@@ -397,6 +423,17 @@ def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
397
423
398
424
spy .assert_called_once ()
399
425
426
+ @make_info_args_kwargs_parametrization (
427
+ DISPATCHER_INFOS ,
428
+ args_kwargs_fn = lambda info : info .sample_inputs (),
429
+ )
430
+ def test_datapoint_output_type (self , info , args_kwargs ):
431
+ (datapoint , * other_args ), kwargs = args_kwargs .load ()
432
+
433
+ output = info .dispatcher (datapoint , * other_args , ** kwargs )
434
+
435
+ assert isinstance (output , type (datapoint ))
436
+
400
437
@pytest .mark .parametrize (
401
438
("dispatcher_info" , "datapoint_type" , "kernel_info" ),
402
439
[
0 commit comments