From a696473d7000c7bb3c55a1a22e061c698625dc2e Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 16 Mar 2022 13:17:22 +0000
Subject: [PATCH 01/18] Making _presets.py classes private

---
 torchvision/models/alexnet.py                   | 2 +-
 torchvision/models/convnext.py                  | 2 +-
 torchvision/models/densenet.py                  | 2 +-
 torchvision/models/detection/faster_rcnn.py     | 2 +-
 torchvision/models/detection/fcos.py            | 2 +-
 torchvision/models/detection/keypoint_rcnn.py   | 2 +-
 torchvision/models/detection/mask_rcnn.py       | 2 +-
 torchvision/models/detection/retinanet.py       | 2 +-
 torchvision/models/detection/ssd.py             | 2 +-
 torchvision/models/detection/ssdlite.py         | 2 +-
 torchvision/models/efficientnet.py              | 2 +-
 torchvision/models/googlenet.py                 | 2 +-
 torchvision/models/inception.py                 | 2 +-
 torchvision/models/mnasnet.py                   | 2 +-
 torchvision/models/mobilenetv2.py               | 2 +-
 torchvision/models/mobilenetv3.py               | 2 +-
 torchvision/models/optical_flow/raft.py         | 2 +-
 torchvision/models/quantization/googlenet.py    | 2 +-
 torchvision/models/quantization/inception.py    | 2 +-
 torchvision/models/quantization/mobilenetv2.py  | 2 +-
 torchvision/models/quantization/mobilenetv3.py  | 2 +-
 torchvision/models/quantization/resnet.py       | 2 +-
 torchvision/models/quantization/shufflenetv2.py | 2 +-
 torchvision/models/regnet.py                    | 2 +-
 torchvision/models/resnet.py                    | 2 +-
 torchvision/models/segmentation/deeplabv3.py    | 2 +-
 torchvision/models/segmentation/fcn.py          | 2 +-
 torchvision/models/segmentation/lraspp.py       | 2 +-
 torchvision/models/shufflenetv2.py              | 2 +-
 torchvision/models/squeezenet.py                | 2 +-
 torchvision/models/vgg.py                       | 2 +-
 torchvision/models/video/resnet.py              | 2 +-
 torchvision/models/vision_transformer.py        | 2 +-
 torchvision/transforms/__init__.py              | 7 -------
 torchvision/transforms/_presets.py              | 4 ++++
 35 files changed, 37 insertions(+), 40 deletions(-)

diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py
index 4df533000f9..fb7f8117cad 100644
--- a/torchvision/models/alexnet.py
+++ b/torchvision/models/alexnet.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 
-from ..transforms import ImageClassificationEval, InterpolationMode
+from ..transforms._presets import ImageClassificationEval, InterpolationMode
 from ..utils import _log_api_usage_once
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES

diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py
index 8d25e77eaa1..dd7ca8a5c37 100644
--- a/torchvision/models/convnext.py
+++ b/torchvision/models/convnext.py
@@ -7,7 +7,7 @@
 from ..ops.misc import Conv2dNormActivation
 from ..ops.stochastic_depth import StochasticDepth
 
-from ..transforms import ImageClassificationEval, InterpolationMode
+from ..transforms._presets import ImageClassificationEval, InterpolationMode
 from ..utils import _log_api_usage_once
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES

diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py
index b0de4529902..52e4f4e9529 100644
--- a/torchvision/models/densenet.py
+++ b/torchvision/models/densenet.py
@@ -9,7 +9,7 @@
 import torch.utils.checkpoint as cp
 from torch import Tensor
 
-from ..transforms import ImageClassificationEval, InterpolationMode
+from ..transforms._presets import ImageClassificationEval, InterpolationMode
 from ..utils import _log_api_usage_once
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES

diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index b5a1df4502c..a0e9e8dfcaf 100644
---
a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 7948cd76ab2..1fcbd370289 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -11,7 +11,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 522545293a0..b6a56aa95bf 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import handle_legacy_interface, _ovewrite_value_param diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 47b32984991..1d9ff9a7f67 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 048d68b52dc..427bc415ed9 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -10,7 +10,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index b907b3fccf8..403b37b0c7c 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops import boxes as box_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights 
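# --- Illustrative sketch, not part of the patch (it mirrors gallery/plot_visualization_utils.py
# in patch 02 below). With the presets now private, callers reach them through the weights
# enum instead of importing them from torchvision.transforms:
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights).eval()
preprocess = weights.transforms()  # instantiates the bundled ObjectDetectionEval preset
# predictions = model([preprocess(img)])  # img: a PIL image or a uint8 tensor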
from .._meta import _COCO_CATEGORIES diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index dad28cfed13..006d2b10eb9 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops.misc import Conv2dNormActivation -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet from .._api import WeightsEnum, Weights diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 9665c169bbf..691e525e1d9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -10,7 +10,7 @@ from torchvision.ops import StochasticDepth from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index e09e6788097..bd07bf6f9ed 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 24d084b62d2..ef998df1b7d 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 287911edbec..11a69af982f 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 1e19db1a314..cf2753bc9ac 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -7,7 +7,7 @@ from torch import nn from ..ops.misc import Conv2dNormActivation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 3a98456416d..b68ef12fa2f 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch 
import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index a506224d4b3..f95411ced21 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,7 +8,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ...transforms import OpticalFlowEval, InterpolationMode +from ...transforms._presets import OpticalFlowEval, InterpolationMode from ...utils import _log_api_usage_once from .._api import Weights, WeightsEnum from .._utils import handle_legacy_interface diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index befc2299c06..61cdc899b8e 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.nn import functional as F -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 697d99d4027..8780c76f264 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -9,7 +9,7 @@ from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 40f5cb544fd..be36d0566c2 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -7,7 +7,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights from ...ops.misc import Conv2dNormActivation -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4b79b7f26ae..be3a9739a9a 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationEval, InterpolationMode from .._api import 
WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 666b1b23163..67e11e7e720 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -13,7 +13,7 @@ ResNeXt101_32X8D_Weights, ) -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index c5bfe698636..3eda3c813ec 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -6,7 +6,7 @@ from torch import Tensor from torchvision.models import shufflenetv2 -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 1015c21b858..4d90d2d551e 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 159749df006..2d24e2b1b8e 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 6e8bf0c398b..a4b00c5e792 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import functional as F -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 5a3ca1f654f..b0836cd07b0 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -3,7 +3,7 @@ from torch import nn -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentationEval, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils 
import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index d1fe15a350d..dfa15804cf8 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentationEval, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index b38c0ac2974..c26eaa1d536 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index d495b3148e5..28e26cd21b6 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.init as init -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 27325c9016c..3398c2d6e9b 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index a6b779d10f1..df215fce0a4 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch import Tensor -from ...transforms import VideoClassificationEval, InterpolationMode +from ...transforms._presets import VideoClassificationEval, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 801e7adc981..9e1b701b144 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from ..ops.misc import Conv2dNormActivation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 94ec34ebe98..77680a14f0d 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1,9 +1,2 @@ 
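# --- Illustrative sketch, not part of the patch: the hunk below removes the public
# re-export, so the preset classes remain importable only from the private module
# torchvision.transforms._presets (as the model files above now do), while the only
# supported public path goes through a weights enum (assuming ResNet50_Weights,
# whose preset wiring mirrors the enum entries shown in later patches):
from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V1
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()  # e.g. ImageClassificationEval(crop_size=224)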
 from .transforms import *
 from .autoaugment import *
-from ._presets import (
-    ObjectDetectionEval,
-    ImageClassificationEval,
-    SemanticSegmentationEval,
-    VideoClassificationEval,
-    OpticalFlowEval,
-)

diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py
index 1776d876ccb..1a1590c276f 100644
--- a/torchvision/transforms/_presets.py
+++ b/torchvision/transforms/_presets.py
@@ -1,3 +1,7 @@
+"""
+This file is part of the private API. Please do not use these classes directly, as they will be modified in
+future versions without warning. The classes should be accessed only via the transforms argument of Weights.
+"""
 from typing import Dict, Optional, Tuple
 
 import torch

From 8e56a373fbab8ce3fb51e6e393de18195cd3e960 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 16 Mar 2022 13:39:16 +0000
Subject: [PATCH 02/18] Remove support of targets on presets.

---
 gallery/plot_optical_flow.py            |  2 +-
 gallery/plot_repurposing_annotations.py |  7 ++--
 gallery/plot_visualization_utils.py     |  8 ++---
 torchvision/transforms/_presets.py      | 43 ++++++++------------------
 4 files changed, 17 insertions(+), 43 deletions(-)

diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py
index 770610fb971..5149ebc541b 100644
--- a/gallery/plot_optical_flow.py
+++ b/gallery/plot_optical_flow.py
@@ -96,7 +96,7 @@ def plot(imgs, **imshow_kwargs):
 def preprocess(img1_batch, img2_batch):
     img1_batch = F.resize(img1_batch, size=[520, 960])
     img2_batch = F.resize(img2_batch, size=[520, 960])
-    return transforms(img1_batch, img2_batch)[:2]
+    return transforms(img1_batch, img2_batch)
 
 
 img1_batch, img2_batch = preprocess(img1_batch, img2_batch)

diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py
index a826a2523f2..0a795b9162b 100644
--- a/gallery/plot_repurposing_annotations.py
+++ b/gallery/plot_repurposing_annotations.py
@@ -146,11 +146,8 @@ def show(imgs):
 print(img.size())
 
 tranforms = weights.transforms()
-img, _ = tranforms(img)
-target = {}
-target["boxes"] = boxes
-target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
-detection_outputs = model(img.unsqueeze(0), [target])
+img = tranforms(img)
+detection_outputs = model(img.unsqueeze(0))
 
 
 ####################################

diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py
index 27fd97681c0..7f92d54ebdd 100644
--- a/gallery/plot_visualization_utils.py
+++ b/gallery/plot_visualization_utils.py
@@ -81,7 +81,7 @@ def show(imgs):
 weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-batch, _ = transforms(batch_int)
+batch = transforms(batch_int)
 
 model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
@@ -131,7 +131,7 @@ def show(imgs):
 model = fcn_resnet50(weights=weights, progress=False)
 model = model.eval()
 
-normalized_batch, _ = transforms(batch)
+normalized_batch = transforms(batch)
 output = model(normalized_batch)['out']
 print(output.shape, output.min().item(), output.max().item())
@@ -272,7 +272,7 @@ def show(imgs):
 weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-batch, _ = transforms(batch_int)
+batch = transforms(batch_int)
 
 model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
@@ -397,7 +397,7 @@ def show(imgs):
 weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-person_float, _ = transforms(person_int)
+person_float = transforms(person_int)
 
 model =
keypointrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 1a1590c276f..e086ee12985 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -20,12 +20,10 @@ class ObjectDetectionEval(nn.Module): - def forward( - self, img: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward(self, img: Tensor) -> Tensor: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) - return F.convert_image_dtype(img, torch.float), target + return F.convert_image_dtype(img, torch.float) class ImageClassificationEval(nn.Module): @@ -95,28 +93,22 @@ def __init__( self._interpolation = interpolation self._interpolation_target = interpolation_target - def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + def forward(self, img: Tensor) -> Tensor: if isinstance(self._size, list): img = F.resize(img, self._size, interpolation=self._interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) - if target: - if isinstance(self._size, list): - target = F.resize(target, self._size, interpolation=self._interpolation_target) - if not isinstance(target, Tensor): - target = F.pil_to_tensor(target) - target = target.squeeze(0).to(torch.int64) - return img, target + return img class OpticalFlowEval(nn.Module): - def forward( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - - img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask) + def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: + if not isinstance(img1, Tensor): + img1 = F.pil_to_tensor(img1) + if not isinstance(img2, Tensor): + img2 = F.pil_to_tensor(img2) img1 = F.convert_image_dtype(img1, torch.float32) img2 = F.convert_image_dtype(img2, torch.float32) @@ -128,19 +120,4 @@ def forward( img1 = img1.contiguous() img2 = img2.contiguous() - return img1, img2, flow, valid_flow_mask - - def _pil_or_numpy_to_tensor( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - if not isinstance(img1, Tensor): - img1 = F.pil_to_tensor(img1) - if not isinstance(img2, Tensor): - img2 = F.pil_to_tensor(img2) - - if flow is not None and not isinstance(flow, Tensor): - flow = torch.from_numpy(flow) - if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor): - valid_flow_mask = torch.from_numpy(valid_flow_mask) - - return img1, img2, flow, valid_flow_mask + return img1, img2 From 65d32d31dbc12a7fbc353e57d7fe43d19989f1e5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Mar 2022 16:55:46 +0000 Subject: [PATCH 03/18] Rewriting the video preset --- torchvision/transforms/_presets.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index e086ee12985..e43c69fff98 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -69,12 +69,24 @@ def __init__( self._interpolation = interpolation def forward(self, vid: Tensor) -> Tensor: - vid = vid.permute(0, 3, 1, 2) # (T, H, 
W, C) => (T, C, H, W)
+        need_squeeze = False
+        if vid.ndim < 5:
+            vid = vid.unsqueeze(dim=0)
+            need_squeeze = True
+
+        vid = vid.permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) => (N, T, C, H, W)
+        N, T, C, H, W = vid.shape
+        vid = vid.view(-1, C, H, W)
         vid = F.resize(vid, self._size, interpolation=self._interpolation)
         vid = F.center_crop(vid, self._crop_size)
         vid = F.convert_image_dtype(vid, torch.float)
         vid = F.normalize(vid, mean=self._mean, std=self._std)
-        return vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)
+        H, W = self._crop_size
+        vid = vid.view(N, T, C, H, W)
+        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)
+
+        if need_squeeze:
+            vid = vid.squeeze(dim=0)
+        return vid

From e7a31a2f0ad73b94cbe1d8fe8f8b9f003f650846 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 16 Mar 2022 16:56:14 +0000
Subject: [PATCH 04/18] Adding tests to check that the bundled transforms are JIT scriptable

---
 test/test_extended_models.py | 66 ++++++++++++++++++++++++++++++++++--
 test/test_models.py          |  5 +--
 2 files changed, 66 insertions(+), 5 deletions(-)

diff --git a/test/test_extended_models.py b/test/test_extended_models.py
index 4bfe03d1ea0..170ba47233a 100644
--- a/test/test_extended_models.py
+++ b/test/test_extended_models.py
@@ -1,5 +1,6 @@
 import importlib
 import os
+import torch
 
 import pytest
 import test_models as TM
@@ -7,8 +8,9 @@ from torchvision.models._api import WeightsEnum, Weights
 from torchvision.models._utils import handle_legacy_interface
 
-run_if_test_with_prototype = pytest.mark.skipif(
-    os.getenv("PYTORCH_TEST_WITH_EXTENDED") != "1",
+
+run_if_test_with_extended = pytest.mark.skipif(
+    os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
     reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
 )
@@ -76,7 +78,7 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
 )
-@run_if_test_with_prototype
+@run_if_test_with_extended
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"]
     defaults = {
@@ -123,6 +125,64 @@ def test_schema_meta_validation(model_fn):
     assert not bad_names
 
 
+@pytest.mark.parametrize(
+    "model_fn",
+    TM.get_models_from_module(models)
+    + TM.get_models_from_module(models.detection)
+    + TM.get_models_from_module(models.quantization)
+    + TM.get_models_from_module(models.segmentation)
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
+)
+@run_if_test_with_extended
+def test_transforms_jit(model_fn):
+    model_name = model_fn.__name__
+    weights_enum = _get_model_weights(model_fn)
+    if len(weights_enum) == 0:
+        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
+
+    defaults = {
+        "models": {
+            "input_shape": (1, 3, 224, 224),
+        },
+        "detection": {
+            "input_shape": (3, 300, 300),
+        },
+        "quantization": {
+            "input_shape": (1, 3, 224, 224),
+            "quantize": True,
+        },
+        "segmentation": {
+            "input_shape": (1, 3, 520, 520),
+        },
+        "video": {
+            "input_shape": (1, 4, 112, 112, 3),
+        },
+        "optical_flow": {
+            "input_shape": (1, 3, 128, 128),
+        },
+    }
+    module_name = model_fn.__module__.split(".")[-2]
+
+    kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
+    input_shape = kwargs.pop("input_shape")
+    x = torch.rand(input_shape)
+    if module_name == "optical_flow":
+        args = (x, x)
+    else:
+        args = (x,)
+
+    problematic_weights = []
+    for w in weights_enum:
+        transforms =
w.transforms() + try: + TM._check_jit_scriptable(transforms, args) + except Exception: + problematic_weights.append(w) + + assert not problematic_weights + + # With this filter, every unexpected warning will be turned into an error @pytest.mark.filterwarnings("error") class TestHandleLegacyInterface: diff --git a/test/test_models.py b/test/test_models.py index 5bef9e24d9f..137c325be1a 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,8 +133,9 @@ def get_export_import_copy(m): if eager_out is None: with torch.no_grad(), freeze_rng_state(): - if unwrapper: - eager_out = nn_module(*args) + eager_out = nn_module(*args) + if unwrapper: + eager_out = unwrapper(eager_out) with torch.no_grad(), freeze_rng_state(): script_out = sm(*args) From a7b454f129fbc2280f55dbf9880e7734b11e276b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Mar 2022 17:02:03 +0000 Subject: [PATCH 05/18] Rename all presets from *Eval to *Inference --- torchvision/models/alexnet.py | 4 +- torchvision/models/convnext.py | 10 ++-- torchvision/models/densenet.py | 10 ++-- torchvision/models/detection/faster_rcnn.py | 8 +-- torchvision/models/detection/fcos.py | 4 +- torchvision/models/detection/keypoint_rcnn.py | 6 +- torchvision/models/detection/mask_rcnn.py | 4 +- torchvision/models/detection/retinanet.py | 4 +- torchvision/models/detection/ssd.py | 4 +- torchvision/models/detection/ssdlite.py | 4 +- torchvision/models/efficientnet.py | 26 ++++----- torchvision/models/googlenet.py | 4 +- torchvision/models/inception.py | 4 +- torchvision/models/mnasnet.py | 6 +- torchvision/models/mobilenetv2.py | 6 +- torchvision/models/mobilenetv3.py | 8 +-- torchvision/models/optical_flow/raft.py | 18 +++--- torchvision/models/quantization/googlenet.py | 4 +- torchvision/models/quantization/inception.py | 4 +- .../models/quantization/mobilenetv2.py | 4 +- .../models/quantization/mobilenetv3.py | 4 +- torchvision/models/quantization/resnet.py | 12 ++-- .../models/quantization/shufflenetv2.py | 6 +- torchvision/models/regnet.py | 58 +++++++++---------- torchvision/models/resnet.py | 34 +++++------ torchvision/models/segmentation/deeplabv3.py | 8 +-- torchvision/models/segmentation/fcn.py | 6 +- torchvision/models/segmentation/lraspp.py | 4 +- torchvision/models/shufflenetv2.py | 6 +- torchvision/models/squeezenet.py | 6 +- torchvision/models/vgg.py | 20 +++---- torchvision/models/video/resnet.py | 8 +-- torchvision/models/vision_transformer.py | 10 ++-- torchvision/transforms/_presets.py | 20 +++---- 34 files changed, 172 insertions(+), 172 deletions(-) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index fb7f8117cad..46887b09e5f 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -55,7 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AlexNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ "task": "image_classification", "architecture": "AlexNet", diff --git a/torchvision/models/convnext.py 
b/torchvision/models/convnext.py index dd7ca8a5c37..7f3d880f66e 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -7,7 +7,7 @@ from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -218,7 +218,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=236), meta={ **_COMMON_META, "num_params": 28589128, @@ -232,7 +232,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=230), meta={ **_COMMON_META, "num_params": 50223688, @@ -246,7 +246,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 88591464, @@ -260,7 +260,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 197767336, diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 52e4f4e9529..681044aefc8 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp from torch import Tensor -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -280,7 +280,7 @@ def _densenet( class DenseNet121_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 7978856, @@ -294,7 +294,7 @@ class DenseNet121_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 28681000, @@ -308,7 +308,7 @@ class DenseNet161_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
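# How these entries are consumed (a sketch; it assumes the WeightsEnum.transforms()
# accessor already used by the gallery scripts in patch 02): the partial below is a
# preset factory that the enum instantiates lazily, so the two calls are equivalent:
#
#     preprocess = DenseNet169_Weights.IMAGENET1K_V1.transforms()
#     preprocess = ImageClassificationInference(crop_size=224)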
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 14149480, @@ -322,7 +322,7 @@ class DenseNet169_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 20013928, diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index a0e9e8dfcaf..49c79cc3ae9 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -336,7 +336,7 @@ def forward(self, x): class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ **_COMMON_META, "num_params": 41755286, @@ -350,7 +350,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ **_COMMON_META, "num_params": 19386354, @@ -364,7 +364,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ **_COMMON_META, "num_params": 19386354, diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 1fcbd370289..b6f55020cff 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -11,7 +11,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -646,7 +646,7 @@ def forward( class FCOS_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ "task": "image_object_detection", "architecture": "FCOS", diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index b6a56aa95bf..fbd82099b2e 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops 
import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -318,7 +318,7 @@ def forward(self, x): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_LEGACY = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ **_COMMON_META, "num_params": 59137258, @@ -329,7 +329,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ) COCO_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ **_COMMON_META, "num_params": 59137258, diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 1d9ff9a7f67..ce03c48a2ad 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -308,7 +308,7 @@ def __init__(self, in_channels, dim_reduced, num_classes): class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ "task": "image_object_detection", "architecture": "MaskRCNN", diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 427bc415ed9..d6c4e843909 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -10,7 +10,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -588,7 +588,7 @@ def forward(self, images, targets=None): class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ "task": "image_object_detection", "architecture": "RetinaNet", diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 403b37b0c7c..d6d4a6a0d51 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops import boxes as box_ops -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from ...utils import _log_api_usage_once from .._api 
import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -28,7 +28,7 @@ class SSD300_VGG16_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ "task": "image_object_detection", "architecture": "SSD", diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 006d2b10eb9..69b8b91f8c7 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops.misc import Conv2dNormActivation -from ...transforms._presets import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetectionInference, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet from .._api import WeightsEnum, Weights @@ -187,7 +187,7 @@ def _mobilenet_extractor( class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetectionInference, meta={ "task": "image_object_detection", "architecture": "SSDLite", diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 691e525e1d9..78733d2ea2a 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -10,7 +10,7 @@ from torchvision.ops import StochasticDepth from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -458,7 +458,7 @@ class EfficientNet_B0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial( - ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -475,7 +475,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -488,7 +488,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ImageClassificationInference, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR ), meta={ **_COMMON_META_V1, @@ -507,7 +507,7 @@ class EfficientNet_B2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial( - ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC ), 
meta={ **_COMMON_META_V1, @@ -524,7 +524,7 @@ class EfficientNet_B3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial( - ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -541,7 +541,7 @@ class EfficientNet_B4_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial( - ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -558,7 +558,7 @@ class EfficientNet_B5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial( - ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -575,7 +575,7 @@ class EfficientNet_B6_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial( - ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -592,7 +592,7 @@ class EfficientNet_B7_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial( - ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ImageClassificationInference, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -609,7 +609,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", transforms=partial( - ImageClassificationEval, + ImageClassificationInference, crop_size=384, resize_size=384, interpolation=InterpolationMode.BILINEAR, @@ -629,7 +629,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", transforms=partial( - ImageClassificationEval, + ImageClassificationInference, crop_size=480, resize_size=480, interpolation=InterpolationMode.BILINEAR, @@ -649,7 +649,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", transforms=partial( - ImageClassificationEval, + ImageClassificationInference, crop_size=480, resize_size=480, interpolation=InterpolationMode.BICUBIC, diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index bd07bf6f9ed..af0fc430a4e 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import Tensor -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import 
_log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tensor: class GoogLeNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index ef998df1b7d..20ad569b8d6 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -410,7 +410,7 @@ def forward(self, x: Tensor) -> Tensor: class Inception_V3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassificationInference, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 11a69af982f..374c0ecc8fa 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -226,7 +226,7 @@ def _load_from_state_dict( class MNASNet0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 2218512, @@ -245,7 +245,7 @@ class MNASNet0_75_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 4383312, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index cf2753bc9ac..63a27186e61 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -7,7 +7,7 @@ from torch import nn from ..ops.misc import Conv2dNormActivation -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -209,7 +209,7 @@ def forward(self, x: Tensor) -> Tensor: class MobileNet_V2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationEval, 
crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", @@ -219,7 +219,7 @@ class MobileNet_V2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index b68ef12fa2f..67ca3dccf76 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _mobilenet_v3( class MobileNet_V3_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 5483032, @@ -328,7 +328,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5483032, @@ -343,7 +343,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): class MobileNet_V3_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 2542856, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index f95411ced21..8921eb96dc3 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,7 +8,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ...transforms._presets import OpticalFlowEval, InterpolationMode +from ...transforms._presets import OpticalFlowInference, InterpolationMode from ...utils import _log_api_usage_once from .._api import Weights, WeightsEnum from .._utils import handle_legacy_interface @@ -523,7 +523,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-things.pth) url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 5257536, @@ -538,7 +538,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, 
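# Sketch of the renamed preset in use (assuming any entry of this enum behaves the
# same): after patch 02 the optical-flow transform takes and returns exactly two
# tensors, which is why gallery/plot_optical_flow.py dropped its "[:2]" slice:
#
#     transforms = Raft_Large_Weights.C_T_V2.transforms()  # an OpticalFlowInference()
#     img1, img2 = transforms(img1_batch, img2_batch)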
meta={ **_COMMON_META, "num_params": 5257536, @@ -553,7 +553,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_V1 = Weights( # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 5257536, @@ -568,7 +568,7 @@ class Raft_Large_Weights(WeightsEnum): # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 5257536, @@ -581,7 +581,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_K_V1 = Weights( # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 5257536, @@ -596,7 +596,7 @@ class Raft_Large_Weights(WeightsEnum): # Same as CT_SKHT with extra fine-tuning on Kitti # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 5257536, @@ -612,7 +612,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-small.pth) url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 990162, @@ -626,7 +626,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlowInference, meta={ **_COMMON_META, "num_params": 990162, diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 61cdc899b8e..15be9ccdcc2 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.nn import functional as F -from ...transforms._presets import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -109,7 +109,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class GoogLeNet_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 8780c76f264..a4577946619 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -9,7 +9,7 @@ from torchvision.models import inception as 
inception_module from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ...transforms._presets import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -175,7 +175,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class Inception_V3_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassificationInference, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index be36d0566c2..992188d28ad 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -7,7 +7,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights from ...ops.misc import Conv2dNormActivation -from ...transforms._presets import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -67,7 +67,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class MobileNet_V2_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV2", diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index be3a9739a9a..c716b192fd4 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ...transforms._presets import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -157,7 +157,7 @@ def _mobilenet_v3_model( class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV3", diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 67e11e7e720..7d55b2c1c2a 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -13,7 +13,7 @@ ResNeXt101_32X8D_Weights, ) -from ...transforms._presets import ImageClassificationEval, InterpolationMode 
+from ...transforms._presets import ImageClassificationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -161,7 +161,7 @@ def _resnet( class ResNet18_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -178,7 +178,7 @@ class ResNet18_QuantizedWeights(WeightsEnum): class ResNet50_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -191,7 +191,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -208,7 +208,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -221,7 +221,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 3eda3c813ec..9a7354a547a 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -6,7 +6,7 @@ from torch import Tensor from torchvision.models import shufflenetv2 -from ...transforms._presets import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassificationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -118,7 +118,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -133,7 +133,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - 
transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 4d90d2d551e..6f57d86b67f 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -416,7 +416,7 @@ def _regnet( class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 4344144, @@ -427,7 +427,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 4344144, @@ -442,7 +442,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): class RegNet_Y_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 6432512, @@ -453,7 +453,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 6432512, @@ -468,7 +468,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): class RegNet_Y_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 11202430, @@ -479,7 +479,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 11202430, @@ -494,7 +494,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): class RegNet_Y_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 19436338, @@ -505,7 +505,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationEval, crop_size=224, 
resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 19436338, @@ -520,7 +520,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): class RegNet_Y_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 39381472, @@ -531,7 +531,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39381472, @@ -546,7 +546,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): class RegNet_Y_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 83590140, @@ -557,7 +557,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 83590140, @@ -572,7 +572,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): class RegNet_Y_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 145046770, @@ -583,7 +583,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 145046770, @@ -603,7 +603,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 5495976, @@ -614,7 +614,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5495976, @@ -629,7 +629,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 7259656, @@ -640,7 +640,7 @@ class 
RegNet_X_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 7259656, @@ -655,7 +655,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): class RegNet_X_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 9190136, @@ -666,7 +666,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 9190136, @@ -681,7 +681,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): class RegNet_X_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 15296552, @@ -692,7 +692,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 15296552, @@ -707,7 +707,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): class RegNet_X_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 39572648, @@ -718,7 +718,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39572648, @@ -733,7 +733,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): class RegNet_X_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 54278536, @@ -744,7 +744,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 54278536, @@ -759,7 +759,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): class RegNet_X_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - 
transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 107811560, @@ -770,7 +770,7 @@ class RegNet_X_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 107811560, diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 2d24e2b1b8e..6f6385a0ba6 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -313,7 +313,7 @@ def _resnet( class ResNet18_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -330,7 +330,7 @@ class ResNet18_Weights(WeightsEnum): class ResNet34_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -347,7 +347,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -360,7 +360,7 @@ class ResNet50_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -377,7 +377,7 @@ class ResNet50_Weights(WeightsEnum): class ResNet101_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -390,7 +390,7 @@ class ResNet101_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -407,7 +407,7 @@ class ResNet101_Weights(WeightsEnum): class ResNet152_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ 
**_COMMON_META, "architecture": "ResNet", @@ -420,7 +420,7 @@ class ResNet152_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -437,7 +437,7 @@ class ResNet152_Weights(WeightsEnum): class ResNeXt50_32X4D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -450,7 +450,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -467,7 +467,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): class ResNeXt101_32X8D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -480,7 +480,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -497,7 +497,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): class Wide_ResNet50_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -510,7 +510,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -527,7 +527,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): class Wide_ResNet101_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -540,7 +540,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", diff --git a/torchvision/models/segmentation/deeplabv3.py 
b/torchvision/models/segmentation/deeplabv3.py index a4b00c5e792..56d575edcc4 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import functional as F -from ...transforms._presets import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -140,7 +140,7 @@ def _deeplabv3_resnet( class DeepLabV3_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentationInference, resize_size=520), meta={ **_COMMON_META, "num_params": 42004074, @@ -155,7 +155,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): class DeepLabV3_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentationInference, resize_size=520), meta={ **_COMMON_META, "num_params": 60996202, @@ -170,7 +170,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentationInference, resize_size=520), meta={ **_COMMON_META, "num_params": 11029328, diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index b0836cd07b0..9157a25a093 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -3,7 +3,7 @@ from torch import nn -from ...transforms._presets import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentationInference, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -59,7 +59,7 @@ def __init__(self, in_channels: int, channels: int) -> None: class FCN_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentationInference, resize_size=520), meta={ **_COMMON_META, "num_params": 35322218, @@ -74,7 +74,7 @@ class FCN_ResNet50_Weights(WeightsEnum): class FCN_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentationInference, resize_size=520), meta={ **_COMMON_META, "num_params": 54314346, diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index dfa15804cf8..e527892d860 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from 
...transforms._presets import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentationInference, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES @@ -96,7 +96,7 @@ def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentationInference, resize_size=520), meta={ "task": "image_semantic_segmentation", "architecture": "LRASPP", diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index c26eaa1d536..4204565868c 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -198,7 +198,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -212,7 +212,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): class ShuffleNet_V2_X1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 28e26cd21b6..9fef930fb06 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.init as init -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -128,7 +128,7 @@ def _squeezenet( class SqueezeNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "min_size": (21, 21), @@ -143,7 +143,7 @@ class SqueezeNet1_0_Weights(WeightsEnum): class SqueezeNet1_1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "min_size": (17, 17), diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 3398c2d6e9b..ccf0b5895e9 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from 
..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -120,7 +120,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b class VGG11_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 132863336, @@ -134,7 +134,7 @@ class VGG11_Weights(WeightsEnum): class VGG11_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 132868840, @@ -148,7 +148,7 @@ class VGG11_BN_Weights(WeightsEnum): class VGG13_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 133047848, @@ -162,7 +162,7 @@ class VGG13_Weights(WeightsEnum): class VGG13_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 133053736, @@ -176,7 +176,7 @@ class VGG13_BN_Weights(WeightsEnum): class VGG16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 138357544, @@ -190,7 +190,7 @@ class VGG16_Weights(WeightsEnum): IMAGENET1K_FEATURES = Weights( url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", transforms=partial( - ImageClassificationEval, + ImageClassificationInference, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), @@ -210,7 +210,7 @@ class VGG16_Weights(WeightsEnum): class VGG16_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 138365992, @@ -224,7 +224,7 @@ class VGG16_BN_Weights(WeightsEnum): class VGG19_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 143667240, @@ -238,7 +238,7 @@ class VGG19_Weights(WeightsEnum): class VGG19_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 143678248, diff --git 
a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index df215fce0a4..f121609c7e9 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch import Tensor -from ...transforms._presets import VideoClassificationEval, InterpolationMode +from ...transforms._presets import VideoClassificationInference, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES @@ -322,7 +322,7 @@ def _video_resnet( class R3D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassificationInference, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R3D", @@ -337,7 +337,7 @@ class R3D_18_Weights(WeightsEnum): class MC3_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassificationInference, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "MC3", @@ -352,7 +352,7 @@ class MC3_18_Weights(WeightsEnum): class R2Plus1D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassificationInference, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R(2+1)D", diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 9e1b701b144..4d34c179ca6 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from ..ops.misc import Conv2dNormActivation -from ..transforms._presets import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassificationInference, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _vision_transformer( class ViT_B_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 86567656, @@ -334,7 +334,7 @@ class ViT_B_16_Weights(WeightsEnum): class ViT_B_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 88224232, @@ -351,7 +351,7 @@ class ViT_B_32_Weights(WeightsEnum): class ViT_L_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), + transforms=partial(ImageClassificationInference, crop_size=224, resize_size=242), meta={ **_COMMON_META, "num_params": 304326632, @@ -368,7 +368,7 @@ class ViT_L_16_Weights(WeightsEnum): class 
ViT_L_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassificationInference, crop_size=224), meta={ **_COMMON_META, "num_params": 306535400, diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index e43c69fff98..455af8d7d84 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -11,22 +11,22 @@ __all__ = [ - "ObjectDetectionEval", - "ImageClassificationEval", - "VideoClassificationEval", - "SemanticSegmentationEval", - "OpticalFlowEval", + "ObjectDetectionInference", + "ImageClassificationInference", + "VideoClassificationInference", + "SemanticSegmentationInference", + "OpticalFlowInference", ] -class ObjectDetectionEval(nn.Module): +class ObjectDetectionInference(nn.Module): def forward(self, img: Tensor) -> Tensor: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) return F.convert_image_dtype(img, torch.float) -class ImageClassificationEval(nn.Module): +class ImageClassificationInference(nn.Module): def __init__( self, crop_size: int, @@ -52,7 +52,7 @@ def forward(self, img: Tensor) -> Tensor: return img -class VideoClassificationEval(nn.Module): +class VideoClassificationInference(nn.Module): def __init__( self, crop_size: Tuple[int, int], @@ -89,7 +89,7 @@ def forward(self, vid: Tensor) -> Tensor: return vid -class SemanticSegmentationEval(nn.Module): +class SemanticSegmentationInference(nn.Module): def __init__( self, resize_size: Optional[int], @@ -115,7 +115,7 @@ def forward(self, img: Tensor) -> Tensor: return img -class OpticalFlowEval(nn.Module): +class OpticalFlowInference(nn.Module): def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: if not isinstance(img1, Tensor): img1 = F.pil_to_tensor(img1)
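
To make the rename above concrete, here is a minimal usage sketch (not part of the patch series; resnet50 and ResNet50_Weights are assumed example names and the import path is an assumption). The preset is reached through weights.transforms(), exactly as the reference scripts in the later patches do:

    # Sketch, assuming the post-rename API: weights.transforms() now yields an
    # *Inference preset (here ImageClassificationInference) instead of *Eval.
    import torch
    from torchvision.models import resnet50, ResNet50_Weights  # assumed import path

    weights = ResNet50_Weights.IMAGENET1K_V1
    model = resnet50(weights=weights).eval()
    preprocess = weights.transforms()  # instantiates the partial stored on the enum member

    img = torch.randint(0, 256, (3, 500, 400), dtype=torch.uint8)  # stand-in for a decoded image
    batch = preprocess(img).unsqueeze(0)  # resize, center crop, rescale, normalize
    scores = model(batch)
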
From 300eff7c15fde90c5c47f71594d222136967e153 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Mar 2022 17:17:14 +0000 Subject: [PATCH 06/18] Minor refactoring --- references/video_classification/presets.py | 4 ++-- references/video_classification/train.py | 4 ++-- torchvision/transforms/_presets.py | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 04039c9a4f1..d24169e42dd 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -6,8 +6,8 @@ class VideoClassificationPresetTrain: def __init__( self, - resize_size, crop_size, + resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), hflip_prob=0.5, @@ -27,7 +27,7 @@ def __call__(self, x): class VideoClassificationPresetEval: - def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): + def __init__(self, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): self.transforms = transforms.Compose( [ ConvertBHWCtoBCHW(), diff --git a/references/video_classification/train.py b/references/video_classification/train.py index d36785ddf96..dc809e7a98a 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -120,7 +120,7 @@ def main(args): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112)) + transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_train from {cache_path}") @@ -151,7 +151,7 @@ def main(args): cache_path = _get_cache_path(valdir) if not args.prototype: - transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) + transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171)) else: if args.weights: weights = prototype.models.get_weight(args.weights) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 455af8d7d84..5ba8c709d22 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -96,14 +96,12 @@ def __init__( mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] = (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, - interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() self._size = [resize_size] if resize_size is not None else None self._mean = list(mean) self._std = list(std) self._interpolation = interpolation - self._interpolation_target = interpolation_target def forward(self, img: Tensor) -> Tensor: if isinstance(self._size, list): @@ -122,8 +120,8 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: if not isinstance(img2, Tensor): img2 = F.pil_to_tensor(img2) - img1 = F.convert_image_dtype(img1, torch.float32) - img2 = F.convert_image_dtype(img2, torch.float32) + img1 = F.convert_image_dtype(img1, torch.float) + img2 = F.convert_image_dtype(img2, torch.float) # map [0, 1] into [-1, 1] img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
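
The refactoring above flips the positional order of resize_size and crop_size in the video presets and moves the train.py call sites to keyword arguments in the same commit. A short sketch of why the keyword style is the safer one (illustrative only; presets is the reference-script module shown in the hunk above):

    # Local import from references/video_classification; illustrative only.
    from presets import VideoClassificationPresetEval

    # With positional arguments, the reordering above would silently swap the
    # two sizes at any stale call site:
    #   VideoClassificationPresetEval((128, 171), (112, 112))  # wrong after this patch
    # Keyword arguments are immune to that kind of signature change:
    transform_test = VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171))
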
From 996f20494c0e5d6f3a44fc5b873a229c7170c7bb Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Mar 2022 17:59:50 +0000 Subject: [PATCH 07/18] Remove --prototype and --pretrained from reference scripts --- references/classification/train.py | 42 ++-------- .../classification/train_quantization.py | 23 +-------- references/classification/utils.py | 8 ++-- references/detection/train.py | 43 +++-------- references/optical_flow/train.py | 34 +++------- references/segmentation/train.py | 48 ++++--------- references/video_classification/train.py | 38 ++------- 7 files changed, 35 insertions(+), 201 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 569cf3009e7..dddf4e73ec6 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -15,12 +15,6 @@ from torchvision.transforms.functional import InterpolationMode -try: - from torchvision import prototype -except ImportError: - prototype = None - - def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") @@ -154,18 +148,13 @@ def load_data(traindir, valdir, args): print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: - if not args.prototype: + if args.weights: + weights = torchvision.models.get_weight(args.weights) + preprocessing = weights.transforms() + else: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) - if args.weights: - weights = prototype.models.get_weight(args.weights) - preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.ImageClassificationEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation - ) dataset_test = torchvision.datasets.ImageFolder( valdir, @@ -191,10 +180,6 @@ def load_data(traindir, valdir, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -236,10 +221,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) - else: - model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) + model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: @@ -446,12 +428,6 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") @@ -496,14 +472,6 @@ def get_args_parser(add_help=True): parser.add_argument( "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" ) - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index 111777a860b..c0e5af1dcfc 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -12,17 +12,7 @@ from train import train_one_epoch, evaluate, load_data -try: - from torchvision import prototype -except ImportError: - prototype = None - - def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -56,10 +46,7 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model - if not args.prototype: - model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) - else: - model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) + model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model.to(device) if not (args.test_only or args.post_training_quantize): @@ -264,14 +251,6 @@ def get_args_parser(add_help=True): "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" ) parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 7f573415c4c..12876f3c241 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T from torchvision import models as M # Classification - model = M.mobilenet_v3_large(pretrained=False) + model = M.mobilenet_v3_large() print(store_model_weights(model, './class.pth')) # Quantized Classification - model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False) + model = M.quantization.mobilenet_v3_large(quantize=False) model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') _ = torch.ao.quantization.prepare_qat(model, inplace=True) print(store_model_weights(model, './qat.pth')) # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False) + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained_backbone=False) print(store_model_weights(model, './obj.pth')) # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True) + model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained_backbone=False, aux_loss=True) print(store_model_weights(model, './segm.pth', strict=False)) Args: diff --git a/references/detection/train.py b/references/detection/train.py index 3909e6413d0..20a9eba0add 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -33,12 +33,6 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(name, image_set, transform, data_path): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] @@ -50,14 +44,12 @@ def get_dataset(name, image_set, transform, data_path): def get_transform(train, args): if train: return presets.DetectionPresetTrain(args.data_augmentation) - elif not args.prototype: - return presets.DetectionPresetEval() + if args.weights: + weights = torchvision.models.get_weight(args.weights) + trans = 
weights.transforms() + return lambda img, target=None: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.ObjectDetectionEval() + return presets.DetectionPresetEval() def get_args_parser(add_help=True): @@ -132,24 +124,10 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters @@ -159,10 +137,6 @@ def get_args_parser(add_help=True): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -204,12 +178,7 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - if not args.prototype: - model = torchvision.models.detection.__dict__[args.model]( - pretrained=args.pretrained, num_classes=num_classes, **kwargs - ) - else: - model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) + model = torchvision.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 83952242eb9..614cbfe02e8 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -9,11 +9,6 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K -try: - from torchvision import prototype -except ImportError: - prototype = None - def get_train_dataset(stage, dataset_root): if stage == "chairs": @@ -138,12 +133,10 @@ def inner_loop(blob): def evaluate(model, args): val_datasets = args.val_dataset or [] - if args.prototype: - if args.weights: - weights = prototype.models.get_weight(args.weights) - preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.OpticalFlowEval() + if args.weights: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + preprocessing = lambda img1, img2, flow=None, valid=None: (*trans(img1, img2), flow, valid) else: preprocessing = OpticalFlowPresetEval()
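
A quick sanity check on the callable above (a sketch, not part of the patch): torchvision's optical-flow datasets call their transforms with the (img1, img2, flow, valid_flow_mask) quadruple and unpack four values back, so the lambda has to splice the transformed image pair into a flat four-tuple rather than return it nested. trans below is a stand-in for weights.transforms():

    # Sketch: eval-time transforms for optical flow must accept and return
    # (img1, img2, flow, valid_flow_mask).
    import torch

    def trans(img1, img2):  # stand-in for the OpticalFlowInference preset
        return img1.float() / 255.0, img2.float() / 255.0

    preprocessing = lambda img1, img2, flow=None, valid=None: (*trans(img1, img2), flow, valid)

    img = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)
    img1, img2, flow, valid = preprocessing(img, img.clone(), None, None)  # unpacks as a flat 4-tuple
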
Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") utils.setup_ddp(args) if args.distributed and args.device == "cpu": raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun") device = torch.device(args.device) - if args.prototype: - model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights) - else: - model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) + model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) if args.distributed: model = model.to(args.local_rank) @@ -356,8 +342,7 @@ def get_args_parser(add_help=True): parser.add_argument( "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small" ) - # TODO: resume, pretrained, and weights should be in an exclusive arg group - parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") + # TODO: resume and weights should be in an exclusive arg group parser.add_argument( "--num_flow_updates", @@ -376,13 +361,6 @@ def get_args_parser(add_help=True): required=True, ) - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)") diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 5dc03945bd7..5b3bf63889d 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -11,12 +11,6 @@ from torch import nn -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(dir_path, name, image_set, transform): def sbd(*args, **kwargs): return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) @@ -35,14 +29,12 @@ def sbd(*args, **kwargs): def get_transform(train, args): if train: return presets.SegmentationPresetTrain(base_size=520, crop_size=480) - elif not args.prototype: - return presets.SegmentationPresetEval(base_size=520) + elif args.weights: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + return lambda img, target = None: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.SemanticSegmentationEval(resize_size=520) + return presets.SegmentationPresetEval(base_size=520) def criterion(inputs, target): @@ -100,10 +92,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -135,16 +123,9 @@ def main(args): dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn ) - if not args.prototype: - model = torchvision.models.segmentation.__dict__[args.model]( - pretrained=args.pretrained, - num_classes=num_classes, - aux_loss=args.aux_loss, - ) - else: - model = prototype.models.segmentation.__dict__[args.model]( - weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss - ) + model = torchvision.models.segmentation.__dict__[args.model]( + weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss + ) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -272,23 +253,10 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters diff --git a/references/video_classification/train.py b/references/video_classification/train.py index dc809e7a98a..05c031305c4 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,11 +12,6 @@ from torch.utils.data.dataloader import default_collate from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler -try: - from torchvision import prototype -except ImportError: - prototype = None - def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): model.train() @@ -96,10 +91,6 @@ def collate_fn(batch): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -150,14 +141,11 @@ def main(args): print("Loading validation data") cache_path = _get_cache_path(valdir) - if not args.prototype: - transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171)) + if args.weights: + weights = torchvision.models.get_weight(args.weights) + transform_test = weights.transforms() else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - transform_test = weights.transforms() - else: - transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171)) + transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -208,10 +196,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - else: - model = prototype.models.video.__dict__[args.model](weights=args.weights) + model = torchvision.models.video.__dict__[args.model](weights=args.weights) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -352,24 +337,11 @@ def parse_args(): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters From 6dd6983ed220cee9ce068c0308af9c7c10acb469 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Mar 2022 18:31:24 +0000 Subject: [PATCH 08/18] remove pretrained_backbone refs --- references/classification/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index 12876f3c241..27398d97234 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -341,11 +341,11 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T print(store_model_weights(model, './qat.pth')) # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained_backbone=False) + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn() print(store_model_weights(model, './obj.pth')) # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained_backbone=False, aux_loss=True) + model = M.segmentation.deeplabv3_mobilenet_v3_large(aux_loss=True) print(store_model_weights(model, './segm.pth', strict=False)) Args: From 0bf8c6d7c2518fe06ec5964bcca0bb63af5a834f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 09:42:11 +0000 Subject: [PATCH 09/18] Corrections and simplifications --- references/classification/train.py | 2 +- references/detection/train.py | 2 +- references/optical_flow/train.py | 4 +-
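One correction in patch 09 below is easy to miss: in references/optical_flow/train.py the evaluation lambda changes from (trans(img1, img2), flow, valid) to trans(img1, img2) + (flow, valid). Assuming the preset returns the preprocessed frames as a tuple, the concatenation yields the flat four-tuple the evaluation loop presumably expects, where the earlier form nested the pair. A minimal sketch of the difference:

    # Sketch only; trans is assumed to return an (img1, img2) tuple.
    nested = (trans(img1, img2), flow, valid)  # 3 items: ((img1, img2), flow, valid)
    flat = trans(img1, img2) + (flow, valid)   # 4 items: (img1, img2, flow, valid)
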
references/segmentation/train.py | 2 +- references/video_classification/train.py | 2 +- test/test_extended_models.py | 1 - torchvision/models/alexnet.py | 4 +- torchvision/models/convnext.py | 10 ++-- torchvision/models/densenet.py | 10 ++-- torchvision/models/detection/faster_rcnn.py | 8 +-- torchvision/models/detection/fcos.py | 4 +- torchvision/models/detection/keypoint_rcnn.py | 6 +- torchvision/models/detection/mask_rcnn.py | 4 +- torchvision/models/detection/retinanet.py | 4 +- torchvision/models/detection/ssd.py | 4 +- torchvision/models/detection/ssdlite.py | 4 +- torchvision/models/efficientnet.py | 26 ++++----- torchvision/models/googlenet.py | 4 +- torchvision/models/inception.py | 4 +- torchvision/models/mnasnet.py | 6 +- torchvision/models/mobilenetv2.py | 6 +- torchvision/models/mobilenetv3.py | 8 +-- torchvision/models/optical_flow/raft.py | 18 +++--- torchvision/models/quantization/googlenet.py | 4 +- torchvision/models/quantization/inception.py | 4 +- .../models/quantization/mobilenetv2.py | 4 +- .../models/quantization/mobilenetv3.py | 4 +- torchvision/models/quantization/resnet.py | 12 ++-- .../models/quantization/shufflenetv2.py | 6 +- torchvision/models/regnet.py | 58 +++++++++---------- torchvision/models/resnet.py | 34 +++++------ torchvision/models/segmentation/deeplabv3.py | 8 +-- torchvision/models/segmentation/fcn.py | 6 +- torchvision/models/segmentation/lraspp.py | 4 +- torchvision/models/shufflenetv2.py | 6 +- torchvision/models/squeezenet.py | 6 +- torchvision/models/vgg.py | 20 +++---- torchvision/models/video/resnet.py | 8 +-- torchvision/models/vision_transformer.py | 10 ++-- torchvision/transforms/_presets.py | 20 +++---- 40 files changed, 178 insertions(+), 179 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index dddf4e73ec6..eb8b56c1ad0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -148,7 +148,7 @@ def load_data(traindir, valdir, args): print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: - if args.weights: + if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) preprocessing = weights.transforms() else: diff --git a/references/detection/train.py b/references/detection/train.py index 20a9eba0add..376a30a4a78 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -44,7 +44,7 @@ def get_dataset(name, image_set, transform, data_path): def get_transform(train, args): if train: return presets.DetectionPresetTrain(args.data_augmentation) - if args.weights: + if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() return lambda img, target = None: (trans(img), target) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 614cbfe02e8..64e1ba67c1f 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -133,10 +133,10 @@ def inner_loop(blob): def evaluate(model, args): val_datasets = args.val_dataset or [] - if args.weights: + if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - return lambda img1, img2, flow = None, valid = None: (trans(img1, img2), flow, valid) + return lambda img1, img2, flow = None, valid = None: trans(img1, img2) + (flow, valid) else: preprocessing = OpticalFlowPresetEval() diff --git a/references/segmentation/train.py 
b/references/segmentation/train.py index 5b3bf63889d..de0db9b028a 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -29,7 +29,7 @@ def sbd(*args, **kwargs): def get_transform(train, args): if train: return presets.SegmentationPresetTrain(base_size=520, crop_size=480) - elif args.weights: + elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() return lambda img, target = None: (trans(img), target) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 05c031305c4..da7ef9fc607 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -141,7 +141,7 @@ def main(args): print("Loading validation data") cache_path = _get_cache_path(valdir) - if args.weights: + if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) transform_test = weights.transforms() else: diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 170ba47233a..3d5995b2c83 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -150,7 +150,6 @@ def test_transforms_jit(model_fn): }, "quantization": { "input_shape": (1, 3, 224, 224), - "quantize": True, }, "segmentation": { "input_shape": (1, 3, 520, 520), diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 46887b09e5f..6ee5b98c673 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -55,7 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AlexNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "AlexNet", diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 7f3d880f66e..8774b9a1bc2 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -7,7 +7,7 @@ from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -218,7 +218,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=236), + transforms=partial(ImageClassification, crop_size=224, resize_size=236), meta={ **_COMMON_META, "num_params": 28589128, @@ -232,7 +232,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=230), + transforms=partial(ImageClassification, 
crop_size=224, resize_size=230), meta={ **_COMMON_META, "num_params": 50223688, @@ -246,7 +246,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 88591464, @@ -260,7 +260,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 197767336, diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 681044aefc8..2ffb29c54cb 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp from torch import Tensor -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -280,7 +280,7 @@ def _densenet( class DenseNet121_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 7978856, @@ -294,7 +294,7 @@ class DenseNet121_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 28681000, @@ -308,7 +308,7 @@ class DenseNet161_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 14149480, @@ -322,7 +322,7 @@ class DenseNet169_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 20013928, diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 49c79cc3ae9..7d18fbe90a3 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -336,7 +336,7 @@ def forward(self, x): 
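The renames in the hunks above and below (ImageClassificationInference becoming ImageClassification, ObjectDetectionInference becoming ObjectDetection, and so on) only shorten the preset names; users still reach the presets through the weights enums. A minimal sketch of the intended end-user flow, assuming the post-rename builders from this patch series and a random tensor standing in for a real image:

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    weights = ResNet50_Weights.IMAGENET1K_V2
    model = resnet50(weights=weights).eval()

    preprocess = weights.transforms()  # an ImageClassification preset instance
    batch = preprocess(torch.rand(3, 500, 400)).unsqueeze(0)  # placeholder image
    scores = model(batch).softmax(dim=1)
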
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 41755286, @@ -350,7 +350,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, @@ -364,7 +364,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index b6f55020cff..27e54a565f2 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -11,7 +11,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -646,7 +646,7 @@ def forward( class FCOS_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "FCOS", diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index fbd82099b2e..2a554a6f56e 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -318,7 +318,7 @@ def forward(self, x): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_LEGACY = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 59137258, @@ -329,7 +329,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ) COCO_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 59137258, diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index ce03c48a2ad..fb60ffcbb0a 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from 
...ops import misc as misc_nn_ops -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -308,7 +308,7 @@ def __init__(self, in_channels, dim_reduced, num_classes): class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "MaskRCNN", diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d6c4e843909..49b9acf45e4 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -10,7 +10,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -588,7 +588,7 @@ def forward(self, images, targets=None): class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "RetinaNet", diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index d6d4a6a0d51..c30919e621c 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops import boxes as box_ops -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -28,7 +28,7 @@ class SSD300_VGG16_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "SSD", diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 69b8b91f8c7..93023337d11 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops.misc import Conv2dNormActivation -from ...transforms._presets import ObjectDetectionInference, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .. 
import mobilenet from .._api import WeightsEnum, Weights @@ -187,7 +187,7 @@ def _mobilenet_extractor( class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=ObjectDetectionInference, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "SSDLite", diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 78733d2ea2a..b9d3b9b30c9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -10,7 +10,7 @@ from torchvision.ops import StochasticDepth from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -458,7 +458,7 @@ class EfficientNet_B0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial( - ImageClassificationInference, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -475,7 +475,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial( - ImageClassificationInference, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -488,7 +488,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", transforms=partial( - ImageClassificationInference, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR ), meta={ **_COMMON_META_V1, @@ -507,7 +507,7 @@ class EfficientNet_B2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial( - ImageClassificationInference, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -524,7 +524,7 @@ class EfficientNet_B3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial( - ImageClassificationInference, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -541,7 +541,7 @@ class EfficientNet_B4_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial( - ImageClassificationInference, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -558,7 +558,7 @@ class 
EfficientNet_B5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial( - ImageClassificationInference, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -575,7 +575,7 @@ class EfficientNet_B6_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial( - ImageClassificationInference, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -592,7 +592,7 @@ class EfficientNet_B7_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial( - ImageClassificationInference, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -609,7 +609,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", transforms=partial( - ImageClassificationInference, + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BILINEAR, @@ -629,7 +629,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", transforms=partial( - ImageClassificationInference, + ImageClassification, crop_size=480, resize_size=480, interpolation=InterpolationMode.BILINEAR, @@ -649,7 +649,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", transforms=partial( - ImageClassificationInference, + ImageClassification, crop_size=480, resize_size=480, interpolation=InterpolationMode.BICUBIC, diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index af0fc430a4e..ced92571974 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import Tensor -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tensor: class GoogLeNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 20ad569b8d6..816fab45549 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import 
_log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -410,7 +410,7 @@ def forward(self, x: Tensor) -> Tensor: class Inception_V3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationInference, crop_size=299, resize_size=342), + transforms=partial(ImageClassification, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 374c0ecc8fa..578e77f7934 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -226,7 +226,7 @@ def _load_from_state_dict( class MNASNet0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2218512, @@ -245,7 +245,7 @@ class MNASNet0_75_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 4383312, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 63a27186e61..085049117ec 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -7,7 +7,7 @@ from torch import nn from ..ops.misc import Conv2dNormActivation -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -209,7 +209,7 @@ def forward(self, x: Tensor) -> Tensor: class MobileNet_V2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", @@ -219,7 +219,7 @@ class MobileNet_V2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 67ca3dccf76..91e1ea91a94 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer -from 
..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _mobilenet_v3( class MobileNet_V3_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 5483032, @@ -328,7 +328,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5483032, @@ -343,7 +343,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): class MobileNet_V3_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2542856, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 8921eb96dc3..244d2b2fac1 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,7 +8,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ...transforms._presets import OpticalFlowInference, InterpolationMode +from ...transforms._presets import OpticalFlow, InterpolationMode from ...utils import _log_api_usage_once from .._api import Weights, WeightsEnum from .._utils import handle_legacy_interface @@ -523,7 +523,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-things.pth) url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -538,7 +538,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -553,7 +553,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_V1 = Weights( # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -568,7 +568,7 @@ class Raft_Large_Weights(WeightsEnum): # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -581,7 +581,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_K_V1 = Weights( # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) 
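The optical-flow presets get the same treatment in this hunk: the Raft weights now carry OpticalFlow rather than OpticalFlowInference. A short sketch of how a checkpoint and its bundled preprocessing would be used together, assuming raft_small accepts the weights argument the way the reference-script diffs earlier in this series do (the inputs are random placeholders):

    import torch
    from torchvision.models.optical_flow import raft_small, Raft_Small_Weights

    weights = Raft_Small_Weights.C_T_V2
    model = raft_small(weights=weights).eval()

    preprocess = weights.transforms()  # an OpticalFlow preset instance
    img1 = torch.rand(1, 3, 384, 512)  # placeholder frame pair
    img2 = torch.rand(1, 3, 384, 512)
    img1, img2 = preprocess(img1, img2)
    flow_predictions = model(img1, img2)  # list of iteratively refined flow fields
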
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -596,7 +596,7 @@ class Raft_Large_Weights(WeightsEnum): # Same as CT_SKHT with extra fine-tuning on Kitti # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -612,7 +612,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-small.pth) url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 990162, @@ -626,7 +626,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=OpticalFlowInference, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 990162, diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 15be9ccdcc2..9944e470352 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.nn import functional as F -from ...transforms._presets import ImageClassificationInference, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -109,7 +109,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class GoogLeNet_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index a4577946619..9a732f79fb7 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -9,7 +9,7 @@ from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ...transforms._presets import ImageClassificationInference, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -175,7 +175,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class Inception_V3_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationInference, crop_size=299, resize_size=342), + transforms=partial(ImageClassification, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 
992188d28ad..1def3d24b28 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -7,7 +7,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights from ...ops.misc import Conv2dNormActivation -from ...transforms._presets import ImageClassificationInference, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -67,7 +67,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class MobileNet_V2_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV2", diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index c716b192fd4..4a203ca7095 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ...transforms._presets import ImageClassificationInference, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -157,7 +157,7 @@ def _mobilenet_v3_model( class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV3", diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 7d55b2c1c2a..ab512a7413f 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -13,7 +13,7 @@ ResNeXt101_32X8D_Weights, ) -from ...transforms._presets import ImageClassificationInference, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -161,7 +161,7 @@ def _resnet( class ResNet18_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -178,7 +178,7 @@ class ResNet18_QuantizedWeights(WeightsEnum): class ResNet50_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -191,7 +191,7 @@ class 
ResNet50_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -208,7 +208,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -221,7 +221,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 9a7354a547a..a3a26120479 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -6,7 +6,7 @@ from torch import Tensor from torchvision.models import shufflenetv2 -from ...transforms._presets import ImageClassificationInference, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -118,7 +118,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -133,7 +133,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 6f57d86b67f..72093686d84 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -416,7 +416,7 @@ def _regnet( class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 
4344144, @@ -427,7 +427,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 4344144, @@ -442,7 +442,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): class RegNet_Y_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 6432512, @@ -453,7 +453,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 6432512, @@ -468,7 +468,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): class RegNet_Y_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 11202430, @@ -479,7 +479,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 11202430, @@ -494,7 +494,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): class RegNet_Y_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 19436338, @@ -505,7 +505,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 19436338, @@ -520,7 +520,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): class RegNet_Y_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 39381472, @@ -531,7 +531,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39381472, @@ -546,7 +546,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): class RegNet_Y_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - 
transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 83590140, @@ -557,7 +557,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 83590140, @@ -572,7 +572,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): class RegNet_Y_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 145046770, @@ -583,7 +583,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 145046770, @@ -603,7 +603,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 5495976, @@ -614,7 +614,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5495976, @@ -629,7 +629,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 7259656, @@ -640,7 +640,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 7259656, @@ -655,7 +655,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): class RegNet_X_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 9190136, @@ -666,7 +666,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 9190136, @@ -681,7 +681,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): 
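The reference scripts never name these enums directly; as the train.py diffs earlier in this series show, they resolve the --weights string through torchvision.models.get_weight. A minimal sketch of that lookup, assuming the fully qualified "Enum.MEMBER" string format and using an entry from the diff just below:

    from torchvision.models import get_weight

    # e.g. the CLI is invoked with --weights RegNet_X_3_2GF_Weights.IMAGENET1K_V2
    weights = get_weight("RegNet_X_3_2GF_Weights.IMAGENET1K_V2")
    preprocess = weights.transforms()   # ImageClassification, crop_size=224, resize_size=232
    print(weights.meta["num_params"])   # 15296552, per the meta dict below
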
class RegNet_X_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 15296552, @@ -692,7 +692,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 15296552, @@ -707,7 +707,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): class RegNet_X_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 39572648, @@ -718,7 +718,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39572648, @@ -733,7 +733,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): class RegNet_X_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 54278536, @@ -744,7 +744,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 54278536, @@ -759,7 +759,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): class RegNet_X_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 107811560, @@ -770,7 +770,7 @@ class RegNet_X_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 107811560, diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 6f6385a0ba6..8f44e553296 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -313,7 +313,7 @@ def _resnet( class ResNet18_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -330,7 +330,7 @@ class ResNet18_Weights(WeightsEnum): class ResNet34_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -347,7 +347,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -360,7 +360,7 @@ class ResNet50_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -377,7 +377,7 @@ class ResNet50_Weights(WeightsEnum): class ResNet101_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -390,7 +390,7 @@ class ResNet101_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -407,7 +407,7 @@ class ResNet101_Weights(WeightsEnum): class ResNet152_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -420,7 +420,7 @@ class ResNet152_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -437,7 +437,7 @@ class ResNet152_Weights(WeightsEnum): class ResNeXt50_32X4D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -450,7 +450,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -467,7 +467,7 @@ class 
ResNeXt50_32X4D_Weights(WeightsEnum): class ResNeXt101_32X8D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -480,7 +480,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -497,7 +497,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): class Wide_ResNet50_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -510,7 +510,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -527,7 +527,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): class Wide_ResNet101_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -540,7 +540,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 56d575edcc4..41ab34bae07 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import functional as F -from ...transforms._presets import SemanticSegmentationInference, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -140,7 +140,7 @@ def _deeplabv3_resnet( class DeepLabV3_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(SemanticSegmentationInference, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 42004074, @@ -155,7 +155,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): class DeepLabV3_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - 
transforms=partial(SemanticSegmentationInference, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 60996202, @@ -170,7 +170,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(SemanticSegmentationInference, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 11029328, diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 9157a25a093..6a760be36dc 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -3,7 +3,7 @@ from torch import nn -from ...transforms._presets import SemanticSegmentationInference, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -59,7 +59,7 @@ def __init__(self, in_channels: int, channels: int) -> None: class FCN_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(SemanticSegmentationInference, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 35322218, @@ -74,7 +74,7 @@ class FCN_ResNet50_Weights(WeightsEnum): class FCN_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - transforms=partial(SemanticSegmentationInference, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 54314346, diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index e527892d860..33684526c6b 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from ...transforms._presets import SemanticSegmentationInference, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES @@ -96,7 +96,7 @@ def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(SemanticSegmentationInference, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ "task": "image_semantic_segmentation", "architecture": "LRASPP", diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 4204565868c..e988b819078 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from 
._meta import _IMAGENET_CATEGORIES @@ -198,7 +198,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -212,7 +212,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): class ShuffleNet_V2_X1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 9fef930fb06..bde8b5efcfd 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.init as init -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -128,7 +128,7 @@ def _squeezenet( class SqueezeNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "min_size": (21, 21), @@ -143,7 +143,7 @@ class SqueezeNet1_0_Weights(WeightsEnum): class SqueezeNet1_1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "min_size": (17, 17), diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index ccf0b5895e9..93bfd5e6ba3 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -120,7 +120,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b class VGG11_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 132863336, @@ -134,7 +134,7 @@ class VGG11_Weights(WeightsEnum): class VGG11_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 132868840, @@ -148,7 +148,7 @@ class VGG11_BN_Weights(WeightsEnum): class VGG13_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + 
transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 133047848, @@ -162,7 +162,7 @@ class VGG13_Weights(WeightsEnum): class VGG13_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 133053736, @@ -176,7 +176,7 @@ class VGG13_BN_Weights(WeightsEnum): class VGG16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 138357544, @@ -190,7 +190,7 @@ class VGG16_Weights(WeightsEnum): IMAGENET1K_FEATURES = Weights( url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", transforms=partial( - ImageClassificationInference, + ImageClassification, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), @@ -210,7 +210,7 @@ class VGG16_Weights(WeightsEnum): class VGG16_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 138365992, @@ -224,7 +224,7 @@ class VGG16_BN_Weights(WeightsEnum): class VGG19_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 143667240, @@ -238,7 +238,7 @@ class VGG19_Weights(WeightsEnum): class VGG19_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 143678248, diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index f121609c7e9..618ddb96ba2 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch import Tensor -from ...transforms._presets import VideoClassificationInference, InterpolationMode +from ...transforms._presets import VideoClassification, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES @@ -322,7 +322,7 @@ def _video_resnet( class R3D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(VideoClassificationInference, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R3D", @@ -337,7 +337,7 @@ class R3D_18_Weights(WeightsEnum): class MC3_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(VideoClassificationInference, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, 
"architecture": "MC3", @@ -352,7 +352,7 @@ class MC3_18_Weights(WeightsEnum): class R2Plus1D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(VideoClassificationInference, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R(2+1)D", diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 4d34c179ca6..fb34cf3c8e1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from ..ops.misc import Conv2dNormActivation -from ..transforms._presets import ImageClassificationInference, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _vision_transformer( class ViT_B_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 86567656, @@ -334,7 +334,7 @@ class ViT_B_16_Weights(WeightsEnum): class ViT_B_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 88224232, @@ -351,7 +351,7 @@ class ViT_B_32_Weights(WeightsEnum): class ViT_L_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationInference, crop_size=224, resize_size=242), + transforms=partial(ImageClassification, crop_size=224, resize_size=242), meta={ **_COMMON_META, "num_params": 304326632, @@ -368,7 +368,7 @@ class ViT_L_16_Weights(WeightsEnum): class ViT_L_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationInference, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 306535400, diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 5ba8c709d22..e20e1fea22c 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -11,22 +11,22 @@ __all__ = [ - "ObjectDetectionInference", - "ImageClassificationInference", - "VideoClassificationInference", - "SemanticSegmentationInference", - "OpticalFlowInference", + "ObjectDetection", + "ImageClassification", + "VideoClassification", + "SemanticSegmentation", + "OpticalFlow", ] -class ObjectDetectionInference(nn.Module): +class ObjectDetection(nn.Module): def forward(self, img: Tensor) -> Tensor: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) return F.convert_image_dtype(img, torch.float) -class ImageClassificationInference(nn.Module): +class ImageClassification(nn.Module): def __init__( self, crop_size: int, @@ -52,7 +52,7 @@ def forward(self, img: Tensor) -> Tensor: return img -class VideoClassificationInference(nn.Module): +class VideoClassification(nn.Module): def __init__( self, crop_size: Tuple[int, int], @@ 
-89,7 +89,7 @@ def forward(self, vid: Tensor) -> Tensor: return vid -class SemanticSegmentationInference(nn.Module): +class SemanticSegmentation(nn.Module): def __init__( self, resize_size: Optional[int], @@ -113,7 +113,7 @@ def forward(self, img: Tensor) -> Tensor: return img -class OpticalFlowInference(nn.Module): +class OpticalFlow(nn.Module): def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: if not isinstance(img1, Tensor): img1 = F.pil_to_tensor(img1) From 3311252a6a0ad7f52019de956d15235c3bbca133 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 09:48:54 +0000 Subject: [PATCH 10/18] Fixing bug --- references/optical_flow/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 64e1ba67c1f..4b814bfd3cb 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -136,7 +136,7 @@ def evaluate(model, args): if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - return lambda img1, img2, flow = None, valid = None: trans(img1, img2) + (flow, valid) + preprocessing = lambda img1, img2, flow = None, valid = None: trans(img1, img2) + (flow, valid) else: preprocessing = OpticalFlowPresetEval() From 6f5dd2d1a97035669ae4d9c4aff7c7c42c85880b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 09:49:31 +0000 Subject: [PATCH 11/18] Fixing linter --- references/detection/train.py | 2 +- references/optical_flow/train.py | 2 +- references/segmentation/train.py | 2 +- test/test_extended_models.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index 376a30a4a78..0cf2cf143bf 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -47,7 +47,7 @@ def get_transform(train, args): if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - return lambda img, target = None: (trans(img), target) + return lambda img, target=None: (trans(img), target) else: return presets.DetectionPresetEval() diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 4b814bfd3cb..fc92155d86e 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -136,7 +136,7 @@ def evaluate(model, args): if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - preprocessing = lambda img1, img2, flow = None, valid = None: trans(img1, img2) + (flow, valid) + preprocessing = lambda img1, img2, flow=None, valid=None: trans(img1, img2) + (flow, valid) else: preprocessing = OpticalFlowPresetEval() diff --git a/references/segmentation/train.py b/references/segmentation/train.py index de0db9b028a..20cbcb86493 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -32,7 +32,7 @@ def get_transform(train, args): elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - return lambda img, target = None: (trans(img), target) + return lambda img, target=None: (trans(img), target) else: return presets.SegmentationPresetEval(base_size=520) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 3d5995b2c83..a07b501e15b 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -1,9 
+1,9 @@ import importlib import os -import torch import pytest import test_models as TM +import torch from torchvision import models from torchvision.models._api import WeightsEnum, Weights from torchvision.models._utils import handle_legacy_interface @@ -169,7 +169,7 @@ def test_transforms_jit(model_fn): if module_name == "optical_flow": args = (x, x) else: - args = (x, ) + args = (x,) problematic_weights = [] for w in weights_enum: From 0446e68eac403bebd0971d5417a9bf8df852ba29 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 09:51:38 +0000 Subject: [PATCH 12/18] Fix flake8 --- references/optical_flow/train.py | 2 +- torchvision/transforms/_presets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index fc92155d86e..04190ca71f0 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -136,7 +136,7 @@ def evaluate(model, args): if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - preprocessing = lambda img1, img2, flow=None, valid=None: trans(img1, img2) + (flow, valid) + preprocessing = lambda img1, img2, flow=None, valid=None: trans(img1, img2) + (flow, valid) # noqa: E731 else: preprocessing = OpticalFlowPresetEval() diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index e20e1fea22c..0bfb1cf9b38 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -2,7 +2,7 @@ This file is part of the private API. Please do not use directly these classes as they will be modified on future versions without warning. The classes should be accessed only via the transforms argument of Weights. 
""" -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch from torch import Tensor, nn From 1c318d97aceeb76cd2479cd0ee73fd540456de67 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 09:56:17 +0000 Subject: [PATCH 13/18] restore documentation example --- gallery/plot_repurposing_annotations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py index 0a795b9162b..7bb68617a17 100644 --- a/gallery/plot_repurposing_annotations.py +++ b/gallery/plot_repurposing_annotations.py @@ -147,7 +147,10 @@ def show(imgs): tranforms = weights.transforms() img = tranforms(img) -detection_outputs = model(img.unsqueeze(0)) +target = {} +target["boxes"] = boxes +target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64) +detection_outputs = model(img.unsqueeze(0), [target]) #################################### From 3a01c89552d69961f7682c912207bfc2a74da2de Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 10:35:49 +0000 Subject: [PATCH 14/18] minor fixes --- references/detection/train.py | 2 +- test/test_models.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index 0cf2cf143bf..c2c5659366f 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -44,7 +44,7 @@ def get_dataset(name, image_set, transform, data_path): def get_transform(train, args): if train: return presets.DetectionPresetTrain(args.data_augmentation) - if args.weights and args.test_only: + elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() return lambda img, target=None: (trans(img), target) diff --git a/test/test_models.py b/test/test_models.py index 137c325be1a..0d45d61df13 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -134,8 +134,6 @@ def get_export_import_copy(m): if eager_out is None: with torch.no_grad(), freeze_rng_state(): eager_out = nn_module(*args) - if unwrapper: - eager_out = unwrapper(eager_out) with torch.no_grad(), freeze_rng_state(): script_out = sm(*args) From 93d654d512f65ac56cef4f1945029f5d307f48df Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 11:34:42 +0000 Subject: [PATCH 15/18] fix optical flow missing param --- references/optical_flow/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 04190ca71f0..1a50d1c617d 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -195,6 +195,7 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): utils.setup_ddp(args) + args.test_only = args.train_dataset is None if args.distributed and args.device == "cpu": raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun") @@ -214,7 +215,7 @@ def main(args): checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) - if args.train_dataset is None: + if args.test_only: # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. 
torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True From fc42bf00eb89a8b0ceb36700e9ef68659a4611c9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 11:45:24 +0000 Subject: [PATCH 16/18] Fixing commands --- docs/source/models.rst | 54 +---------------------------- references/classification/README.md | 27 ++++++--------- references/optical_flow/README.md | 4 +-- 3 files changed, 13 insertions(+), 72 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 50af05360e4..39543cb8027 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -98,58 +98,6 @@ You can construct a model with random weights by calling its constructor: convnext_large = models.convnext_large() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. -These can be constructed by passing ``pretrained=True``: - -.. code:: python - - import torchvision.models as models - resnet18 = models.resnet18(pretrained=True) - alexnet = models.alexnet(pretrained=True) - squeezenet = models.squeezenet1_0(pretrained=True) - vgg16 = models.vgg16(pretrained=True) - densenet = models.densenet161(pretrained=True) - inception = models.inception_v3(pretrained=True) - googlenet = models.googlenet(pretrained=True) - shufflenet = models.shufflenet_v2_x1_0(pretrained=True) - mobilenet_v2 = models.mobilenet_v2(pretrained=True) - mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True) - mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True) - resnext50_32x4d = models.resnext50_32x4d(pretrained=True) - wide_resnet50_2 = models.wide_resnet50_2(pretrained=True) - mnasnet = models.mnasnet1_0(pretrained=True) - efficientnet_b0 = models.efficientnet_b0(pretrained=True) - efficientnet_b1 = models.efficientnet_b1(pretrained=True) - efficientnet_b2 = models.efficientnet_b2(pretrained=True) - efficientnet_b3 = models.efficientnet_b3(pretrained=True) - efficientnet_b4 = models.efficientnet_b4(pretrained=True) - efficientnet_b5 = models.efficientnet_b5(pretrained=True) - efficientnet_b6 = models.efficientnet_b6(pretrained=True) - efficientnet_b7 = models.efficientnet_b7(pretrained=True) - efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True) - efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True) - efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True) - regnet_y_400mf = models.regnet_y_400mf(pretrained=True) - regnet_y_800mf = models.regnet_y_800mf(pretrained=True) - regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True) - regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True) - regnet_y_8gf = models.regnet_y_8gf(pretrained=True) - regnet_y_16gf = models.regnet_y_16gf(pretrained=True) - regnet_y_32gf = models.regnet_y_32gf(pretrained=True) - regnet_x_400mf = models.regnet_x_400mf(pretrained=True) - regnet_x_800mf = models.regnet_x_800mf(pretrained=True) - regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True) - regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True) - regnet_x_8gf = models.regnet_x_8gf(pretrained=True) - regnet_x_16gf = models.regnet_x_16gf(pretrained=True) - regnet_x_32gf = models.regnet_x_32gf(pretrained=True) - vit_b_16 = models.vit_b_16(pretrained=True) - vit_b_32 = models.vit_b_32(pretrained=True) - vit_l_16 = models.vit_l_16(pretrained=True) - vit_l_32 = models.vit_l_32(pretrained=True) - convnext_tiny = models.convnext_tiny(pretrained=True) - convnext_small = models.convnext_small(pretrained=True) - convnext_base = models.convnext_base(pretrained=True) - convnext_large =
models.convnext_large(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_HOME` environment variable. See @@ -525,7 +473,7 @@ Obtaining a pre-trained quantized model can be done with a few lines of code: .. code:: python import torchvision.models as models - model = models.quantization.mobilenet_v2(pretrained=True, quantize=True) + model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True) model.eval() # run the model with quantized inputs and weights out = model(torch.rand(1, 3, 224, 224)) diff --git a/references/classification/README.md b/references/classification/README.md index 173fb454995..289a73b81f9 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -43,7 +43,8 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model ``` torchrun --nproc_per_node=8 train.py --model inception_v3\ - --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained + --val-resize-size 342 --val-crop-size 299 --train-crop-size 299\ + --test-only --weights Inception_V3_Weights.IMAGENET1K_V1 ``` ### ResNet @@ -96,22 +97,14 @@ The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTo All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands: ``` -torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\ - --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\ - --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\ - --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\ - --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\ - --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\ - --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1 +torchrun 
--nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1 ``` diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md index a7620ce4be6..a7ac0223739 100644 --- a/references/optical_flow/README.md +++ b/references/optical_flow/README.md @@ -51,7 +51,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \ ### Evaluation ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 ``` This should give an epe of about 1.3822 on the clean pass and 2.7161 on the @@ -67,6 +67,6 @@ Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: You can also evaluate on Kitti train: ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679 ``` From 98df9a8a55460909815d46013485d1f5c5fcacb1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 12:05:22 +0000 Subject: [PATCH 17/18] Adding weights_backbone support in detection and segmentation --- references/detection/README.md | 18 +++++++++--------- references/detection/train.py | 5 ++++- references/segmentation/README.md | 12 ++++++------ references/segmentation/train.py | 3 ++- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/references/detection/README.md b/references/detection/README.md index 3695644138b..aec7c10e1b5 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -24,35 +24,35 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs. 
``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large 320 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### FCOS ResNet-50 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fcos_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### RetinaNet ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model retinanet_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### SSD300 VGG16 @@ -60,7 +60,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssd300_vgg16 --epochs 120\ --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\ - --weight-decay 0.0005 --data-augmentation ssd + --weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES ``` ### SSDlite320 MobileNetV3-Large @@ -68,7 +68,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\ --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\ - --weight-decay 0.00004 --data-augmentation ssdlite + --weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` @@ -76,7 +76,7 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` @@ -84,5 +84,5 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\ - --lr-steps 36 43 --aspect-ratio-group-factor 3 + --lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` diff --git a/references/detection/train.py b/references/detection/train.py index c2c5659366f..0e0a0d70fad 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -129,6 +129,7 @@ def get_args_parser(add_help=True): parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, 
help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") @@ -178,7 +179,9 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - model = torchvision.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) + model = torchvision.models.detection.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs + ) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/references/segmentation/README.md b/references/segmentation/README.md index e9b5391215a..2c7391c8380 100644 --- a/references/segmentation/README.md +++ b/references/segmentation/README.md @@ -14,30 +14,30 @@ You must modify the following flags: ## fcn_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## fcn_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ## lraspp_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 20cbcb86493..b4e55acd407 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -124,7 +124,7 @@ def main(args): ) model = torchvision.models.segmentation.__dict__[args.model]( - weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss ) model.to(device) if args.distributed: @@ -258,6 +258,7 @@ def get_args_parser(add_help=True): parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--weights", default=None, type=str, help="the weights enum 
name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") From b916fd249c8a7f7f78ef2c25869a2ed6cc768fe5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 13:09:37 +0000 Subject: [PATCH 18/18] Updating the commands for InceptionV3 --- references/classification/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/references/classification/README.md b/references/classification/README.md index 289a73b81f9..c274c997791 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -43,7 +43,6 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model ``` torchrun --nproc_per_node=8 train.py --model inception_v3\ - --val-resize-size 342 --val-crop-size 299 --train-crop-size 299\ --test-only --weights Inception_V3_Weights.IMAGENET1K_V1 ```
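A minimal usage sketch of the renamed presets (not part of the patch series; it assumes a torchvision build with these patches applied). Each `Weights` entry stores its preset as a `partial`, so `weights.transforms()` instantiates the matching inference transform — the same pattern the reference scripts above use after `torchvision.models.get_weight`. The image path below is a placeholder.

```python
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

# ResNet50_Weights.IMAGENET1K_V2 stores
# partial(ImageClassification, crop_size=224, resize_size=232);
# calling transforms() instantiates that preset.
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()

img = read_image("example.jpg")       # placeholder path; uint8 CHW tensor
batch = preprocess(img).unsqueeze(0)  # resize, center-crop, rescale, normalize
scores = model(batch).squeeze(0).softmax(0)
# assumes the weights meta carries the ImageNet class names under "categories"
category = weights.meta["categories"][scores.argmax().item()]
print(category)
```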