1
1
import math
2
2
import numbers
3
3
import warnings
4
+ from types import SimpleNamespace
4
5
from typing import Any , cast , Dict , List , Optional , Tuple , Union
5
6
6
7
import PIL .Image
12
13
from torchvision .prototype .transforms import functional as F , InterpolationMode , Transform
13
14
14
15
from ._transform import _RandomApplyTransform
15
- from .utils import get_dimensions , has_any , is_simple_tensor , query_chw , query_spatial_size
16
+ from .utils import has_any , is_simple_tensor , query_chw , query_spatial_size
16
17
17
18
18
19
class RandomErasing (_RandomApplyTransform ):
@@ -53,24 +54,19 @@ def __init__(
53
54
54
55
self ._log_ratio = torch .log (torch .tensor (self .ratio ))
55
56
56
- @staticmethod
57
- def _get_params_internal (
58
- img_c : int ,
59
- img_h : int ,
60
- img_w : int ,
61
- scale : Tuple [float , float ],
62
- log_ratio : torch .Tensor ,
63
- value : Optional [List [float ]] = None ,
64
- ) -> Tuple [int , int , int , int , Optional [torch .Tensor ]]:
65
- if value is not None and not (len (value ) in (1 , img_c )):
57
+ def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
58
+ img_c , img_h , img_w = query_chw (flat_inputs )
59
+
60
+ if self .value is not None and not (len (self .value ) in (1 , img_c )):
66
61
raise ValueError (
67
62
f"If value is a sequence, it should have either a single value or { img_c } (number of inpt channels)"
68
63
)
69
64
70
65
area = img_h * img_w
71
66
67
+ log_ratio = self ._log_ratio
72
68
for _ in range (10 ):
73
- erase_area = area * torch .empty (1 ).uniform_ (scale [0 ], scale [1 ]).item ()
69
+ erase_area = area * torch .empty (1 ).uniform_ (self . scale [0 ], self . scale [1 ]).item ()
74
70
aspect_ratio = torch .exp (
75
71
torch .empty (1 ).uniform_ (
76
72
log_ratio [0 ], # type: ignore[arg-type]
@@ -83,34 +79,18 @@ def _get_params_internal(
83
79
if not (h < img_h and w < img_w ):
84
80
continue
85
81
86
- if value is None :
82
+ if self . value is None :
87
83
v = torch .empty ([img_c , h , w ], dtype = torch .float32 ).normal_ ()
88
84
else :
89
- v = torch .tensor (value )[:, None , None ]
85
+ v = torch .tensor (self . value )[:, None , None ]
90
86
91
87
i = int (torch .randint (0 , img_h - h + 1 , size = (1 ,)))
92
88
j = int (torch .randint (0 , img_w - w + 1 , size = (1 ,)))
93
89
break
94
90
else :
95
91
i , j , h , w , v = 0 , 0 , img_h , img_w , None
96
92
97
- return i , j , h , w , v
98
-
99
- def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
100
- img_c , img_h , img_w = query_chw (flat_inputs )
101
- return dict (
102
- zip (
103
- "ijhwv" ,
104
- self ._get_params_internal (
105
- img_c ,
106
- img_h ,
107
- img_w ,
108
- self .scale ,
109
- self ._log_ratio ,
110
- self .value , # type: ignore[arg-type]
111
- ),
112
- )
113
- )
93
+ return dict (i = i , j = j , h = h , w = w , v = v )
114
94
115
95
@staticmethod
116
96
def get_params (
@@ -119,13 +99,12 @@ def get_params(
119
99
ratio : Tuple [float , float ],
120
100
value : Optional [List [float ]] = None ,
121
101
) -> Tuple [int , int , int , int , torch .Tensor ]:
122
- img_c , img_h , img_w = get_dimensions (image )
123
- i , j , h , w , v = RandomErasing ._get_params_internal (
124
- img_c , img_h , img_w , scale , torch .log (torch .tensor (ratio )), value
125
- )
102
+ self = SimpleNamespace (scale = scale , _log_ratio = torch .log (torch .tensor (ratio )), value = value )
103
+ params = RandomErasing ._get_params (self , [image ]) # type: ignore[arg-type]
104
+ v = params ["v" ]
126
105
if v is None :
127
106
v = image
128
- return i , j , h , w , v
107
+ return params [ "i" ], params [ "j" ], params [ "h" ], params [ "w" ] , v
129
108
130
109
def _transform (
131
110
self , inpt : Union [datapoints .ImageType , datapoints .VideoType ], params : Dict [str , Any ]
0 commit comments