Skip to content

Commit 6b829be

Browse files
committed
Adding transforms for sharpness.
1 parent 9a12648 commit 6b829be

File tree

3 files changed

+62
-12
lines changed

3 files changed

+62
-12
lines changed

test/test_transforms.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,37 @@ def test_adjust_hue(self):
12321232
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
12331233
self.assertTrue(np.allclose(y_np, y_ans))
12341234

1235+
def test_adjust_sharpness(self):
1236+
x_shape = [4, 4, 3]
1237+
x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
1238+
0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
1239+
111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
1240+
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1241+
x_pil = Image.fromarray(x_np, mode='RGB')
1242+
1243+
# test 0
1244+
y_pil = F.adjust_sharpness(x_pil, 1)
1245+
y_np = np.array(y_pil)
1246+
self.assertTrue(np.allclose(y_np, x_np))
1247+
1248+
# test 1
1249+
y_pil = F.adjust_sharpness(x_pil, 0.5)
1250+
y_np = np.array(y_pil)
1251+
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
1252+
30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
1253+
107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
1254+
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1255+
self.assertTrue(np.allclose(y_np, y_ans))
1256+
1257+
# test 2
1258+
y_pil = F.adjust_sharpness(x_pil, 2)
1259+
y_np = np.array(y_pil)
1260+
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
1261+
0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
1262+
119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
1263+
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1264+
self.assertTrue(np.allclose(y_np, y_ans))
1265+
12351266
def test_adjust_gamma(self):
12361267
x_shape = [2, 2, 3]
12371268
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
@@ -1268,10 +1299,11 @@ def test_adjusts_L_mode(self):
12681299
self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
12691300
self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
12701301
self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
1302+
self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
12711303
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')
12721304

12731305
def test_color_jitter(self):
1274-
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)
1306+
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2)
12751307

12761308
x_shape = [2, 2, 3]
12771309
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]

test/test_transforms_tensor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,14 @@ def test_color_jitter(self):
131131
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
132132
)
133133

134+
for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
135+
meth_kwargs = {"sharpness": f}
136+
self._test_class_op(
137+
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
138+
)
139+
134140
# All 4 parameters together
135-
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
141+
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2}
136142
self._test_class_op(
137143
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
138144
)

torchvision/transforms/transforms.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ def __repr__(self):
10391039

10401040

10411041
class ColorJitter(torch.nn.Module):
1042-
"""Randomly change the brightness, contrast and saturation of an image.
1042+
"""Randomly change the brightness, contrast, saturation, hue and sharpness of an image.
10431043
10441044
Args:
10451045
brightness (float or tuple of float (min, max)): How much to jitter brightness.
@@ -1054,15 +1054,19 @@ class ColorJitter(torch.nn.Module):
10541054
hue (float or tuple of float (min, max)): How much to jitter hue.
10551055
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
10561056
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
1057+
sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
1058+
sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
1059+
or the given [min, max]. Should be non negative numbers.
10571060
"""
10581061

1059-
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
1062+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0):
10601063
super().__init__()
10611064
self.brightness = self._check_input(brightness, 'brightness')
10621065
self.contrast = self._check_input(contrast, 'contrast')
10631066
self.saturation = self._check_input(saturation, 'saturation')
10641067
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
10651068
clip_first_on_zero=False)
1069+
self.sharpness = self._check_input(sharpness, 'sharpness')
10661070

10671071
@torch.jit.unused
10681072
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
@@ -1078,7 +1082,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
10781082
else:
10791083
raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
10801084

1081-
# if value is 0 or (1., 1.) for brightness/contrast/saturation
1085+
# if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness
10821086
# or (0., 0.) for hue, do nothing
10831087
if value[0] == value[1] == center:
10841088
value = None
@@ -1088,8 +1092,10 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
10881092
def get_params(brightness: Optional[List[float]],
10891093
contrast: Optional[List[float]],
10901094
saturation: Optional[List[float]],
1091-
hue: Optional[List[float]]
1092-
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
1095+
hue: Optional[List[float]],
1096+
sharpness: Optional[List[float]]
1097+
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float],
1098+
Optional[float]]:
10931099
"""Get the parameters for the randomized transform to be applied on image.
10941100
10951101
Args:
@@ -1101,19 +1107,22 @@ def get_params(brightness: Optional[List[float]],
11011107
uniformly. Pass None to turn off the transformation.
11021108
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
11031109
Pass None to turn off the transformation.
1110+
sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen
1111+
uniformly. Pass None to turn off the transformation.
11041112
11051113
Returns:
11061114
tuple: The parameters used to apply the randomized transform
11071115
along with their random order.
11081116
"""
1109-
fn_idx = torch.randperm(4)
1117+
fn_idx = torch.randperm(5)
11101118

11111119
b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
11121120
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
11131121
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
11141122
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
1123+
sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1]))
11151124

1116-
return fn_idx, b, c, s, h
1125+
return fn_idx, b, c, s, h, sp
11171126

11181127
def forward(self, img):
11191128
"""
@@ -1123,8 +1132,8 @@ def forward(self, img):
11231132
Returns:
11241133
PIL Image or Tensor: Color jittered image.
11251134
"""
1126-
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
1127-
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
1135+
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \
1136+
self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness)
11281137

11291138
for fn_id in fn_idx:
11301139
if fn_id == 0 and brightness_factor is not None:
@@ -1135,6 +1144,8 @@ def forward(self, img):
11351144
img = F.adjust_saturation(img, saturation_factor)
11361145
elif fn_id == 3 and hue_factor is not None:
11371146
img = F.adjust_hue(img, hue_factor)
1147+
elif fn_id == 4 and sharpness_factor is not None:
1148+
img = F.adjust_sharpness(img, sharpness_factor)
11381149

11391150
return img
11401151

@@ -1143,7 +1154,8 @@ def __repr__(self):
11431154
format_string += 'brightness={0}'.format(self.brightness)
11441155
format_string += ', contrast={0}'.format(self.contrast)
11451156
format_string += ', saturation={0}'.format(self.saturation)
1146-
format_string += ', hue={0})'.format(self.hue)
1157+
format_string += ', hue={0}'.format(self.hue)
1158+
format_string += ', sharpness={0})'.format(self.sharpness)
11471159
return format_string
11481160

11491161

0 commit comments

Comments
 (0)