
[RFC] Make transforms.functional methods differentiable w.r.t. their parameters #5157

Description

@vfdev-5

This is an RFC that continues the discussion in #5000 by @ain-soph and PRs #4995 and #5110 on updating the functional tensor methods F.* to accept learnable parameters (tensors with requires_grad=True) and propagate the gradient to them.

For motivation and context, please see #5000.

Proposal

Torchvision transformations can work on PIL images and torch Tensors, and they accept scalars or lists of scalars as parameters. For example:

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
alpha = 45
center = [1, 2]
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor

The proposal is to make it possible to learn parameters like alpha and center using gradient descent:

x = torch.rand(1, 3, 32, 32)
- alpha = 45
+ alpha = torch.tensor(45.0, requires_grad=True)
- center = [1, 2]
+ center = torch.tensor([1.0, 2.0], requires_grad=True)
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor that requires grad
assert out.requires_grad

# parameters receive grads:
out.mean().backward()  # some dummy criterion
assert alpha.grad is not None
assert center.grad is not None

and also to keep the previous API working (no BC-breaking changes).
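For illustration, here is a minimal sketch of how a learnable angle could be optimized once this proposal lands. It assumes F.rotate accepts tensor parameters as proposed above; the target and the L2 criterion are just placeholders:

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
target = torch.rot90(x, k=1, dims=(2, 3))  # dummy target: a 90-degree rotation
alpha = torch.tensor(0.0, requires_grad=True)  # learnable angle in degrees

optimizer = torch.optim.SGD([alpha], lr=1.0)
for _ in range(100):
    optimizer.zero_grad()
    out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR)
    loss = (out - target).pow(2).mean()  # dummy criterion
    loss.backward()
    optimizer.step()
# alpha should drift toward +/-90 degrees (sign depends on the rotation convention)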

Implementation

In terms of API, it would require updates like:

def rotate(
    img: Tensor,
-   angle: float,
+   angle: Union[float, int, Tensor],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
-   center: Optional[List[int]] = None,
+   center: Optional[Union[List[int], Tuple[int, int], Tensor]] = None,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
) -> Tensor:

Note: we need to keep the transforms torch jit scriptable, and thus we may also be limited by what torch jit script supports (simply adding Union[float, Tensor] does not always work and can break compatibility with the stable version).
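As a rough illustration, one pattern is to normalize the parameter to a tensor in a small helper. This assumes a PyTorch version whose TorchScript supports Union in signatures (1.10+), and the helper name is hypothetical:

from typing import Union

import torch
from torch import Tensor

def _as_float_tensor(value: Union[float, Tensor]) -> Tensor:
    # Normalize a scalar or tensor parameter to a float tensor,
    # keeping the autograd graph intact when a tensor is passed in.
    if isinstance(value, Tensor):
        return value
    return torch.tensor(float(value))

# Union annotations are only scriptable on recent PyTorch versions,
# which is exactly the compatibility constraint mentioned above:
scripted = torch.jit.script(_as_float_tensor)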

In terms of implementation, we have to ensure that:

  • methods with updated parameters still support all previously accepted data types
  • methods remain torch jit scriptable
  • methods verify that the input image is a float tensor (otherwise there is no grad propagation)
  • methods propagate grads for tensor inputs <=> all internal ops on the tensor branch propagate grads (see the sketch after this list)
  • only floating-point parameters can accept values as Tensors
    • for example, rotate with a learnable floating-point angle
    • IMO, we can't make the (integer) output size learnable in the resize op (please correct me if there is a way)
    • certain integer parameters can be promoted to float, e.g. translate in affine
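To make the grad-propagation point concrete, here is a hypothetical sketch of a tensor branch for rotate built only from tensor ops (affine_grid/grid_sample), so autograd can reach the angle; this is an illustration, not the actual torchvision implementation:

import math

import torch
import torch.nn.functional as nnF

def rotate_differentiable(img: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    # Rotate around the image center using only tensor ops so that
    # gradients flow back to `angle` (given in degrees).
    theta = angle * math.pi / 180.0
    cos, sin = torch.cos(theta), torch.sin(theta)
    zero = torch.zeros_like(cos)
    # 2x3 affine matrix assembled with torch.stack to keep the graph.
    mat = torch.stack([
        torch.stack([cos, -sin, zero]),
        torch.stack([sin, cos, zero]),
    ]).unsqueeze(0)
    grid = nnF.affine_grid(mat, list(img.shape), align_corners=False)
    return nnF.grid_sample(img, grid, mode="bilinear", align_corners=False)

x = torch.rand(1, 3, 32, 32)
alpha = torch.tensor(45.0, requires_grad=True)
rotate_differentiable(x, alpha).mean().backward()
assert alpha.grad is not None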

Example with the affine and rotate ops: #5110

Transforms to update

Please comment here if I'm missing any op that we could add to the list.

cc @vfdev-5 @datumbox
