@@ -1051,38 +1051,35 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
1051
1051
return value
1052
1052
1053
1053
@staticmethod
1054
- @torch .jit .unused
1055
- def get_params (brightness , contrast , saturation , hue ):
1056
- """Get a randomized transform to be applied on image.
1054
+ def get_params (brightness : Optional [List [float ]],
1055
+ contrast : Optional [List [float ]],
1056
+ saturation : Optional [List [float ]],
1057
+ hue : Optional [List [float ]]
1058
+ ) -> Tuple [Tensor , Optional [float ], Optional [float ], Optional [float ], Optional [float ]]:
1059
+ """Get the parameters for the randomized transform to be applied on image.
1057
1060
1058
- Arguments are same as that of __init__.
1061
+ Args:
1062
+ brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
1063
+ uniformly. Pass None to turn off the transformation.
1064
+ contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
1065
+ uniformly. Pass None to turn off the transformation.
1066
+ saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
1067
+ uniformly. Pass None to turn off the transformation.
1068
+ hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
1069
+ Pass None to turn off the transformation.
1059
1070
1060
1071
Returns:
1061
- Transform which randomly adjusts brightness, contrast and
1062
- saturation in a random order.
1072
+ tuple: The parameters used to apply the randomized transform
1073
+ along with their random order.
1063
1074
"""
1064
- transforms = []
1065
-
1066
- if brightness is not None :
1067
- brightness_factor = random .uniform (brightness [0 ], brightness [1 ])
1068
- transforms .append (Lambda (lambda img : F .adjust_brightness (img , brightness_factor )))
1069
-
1070
- if contrast is not None :
1071
- contrast_factor = random .uniform (contrast [0 ], contrast [1 ])
1072
- transforms .append (Lambda (lambda img : F .adjust_contrast (img , contrast_factor )))
1073
-
1074
- if saturation is not None :
1075
- saturation_factor = random .uniform (saturation [0 ], saturation [1 ])
1076
- transforms .append (Lambda (lambda img : F .adjust_saturation (img , saturation_factor )))
1077
-
1078
- if hue is not None :
1079
- hue_factor = random .uniform (hue [0 ], hue [1 ])
1080
- transforms .append (Lambda (lambda img : F .adjust_hue (img , hue_factor )))
1075
+ fn_idx = torch .randperm (4 )
1081
1076
1082
- random .shuffle (transforms )
1083
- transform = Compose (transforms )
1077
+ b = None if brightness is None else float (torch .empty (1 ).uniform_ (brightness [0 ], brightness [1 ]))
1078
+ c = None if contrast is None else float (torch .empty (1 ).uniform_ (contrast [0 ], contrast [1 ]))
1079
+ s = None if saturation is None else float (torch .empty (1 ).uniform_ (saturation [0 ], saturation [1 ]))
1080
+ h = None if hue is None else float (torch .empty (1 ).uniform_ (hue [0 ], hue [1 ]))
1084
1081
1085
- return transform
1082
+ return fn_idx , b , c , s , h
1086
1083
1087
1084
def forward (self , img ):
1088
1085
"""
@@ -1092,26 +1089,17 @@ def forward(self, img):
1092
1089
Returns:
1093
1090
PIL Image or Tensor: Color jittered image.
1094
1091
"""
1095
- fn_idx = torch .randperm (4 )
1092
+ fn_idx , brightness_factor , contrast_factor , saturation_factor , hue_factor = \
1093
+ self .get_params (self .brightness , self .contrast , self .saturation , self .hue )
1094
+
1096
1095
for fn_id in fn_idx :
1097
- if fn_id == 0 and self .brightness is not None :
1098
- brightness = self .brightness
1099
- brightness_factor = torch .tensor (1.0 ).uniform_ (brightness [0 ], brightness [1 ]).item ()
1096
+ if fn_id == 0 and brightness_factor is not None :
1100
1097
img = F .adjust_brightness (img , brightness_factor )
1101
-
1102
- if fn_id == 1 and self .contrast is not None :
1103
- contrast = self .contrast
1104
- contrast_factor = torch .tensor (1.0 ).uniform_ (contrast [0 ], contrast [1 ]).item ()
1098
+ elif fn_id == 1 and contrast_factor is not None :
1105
1099
img = F .adjust_contrast (img , contrast_factor )
1106
-
1107
- if fn_id == 2 and self .saturation is not None :
1108
- saturation = self .saturation
1109
- saturation_factor = torch .tensor (1.0 ).uniform_ (saturation [0 ], saturation [1 ]).item ()
1100
+ elif fn_id == 2 and saturation_factor is not None :
1110
1101
img = F .adjust_saturation (img , saturation_factor )
1111
-
1112
- if fn_id == 3 and self .hue is not None :
1113
- hue = self .hue
1114
- hue_factor = torch .tensor (1.0 ).uniform_ (hue [0 ], hue [1 ]).item ()
1102
+ elif fn_id == 3 and hue_factor is not None :
1115
1103
img = F .adjust_hue (img , hue_factor )
1116
1104
1117
1105
return img
0 commit comments