Fix some annotations in transforms v2 for JIT v1 compatibility #7252

Merged (8 commits) on Feb 15, 2023

Conversation

@pmeier (Collaborator) commented Feb 15, 2023

TL;DR: this reverts the changes to the annotations of the fill and padding parameters that we made for v2, which turned out not to be compatible with the JIT behavior of v1.


We have three groups of transforms that take the fill parameter in v1:

  • AA family:
    fill (sequence or number, optional): Pixel fill value for the area outside the transformed
    image. If given a number, the value is used for all bands respectively.
  • Affine family:
    fill (sequence or number): Pixel fill value for the area outside the transformed
    image. Default is ``0``. If given a number, the value is used for all bands respectively.
  • transforms.Pad
    fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
    length 3, it is used to fill R, G, B channels respectively.
    This value is only used when the padding_mode is constant.
    Only number is supported for torch Tensor.
    Only int or tuple value is supported for PIL Image.

In eager mode, v1 and v2 behave the same, and we enforce that in our consistency tests. Under JIT, however, they behave differently. The transform classes don't annotate the fill parameter themselves, so we have to look at the functionals:


Let's start with the AA and affine families, since they use the same annotation:

import torch
from torchvision.transforms import functional as F_v1
from torchvision.prototype.transforms import functional as F_v2

name = "rotate"
args = (torch.rand(3, 256, 256),)
kwargs = dict(angle=30)


for version, F in [
    ("v1", F_v1),
    ("v2", F_v2),
]:
    eager = getattr(F, name)
    scripted = torch.jit.script(eager)

    print(version, name)

    for fill in [None, 1, 0.5, [1], [0.5], (1,), (0.5,), [1, 0, 1], [0.1, 0.2, 0.3], (1, 0, 1), (0.1, 0.2, 0.3)]:
        # Record whether each fill value is accepted in eager and in scripted mode.
        try:
            eager(*args, **kwargs, fill=fill)
            eager_result = "PASS"
        except Exception:
            eager_result = "FAIL"

        try:
            scripted(*args, **kwargs, fill=fill)
            scripted_result = "PASS"
        except Exception:
            scripted_result = "FAIL"

        print(f"{str(fill):>15}: eager {eager_result}, scripted {scripted_result}")

    print("-" * 80)

On main this prints:

v1 rotate
           None: eager PASS, scripted PASS
              1: eager PASS, scripted FAIL
            0.5: eager PASS, scripted FAIL
            [1]: eager PASS, scripted PASS
          [0.5]: eager PASS, scripted PASS
           (1,): eager PASS, scripted PASS
         (0.5,): eager PASS, scripted PASS
      [1, 0, 1]: eager PASS, scripted PASS
[0.1, 0.2, 0.3]: eager PASS, scripted PASS
      (1, 0, 1): eager PASS, scripted PASS
(0.1, 0.2, 0.3): eager PASS, scripted PASS
--------------------------------------------------------------------------------
v2 rotate
           None: eager PASS, scripted PASS
              1: eager PASS, scripted PASS
            0.5: eager PASS, scripted PASS
            [1]: eager PASS, scripted FAIL
          [0.5]: eager PASS, scripted PASS
           (1,): eager PASS, scripted FAIL
         (0.5,): eager PASS, scripted FAIL
      [1, 0, 1]: eager PASS, scripted FAIL
[0.1, 0.2, 0.3]: eager PASS, scripted PASS
      (1, 0, 1): eager PASS, scripted FAIL
(0.1, 0.2, 0.3): eager PASS, scripted FAIL
--------------------------------------------------------------------------------
  • v1 does not accept scalar ints or floats when scripted, but passes for everything else
  • v2 accepts Python scalars when scripted, but fails for tuples and for lists of integers

So how did that happen? In v2 we changed the annotation to

FillTypeJIT = Union[int, float, List[float], None]

Meaning, the failures for lists of integers are "expected" (we'll get to why later). But what happened to the tuples? After all, v1 didn't annotate them either.

This is caused by some (undocumented) automagic in JIT: annotating a parameter with List[int] will automatically handle tuple inputs as well:

from typing import List, Union

@torch.jit.script
def foo(data: List[int]) -> torch.Tensor:
    if isinstance(data, int):
        data = [data]
    return torch.tensor(data)

foo((1, 2, 3))  # works: the tuple input is handled automatically

However, if we correct the annotation to Union[int, List[int]], the automagic is no longer applied:

@torch.jit.script
def bar(data: Union[int, List[int]]) -> torch.Tensor:
    if isinstance(data, int):
        data = [data]
    return torch.tensor(data)

bar((1, 2, 3))
RuntimeError: bar() Expected a value of type 'Union[List[int], int]' for argument 'data' but instead found type 'tuple'.
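
Before we get back to the tuples, note that this strictness is also why the failures for lists of integers above were "expected": once the value has to match the broader Union at the scripted call boundary, TorchScript no longer converts List[int] to List[float] either. A minimal sketch (check_fill is a made-up helper; its annotation mirrors the v2 FillTypeJIT, and the expected passes and failures mirror the v2 rotate results above):

from typing import List, Union

import torch

@torch.jit.script
def check_fill(fill: Union[int, float, List[float], None]) -> bool:
    # The body is irrelevant here; what matters is which Python values
    # the scripted call boundary accepts for this Union annotation.
    return fill is not None

check_fill(1)                # passes: matches the int member
check_fill([0.1, 0.2, 0.3])  # passes: matches List[float]
check_fill([1, 0, 1])        # RuntimeError: List[int] is not converted to List[float]
check_fill((0.1, 0.2, 0.3))  # RuntimeError: tuples are rejected as well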

Well, we could just add Tuple[int] to the new annotation, right? Nope. Tuple[int] is not equivalent to List[int]; that would be Tuple[int, ...], which is not supported by JIT. And since the length of fill corresponds to the number of channels, which is only known at runtime, we cannot use Tuple[int, int, int] or any other fixed length. Meaning, we need to rely on the automagic for BC and revert our annotation changes.
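
For comparison, here is a minimal sketch of the v1-style annotation we revert to (check_fill_v1 is a made-up helper); the expected outcomes mirror the v1 rotate results above, where the automagic converts both tuples and integer elements:

from typing import List, Optional

import torch

@torch.jit.script
def check_fill_v1(fill: Optional[List[float]] = None) -> bool:
    # Again, only the behavior at the scripted call boundary matters here.
    return fill is not None

check_fill_v1(None)             # passes
check_fill_v1([1, 0, 1])        # passes: the list of ints is converted to List[float]
check_fill_v1((0.1, 0.2, 0.3))  # passes: the tuple is converted to a list
check_fill_v1(1)                # RuntimeError: bare scalars are still rejected, as in v1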


So is this the end of the story? Nope again. As we saw above, F.pad uses different annotations. Let's run our script from above with

name = "pad"
args = (torch.rand(3, 256, 256),)
kwargs = dict(padding=[2])

which on main prints:

v1 pad
           None: eager PASS, scripted FAIL
              1: eager PASS, scripted PASS
            0.5: eager PASS, scripted PASS
            [1]: eager FAIL, scripted FAIL
          [0.5]: eager FAIL, scripted FAIL
           (1,): eager FAIL, scripted FAIL
         (0.5,): eager FAIL, scripted FAIL
      [1, 0, 1]: eager FAIL, scripted FAIL
[0.1, 0.2, 0.3]: eager FAIL, scripted FAIL
      (1, 0, 1): eager FAIL, scripted FAIL
(0.1, 0.2, 0.3): eager FAIL, scripted FAIL
--------------------------------------------------------------------------------
v2 pad
           None: eager PASS, scripted PASS
              1: eager PASS, scripted PASS
            0.5: eager PASS, scripted PASS
            [1]: eager PASS, scripted FAIL
          [0.5]: eager PASS, scripted PASS
           (1,): eager PASS, scripted FAIL
         (0.5,): eager PASS, scripted FAIL
      [1, 0, 1]: eager PASS, scripted FAIL
[0.1, 0.2, 0.3]: eager PASS, scripted PASS
      (1, 0, 1): eager PASS, scripted FAIL
(0.1, 0.2, 0.3): eager PASS, scripted FAIL
--------------------------------------------------------------------------------
  • v1 only works with scalar fills, even in eager mode (and None fails when scripted)
  • v2 adds support for multi-channel fills in eager mode and even passes for some of them when scripted

Meaning, we can keep our new annotation for F.pad's fill, since the v2 variant supports a superset of the v1 values in both eager and scripted mode.
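
To make the "superset" concrete, here is a minimal sketch with the union annotation we keep for the pad fill (check_pad_fill is a made-up helper; the expected outcomes mirror the v2 pad results above):

from typing import List, Optional, Union

import torch

@torch.jit.script
def check_pad_fill(fill: Optional[Union[int, float, List[float]]] = None) -> bool:
    # Only the behavior at the scripted call boundary matters here.
    return fill is not None

check_pad_fill(None)             # passes
check_pad_fill(1)                # passes: scalar fills now also work when scripted
check_pad_fill([0.1, 0.2, 0.3])  # passes: multi-channel float fills work as well
check_pad_fill((0.1, 0.2, 0.3))  # RuntimeError: tuples still don't, but v1 never supported them either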


What's left is the padding argument on F.pad:

You can probably see where this is going. Changing the script from above to

name = "pad"
args = (torch.rand(3, 256, 256),)
kwargs = dict()

and iterating over different padding values gives us:

v1 pad
           1: eager PASS, scripted FAIL
         [2]: eager PASS, scripted PASS
        (2,): eager PASS, scripted PASS
      [3, 4]: eager PASS, scripted PASS
      (3, 4): eager PASS, scripted PASS
[5, 6, 7, 8]: eager PASS, scripted PASS
(5, 6, 7, 8): eager PASS, scripted PASS
--------------------------------------------------------------------------------
v2 pad
           1: eager PASS, scripted PASS
         [2]: eager PASS, scripted PASS
        (2,): eager PASS, scripted FAIL
      [3, 4]: eager PASS, scripted PASS
      (3, 4): eager PASS, scripted FAIL
[5, 6, 7, 8]: eager PASS, scripted PASS
(5, 6, 7, 8): eager PASS, scripted FAIL
--------------------------------------------------------------------------------

Meaning, we need to revert the new padding annotation as well and rely on the automagic handling.
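
For completeness, the padding story in miniature (the function names are made up; the annotations correspond to the v1-style List[int] and the v2-style Union[int, List[int]], and the expected outcomes mirror the pad results above):

from typing import List, Union

import torch

@torch.jit.script
def padding_v1_style(padding: List[int]) -> torch.Tensor:
    return torch.tensor(padding)

@torch.jit.script
def padding_v2_style(padding: Union[int, List[int]]) -> torch.Tensor:
    if isinstance(padding, int):
        padding = [padding]
    return torch.tensor(padding)

padding_v1_style((3, 4))  # passes: the automagic converts the tuple
padding_v1_style(1)       # RuntimeError: a bare int is not converted to List[int]
padding_v2_style(1)       # passes: the Union covers the scalar ...
padding_v2_style((3, 4))  # RuntimeError: ... but the tuple conversion is gone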


If this whole story weren't so sad, it probably should have been a blog post rather than a PR description.

cc @vfdev-5 @bjuncek

@@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
)


def xfail_jit_tuple_instead_of_list(name, *, reason=None):

@pmeier (Collaborator, Author) commented:

We actually observed that our functionals didn't work for tuples, but failed to check whether v1 enforces this. Since we have now aligned the behavior, we can also remove this helper, as it is no longer in use.

xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok

@pmeier (Collaborator, Author) commented:

We were on the right track 🤦

@@ -450,21 +430,21 @@ def _full_affine_params(**partial_params):
]


def get_fills(*, num_channels, dtype, vector=True):
def get_fills(*, num_channels, dtype):

@pmeier (Collaborator, Author) commented:

We now make sure that we get all possible fill types.

@@ -12,7 +12,7 @@

D = TypeVar("D", bound="Datapoint")
FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None]
FillTypeJIT = Optional[List[float]]

@pmeier (Collaborator, Author) commented:

Revert this to what we had in v1 ...

@@ -118,7 +118,7 @@ def resized_crop(
def pad(
self,
padding: Union[int, Sequence[int]],
fill: FillTypeJIT = None,
fill: Optional[Union[int, float, List[float]]] = None,

@pmeier (Collaborator, Author) commented:

... but keep it for F.pad

@NicolasHug (Member) left a comment:

Thanks a lot Philip.

IIUC, you didn't really add new tests to make sure everything is OK, and instead you removed some of the xfail marks so that the pre-existing tests are actually run? Are we sure they cover all the cases we want to support?

Regardless, I'll approve to unblock so we can merge ASAP and test these new changes against #7159

pmeier commented Feb 15, 2023

IIUC, you didn't really add new tests to make sure everything is OK, and instead you removed some of the xfail marks so that the pre-existing tests are actually run? Are we sure they cover all the cases we want to support?

Yes and no. Yes, I've removed some xfails, but I also expanded the tested parameters; see #7252 (comment). Previously, we didn't test single-value lists, or tuples in general, for fill.

@NicolasHug merged commit f9d1883 into pytorch:main Feb 15, 2023
@pmeier deleted the jit-fill branch February 15, 2023 10:43
facebook-github-bot pushed a commit that referenced this pull request Mar 28, 2023
…ty (#7252)

Reviewed By: vmoens

Differential Revision: D44416629

fbshipit-source-id: ab4950cc6c3d313355f29c069838fb96fe9a2dbf