@@ -138,6 +138,22 @@ def make_one_hot_labels(
138
138
yield make_one_hot_label (extra_dims_ )
139
139
140
140
141
def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype=torch.long):
    """Create one random ``features.SegmentationMask`` for tests.

    Args:
        size: Two trailing spatial dimensions of the mask. When falsy, both
            values are drawn with ``torch.randint(16, 33, (2,))``.
        num_categories: Passed to ``make_tensor`` as ``high`` (``low`` is 0).
        extra_dims: Leading dimensions placed in front of the singleton
            dimension.
        dtype: dtype forwarded to ``make_tensor``.

    Returns:
        A ``features.SegmentationMask`` wrapping data of shape
        ``(*extra_dims, 1, *size)``.
    """
    if not size:
        # No usable size supplied: pick a random one.
        size = torch.randint(16, 33, (2,)).tolist()

    # Leading dims first, then a singleton dimension, then the spatial size.
    mask_shape = tuple(extra_dims) + (1,) + tuple(size)
    raw = make_tensor(mask_shape, low=0, high=num_categories, dtype=dtype)
    return features.SegmentationMask(raw)
146
+
147
+
148
def make_segmentation_masks(
    image_sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.long,),
    extra_dims=((), (4,), (2, 3)),
):
    """Yield one segmentation mask per combination of size, dtype and extra dims.

    Args:
        image_sizes: Iterable of sizes, each forwarded as ``size``.
        dtypes: Iterable of dtypes, each forwarded as ``dtype``.
        extra_dims: Iterable of leading-dimension tuples, each forwarded as
            ``extra_dims``.

    Yields:
        The result of ``make_segmentation_mask`` for every element of the
        cartesian product of the three iterables.
    """
    combinations = itertools.product(image_sizes, dtypes, extra_dims)
    yield from (
        make_segmentation_mask(size=spatial_size, dtype=mask_dtype, extra_dims=leading_dims)
        for spatial_size, mask_dtype, leading_dims in combinations
    )
155
+
156
+
141
157
class SampleInput :
142
158
def __init__ (self , * args , ** kwargs ):
143
159
self .args = args
@@ -212,7 +228,7 @@ def resize_bounding_box():
212
228
@register_kernel_info_from_sample_inputs_fn
213
229
def affine_image_tensor ():
214
230
for image , angle , translate , scale , shear in itertools .product (
215
- make_images (extra_dims = ()),
231
+ make_images (extra_dims = ((), ( 4 ,) )),
216
232
[- 87 , 15 , 90 ], # angle
217
233
[5 , - 5 ], # translate
218
234
[0.77 , 1.27 ], # scale
@@ -248,6 +264,24 @@ def affine_bounding_box():
248
264
)
249
265
250
266
267
@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
    """Yield ``SampleInput`` objects for the affine segmentation-mask kernel.

    Every mask from ``make_segmentation_masks`` (with no extra dims and with
    one extra dim of size 4) is combined with each angle, translate, scale
    and shear value listed below. The scalar translate and shear values are
    duplicated into 2-tuples.
    """
    masks = make_segmentation_masks(extra_dims=((), (4,)))
    angles = [-87, 15, 90]
    translations = [5, -5]
    scales = [0.77, 1.27]
    shears = [0, 12]

    for mask, angle, offset, scale, shear_value in itertools.product(masks, angles, translations, scales, shears):
        yield SampleInput(
            mask,
            angle=angle,
            translate=(offset, offset),
            scale=scale,
            shear=(shear_value, shear_value),
        )
283
+
284
+
251
285
@register_kernel_info_from_sample_inputs_fn
252
286
def rotate_bounding_box ():
253
287
for bounding_box , angle , expand , center in itertools .product (
@@ -444,6 +478,76 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
444
478
np .testing .assert_allclose (out_box .cpu ().numpy (), a_out_box )
445
479
446
480
481
@pytest.mark.parametrize("angle", [-54, 56])
@pytest.mark.parametrize("translate", [-7, 8])
@pytest.mark.parametrize("scale", [0.89, 1.12])
@pytest.mark.parametrize("shear", [4])
@pytest.mark.parametrize("center", [None, (12, 14)])
def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center):
    """Check ``F.affine_segmentation_mask`` against a per-pixel reference.

    For every generated mask the kernel output is compared with a reference
    built by mapping each output pixel back through the inverse affine matrix.
    The parametrized ``center`` is passed to the kernel unchanged for every
    mask; the midpoint fallback is only used by the reference computation.
    """

    def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
        # The reference only handles a single mask with a leading singleton dim.
        assert mask.ndim == 3 and mask.shape[0] == 1
        affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
        # Keep the first two rows of the inverse: they map a homogeneous
        # output point (x, y, 1) to input (x, y).
        inv_affine_matrix = np.linalg.inv(affine_matrix)[:2, :]

        expected_mask = torch.zeros_like(mask.cpu())
        height, width = expected_mask.shape[1], expected_mask.shape[2]
        for out_y in range(height):
            for out_x in range(width):
                # Sample at the pixel center, then floor to an input index.
                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
                in_x, in_y = input_pt[:2]
                # Pixels that map outside the input keep the zero fill value.
                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
        return expected_mask.to(mask.device)

    for mask in make_segmentation_masks(extra_dims=((), (4,))):
        output_mask = F.affine_segmentation_mask(
            mask,
            angle=angle,
            translate=(translate, translate),
            scale=scale,
            shear=(shear, shear),
            center=center,
        )

        # Resolve the center per mask without rebinding the parametrized
        # ``center``; otherwise later masks would be called with the first
        # mask's midpoint instead of ``None``.
        if center is None:
            reference_center = [s // 2 for s in mask.shape[-2:][::-1]]
        else:
            reference_center = center

        # A 4-d input is treated as a batch of single masks along dim 0.
        single_masks = [mask] if mask.ndim < 4 else list(mask)

        expected_per_mask = [
            _compute_expected_mask(
                single_mask, angle, (translate, translate), scale, (shear, shear), reference_center
            )
            for single_mask in single_masks
        ]
        if len(expected_per_mask) > 1:
            expected_masks = torch.stack(expected_per_mask)
        else:
            expected_masks = expected_per_mask[0]
        torch.testing.assert_close(output_mask, expected_masks)
529
+
530
+
531
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_affine_segmentation_mask_on_fixed_input(device):
    """Check the affine mask kernel against a hand-built expected output.

    The test runs once per device returned by ``cpu_and_gpu()``. The input is
    a 32x32 mask holding two labelled squares, one near the top-left corner
    and one near the bottom-left corner. The expected output is built with
    ``torch.rot90`` (``k=-1``), a nearest-neighbour upscale to 64x64 and a
    crop of the central 32x32 region.
    """
    side = 32
    upscaled_side = 64

    # Fixed input: label 1 in the top-left square, label 2 in the bottom-left one.
    mask = torch.zeros(1, side, side, dtype=torch.long, device=device)
    mask[0, 2:10, 2:10] = 1
    mask[0, side - 9 : side - 3, 3:9] = 2

    # Expected result: quarter-turn rotation, 2x upscale, then central crop.
    rotated = torch.rot90(mask, k=-1, dims=(-2, -1))
    upscaled = torch.nn.functional.interpolate(
        rotated[None, :].float(), size=(upscaled_side, upscaled_side), mode="nearest"
    )
    central = slice(16, upscaled_side - 16)
    expected_mask = upscaled[0, :, central, central].long()

    out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], upscaled_side / side, [0.0, 0.0])

    torch.testing.assert_close(out_mask, expected_mask)
549
+
550
+
447
551
@pytest .mark .parametrize ("angle" , range (- 90 , 90 , 56 ))
448
552
@pytest .mark .parametrize ("expand" , [True , False ])
449
553
@pytest .mark .parametrize ("center" , [None , (12 , 14 )])
0 commit comments