3
3
import torch
4
4
from torch import Tensor , nn
5
5
6
- from ... import transforms as T
7
- from ...transforms import functional as F
6
+ from ...transforms import functional as F , InterpolationMode
8
7
9
8
10
9
__all__ = ["CocoEval" , "ImageNetEval" , "Kinect400Eval" , "VocEval" ]
@@ -26,42 +25,47 @@ def __init__(
26
25
resize_size : int = 256 ,
27
26
mean : Tuple [float , ...] = (0.485 , 0.456 , 0.406 ),
28
27
std : Tuple [float , ...] = (0.229 , 0.224 , 0.225 ),
29
- interpolation : T . InterpolationMode = T . InterpolationMode .BILINEAR ,
28
+ interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
30
29
) -> None :
31
30
super ().__init__ ()
32
- self ._resize = T .Resize (resize_size , interpolation = interpolation )
33
- self ._crop = T .CenterCrop (crop_size )
34
- self ._normalize = T .Normalize (mean = mean , std = std )
31
+ self ._crop_size = [crop_size ]
32
+ self ._size = [resize_size ]
33
+ self ._mean = list (mean )
34
+ self ._std = list (std )
35
+ self ._interpolation = interpolation
35
36
36
37
def forward (self , img : Tensor ) -> Tensor :
37
- img = self ._crop (self ._resize (img ))
38
+ img = F .resize (img , self ._size , interpolation = self ._interpolation )
39
+ img = F .center_crop (img , self ._crop_size )
38
40
if not isinstance (img , Tensor ):
39
41
img = F .pil_to_tensor (img )
40
42
img = F .convert_image_dtype (img , torch .float )
41
- return self ._normalize (img )
43
+ img = F .normalize (img , mean = self ._mean , std = self ._std )
44
+ return img
42
45
43
46
44
47
class Kinect400Eval (nn .Module ):
45
48
def __init__ (
46
49
self ,
47
- resize_size : Tuple [int , int ],
48
50
crop_size : Tuple [int , int ],
51
+ resize_size : Tuple [int , int ],
49
52
mean : Tuple [float , ...] = (0.43216 , 0.394666 , 0.37645 ),
50
53
std : Tuple [float , ...] = (0.22803 , 0.22145 , 0.216989 ),
51
- interpolation : T . InterpolationMode = T . InterpolationMode .BILINEAR ,
54
+ interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
52
55
) -> None :
53
56
super ().__init__ ()
54
- self ._convert = T .ConvertImageDtype (torch .float )
55
- self ._resize = T .Resize (resize_size , interpolation = interpolation )
56
- self ._normalize = T .Normalize (mean = mean , std = std )
57
- self ._crop = T .CenterCrop (crop_size )
57
+ self ._crop_size = list (crop_size )
58
+ self ._size = list (resize_size )
59
+ self ._mean = list (mean )
60
+ self ._std = list (std )
61
+ self ._interpolation = interpolation
58
62
59
63
def forward (self , vid : Tensor ) -> Tensor :
60
64
vid = vid .permute (0 , 3 , 1 , 2 ) # (T, H, W, C) => (T, C, H, W)
61
- vid = self . _convert (vid )
62
- vid = self . _resize (vid )
63
- vid = self . _normalize (vid )
64
- vid = self . _crop (vid )
65
+ vid = F . resize (vid , self . _size , interpolation = self . _interpolation )
66
+ vid = F . center_crop (vid , self . _crop_size )
67
+ vid = F . convert_image_dtype (vid , torch . float )
68
+ vid = F . normalize (vid , mean = self . _mean , std = self . _std )
65
69
return vid .permute (1 , 0 , 2 , 3 ) # (T, C, H, W) => (C, T, H, W)
66
70
67
71
@@ -71,8 +75,8 @@ def __init__(
71
75
resize_size : int ,
72
76
mean : Tuple [float , ...] = (0.485 , 0.456 , 0.406 ),
73
77
std : Tuple [float , ...] = (0.229 , 0.224 , 0.225 ),
74
- interpolation : T . InterpolationMode = T . InterpolationMode .BILINEAR ,
75
- interpolation_target : T . InterpolationMode = T . InterpolationMode .NEAREST ,
78
+ interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
79
+ interpolation_target : InterpolationMode = InterpolationMode .NEAREST ,
76
80
) -> None :
77
81
super ().__init__ ()
78
82
self ._size = [resize_size ]
0 commit comments