@@ -1039,7 +1039,7 @@ def __repr__(self):
1039
1039
1040
1040
1041
1041
class ColorJitter (torch .nn .Module ):
1042
- """Randomly change the brightness, contrast and saturation of an image.
1042
+ """Randomly change the brightness, contrast, saturation, hue and sharpness of an image.
1043
1043
1044
1044
Args:
1045
1045
brightness (float or tuple of float (min, max)): How much to jitter brightness.
@@ -1054,15 +1054,19 @@ class ColorJitter(torch.nn.Module):
1054
1054
hue (float or tuple of float (min, max)): How much to jitter hue.
1055
1055
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
1056
1056
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
1057
+ sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
1058
+ sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
1059
+ or the given [min, max]. Should be non negative numbers.
1057
1060
"""
1058
1061
1059
- def __init__ (self , brightness = 0 , contrast = 0 , saturation = 0 , hue = 0 ):
1062
+ def __init__ (self , brightness = 0 , contrast = 0 , saturation = 0 , hue = 0 , sharpness = 0 ):
1060
1063
super ().__init__ ()
1061
1064
self .brightness = self ._check_input (brightness , 'brightness' )
1062
1065
self .contrast = self ._check_input (contrast , 'contrast' )
1063
1066
self .saturation = self ._check_input (saturation , 'saturation' )
1064
1067
self .hue = self ._check_input (hue , 'hue' , center = 0 , bound = (- 0.5 , 0.5 ),
1065
1068
clip_first_on_zero = False )
1069
+ self .sharpness = self ._check_input (sharpness , 'sharpness' )
1066
1070
1067
1071
@torch .jit .unused
1068
1072
def _check_input (self , value , name , center = 1 , bound = (0 , float ('inf' )), clip_first_on_zero = True ):
@@ -1078,7 +1082,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
1078
1082
else :
1079
1083
raise TypeError ("{} should be a single number or a list/tuple with lenght 2." .format (name ))
1080
1084
1081
- # if value is 0 or (1., 1.) for brightness/contrast/saturation
1085
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness
1082
1086
# or (0., 0.) for hue, do nothing
1083
1087
if value [0 ] == value [1 ] == center :
1084
1088
value = None
@@ -1088,8 +1092,10 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
1088
1092
def get_params (brightness : Optional [List [float ]],
1089
1093
contrast : Optional [List [float ]],
1090
1094
saturation : Optional [List [float ]],
1091
- hue : Optional [List [float ]]
1092
- ) -> Tuple [Tensor , Optional [float ], Optional [float ], Optional [float ], Optional [float ]]:
1095
+ hue : Optional [List [float ]],
1096
+ sharpness : Optional [List [float ]]
1097
+ ) -> Tuple [Tensor , Optional [float ], Optional [float ], Optional [float ], Optional [float ],
1098
+ Optional [float ]]:
1093
1099
"""Get the parameters for the randomized transform to be applied on image.
1094
1100
1095
1101
Args:
@@ -1101,19 +1107,22 @@ def get_params(brightness: Optional[List[float]],
1101
1107
uniformly. Pass None to turn off the transformation.
1102
1108
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
1103
1109
Pass None to turn off the transformation.
1110
+ sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen
1111
+ uniformly. Pass None to turn off the transformation.
1104
1112
1105
1113
Returns:
1106
1114
tuple: The parameters used to apply the randomized transform
1107
1115
along with their random order.
1108
1116
"""
1109
- fn_idx = torch .randperm (4 )
1117
+ fn_idx = torch .randperm (5 )
1110
1118
1111
1119
b = None if brightness is None else float (torch .empty (1 ).uniform_ (brightness [0 ], brightness [1 ]))
1112
1120
c = None if contrast is None else float (torch .empty (1 ).uniform_ (contrast [0 ], contrast [1 ]))
1113
1121
s = None if saturation is None else float (torch .empty (1 ).uniform_ (saturation [0 ], saturation [1 ]))
1114
1122
h = None if hue is None else float (torch .empty (1 ).uniform_ (hue [0 ], hue [1 ]))
1123
+ sp = None if sharpness is None else float (torch .empty (1 ).uniform_ (sharpness [0 ], sharpness [1 ]))
1115
1124
1116
- return fn_idx , b , c , s , h
1125
+ return fn_idx , b , c , s , h , sp
1117
1126
1118
1127
def forward (self , img ):
1119
1128
"""
@@ -1123,8 +1132,8 @@ def forward(self, img):
1123
1132
Returns:
1124
1133
PIL Image or Tensor: Color jittered image.
1125
1134
"""
1126
- fn_idx , brightness_factor , contrast_factor , saturation_factor , hue_factor = \
1127
- self .get_params (self .brightness , self .contrast , self .saturation , self .hue )
1135
+ fn_idx , brightness_factor , contrast_factor , saturation_factor , hue_factor , sharpness_factor = \
1136
+ self .get_params (self .brightness , self .contrast , self .saturation , self .hue , self . sharpness )
1128
1137
1129
1138
for fn_id in fn_idx :
1130
1139
if fn_id == 0 and brightness_factor is not None :
@@ -1135,6 +1144,8 @@ def forward(self, img):
1135
1144
img = F .adjust_saturation (img , saturation_factor )
1136
1145
elif fn_id == 3 and hue_factor is not None :
1137
1146
img = F .adjust_hue (img , hue_factor )
1147
+ elif fn_id == 4 and sharpness_factor is not None :
1148
+ img = F .adjust_sharpness (img , sharpness_factor )
1138
1149
1139
1150
return img
1140
1151
@@ -1143,7 +1154,8 @@ def __repr__(self):
1143
1154
format_string += 'brightness={0}' .format (self .brightness )
1144
1155
format_string += ', contrast={0}' .format (self .contrast )
1145
1156
format_string += ', saturation={0}' .format (self .saturation )
1146
- format_string += ', hue={0})' .format (self .hue )
1157
+ format_string += ', hue={0}' .format (self .hue )
1158
+ format_string += ', sharpness={0})' .format (self .sharpness )
1147
1159
return format_string
1148
1160
1149
1161
0 commit comments