diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 09e313e5bed..8ac0d857753 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -80,14 +80,16 @@ def _check_input( if value is None: return None - if isinstance(value, float): + if isinstance(value, (int, float)): if value < 0: raise ValueError(f"If {name} is a single number, it must be non negative.") value = [center - value, center + value] if clip_first_on_zero: value[0] = max(value[0], 0.0) - elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2): - raise TypeError(f"{name} should be a single number or a sequence with length 2.") + elif isinstance(value, collections.abc.Sequence) and len(value) == 2: + value = [float(v) for v in value] + else: + raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.") if not bound[0] <= value[0] <= value[1] <= bound[1]: raise ValueError(f"{name} values should be between {bound}, but got {value}.")