diff --git a/references/classification/train.py b/references/classification/train.py index b00a11fcac3..569cf3009e7 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -163,7 +163,7 @@ def load_data(traindir, valdir, args): weights = prototype.models.get_weight(args.weights) preprocessing = weights.transforms() else: - preprocessing = prototype.transforms.ImageNetEval( + preprocessing = prototype.transforms.ImageClassificationEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) diff --git a/references/detection/train.py b/references/detection/train.py index 765f8144364..3909e6413d0 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -57,7 +57,7 @@ def get_transform(train, args): weights = prototype.models.get_weight(args.weights) return weights.transforms() else: - return prototype.transforms.CocoEval() + return prototype.transforms.ObjectDetectionEval() def get_args_parser(add_help=True): diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index fdb9e4d8d7a..7f2f362c73d 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -137,7 +137,7 @@ def validate(model, args): weights = prototype.models.get_weight(args.weights) preprocessing = weights.transforms() else: - preprocessing = prototype.transforms.RaftEval() + preprocessing = prototype.transforms.OpticalFlowEval() else: preprocessing = OpticalFlowPresetEval() diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 436dd491dca..5dc03945bd7 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -42,7 +42,7 @@ def get_transform(train, args): weights = prototype.models.get_weight(args.weights) return weights.transforms() else: - return prototype.transforms.VocEval(resize_size=520) + return prototype.transforms.SemanticSegmentationEval(resize_size=520) def criterion(inputs, target): diff --git a/references/video_classification/train.py b/references/video_classification/train.py index df8687ff6c2..d36785ddf96 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -157,7 +157,7 @@ def main(args): weights = prototype.models.get_weight(args.weights) transform_test = weights.transforms() else: - transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171)) + transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index 091205d0872..204a68236d3 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.alexnet import AlexNet @@ -16,7 +16,7 @@ class AlexNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ "task": "image_classification", "architecture": "AlexNet", diff --git 
a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index ab9d08fbd3a..7d63ee155db 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, List, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.convnext import ConvNeXt, CNBlockConfig @@ -56,7 +56,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=236), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), meta={ **_COMMON_META, "num_params": 28589128, @@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=230), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), meta={ **_COMMON_META, "num_params": 50223688, @@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 88591464, @@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 197767336, diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index 248f0977b3b..4ad9be028e5 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Tuple import torch.nn as nn -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.densenet import DenseNet @@ -78,7 +78,7 @@ def _densenet( class DenseNet121_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 7978856, @@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 28681000, @@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageNetEval, crop_size=224), 
+ transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 14149480, @@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 20013928, diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 4fbbbc0c1e8..ecdd9bdb423 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Union from torch import nn -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.faster_rcnn import ( @@ -43,7 +43,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ **_COMMON_META, "num_params": 41755286, @@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ **_COMMON_META, "num_params": 19386354, @@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ **_COMMON_META, "num_params": 19386354, diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index faa181b60b0..db3a679a62d 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -1,7 +1,7 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.fcos import ( @@ -27,7 +27,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ "task": "image_object_detection", "architecture": "FCOS", diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index c10d761fa26..e0b4d7061fa 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -1,7 +1,7 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.keypoint_rcnn import ( @@ -37,7 +37,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_LEGACY = 
Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ **_COMMON_META, "num_params": 59137258, @@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ) COCO_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ **_COMMON_META, "num_params": 59137258, diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 3e438dab160..187bf6912b4 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -1,7 +1,7 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.mask_rcnn import ( @@ -27,7 +27,7 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ "task": "image_object_detection", "architecture": "MaskRCNN", diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index b819150ade0..eadd6c635ca 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -1,7 +1,7 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.retinanet import ( @@ -28,7 +28,7 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ "task": "image_object_detection", "architecture": "RetinaNet", diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 40568535277..3cab044958d 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -1,7 +1,7 @@ import warnings from typing import Any, Optional -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.ssd import ( @@ -25,7 +25,7 @@ class SSD300_VGG16_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ "task": "image_object_detection", "architecture": "SSD", diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 6bc542df5b3..6de34acb5ae 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional from torch import nn -from torchvision.prototype.transforms import CocoEval +from torchvision.prototype.transforms import ObjectDetectionEval from torchvision.transforms.functional import 
InterpolationMode from ....models.detection.ssdlite import ( @@ -30,7 +30,7 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=CocoEval, + transforms=ObjectDetectionEval, meta={ "task": "image_object_detection", "architecture": "SSDLite", diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 2619709764f..cb6d2bb2b35 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Sequence, Union from torch import nn -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf @@ -85,7 +85,9 @@ def _efficientnet( class EfficientNet_B0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 5288548, @@ -100,7 +102,9 @@ class EfficientNet_B0_Weights(WeightsEnum): class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 7794184, @@ -111,7 +115,9 @@ class EfficientNet_B1_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", - transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR), + transforms=partial( + ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ), meta={ **_COMMON_META_V1, "num_params": 7794184, @@ -128,7 +134,9 @@ class EfficientNet_B1_Weights(WeightsEnum): class EfficientNet_B2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 9109994, @@ -143,7 +151,9 @@ class EfficientNet_B2_Weights(WeightsEnum): class EfficientNet_B3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 12233232, @@ -158,7 +168,9 @@ class EfficientNet_B3_Weights(WeightsEnum): class EfficientNet_B4_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 19341616, @@ -173,7 +185,9 @@ class EfficientNet_B4_Weights(WeightsEnum): class EfficientNet_B5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 30389784, @@ -188,7 +202,9 @@ class EfficientNet_B5_Weights(WeightsEnum): class EfficientNet_B6_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 43040704, @@ -203,7 +219,9 @@ class EfficientNet_B6_Weights(WeightsEnum): class EfficientNet_B7_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), + transforms=partial( + ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ), meta={ **_COMMON_META_V1, "num_params": 66347960, @@ -219,7 +237,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", transforms=partial( - ImageNetEval, + ImageClassificationEval, crop_size=384, resize_size=384, interpolation=InterpolationMode.BILINEAR, @@ -239,7 +257,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", transforms=partial( - ImageNetEval, + ImageClassificationEval, crop_size=480, resize_size=480, interpolation=InterpolationMode.BILINEAR, @@ -259,7 +277,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", transforms=partial( - ImageNetEval, + ImageClassificationEval, crop_size=480, resize_size=480, interpolation=InterpolationMode.BICUBIC, diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 9a561e51390..70dc0d9db5c 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -2,7 +2,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs @@ -17,7 +17,7 @@ class GoogLeNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageNetEval, crop_size=224), 
+ transforms=partial(ImageClassificationEval, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index 088062f3010..eec78a26236 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs @@ -16,7 +16,7 @@ class Inception_V3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageNetEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index 90c0a43a5e6..c48e34a7be5 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.mnasnet import MNASNet @@ -38,7 +38,7 @@ class MNASNet0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 2218512, @@ -57,7 +57,7 @@ class MNASNet0_75_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 4383312, diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 7176252111a..71b412898fe 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv2 import MobileNetV2 @@ -28,7 +28,7 @@ class MobileNet_V2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", @@ -38,7 +38,7 @@ class MobileNet_V2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, 
resize_size=232), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index 1e342b83bcc..aaf9c2c85a4 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional, List -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig @@ -51,7 +51,7 @@ def _mobilenet_v3( class MobileNet_V3_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 5483032, @@ -62,7 +62,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5483032, @@ -77,7 +77,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): class MobileNet_V3_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 2542856, diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index bf8634efd4f..24e87f3d4f9 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -4,7 +4,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.prototype.transforms import RaftEval +from torchvision.prototype.transforms import OpticalFlowEval from torchvision.transforms.functional import InterpolationMode from .._api import WeightsEnum @@ -33,7 +33,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-things.pth) url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 5257536, @@ -48,7 +48,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 5257536, @@ -63,7 +63,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_V1 = Weights( # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 5257536, @@ -78,7 +78,7 @@ class Raft_Large_Weights(WeightsEnum): # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) # 
Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 5257536, @@ -91,7 +91,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_K_V1 = Weights( # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 5257536, @@ -106,7 +106,7 @@ class Raft_Large_Weights(WeightsEnum): # Same as CT_SKHT with extra fine-tuning on Kitti # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 5257536, @@ -122,7 +122,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-small.pth) url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 990162, @@ -136,7 +136,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=RaftEval, + transforms=OpticalFlowEval, meta={ **_COMMON_META, "num_params": 990162, diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 71d5cc3f5e0..cca6ba25060 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -2,7 +2,7 @@ from functools import partial from typing import Any, Optional, Union -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.googlenet import ( @@ -26,7 +26,7 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index 4b820228de2..2639b7de14f 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional, Union -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.inception import ( @@ -25,7 +25,7 @@ class Inception_V3_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageNetEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": 
"InceptionV3", diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index d3fabc44beb..a9789583fe6 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional, Union -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.mobilenetv2 import ( @@ -26,7 +26,7 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV2", diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index e3369d1dc5f..915308d948f 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -2,7 +2,7 @@ from typing import Any, List, Optional, Union import torch -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.mobilenetv3 import ( @@ -59,7 +59,7 @@ def _mobilenet_v3_model( class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV3", diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 61f7dbbeda0..9e2e29db0bf 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, List, Optional, Type, Union -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.resnet import ( @@ -68,7 +68,7 @@ def _resnet( class ResNet18_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -85,7 +85,7 @@ class ResNet18_QuantizedWeights(WeightsEnum): class ResNet50_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -98,7 +98,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", 
- transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -115,7 +115,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -128,7 +128,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 7c550bb2d7b..e21349ff8e0 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, List, Optional, Union -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.shufflenetv2 import ( @@ -67,7 +67,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -82,7 +82,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index 44b7678d8e5..d5e2b535532 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -2,7 +2,7 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.regnet import RegNet, BlockParams @@ -77,7 +77,7 @@ def _regnet( class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 4344144, @@ -88,7 +88,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, 
crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 4344144, @@ -103,7 +103,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): class RegNet_Y_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 6432512, @@ -114,7 +114,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 6432512, @@ -129,7 +129,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): class RegNet_Y_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 11202430, @@ -140,7 +140,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 11202430, @@ -155,7 +155,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): class RegNet_Y_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 19436338, @@ -166,7 +166,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 19436338, @@ -181,7 +181,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): class RegNet_Y_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 39381472, @@ -192,7 +192,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39381472, @@ -207,7 +207,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): class RegNet_Y_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 83590140, @@ -218,7 +218,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + 
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 83590140, @@ -233,7 +233,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): class RegNet_Y_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 145046770, @@ -244,7 +244,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 145046770, @@ -264,7 +264,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 5495976, @@ -275,7 +275,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5495976, @@ -290,7 +290,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 7259656, @@ -301,7 +301,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 7259656, @@ -316,7 +316,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): class RegNet_X_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 9190136, @@ -327,7 +327,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 9190136, @@ -342,7 +342,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): class RegNet_X_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 15296552, @@ -353,7 +353,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageNetEval, 
crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 15296552, @@ -368,7 +368,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): class RegNet_X_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 39572648, @@ -379,7 +379,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39572648, @@ -394,7 +394,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): class RegNet_X_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 54278536, @@ -405,7 +405,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 54278536, @@ -420,7 +420,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): class RegNet_X_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 107811560, @@ -431,7 +431,7 @@ class RegNet_X_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 107811560, diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index cf18e03fab9..35e30c0e760 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, List, Optional, Type, Union -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.resnet import BasicBlock, Bottleneck, ResNet @@ -63,7 +63,7 @@ def _resnet( class ResNet18_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -80,7 +80,7 @@ class ResNet18_Weights(WeightsEnum): class ResNet34_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -97,7 +97,7 @@ class 
ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -110,7 +110,7 @@ class ResNet50_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -127,7 +127,7 @@ class ResNet50_Weights(WeightsEnum): class ResNet101_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -140,7 +140,7 @@ class ResNet101_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -157,7 +157,7 @@ class ResNet101_Weights(WeightsEnum): class ResNet152_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -170,7 +170,7 @@ class ResNet152_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -187,7 +187,7 @@ class ResNet152_Weights(WeightsEnum): class ResNeXt50_32X4D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -200,7 +200,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -217,7 +217,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): class ResNeXt101_32X8D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -230,7 +230,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -247,7 +247,7 @@ class 
ResNeXt101_32X8D_Weights(WeightsEnum): class Wide_ResNet50_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -260,7 +260,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -277,7 +277,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): class Wide_ResNet101_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -290,7 +290,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageNetEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 43d3ab131b3..7165078161f 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import VocEval +from torchvision.prototype.transforms import SemanticSegmentationEval from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet @@ -36,7 +36,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(VocEval, resize_size=520), + transforms=partial(SemanticSegmentationEval, resize_size=520), meta={ **_COMMON_META, "num_params": 42004074, @@ -51,7 +51,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): class DeepLabV3_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - transforms=partial(VocEval, resize_size=520), + transforms=partial(SemanticSegmentationEval, resize_size=520), meta={ **_COMMON_META, "num_params": 60996202, @@ -66,7 +66,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(VocEval, resize_size=520), + transforms=partial(SemanticSegmentationEval, resize_size=520), meta={ **_COMMON_META, "num_params": 11029328, diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 6b3a672de48..1dfc251844f 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -1,7 +1,7 @@ from functools import partial from typing 
import Any, Optional -from torchvision.prototype.transforms import VocEval +from torchvision.prototype.transforms import SemanticSegmentationEval from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.fcn import FCN, _fcn_resnet @@ -26,7 +26,7 @@ class FCN_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(VocEval, resize_size=520), + transforms=partial(SemanticSegmentationEval, resize_size=520), meta={ **_COMMON_META, "num_params": 35322218, @@ -41,7 +41,7 @@ class FCN_ResNet50_Weights(WeightsEnum): class FCN_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - transforms=partial(VocEval, resize_size=520), + transforms=partial(SemanticSegmentationEval, resize_size=520), meta={ **_COMMON_META, "num_params": 54314346, diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 7359e08102e..2c0fa6f0aff 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import VocEval +from torchvision.prototype.transforms import SemanticSegmentationEval from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 @@ -17,7 +17,7 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(VocEval, resize_size=520), + transforms=partial(SemanticSegmentationEval, resize_size=520), meta={ "task": "image_semantic_segmentation", "architecture": "LRASPP", diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 41fa64c14e7..48047a70c60 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import ImageNetEval +from torchvision.prototype.transforms import ImageClassificationEval from torchvision.transforms.functional import InterpolationMode from ...models.shufflenetv2 import ShuffleNetV2 @@ -55,7 +55,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -69,7 +69,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): class ShuffleNet_V2_X1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageNetEval, crop_size=224), + transforms=partial(ImageClassificationEval, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 5d26466f92d..7f6a034ed6c 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -1,7 +1,7 @@ from functools import partial from typing import Any, Optional -from 
diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py
index 41fa64c14e7..48047a70c60 100644
--- a/torchvision/prototype/models/shufflenetv2.py
+++ b/torchvision/prototype/models/shufflenetv2.py
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, Optional
 
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 
 from ...models.shufflenetv2 import ShuffleNetV2
@@ -55,7 +55,7 @@ def _shufflenetv2(
 class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 1366792,
@@ -69,7 +69,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
 class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 2278604,
diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py
index 5d26466f92d..7f6a034ed6c 100644
--- a/torchvision/prototype/models/squeezenet.py
+++ b/torchvision/prototype/models/squeezenet.py
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, Optional
 
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 
 from ...models.squeezenet import SqueezeNet
@@ -27,7 +27,7 @@ class SqueezeNet1_0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "min_size": (21, 21),
@@ -42,7 +42,7 @@ class SqueezeNet1_0_Weights(WeightsEnum):
 class SqueezeNet1_1_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "min_size": (17, 17),
diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py
index de0ce518629..233c35418ed 100644
--- a/torchvision/prototype/models/vgg.py
+++ b/torchvision/prototype/models/vgg.py
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, Optional
 
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 
 from ...models.vgg import VGG, make_layers, cfgs
@@ -55,7 +55,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
 class VGG11_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg11-8a719046.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 132863336,
@@ -69,7 +69,7 @@ class VGG11_Weights(WeightsEnum):
 class VGG11_BN_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 132868840,
@@ -83,7 +83,7 @@ class VGG11_BN_Weights(WeightsEnum):
 class VGG13_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg13-19584684.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 133047848,
@@ -97,7 +97,7 @@ class VGG13_Weights(WeightsEnum):
 class VGG13_BN_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 133053736,
@@ -111,7 +111,7 @@ class VGG13_BN_Weights(WeightsEnum):
 class VGG16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg16-397923af.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 138357544,
@@ -125,7 +125,10 @@ class VGG16_Weights(WeightsEnum):
     IMAGENET1K_FEATURES = Weights(
         url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
         transforms=partial(
-            ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0)
+            ImageClassificationEval,
+            crop_size=224,
+            mean=(0.48235, 0.45882, 0.40784),
+            std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
         ),
         meta={
             **_COMMON_META,
@@ -142,7 +145,7 @@ class VGG16_Weights(WeightsEnum):
 class VGG16_BN_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 138365992,
@@ -156,7 +159,7 @@ class VGG16_BN_Weights(WeightsEnum):
 class VGG19_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 143667240,
@@ -170,7 +173,7 @@ class VGG19_Weights(WeightsEnum):
 class VGG19_BN_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 143678248,
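For the classification presets renamed above: the weight entries construct ImageClassificationEval with a crop_size, and the VGG16 IMAGENET1K_FEATURES entry additionally overrides mean and std. A small sketch using only arguments that appear in this diff; the input tensor is a made-up placeholder, and the commented output shape assumes the usual resize-then-center-crop behaviour of this preset.

    import torch
    from torchvision.prototype.transforms import ImageClassificationEval

    # Default ImageNet-style eval preprocessing, as used by most entries above.
    preprocess = ImageClassificationEval(crop_size=224)

    # Overriding normalization, mirroring the VGG16 IMAGENET1K_FEATURES entry above.
    feature_preprocess = ImageClassificationEval(
        crop_size=224,
        mean=(0.48235, 0.45882, 0.40784),
        std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
    )

    img = torch.rand(3, 320, 480)         # dummy image tensor in [0, 1]
    batch = preprocess(img).unsqueeze(0)  # expected shape: [1, 3, 224, 224]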
diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py
index 90c3d605232..790d254d266 100644
--- a/torchvision/prototype/models/video/resnet.py
+++ b/torchvision/prototype/models/video/resnet.py
@@ -2,7 +2,7 @@
 from typing import Any, Callable, List, Optional, Sequence, Type, Union
 
 from torch import nn
-from torchvision.prototype.transforms import Kinect400Eval
+from torchvision.prototype.transforms import VideoClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 
 from ....models.video.resnet import (
@@ -65,7 +65,7 @@ def _video_resnet(
 class R3D_18_Weights(WeightsEnum):
     KINETICS400_V1 = Weights(
         url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
-        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
+        transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "architecture": "R3D",
@@ -80,7 +80,7 @@ class R3D_18_Weights(WeightsEnum):
 class MC3_18_Weights(WeightsEnum):
     KINETICS400_V1 = Weights(
         url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
-        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
+        transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "architecture": "MC3",
@@ -95,7 +95,7 @@ class MC3_18_Weights(WeightsEnum):
 class R2Plus1D_18_Weights(WeightsEnum):
     KINETICS400_V1 = Weights(
         url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
-        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
+        transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "architecture": "R(2+1)D",
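The video weights above all switch to VideoClassificationEval with a (112, 112) crop and a (128, 171) resize. A sketch of driving the preset on a dummy clip; the clip shape is invented for illustration, and the output layout follows the (T, C, H, W) to (C, T, H, W) permute noted in _presets.py further down.

    import torch
    from torchvision.prototype.transforms import VideoClassificationEval

    # Same arguments as the Kinetics-400 weight entries above.
    preprocess = VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))

    clip = torch.rand(16, 3, 240, 320)  # dummy clip, (T, C, H, W)
    out = preprocess(clip)              # expected shape: [3, 16, 112, 112], i.e. (C, T, H, W)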
diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
index 28446aae30e..468903b6b94 100644
--- a/torchvision/prototype/models/vision_transformer.py
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -5,7 +5,7 @@
 from functools import partial
 from typing import Any, Optional
 
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 
 from ...models.vision_transformer import VisionTransformer, interpolate_embeddings  # noqa: F401
@@ -38,7 +38,7 @@ class ViT_B_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 86567656,
@@ -55,7 +55,7 @@ class ViT_B_16_Weights(WeightsEnum):
 class ViT_B_32_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 88224232,
@@ -72,7 +72,7 @@ class ViT_B_32_Weights(WeightsEnum):
 class ViT_L_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=242),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242),
         meta={
             **_COMMON_META,
             "num_params": 304326632,
@@ -89,7 +89,7 @@ class ViT_L_16_Weights(WeightsEnum):
 class ViT_L_32_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 306535400,
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 4641cc5ab86..16369428e47 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -10,5 +10,11 @@
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
-from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
+from ._presets import (
+    ObjectDetectionEval,
+    ImageClassificationEval,
+    SemanticSegmentationEval,
+    VideoClassificationEval,
+    OpticalFlowEval,
+)
 from ._type_conversion import DecodeImage, LabelToOneHot
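Since the presets stay re-exported from torchvision.prototype.transforms, downstream code only has to swap the old dataset-oriented names for the new task-oriented ones; the import path itself does not change.

    # Old names (removed in this diff):
    # from torchvision.prototype.transforms import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval

    # New names:
    from torchvision.prototype.transforms import (
        ObjectDetectionEval,
        ImageClassificationEval,
        SemanticSegmentationEval,
        VideoClassificationEval,
        OpticalFlowEval,
    )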
diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py
index d7c4ddb4684..3ab045b3ddb 100644
--- a/torchvision/prototype/transforms/_presets.py
+++ b/torchvision/prototype/transforms/_presets.py
@@ -6,10 +6,16 @@
 from ...transforms import functional as F, InterpolationMode
 
-__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"]
+__all__ = [
+    "ObjectDetectionEval",
+    "ImageClassificationEval",
+    "VideoClassificationEval",
+    "SemanticSegmentationEval",
+    "OpticalFlowEval",
+]
 
 
-class CocoEval(nn.Module):
+class ObjectDetectionEval(nn.Module):
     def forward(
         self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
@@ -18,7 +24,7 @@ def forward(
         return F.convert_image_dtype(img, torch.float), target
 
 
-class ImageNetEval(nn.Module):
+class ImageClassificationEval(nn.Module):
     def __init__(
         self,
         crop_size: int,
@@ -44,7 +50,7 @@ def forward(self, img: Tensor) -> Tensor:
         return img
 
 
-class Kinect400Eval(nn.Module):
+class VideoClassificationEval(nn.Module):
     def __init__(
         self,
         crop_size: Tuple[int, int],
@@ -69,7 +75,7 @@ def forward(self, vid: Tensor) -> Tensor:
         return vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)
 
 
-class VocEval(nn.Module):
+class SemanticSegmentationEval(nn.Module):
     def __init__(
         self,
         resize_size: int,
@@ -99,7 +105,7 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor,
         return img, target
 
 
-class RaftEval(nn.Module):
+class OpticalFlowEval(nn.Module):
     def forward(
         self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
     ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
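Taken together, the rename leaves the mechanics of the weight entries untouched: each entry above still stores its preprocessing as a transforms=partial(...) callable, so callers can either instantiate a preset by its new name or obtain it from a weight entry. A rough sketch under the assumption that the weight enums are exposed under torchvision.prototype.models and that calling .transforms() on an entry invokes the stored partial; the image tensor is a dummy placeholder.

    import torch
    from torchvision.prototype import models
    from torchvision.prototype.transforms import ObjectDetectionEval, OpticalFlowEval

    # Obtain the preset attached to a weight entry (here: SemanticSegmentationEval(resize_size=520)).
    weights = models.segmentation.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
    preprocess = weights.transforms()

    # The detection and optical-flow presets take no constructor arguments in this diff.
    det_preprocess = ObjectDetectionEval()
    flow_preprocess = OpticalFlowEval()

    img = torch.rand(3, 480, 640)     # dummy image tensor
    det_img, _ = det_preprocess(img)  # converts the image dtype to float; the target passes through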