@@ -210,18 +210,22 @@ def affine_image_tensor(
210
210
fill : Optional [List [float ]] = None ,
211
211
center : Optional [List [float ]] = None ,
212
212
) -> torch .Tensor :
213
+ num_channels , height , width = img .shape [- 3 :]
214
+ extra_dims = img .shape [:- 3 ]
215
+ img = img .view (- 1 , num_channels , height , width )
216
+
213
217
angle , translate , shear , center = _affine_parse_args (angle , translate , scale , shear , interpolation , center )
214
218
215
219
center_f = [0.0 , 0.0 ]
216
220
if center is not None :
217
- _ , height , width = get_dimensions_image_tensor (img )
218
221
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
219
222
center_f = [1.0 * (c - s * 0.5 ) for c , s in zip (center , [width , height ])]
220
223
221
224
translate_f = [1.0 * t for t in translate ]
222
225
matrix = _get_inverse_affine_matrix (center_f , angle , translate_f , scale , shear )
223
226
224
- return _FT .affine (img , matrix , interpolation = interpolation .value , fill = fill )
227
+ output = _FT .affine (img , matrix , interpolation = interpolation .value , fill = fill )
228
+ return output .view (extra_dims + (num_channels , height , width ))
225
229
226
230
227
231
def affine_image_pil (
@@ -344,15 +348,15 @@ def affine_bounding_box(
344
348
345
349
346
350
def affine_segmentation_mask (
347
- img : torch .Tensor ,
351
+ mask : torch .Tensor ,
348
352
angle : float ,
349
353
translate : List [float ],
350
354
scale : float ,
351
355
shear : List [float ],
352
356
center : Optional [List [float ]] = None ,
353
357
) -> torch .Tensor :
354
358
return affine_image_tensor (
355
- img ,
359
+ mask ,
356
360
angle = angle ,
357
361
translate = translate ,
358
362
scale = scale ,
@@ -423,6 +427,10 @@ def rotate_image_tensor(
423
427
fill : Optional [List [float ]] = None ,
424
428
center : Optional [List [float ]] = None ,
425
429
) -> torch .Tensor :
430
+ num_channels , height , width = img .shape [- 3 :]
431
+ extra_dims = img .shape [:- 3 ]
432
+ img = img .view (- 1 , num_channels , height , width )
433
+
426
434
center_f = [0.0 , 0.0 ]
427
435
if center is not None :
428
436
if expand :
@@ -435,7 +443,8 @@ def rotate_image_tensor(
435
443
# due to current incoherence of rotation angle direction between affine and rotate implementations
436
444
# we need to set -angle.
437
445
matrix = _get_inverse_affine_matrix (center_f , - angle , [0.0 , 0.0 ], 1.0 , [0.0 , 0.0 ])
438
- return _FT .rotate (img , matrix , interpolation = interpolation .value , expand = expand , fill = fill )
446
+ output = _FT .rotate (img , matrix , interpolation = interpolation .value , expand = expand , fill = fill )
447
+ return output .view (extra_dims + (num_channels , height , width ))
439
448
440
449
441
450
def rotate_image_pil (
@@ -518,15 +527,15 @@ def rotate(
518
527
def pad_image_tensor(
    img: torch.Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
) -> torch.Tensor:
    """Pad an image tensor that may carry any number of leading batch dimensions.

    The input is flattened to a 4D ``(-1, C, H, W)`` tensor, handed to ``_FT.pad``,
    and the result is restored to the original leading dimensions with the
    padded spatial size.

    Args:
        img: Tensor whose last three dimensions are unpacked as
            ``(num_channels, height, width)``; it must therefore have at
            least three dimensions.
        padding: Forwarded unchanged to ``_FT.pad``.
        fill: Forwarded unchanged to ``_FT.pad``.
        padding_mode: Forwarded unchanged to ``_FT.pad``.

    Returns:
        Tensor of shape ``extra_dims + (num_channels, new_height, new_width)``,
        where ``extra_dims`` are the leading dimensions of ``img`` and the new
        spatial size is read from the output of ``_FT.pad``.
    """
    num_channels, height, width = img.shape[-3:]
    extra_dims = img.shape[:-3]

    # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous inputs
    # (e.g. after ``permute``/``transpose``), while ``reshape`` returns a view
    # when possible and only copies when it has to.
    padded_image = _FT.pad(
        img=img.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
    )

    # The spatial size changes with padding, so read it back from the result
    # rather than reusing the input height/width.
    new_height, new_width = padded_image.shape[-2:]
    return padded_image.reshape(extra_dims + (num_channels, new_height, new_width))
530
539
531
540
532
541
# TODO: This should be removed once pytorch pad supports non-scalar padding values
0 commit comments