
transforms.Pad returns error with tuple fill for tensors #3227


Open
sagadre opened this issue Jan 7, 2021 · 4 comments

Comments


sagadre commented Jan 7, 2021

🐛 Bug

Passing a 3-tuple fill, e.g. Pad(5, fill=(117, 250, 87)), returns an error when called on a tensor.

To Reproduce

The issue seems to be here:
https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional_tensor.py#L408

Tuple fills don't seem to be supported for Tensors; however, the docs suggest that they should be:
https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Pad

It seems like support for this feature exists for PIL Images but not for Tensors. Here is the check in functional_pil.py:
https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional_pil.py#L130

cc @vfdev-5


voldemortX commented Jan 7, 2021

@sagadre You are right: tensors do not support Pad with sequence fills, as mentioned in the current doc here.

#3224 also listed some other issues. We are still updating the documentation at the moment.


voldemortX commented Jan 9, 2021

@datumbox I think the issue here is that pad() uses torch.nn.functional.pad(), which only supports a single-value fill. An obvious workaround would be to call that function C (number of channels) times, but that may be, how do I put this, not very elegant.
Maybe we could just get rid of torch.nn.functional.pad() like this, since we only need fill for constant pad:
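To illustrate the limitation with a minimal sketch (dummy data, not from this thread): torch.nn.functional.pad accepts only a scalar value for constant padding, so passing a per-channel tuple is rejected.

```python
import torch
import torch.nn.functional as F

img = torch.zeros(3, 4, 4)

# A scalar constant fill works: pad last two dims by 1 on each side.
out = F.pad(img, [1, 1, 1, 1], mode="constant", value=117.0)
print(tuple(out.shape))  # (3, 6, 6)

# A per-channel tuple fill is rejected, since value must be a single number.
try:
    F.pad(img, [1, 1, 1, 1], mode="constant", value=(117, 250, 87))
    err_name = None
except (TypeError, RuntimeError) as e:
    err_name = type(e).__name__
print(err_name)
```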

out_h = img.shape[-2] + p[2] + p[3]  # p = [left, right, top, bottom]
out_w = img.shape[-1] + p[0] + p[1]
# Build an output canvas filled with the per-channel constant, then paste img into the interior:
out_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len(fill), 1, 1).expand(-1, -1, out_h, out_w).clone()
out_img[:, :, p[2]: out_h - p[3], p[0]: out_w - p[1]] = img

What do you think?
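For reference, the snippet above can be wrapped into a runnable sketch (the function name and dummy data are illustrative, not part of torchvision; img is assumed NCHW):

```python
import torch

def pad_constant(img, p, fill):
    # img: NCHW tensor; p: [left, right, top, bottom]; fill: per-channel tuple.
    out_h = img.shape[-2] + p[2] + p[3]
    out_w = img.shape[-1] + p[0] + p[1]
    # Build a canvas filled with the per-channel constant...
    out_img = (
        torch.tensor(fill, dtype=img.dtype, device=img.device)
        .view(1, len(fill), 1, 1)
        .expand(-1, -1, out_h, out_w)
        .clone()
    )
    # ...then paste the original image into the interior region.
    out_img[:, :, p[2]: out_h - p[3], p[0]: out_w - p[1]] = img
    return out_img

img = torch.zeros(1, 3, 2, 2, dtype=torch.uint8)
out = pad_constant(img, [1, 1, 1, 1], (117, 250, 87))
print(tuple(out.shape))          # (1, 3, 4, 4)
print(out[0, :, 0, 0].tolist())  # [117, 250, 87]  (padded border pixel)
print(out[0, :, 1, 1].tolist())  # [0, 0, 0]       (original image pixel)
```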


datumbox commented Jan 9, 2021

@voldemortX This could be a good workaround for someone who really wants to support the transformation in their codebase, but I don't think this is something we should add to TorchVision because it would be slow. A more elegant and fast approach would be to examine extending the upstream pad.


voldemortX commented Jan 9, 2021

> @voldemortX This could be a good workaround for someone who really wants to support the transformation in their codebase, but I don't think this is something we should add to TorchVision because it would be slow. A more elegant and fast approach would be to examine extending the upstream pad.

I just quickly checked out the workaround and I agree with you. This should be supported in pytorch rather than torchvision.

EDIT: Maybe I can look into this at pytorch next week. @datumbox
If someone really wants to use it at the moment, they can try the workaround above.
