Commit 775129b

Remove non-functional Transforms from presets. (#4952)
1 parent 4b2ad55 commit 775129b

File tree: 2 files changed, 27 additions & 23 deletions


torchvision/prototype/models/video/resnet.py

Lines changed: 3 additions & 3 deletions
@@ -62,7 +62,7 @@ def _video_resnet(
 class R3D_18Weights(Weights):
     Kinetics400_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
-        transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
+        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "acc@1": 52.75,
@@ -74,7 +74,7 @@ class R3D_18Weights(Weights):
 class MC3_18Weights(Weights):
     Kinetics400_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
-        transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
+        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "acc@1": 53.90,
@@ -86,7 +86,7 @@ class MC3_18Weights(Weights):
 class R2Plus1D_18Weights(Weights):
     Kinetics400_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
-        transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
+        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "acc@1": 57.50,

torchvision/prototype/transforms/_presets.py

Lines changed: 24 additions & 20 deletions
@@ -3,8 +3,7 @@
 import torch
 from torch import Tensor, nn
 
-from ... import transforms as T
-from ...transforms import functional as F
+from ...transforms import functional as F, InterpolationMode
 
 
 __all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
@@ -26,42 +25,47 @@ def __init__(
         resize_size: int = 256,
         mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
         std: Tuple[float, ...] = (0.229, 0.224, 0.225),
-        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._resize = T.Resize(resize_size, interpolation=interpolation)
-        self._crop = T.CenterCrop(crop_size)
-        self._normalize = T.Normalize(mean=mean, std=std)
+        self._crop_size = [crop_size]
+        self._size = [resize_size]
+        self._mean = list(mean)
+        self._std = list(std)
+        self._interpolation = interpolation
 
     def forward(self, img: Tensor) -> Tensor:
-        img = self._crop(self._resize(img))
+        img = F.resize(img, self._size, interpolation=self._interpolation)
+        img = F.center_crop(img, self._crop_size)
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
-        return self._normalize(img)
+        img = F.normalize(img, mean=self._mean, std=self._std)
+        return img
 
 
 class Kinect400Eval(nn.Module):
     def __init__(
         self,
-        resize_size: Tuple[int, int],
         crop_size: Tuple[int, int],
+        resize_size: Tuple[int, int],
         mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
         std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
-        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._convert = T.ConvertImageDtype(torch.float)
-        self._resize = T.Resize(resize_size, interpolation=interpolation)
-        self._normalize = T.Normalize(mean=mean, std=std)
-        self._crop = T.CenterCrop(crop_size)
+        self._crop_size = list(crop_size)
+        self._size = list(resize_size)
+        self._mean = list(mean)
+        self._std = list(std)
+        self._interpolation = interpolation
 
     def forward(self, vid: Tensor) -> Tensor:
         vid = vid.permute(0, 3, 1, 2)  # (T, H, W, C) => (T, C, H, W)
-        vid = self._convert(vid)
-        vid = self._resize(vid)
-        vid = self._normalize(vid)
-        vid = self._crop(vid)
+        vid = F.resize(vid, self._size, interpolation=self._interpolation)
+        vid = F.center_crop(vid, self._crop_size)
+        vid = F.convert_image_dtype(vid, torch.float)
+        vid = F.normalize(vid, mean=self._mean, std=self._std)
         return vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)
@@ -71,8 +75,8 @@ def __init__(
         resize_size: int,
         mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
         std: Tuple[float, ...] = (0.229, 0.224, 0.225),
-        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
-        interpolation_target: T.InterpolationMode = T.InterpolationMode.NEAREST,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation_target: InterpolationMode = InterpolationMode.NEAREST,
     ) -> None:
         super().__init__()
         self._size = [resize_size]
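
After this change the presets hold only plain configuration (sizes, mean/std, interpolation) and call torchvision.transforms.functional ops inside forward, instead of registering T.Resize/T.CenterCrop/T.Normalize child modules. A minimal self-contained sketch of that pattern, using a hypothetical TinyEval preset (the functional calls mirror the diff; the class itself is not part of this commit):

import torch
from torch import Tensor, nn
from torchvision.transforms import functional as F, InterpolationMode


class TinyEval(nn.Module):
    # Hypothetical preset in the style of ImageNetEval after this commit.
    def __init__(self, crop_size: int, resize_size: int = 256) -> None:
        super().__init__()
        # Plain attributes rather than child transform modules.
        self._crop_size = [crop_size]
        self._size = [resize_size]
        self._interpolation = InterpolationMode.BILINEAR

    def forward(self, img: Tensor) -> Tensor:
        img = F.resize(img, self._size, interpolation=self._interpolation)
        img = F.center_crop(img, self._crop_size)
        if not isinstance(img, Tensor):  # PIL input is converted to a tensor
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float)


preset = TinyEval(crop_size=224)
out = preset(torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8))  # float tensor, 3x224x224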
