Description
We are currently in the process of revamping the `transforms` module, but there is still a lot of porting left. The design is now stable enough that the porting process should be manageable for someone not intimately familiar with the new design. Thus, we are actively looking for contributions to help us finish this.
Here is the list of transformations that still need to be ported to achieve feature parity (minus some special transformations) between the new and old API:
- Port `transforms.Pad` to `prototype.transforms` #5521
- Port `transforms.RandomCrop` to `prototype.transforms` #5522
- Port `transforms.RandomHorizontalFlip` to `prototype.transforms` #5523
- Port `transforms.RandomVerticalFlip` to `prototype.transforms` #5524
- Port `transforms.RandomPerspective` to `prototype.transforms` #5525
- Port `transforms.FiveCrop` to `prototype.transforms` #5526
- Port `transforms.TenCrop` to `prototype.transforms` #5527
- Port `transforms.ColorJitter` to `prototype.transforms` #5528
- Port `transforms.RandomRotation` to `prototype.transforms` #5529
- Port `transforms.RandomAffine` to `prototype.transforms` #5530
- Port `transforms.GaussianBlur` to `prototype.transforms` #5531
- Port `transforms.RandomInvert` to `prototype.transforms` #5532
- Port `transforms.RandomPosterize` to `prototype.transforms` #5533
- Port `transforms.RandomSolarize` to `prototype.transforms` #5534
- Port `transforms.RandomAdjustSharpness` to `prototype.transforms` #5535
- Port `transforms.RandomAutocontrast` to `prototype.transforms` #5536
- Port `transforms.RandomEqualize` to `prototype.transforms` #5537
- Port `transforms.LinearTransformation` to `prototype.transforms` #5538
- `RandomHorizontalFlip` (Transforms without dispatcher #5421)
- `Resize` (Transforms without dispatcher #5421)
- `CenterCrop` (Transforms without dispatcher #5421)
- `RandomResizedCrop` (Transforms without dispatcher #5421)
- `Normalize` (Transforms without dispatcher #5421)
- `RandomErasing` (Transforms without dispatcher #5421)
- AutoAugment transforms:
  - `RandAugment` (Transforms without dispatcher #5421)
  - `TrivialAugmentWide` (Transforms without dispatcher #5421)
  - `AutoAugment` (Transforms without dispatcher #5421)
  - add prototype AugMix transform #5492
There is an issue for each transformation. Please comment on the respective issue if you want to take up a task so we can assign it to you.
Here is a recipe for how the porting process looks:
- Port the kernels from `transforms.functional_tensor` and `transforms.functional_pil` to `prototype.transforms.functional`. In most cases this means just binding them to a new name (`vision/torchvision/prototype/transforms/functional/_geometry.py`, lines 14 to 15 in 97385df):

  ```python
  horizontal_flip_image_tensor = _FT.hflip
  horizontal_flip_image_pil = _FP.hflip
  ```

  The naming scheme is `{kernel_name}_{feature_type}_{feature_subtype}`. In the example above, the kernel name is `horizontal_flip`, the feature type is `image`, and the subtypes are `tensor` and `pil`.

  In general, the new kernels should have the same signature as the old dispatchers from `torchvision.functional`. In most cases this is given by default, but sometimes there is some common pre-processing performed in the dispatchers before the kernels are called. In these cases, the canonical way is to move the common functionality into a private helper function and define the new kernels to call the helper first and the old kernel afterwards.
- Create a new transform in `prototype.transforms` that inherits from `prototype.transforms.Transform`. The constructor can be copy-pasted from the corresponding transform in `transforms`.
- Implement the `_transform` method. It receives two arguments: `input` and `params` (see below). `input` can be any non-container object, i.e. no lists, tuples, dictionaries, and so on, so the implementation needs to check the input type and dispatch accordingly. The general behavior should be to handle what we can and let the rest pass through (there are exceptions to this; see below). The implementation should look something like this:

  ```python
  class Foo(Transform):
      def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
          if isinstance(input, features.Image):
              output = F.foo_image_tensor(input, ...)
              return features.Image.new_like(input, output)
          elif is_simple_tensor(input):
              return F.foo_image_tensor(input, ...)
          elif isinstance(input, PIL.Image.Image):
              return F.foo_image_pil(input, ...)
          else:
              return input
  ```
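The "move common pre-processing into a private helper" idea from the kernel-porting step can be sketched as follows. Everything here is invented for illustration: `foo`, `_foo_preprocess`, and the stand-in "old" kernel are not real torchvision names, and plain Python lists stand in for image tensors.

```python
# Hypothetical sketch of the pre-processing helper pattern; `foo`,
# `_foo_preprocess`, and `_old_foo_tensor_kernel` are invented names,
# not real torchvision functions.

def _old_foo_tensor_kernel(image, fill):
    # stand-in for an existing kernel that expects `fill` as a list
    return [pixel + fill[0] for pixel in image]

def _foo_preprocess(fill):
    # common pre-processing that the old dispatcher performed
    # before calling the per-type kernels
    if isinstance(fill, (int, float)):
        fill = [float(fill)]
    return fill

def foo_image_tensor(image, fill=0):
    # new kernel: run the shared helper first, then the old kernel
    return _old_foo_tensor_kernel(image, _foo_preprocess(fill))
```

With this split, `foo_image_pil` could call the same `_foo_preprocess` helper, so both new kernels keep the old dispatcher's signature without duplicating the pre-processing.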
- Some transformations could in theory support other feature types such as bounding boxes, but we currently don't have kernels for them or never will. In these cases it is crucial not to perform the transformation on the image while letting the rest pass through, because that invalidates the correspondence. For example, applying `RandomRotation` only to an image but ignoring a bounding box renders the bounding box invalid. To avoid this, overwrite the `forward` method and fail if unsupported types are detected (`vision/torchvision/prototype/transforms/_geometry.py`, lines 74 to 78 in 97385df):

  ```python
  def forward(self, *inputs: Any) -> Any:
      sample = inputs if len(inputs) > 1 else inputs[0]
      if has_any(sample, features.BoundingBox, features.SegmentationMask):
          raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
      return super().forward(sample)
  ```
- Some random transformations take a `p` parameter that indicates the probability of applying the transformation. Overwrite the `forward` method and perform this check there before calling the `forward` of the super class (`vision/torchvision/prototype/transforms/_augment.py`, lines 102 to 105 in 97385df):

  ```python
  elif torch.rand(1) >= self.p:
      return sample

  return super().forward(sample)
  ```
- Some random transformations need to sample parameters at runtime. In the old implementations this is usually done in a

  ```python
  @staticmethod
  def get_params():
      ...
  ```

  The new architecture is similar but not the same: you can overwrite the `_get_params(self, sample) -> Dict[str, Any]` method (`vision/torchvision/prototype/transforms/_geometry.py`, lines 114 to 116 in 97385df):

  ```python
  def _get_params(self, sample: Any) -> Dict[str, Any]:
      image = query_image(sample)
      _, height, width = get_image_dimensions(image)
  ```

  The returned dictionary is available through the `params` parameter in the `_transform` method. The `query_image` function used above finds an image in the `sample` without worrying about its actual structure, which is useful if the image dimensions are needed to generate the parameters.
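To see how `forward`, `_get_params`, and `_transform` fit together, here is a minimal, self-contained mock of the flow. This is not the actual `prototype.transforms.Transform` base class (which also walks nested sample structures), and `RandomShift` is an invented transform operating on plain lists of numbers.

```python
import random
from typing import Any, Dict


class Transform:
    # simplified stand-in for prototype.transforms.Transform:
    # sample the params once per call, then hand them to _transform
    def forward(self, sample: Any) -> Any:
        params = self._get_params(sample)
        return self._transform(sample, params)

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        return {}

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError


class RandomShift(Transform):
    # invented transform: adds a randomly sampled offset to every
    # "pixel" of a list-of-numbers "image"
    def _get_params(self, sample: Any) -> Dict[str, Any]:
        return {"offset": random.randint(1, 3)}

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, list):
            return [pixel + params["offset"] for pixel in input]
        return input  # let unsupported types pass through


t = RandomShift()
out = t.forward([0, 10, 20])
# every pixel is shifted by the *same* sampled offset, because the
# params are drawn once in forward() and reused in _transform()
```

The key point the mock illustrates is that randomness lives in `_get_params`, so one sampled set of parameters can be applied consistently to every supported element of a sample.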
These transformations are not simple, so there might be cases where the recipe above is not sufficient. If you have any questions or hit blockers, feel free to send a PR with what you have and ping me there so I can have a look and help you out.