2
2
import numbers
3
3
import warnings
4
4
from enum import Enum
5
- from typing import List , Tuple , Any , Optional
5
+ from typing import List , Tuple , Any , Optional , Union
6
6
7
7
import numpy as np
8
8
import torch
@@ -948,12 +948,48 @@ def _get_inverse_affine_matrix(
948
948
return matrix
949
949
950
950
951
+ def _get_inverse_affine_matrix_tensor (
952
+ center : Tensor , angle : Tensor , translate : Tensor , scale : Tensor , shear : Tensor
953
+ ) -> Tensor :
954
+ output = torch .zeros (3 , 3 )
955
+
956
+ rot = angle * torch .pi / 180.0
957
+ shear_rad = shear * torch .pi / 180.0
958
+
959
+ m_center = torch .eye (3 , 3 )
960
+ m_center [:2 , 2 ] = center
961
+
962
+ i_m_center = torch .eye (3 , 3 )
963
+ i_m_center [:2 , 2 ] = - center
964
+
965
+ i_m_translate = torch .eye (3 , 3 )
966
+ i_m_translate [:2 , 2 ] = - translate
967
+
968
+ # RSS without scaling
969
+ sx , sy = shear_rad [0 ], shear_rad [1 ]
970
+ a = torch .cos (rot - sy ) / torch .cos (sy )
971
+ b = torch .cos (rot - sy ) * torch .tan (sx ) / torch .cos (sy ) + torch .sin (rot )
972
+ c = - torch .sin (rot - sy ) / torch .cos (sy )
973
+ d = - torch .sin (rot - sy ) * torch .tan (sx ) / torch .cos (sy ) + torch .cos (rot )
974
+
975
+ output [0 , 0 ] = d
976
+ output [0 , 1 ] = b
977
+ output [1 , 0 ] = c
978
+ output [1 , 1 ] = a
979
+ output = output / scale
980
+ output [2 , 2 ] = 1.0
981
+
982
+ output = torch .chain_matmul (m_center , output , i_m_center , i_m_translate )
983
+ output = output [:2 , :]
984
+ return output
985
+
986
+
951
987
def rotate (
952
988
img : Tensor ,
953
- angle : float ,
989
+ angle : Union [ float , int , Tensor ] ,
954
990
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
955
991
expand : bool = False ,
956
- center : Optional [List [int ]] = None ,
992
+ center : Optional [Union [ List [int ], Tuple [ int , int ], Tensor ]] = None ,
957
993
fill : Optional [List [float ]] = None ,
958
994
resample : Optional [int ] = None ,
959
995
) -> Tensor :
@@ -963,7 +999,7 @@ def rotate(
963
999
964
1000
Args:
965
1001
img (PIL Image or Tensor): image to be rotated.
966
- angle (number): rotation angle value in degrees, counter-clockwise.
1002
+ angle (number or Tensor ): rotation angle value in degrees, counter-clockwise.
967
1003
interpolation (InterpolationMode): Desired interpolation enum defined by
968
1004
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
969
1005
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
@@ -972,7 +1008,7 @@ def rotate(
972
1008
If true, expands the output image to make it large enough to hold the entire rotated image.
973
1009
If false or omitted, make the output image the same size as the input image.
974
1010
Note that the expand flag assumes rotation around the center and no translation.
975
- center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
1011
+ center (sequence or Tensor , optional): Optional center of rotation. Origin is the upper left corner.
976
1012
Default is the center of the image.
977
1013
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
978
1014
image. If given a number, the value is used for all bands respectively.
@@ -1001,28 +1037,48 @@ def rotate(
1001
1037
)
1002
1038
interpolation = _interpolation_modes_from_int (interpolation )
1003
1039
1004
- if not isinstance (angle , (int , float )):
1005
- raise TypeError ("Argument angle should be int or float" )
1040
+ if not isinstance (angle , (int , float , Tensor )):
1041
+ raise TypeError ("Argument angle should be int or float or Tensor " )
1006
1042
1007
- if center is not None and not isinstance (center , (list , tuple )):
1008
- raise TypeError ("Argument center should be a sequence" )
1043
+ if center is not None and not isinstance (center , (list , tuple , Tensor )):
1044
+ raise TypeError ("Argument center should be a sequence or a Tensor " )
1009
1045
1010
1046
if not isinstance (interpolation , InterpolationMode ):
1011
1047
raise TypeError ("Argument interpolation should be a InterpolationMode" )
1012
1048
1013
1049
if not isinstance (img , torch .Tensor ):
1050
+ if not isinstance (angle , (int , float )):
1051
+ raise TypeError ("Argument angle should be int or float" )
1052
+
1053
+ if center is not None and not isinstance (center , (list , tuple )):
1054
+ raise TypeError ("Argument center should be a sequence" )
1055
+
1014
1056
pil_interpolation = pil_modes_mapping [interpolation ]
1015
1057
return F_pil .rotate (img , angle = angle , interpolation = pil_interpolation , expand = expand , center = center , fill = fill )
1016
1058
1017
- center_f = [0.0 , 0.0 ]
1059
+ if isinstance (angle , torch .Tensor ) and angle .requires_grad :
1060
+ # assert img.dtype is float
1061
+ pass
1062
+
1063
+ center_t = torch .tensor ([0.0 , 0.0 ])
1018
1064
if center is not None :
1019
- img_size = get_image_size (img )
1065
+ # ct = torch.tensor([float(c) for c in list(center)]) if not isinstance(center, Tensor) else center
1066
+ # THIS DOES NOT PASS JIT as we mix list/tuple of ints but list/tuple of floats are required
1067
+ ct = torch .tensor (center ) if not isinstance (center , Tensor ) else center
1068
+ img_size = torch .tensor (get_image_size (img ))
1020
1069
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
1021
- center_f = [ 1.0 * (c - s * 0.5 ) for c , s in zip ( center , img_size )]
1070
+ center_t = 1.0 * (ct - img_size * 0.5 )
1022
1071
1023
1072
# due to current incoherence of rotation angle direction between affine and rotate implementations
1024
1073
# we need to set -angle.
1025
- matrix = _get_inverse_affine_matrix (center_f , - angle , [0.0 , 0.0 ], 1.0 , [0.0 , 0.0 ])
1074
+ angle_t = torch .tensor (float (angle )) if not isinstance (angle , Tensor ) else angle
1075
+ matrix = _get_inverse_affine_matrix_tensor (
1076
+ center_t ,
1077
+ - angle_t ,
1078
+ torch .tensor ([0.0 , 0.0 ]),
1079
+ torch .tensor (1.0 ),
1080
+ torch .tensor ([0.0 , 0.0 ])
1081
+ )
1026
1082
return F_t .rotate (img , matrix = matrix , interpolation = interpolation .value , expand = expand , fill = fill )
1027
1083
1028
1084
0 commit comments