[RFC] Make transforms.functional methods differentiable w.r.t. their parameters #5157

Open · vfdev-5 opened this issue Jan 4, 2022 · 1 comment

vfdev-5 (Collaborator) commented Jan 4, 2022

This RFC continues the discussion in #5000 by @ain-soph and the PRs #4995 and #5110 on updating the functional tensor methods F.* to accept learnable parameters (tensors with requires_grad=True) and propagate the gradient.

For the motivation and the context, please see #5000

Proposal

Torchvision transformations can work on PIL images and torch Tensors and accept scalars and lists of scalars as parameters. For example,

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
alpha = 45
center = [1, 2]
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor

The proposal is to be able to learn parameters like alpha and center using gradient descent:

x = torch.rand(1, 3, 32, 32)
- alpha = 45
+ alpha = torch.tensor(45.0, requires_grad=True)
- center = [1, 2]
+ center = torch.tensor([1.0, 2.0], requires_grad=True)
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor that requires grad
assert out.requires_grad

# parameters can have grads:
out.mean().backward()  # some dummy criterion
assert alpha.grad is not None
assert center.grad is not None

while keeping the previous API (no BC-breaking changes).
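
To make the intended workflow concrete, here is a minimal sketch of optimizing the rotation parameters with SGD, assuming F.rotate propagates gradients as proposed; the target image and MSE loss are illustrative placeholders:

import torch
import torch.nn.functional
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
target = torch.rand(1, 3, 32, 32)  # placeholder optimization target

# learnable transformation parameters
alpha = torch.tensor(45.0, requires_grad=True)
center = torch.tensor([1.0, 2.0], requires_grad=True)
optimizer = torch.optim.SGD([alpha, center], lr=0.1)

for _ in range(10):
    optimizer.zero_grad()
    out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
    loss = torch.nn.functional.mse_loss(out, target)
    loss.backward()
    optimizer.step()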

Implementation

In terms of API, it would require updates like:

def rotate(
    img: Tensor,
-   angle: float,
+   angle: Union[float, int, Tensor],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
-   center: Optional[List[int]] = None,
+   center: Optional[Union[List[int], Tuple[int, int], Tensor]] = None,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
) -> Tensor:

Note: we need to keep the transforms torch jit scriptable, so we are also limited by what torch jit script supports (simply adding Union[float, Tensor] does not always work and can break compatibility with the stable version).
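
As an illustration (a hypothetical helper, not torchvision's actual code), one scriptable pattern is to normalize a scalar-or-Tensor parameter at the top of the method, keeping Tensor inputs untouched so their autograd history is preserved; this assumes a PyTorch version whose TorchScript supports Union annotations:

import torch
from torch import Tensor
from typing import Union

@torch.jit.script
def _as_angle_tensor(angle: Union[int, float, Tensor]) -> Tensor:
    # Tensors pass through unchanged so requires_grad / grad_fn survive;
    # plain scalars become fresh (non-learnable) float tensors.
    if isinstance(angle, Tensor):
        return angle
    if isinstance(angle, int):
        return torch.tensor(float(angle))
    return torch.tensor(angle)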

In terms of implementation, we have to ensure that:

  • methods with updated parameters still support all previous data types
  • methods are torch jit scriptable
  • methods verify that the input image is a float tensor (there is no grad propagation otherwise)
  • methods propagate grads for tensor inputs <=> all internal ops on the tensor branch propagate grads (a check is sketched after this list)
  • only floating-point parameters can accept values as Tensors
    • for example, rotate with a learnable floating-point angle
    • IMO, we can't make the (integer) output size learnable in the resize op (please correct me if there is a way)
    • certain integer parameters can be promoted to float, e.g. translate in affine
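
For the grad-propagation point above, a check along these lines could be used once the proposal is implemented (a sketch only; gradcheck needs double precision, and tolerances may need tuning for interpolation ops):

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 8, 8, dtype=torch.double)

def rotate_by(angle: torch.Tensor) -> torch.Tensor:
    return F.rotate(x, angle, interpolation=InterpolationMode.BILINEAR)

angle = torch.tensor(30.0, dtype=torch.double, requires_grad=True)
# numerically compares analytical grads w.r.t. angle against finite differences
torch.autograd.gradcheck(rotate_by, (angle,))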

Example with the affine and rotate ops: #5110

Transforms to update

Please comment here if I'm missing any op that we could add to the list.

cc @vfdev-5 @datumbox

ain-soph (Contributor) commented Jan 4, 2022

I think gaussian_blur's kernel_size, posterize, and solarize are not mathematically differentiable. Maybe we can just ignore them?
