Skip to content

Commit f0b7000

Browse files
authored
allow integer parameters in ColorJitter (#7255)
1 parent 1e19d73 commit f0b7000

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchvision/prototype/transforms/_color.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,16 @@ def _check_input(
8080
if value is None:
8181
return None
8282

83-
if isinstance(value, float):
83+
if isinstance(value, (int, float)):
8484
if value < 0:
8585
raise ValueError(f"If {name} is a single number, it must be non negative.")
8686
value = [center - value, center + value]
8787
if clip_first_on_zero:
8888
value[0] = max(value[0], 0.0)
89-
elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2):
90-
raise TypeError(f"{name} should be a single number or a sequence with length 2.")
89+
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
90+
value = [float(v) for v in value]
91+
else:
92+
raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.")
9193

9294
if not bound[0] <= value[0] <= value[1] <= bound[1]:
9395
raise ValueError(f"{name} values should be between {bound}, but got {value}.")

0 commit comments

Comments
 (0)