Description
A few years ago, we introduced handling of custom fill values in `_apply_grid_transform` using a mask approach:
vision/torchvision/transforms/functional_tensor.py, lines 550 to 568 at commit 0d69e35:

```python
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
    dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
    img = torch.cat((img, dummy), dim=1)

img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

# Fill with required color
if fill is not None:
    mask = img[:, -1:, :, :]  # N * 1 * H * W
    img = img[:, :-1, :, :]  # N * C * H * W
    mask = mask.expand_as(img)
    len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
    fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
    if mode == "nearest":
        mask = mask < 0.5
        img[mask] = fill_img[mask]
    else:  # 'bilinear'
        img = img * mask + (1.0 - mask) * fill_img
```
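To make the dummy-channel trick concrete, here is a minimal 1-D sketch in pure Python, so it runs without torch. It assumes that linear interpolation with zero padding along a single row stands in for `grid_sample(..., mode="bilinear", padding_mode="zeros")`, and that sample coordinates are given in pixel units rather than the normalized grid `grid_sample` expects. `sample_linear_zeros` and `warp_with_fill` are hypothetical helpers, not torchvision functions. Interpolating an all-ones channel with the same weights yields the fraction of each sample's weight that fell inside the image, which is what `mask` holds after the `torch.cat` and slicing steps.

```python
# Minimal 1-D sketch of the "dummy mask" trick, pure Python (no torch).
# Assumption: linear interpolation with zero padding along one row stands in
# for grid_sample(..., mode="bilinear", padding_mode="zeros"), and coordinates
# are in pixel units instead of grid_sample's normalized [-1, 1] grid.
import math
from typing import List, Sequence


def sample_linear_zeros(signal: Sequence[float], x: float) -> float:
    """Linearly interpolate `signal` at position x; taps outside it read 0."""
    x0 = math.floor(x)
    w1 = x - x0  # weight of the right neighbour
    w0 = 1.0 - w1  # weight of the left neighbour

    def tap(i: int) -> float:
        return signal[i] if 0 <= i < len(signal) else 0.0

    return w0 * tap(x0) + w1 * tap(x0 + 1)


def warp_with_fill(signal: Sequence[float], coords: Sequence[float], fill: float) -> List[float]:
    """Hypothetical helper: resample `signal` at `coords`, blending in `fill`."""
    ones = [1.0] * len(signal)  # plays the role of the appended all-ones channel
    out = []
    for x in coords:
        value = sample_linear_zeros(signal, x)  # image channel, zero padded
        mask = sample_linear_zeros(ones, x)  # in-bounds weight, in [0, 1]
        # Same formula as the 'bilinear' branch of the torchvision code.
        out.append(value * mask + (1.0 - mask) * fill)
    return out
```

For example, `warp_with_fill([10.0, 20.0, 30.0], [0.0, 1.0, 2.5, 3.5], fill=100.0)` returns `[10.0, 20.0, 57.5, 100.0]`: the sample at 2.5 is half in bounds, so it is half image and half fill, and the sample at 3.5 is entirely fill.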
There are a few minor problems with this approach:
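- If we pass `fill=[0.0, ]`, we would expect a result similar to `fill=None`. This is not exactly true in bilinear interpolation mode, where we linearly interpolate between `img` and `fill_img`: `grid_sample` has already blended the border pixels with the zero padding, so with a zero fill the blend reduces to `img * mask` and those pixels are attenuated a second time: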
vision/torchvision/transforms/functional_tensor.py, lines 567 to 568 at commit 0d69e35:

```python
else:  # 'bilinear'
    img = img * mask + (1.0 - mask) * fill_img
```
Most probably, we would like to skip the `fill_img` creation for all fill values with `sum(fill) == 0`, since `grid_sample` already pads with zeros:
```diff
- if fill is not None:
+ if fill is not None and sum(fill) > 0:
```
- The linear interpolation between `fill_img` and `img` may be replaced by directly applying a mask:

```python
mask = mask < 0.9999
img[mask] = fill_img[mask]
```

That would better match PIL Image behaviour.
vision/torchvision/transforms/functional_tensor.py, lines 567 to 568 at commit 0d69e35:

```python
else:  # 'bilinear'
    img = img * mask + (1.0 - mask) * fill_img
```
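The two proposals can be compared on a toy example. This is a minimal pure-Python 1-D sketch, assuming that linear interpolation with zero padding stands in for `grid_sample(..., mode="bilinear", padding_mode="zeros")` and that coordinates are in pixel units. `warp`, `skip_zero_fill` and `use_threshold` are hypothetical names for illustration, and the scalar `fill == 0.0` check stands in for the `sum(fill) > 0` condition from the diff.

```python
# Pure-Python 1-D comparison of the current bilinear blending with the two
# proposals. Assumption: linear interpolation with zero padding stands in for
# grid_sample(..., mode="bilinear", padding_mode="zeros"); coordinates are in
# pixel units; `warp` and its flags are hypothetical, not torchvision API.
import math
from typing import List, Optional, Sequence


def sample_linear_zeros(signal: Sequence[float], x: float) -> float:
    """Linearly interpolate `signal` at position x; taps outside it read 0."""
    x0 = math.floor(x)
    w1 = x - x0
    w0 = 1.0 - w1

    def tap(i: int) -> float:
        return signal[i] if 0 <= i < len(signal) else 0.0

    return w0 * tap(x0) + w1 * tap(x0 + 1)


def warp(
    signal: Sequence[float],
    coords: Sequence[float],
    fill: Optional[float],
    skip_zero_fill: bool = False,
    use_threshold: bool = False,
) -> List[float]:
    ones = [1.0] * len(signal)  # interpolated "dummy" channel
    out = []
    for x in coords:
        value = sample_linear_zeros(signal, x)
        # fill=None: keep the zero-padded samples as they are.
        # Proposal 1: treat a zero fill the same way, since the zero padding
        # has already produced that color.
        if fill is None or (skip_zero_fill and fill == 0.0):
            out.append(value)
            continue
        mask = sample_linear_zeros(ones, x)  # in-bounds weight, in [0, 1]
        if use_threshold:
            # Proposal 2: keep a sample only if it is (almost) fully in bounds.
            out.append(fill if mask < 0.9999 else value)
        else:
            # Current bilinear behaviour: blend image and fill color.
            out.append(value * mask + (1.0 - mask) * fill)
    return out
```

With `s = [10.0, 20.0, 30.0]` and coordinates `[1.0, 2.5]`, `warp(s, c, None)` returns `[20.0, 15.0]`, while the current blending with `fill=0.0` returns `[20.0, 7.5]`, which is the double attenuation described in the first point; `skip_zero_fill=True` restores `[20.0, 15.0]`. With `fill=100.0`, blending gives `[20.0, 57.5]` and thresholding gives `[20.0, 100.0]`. A threshold of `0.9999` rather than `1.0` presumably tolerates floating-point rounding in the interpolated mask of fully in-bounds samples. For a multi-channel fill, `sum(fill) > 0` would also skip fills whose entries are negative or cancel out, so an explicit check such as `all(f == 0 for f in fill)` may be a safer condition.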
cc @datumbox
Neltherion