
Commit d8654bb

Refactor preset transforms (#5562)
* Refactor preset transforms
* Making presets public.
1 parent: 2b5ab1b · commit: d8654bb

40 files changed: +207 -174 lines changed
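
The changes below rename torchvision's dataset-named evaluation presets to task-named, public classes: ImageNetEval becomes ImageClassificationEval, CocoEval becomes ObjectDetectionEval, VocEval becomes SemanticSegmentationEval, Kinect400Eval becomes VideoClassificationEval, and RaftEval becomes OpticalFlowEval. A minimal sketch of the renamed classification preset in use, under the assumption that the prototype preset is a callable transform with the keyword arguments shown in the diffs; the image path and weight name are illustrative only:

from PIL import Image
from torchvision.prototype import models, transforms
from torchvision.transforms.functional import InterpolationMode

# Renamed, now-public eval preset (previously ImageNetEval); prototype API, may change.
preprocessing = transforms.ImageClassificationEval(
    crop_size=224, resize_size=256, interpolation=InterpolationMode.BILINEAR
)

img = Image.open("example.jpg")  # illustrative input image
batch = preprocessing(img).unsqueeze(0)  # preset is assumed to return a CHW float tensor

# Or resolve the preset from a weights entry, as the reference scripts below do:
weights = models.get_weight("AlexNet_Weights.IMAGENET1K_V1")  # illustrative weight name
batch = weights.transforms()(img).unsqueeze(0)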

references/classification/train.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
         weights = prototype.models.get_weight(args.weights)
         preprocessing = weights.transforms()
     else:
-        preprocessing = prototype.transforms.ImageNetEval(
+        preprocessing = prototype.transforms.ImageClassificationEval(
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
         )

references/detection/train.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def get_transform(train, args):
         weights = prototype.models.get_weight(args.weights)
         return weights.transforms()
     else:
-        return prototype.transforms.CocoEval()
+        return prototype.transforms.ObjectDetectionEval()


 def get_args_parser(add_help=True):

references/optical_flow/train.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def validate(model, args):
             weights = prototype.models.get_weight(args.weights)
             preprocessing = weights.transforms()
         else:
-            preprocessing = prototype.transforms.RaftEval()
+            preprocessing = prototype.transforms.OpticalFlowEval()
     else:
         preprocessing = OpticalFlowPresetEval()

references/segmentation/train.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def get_transform(train, args):
         weights = prototype.models.get_weight(args.weights)
         return weights.transforms()
     else:
-        return prototype.transforms.VocEval(resize_size=520)
+        return prototype.transforms.SemanticSegmentationEval(resize_size=520)


 def criterion(inputs, target):

references/video_classification/train.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def main(args):
         weights = prototype.models.get_weight(args.weights)
         transform_test = weights.transforms()
     else:
-        transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
+        transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")

torchvision/prototype/models/alexnet.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.alexnet import AlexNet
@@ -16,7 +16,7 @@
 class AlexNet_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "AlexNet",

torchvision/prototype/models/convnext.py

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,7 @@
 from functools import partial
 from typing import Any, List, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.convnext import ConvNeXt, CNBlockConfig
@@ -56,7 +56,7 @@ def _convnext(
 class ConvNeXt_Tiny_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236),
         meta={
             **_COMMON_META,
             "num_params": 28589128,
@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
 class ConvNeXt_Small_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230),
         meta={
             **_COMMON_META,
             "num_params": 50223688,
@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
 class ConvNeXt_Base_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 88591464,
@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
 class ConvNeXt_Large_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 197767336,

torchvision/prototype/models/densenet.py

Lines changed: 5 additions & 5 deletions
@@ -3,7 +3,7 @@
 from typing import Any, Optional, Tuple

 import torch.nn as nn
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.densenet import DenseNet
@@ -78,7 +78,7 @@ def _densenet(
 class DenseNet121_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 7978856,
@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum):
 class DenseNet161_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 28681000,
@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum):
 class DenseNet169_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 14149480,
@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum):
 class DenseNet201_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet201-c1103571.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 20013928,

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Optional, Union

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.faster_rcnn import (
@@ -43,7 +43,7 @@
 class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 41755286,
@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,
@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,

torchvision/prototype/models/detection/fcos.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.fcos import (
@@ -27,7 +27,7 @@
 class FCOS_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "FCOS",

torchvision/prototype/models/detection/keypoint_rcnn.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.keypoint_rcnn import (
@@ -37,7 +37,7 @@
 class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_LEGACY = Weights(
         url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 59137258,
@@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
     )
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 59137258,

torchvision/prototype/models/detection/mask_rcnn.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.mask_rcnn import (
@@ -27,7 +27,7 @@
 class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "MaskRCNN",

torchvision/prototype/models/detection/retinanet.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.retinanet import (
@@ -28,7 +28,7 @@
 class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "RetinaNet",

torchvision/prototype/models/detection/ssd.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import warnings
 from typing import Any, Optional

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.ssd import (
@@ -25,7 +25,7 @@
 class SSD300_VGG16_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "SSD",

torchvision/prototype/models/detection/ssdlite.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from typing import Any, Callable, Optional

 from torch import nn
-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.ssdlite import (
@@ -30,7 +30,7 @@
 class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "SSDLite",
