Description
This is an RFC that continues the discussion in #5000 by @ain-soph and PRs #4995 and #5110 on updating the functional tensor methods in F.* to accept learnable parameters (tensors with requires_grad=True) and to propagate the gradient.
For the motivation and context, please see #5000.
Proposal
Torchvision transformations work on PIL images and torch Tensors and accept scalars or lists of scalars as parameters. For example:
```python
x = torch.rand(1, 3, 32, 32)
alpha = 45
center = [1, 2]
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor
```
The proposal is to be able to learn parameters like alpha and center using gradient descent:
```diff
 x = torch.rand(1, 3, 32, 32)
- alpha = 45
+ alpha = torch.tensor(45.0, requires_grad=True)
- center = [1, 2]
+ center = torch.tensor([1.0, 2.0], requires_grad=True)
 out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
 # out is a tensor that requires grad
 assert out.requires_grad
 # parameters can have grads:
 out.mean().backward()  # some dummy criterion
 assert alpha.grad is not None
 assert center.grad is not None
```
while also keeping the previous API (no BC-breaking changes).
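If the proposal lands, fitting a rotation angle by gradient descent could look like the following sketch (hypothetical usage: the current stable F.rotate does not yet accept a learnable tensor angle, and the target here is a dummy):
```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
target = torch.rand(1, 3, 32, 32)  # dummy target, for illustration only

alpha = torch.tensor(45.0, requires_grad=True)
optimizer = torch.optim.SGD([alpha], lr=0.1)

for _ in range(100):
    optimizer.zero_grad()
    out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR)
    loss = (out - target).pow(2).mean()  # some dummy criterion
    loss.backward()  # gradient flows back to alpha
    optimizer.step()
```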
Implementation
In terms of API, it would require updates like:
```diff
 def rotate(
     img: Tensor,
-    angle: float,
+    angle: Union[float, int, Tensor],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    center: Optional[List[int]] = None,
+    center: Optional[Union[List[int], Tuple[int, int], Tensor]] = None,
     fill: Optional[List[float]] = None,
     resample: Optional[int] = None,
 ) -> Tensor:
```
Note: we need to keep transforms torch jit scriptable, so we are also limited by what torch jit script supports (simply adding Union[float, Tensor] does not always work and can break compatibility with the stable version).
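One pattern that can keep an op scriptable is to normalize the parameter to a float tensor up front, so the rest of the code only deals with a single Tensor branch. A minimal sketch, assuming a PyTorch version whose TorchScript supports Union annotations (the helper name is hypothetical, not an existing torchvision function):
```python
import torch
from torch import Tensor
from typing import Union


def _to_float_tensor(value: Union[int, float, Tensor]) -> Tensor:
    # Hypothetical helper: turn a scalar parameter into a float tensor so that
    # downstream ops see a single type and autograd can track tensor inputs.
    if isinstance(value, Tensor):
        # int/bool tensors cannot require grad, so casting them to float is safe
        return value if value.is_floating_point() else value.float()
    return torch.tensor(float(value))


# Should remain scriptable under the stated assumption:
scripted = torch.jit.script(_to_float_tensor)
```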
In terms of implementation, we have to ensure that:
- methods with updated parameters still support all previous data types
- methods are torch jit scriptable
- methods verify that the input image is a float tensor (no grad propagation otherwise)
- methods propagate grads for tensor inputs, i.e. all internal ops in the tensor branch propagate grads (see the gradcheck sketch below)
- only floating-point parameters can accept values as Tensors
  - for example, rotate with a learnable floating-point angle
  - IMO, we can't make the output (integer) size learnable in the resize op (please correct me if there is a way)
  - certain integer parameters can be promoted to float, e.g. translate in affine
Example with affine and rotate ops: #5110
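To verify the grad-propagation requirement, torch.autograd.gradcheck can compare analytical and numerical gradients w.r.t. the parameter. A minimal sketch, assuming the updated F.rotate from #5110 that accepts a tensor angle (gradcheck expects float64 inputs):
```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32, dtype=torch.float64)


def fn(angle: torch.Tensor) -> torch.Tensor:
    # Only the angle is a gradcheck input; the image is captured as a constant.
    return F.rotate(x, angle, interpolation=InterpolationMode.BILINEAR)


angle = torch.tensor(20.0, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(fn, (angle,))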
Transforms to update
- normalize, params: mean and std
- adjust_brightness, params: brightness_factor
- adjust_contrast, params: contrast_factor
- adjust_saturation, params: saturation_factor
- adjust_hue, params: hue_factor
- adjust_gamma, params: gamma, gain
- rotate, params: angle, center (#5110: Make F.rotate/F.affine accept learnable params)
- affine, params: angle, translate, scale, shear (#5110: Make F.rotate/F.affine accept learnable params)
- gaussian_blur, params: kernel_size, sigma ?
- posterize, params: bits ?
- solarize, params: threshold
- adjust_sharpness, params: sharpness_factor
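To illustrate the kind of change each entry would need, here is a minimal, self-contained sketch for adjust_brightness. This is hypothetical code, not the actual torchvision internals, and it assumes float images in [0, 1]:
```python
import torch
from torch import Tensor
from typing import Union


def adjust_brightness_sketch(img: Tensor, brightness_factor: Union[float, Tensor]) -> Tensor:
    # Keep the factor as a tensor so autograd can track it when it requires grad.
    if not isinstance(brightness_factor, Tensor):
        brightness_factor = torch.tensor(float(brightness_factor), device=img.device)
    # Gradients only propagate through float images.
    if not img.is_floating_point():
        raise TypeError("img must be a float tensor for gradient propagation")
    # Blend with a black image; with zeros this reduces to scaling by the factor.
    black = torch.zeros_like(img)
    out = brightness_factor * img + (1.0 - brightness_factor) * black
    return out.clamp(0.0, 1.0)  # assumes images in [0, 1]
```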
Please comment here if I'm missing any op that we could add to the list.