-
Notifications
You must be signed in to change notification settings - Fork 7.1k
add prototype transforms that use the prototype dispatchers #5418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
c886c82
add prototype transforms that use the prototype dispatchers
pmeier f9e8e9c
simplify
pmeier 62cb2e5
add logger
pmeier 4ba0fd0
remove legacy classes
pmeier e7a1b46
make get_params private
pmeier eb2eafc
remove randbool method
pmeier a464c47
remove AutoAugmentDispatcher
pmeier 2f9c242
add high level kernels for meta conversion
pmeier 5342bd2
remove transforms meta abstraction from auto augment transforms
pmeier 24a5d27
Merge branch 'main' into refactor-transforms
pmeier c767ac0
appease mypy
pmeier 430dd45
add smoke tests for transforms
pmeier 06383b4
remove Query object
pmeier 4c0bcd2
remove extra_repr helper
pmeier 6de3776
fix tests
pmeier 5537f45
appease mypy
pmeier 5fc63a6
Merge branch 'main' into refactor-transforms
pmeier 061d541
revert some changes on the kernel tests
pmeier f65e5d4
fix dispatcher annotations
pmeier ca6a94a
remove float cast for torch.rand
pmeier 72ec278
add helper to query image
pmeier 536735e
fix imports
pmeier e931f03
address auto augment comments
pmeier 4117662
cleanup
pmeier 7ac1c8b
Merge branch 'main' into refactor-transforms
pmeier 7799069
Merge branch 'main' into refactor-transforms
pmeier File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import itertools | ||
|
||
import PIL.Image | ||
import pytest | ||
import torch | ||
from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels | ||
from torchvision.prototype import transforms, features | ||
from torchvision.transforms.functional import to_pil_image | ||
|
||
|
||
def make_vanilla_tensor_images(*args, **kwargs): | ||
for image in make_images(*args, **kwargs): | ||
if image.ndim > 3: | ||
continue | ||
yield image.data | ||
|
||
|
||
def make_pil_images(*args, **kwargs): | ||
for image in make_vanilla_tensor_images(*args, **kwargs): | ||
yield to_pil_image(image) | ||
|
||
|
||
def make_vanilla_tensor_bounding_boxes(*args, **kwargs): | ||
for bounding_box in make_bounding_boxes(*args, **kwargs): | ||
yield bounding_box.data | ||
|
||
|
||
INPUT_CREATIONS_FNS = { | ||
features.Image: make_images, | ||
features.BoundingBox: make_bounding_boxes, | ||
features.OneHotLabel: make_one_hot_labels, | ||
torch.Tensor: make_vanilla_tensor_images, | ||
PIL.Image.Image: make_pil_images, | ||
} | ||
|
||
|
||
def parametrize(transforms_with_inputs): | ||
return pytest.mark.parametrize( | ||
("transform", "input"), | ||
[ | ||
pytest.param( | ||
transform, | ||
input, | ||
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", | ||
) | ||
for transform, inputs in transforms_with_inputs | ||
for idx, input in enumerate(inputs) | ||
], | ||
) | ||
|
||
|
||
def parametrize_from_transforms(*transforms): | ||
transforms_with_inputs = [] | ||
for transform in transforms: | ||
dispatcher = transform._DISPATCHER | ||
if dispatcher is None: | ||
continue | ||
|
||
for type_ in dispatcher._kernels: | ||
try: | ||
inputs = INPUT_CREATIONS_FNS[type_]() | ||
except KeyError: | ||
continue | ||
|
||
transforms_with_inputs.append((transform, inputs)) | ||
|
||
return parametrize(transforms_with_inputs) | ||
|
||
|
||
class TestSmoke: | ||
@parametrize_from_transforms( | ||
transforms.RandomErasing(), | ||
transforms.HorizontalFlip(), | ||
transforms.Resize([16, 16]), | ||
transforms.CenterCrop([16, 16]), | ||
transforms.ConvertImageDtype(), | ||
) | ||
def test_common(self, transform, input): | ||
transform(input) | ||
|
||
@parametrize( | ||
[ | ||
( | ||
transform, | ||
[ | ||
dict( | ||
image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float), | ||
one_hot_label=features.OneHotLabel.new_like( | ||
one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float | ||
), | ||
) | ||
for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels()) | ||
], | ||
) | ||
for transform in [ | ||
transforms.RandomMixup(alpha=1.0), | ||
transforms.RandomCutmix(alpha=1.0), | ||
] | ||
] | ||
) | ||
def test_mixup_cutmix(self, transform, input): | ||
transform(input) | ||
|
||
@parametrize( | ||
[ | ||
( | ||
transform, | ||
itertools.chain.from_iterable( | ||
fn(dtypes=[torch.uint8], extra_dims=[(4,)]) | ||
for fn in [ | ||
make_images, | ||
make_vanilla_tensor_images, | ||
make_pil_images, | ||
] | ||
), | ||
) | ||
for transform in ( | ||
transforms.RandAugment(), | ||
transforms.TrivialAugmentWide(), | ||
transforms.AutoAugment(), | ||
) | ||
] | ||
) | ||
def test_auto_augment(self, transform, input): | ||
transform(input) | ||
|
||
@parametrize( | ||
[ | ||
( | ||
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), | ||
itertools.chain.from_iterable( | ||
fn(color_spaces=["rgb"], dtypes=[torch.float32]) | ||
for fn in [ | ||
make_images, | ||
make_vanilla_tensor_images, | ||
] | ||
), | ||
), | ||
] | ||
) | ||
def test_normalize(self, transform, input): | ||
transform(input) | ||
|
||
@parametrize( | ||
[ | ||
( | ||
transforms.ConvertColorSpace("grayscale"), | ||
itertools.chain( | ||
make_images(), | ||
make_vanilla_tensor_images(color_spaces=["rgb"]), | ||
make_pil_images(color_spaces=["rgb"]), | ||
), | ||
) | ||
] | ||
) | ||
def test_convert_bounding_color_space(self, transform, input): | ||
transform(input) | ||
|
||
@parametrize( | ||
[ | ||
( | ||
transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"), | ||
itertools.chain( | ||
make_bounding_boxes(), | ||
make_vanilla_tensor_bounding_boxes(formats=["xywh"]), | ||
), | ||
) | ||
] | ||
) | ||
def test_convert_bounding_box_format(self, transform, input): | ||
transform(input) | ||
|
||
@parametrize( | ||
[ | ||
( | ||
transforms.RandomResizedCrop([16, 16]), | ||
itertools.chain( | ||
make_images(extra_dims=[(4,)]), | ||
make_vanilla_tensor_images(), | ||
make_pil_images(), | ||
), | ||
) | ||
] | ||
) | ||
def test_random_resized_crop(self, transform, input): | ||
transform(input) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip | ||
from . import kernels # usort: skip | ||
from . import functional # usort: skip | ||
from .kernels import InterpolationMode # usort: skip | ||
from ._transform import Transform # usort: skip | ||
|
||
from ._augment import RandomErasing, RandomMixup, RandomCutmix | ||
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment | ||
from ._container import Compose, RandomApply, RandomChoice, RandomOrder | ||
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop | ||
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace | ||
from ._misc import Identity, Normalize, ToDtype, Lambda | ||
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval | ||
from ._type_conversion import DecodeImage, LabelToOneHot |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These test are just a quick and dirty implementation to see if the transforms don't error when called with the supported inputs.