Skip to content

Commit 80187fe

Browse files
datumboxvfdev-5
authored andcommitted
[BC-breaking] ColorJitter gets its random params by calling get_params() (pytorch#3001)
* ColorJitter gets its random params by calling get_params(). * Update arguments. * Styles. * Add description for Nones. * Chainging Nones to optional.
1 parent e4a45bb commit 80187fe

File tree

1 file changed

+30
-42
lines changed

1 file changed

+30
-42
lines changed

torchvision/transforms/transforms.py

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,38 +1051,35 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
10511051
return value
10521052

10531053
@staticmethod
1054-
@torch.jit.unused
1055-
def get_params(brightness, contrast, saturation, hue):
1056-
"""Get a randomized transform to be applied on image.
1054+
def get_params(brightness: Optional[List[float]],
1055+
contrast: Optional[List[float]],
1056+
saturation: Optional[List[float]],
1057+
hue: Optional[List[float]]
1058+
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
1059+
"""Get the parameters for the randomized transform to be applied on image.
10571060
1058-
Arguments are same as that of __init__.
1061+
Args:
1062+
brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
1063+
uniformly. Pass None to turn off the transformation.
1064+
contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
1065+
uniformly. Pass None to turn off the transformation.
1066+
saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
1067+
uniformly. Pass None to turn off the transformation.
1068+
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
1069+
Pass None to turn off the transformation.
10591070
10601071
Returns:
1061-
Transform which randomly adjusts brightness, contrast and
1062-
saturation in a random order.
1072+
tuple: The parameters used to apply the randomized transform
1073+
along with their random order.
10631074
"""
1064-
transforms = []
1065-
1066-
if brightness is not None:
1067-
brightness_factor = random.uniform(brightness[0], brightness[1])
1068-
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
1069-
1070-
if contrast is not None:
1071-
contrast_factor = random.uniform(contrast[0], contrast[1])
1072-
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
1073-
1074-
if saturation is not None:
1075-
saturation_factor = random.uniform(saturation[0], saturation[1])
1076-
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
1077-
1078-
if hue is not None:
1079-
hue_factor = random.uniform(hue[0], hue[1])
1080-
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
1075+
fn_idx = torch.randperm(4)
10811076

1082-
random.shuffle(transforms)
1083-
transform = Compose(transforms)
1077+
b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
1078+
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
1079+
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
1080+
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
10841081

1085-
return transform
1082+
return fn_idx, b, c, s, h
10861083

10871084
def forward(self, img):
10881085
"""
@@ -1092,26 +1089,17 @@ def forward(self, img):
10921089
Returns:
10931090
PIL Image or Tensor: Color jittered image.
10941091
"""
1095-
fn_idx = torch.randperm(4)
1092+
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
1093+
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
1094+
10961095
for fn_id in fn_idx:
1097-
if fn_id == 0 and self.brightness is not None:
1098-
brightness = self.brightness
1099-
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
1096+
if fn_id == 0 and brightness_factor is not None:
11001097
img = F.adjust_brightness(img, brightness_factor)
1101-
1102-
if fn_id == 1 and self.contrast is not None:
1103-
contrast = self.contrast
1104-
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
1098+
elif fn_id == 1 and contrast_factor is not None:
11051099
img = F.adjust_contrast(img, contrast_factor)
1106-
1107-
if fn_id == 2 and self.saturation is not None:
1108-
saturation = self.saturation
1109-
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
1100+
elif fn_id == 2 and saturation_factor is not None:
11101101
img = F.adjust_saturation(img, saturation_factor)
1111-
1112-
if fn_id == 3 and self.hue is not None:
1113-
hue = self.hue
1114-
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
1102+
elif fn_id == 3 and hue_factor is not None:
11151103
img = F.adjust_hue(img, hue_factor)
11161104

11171105
return img

0 commit comments

Comments
 (0)