@@ -40,24 +40,22 @@ def __init__(
40
40
raise ValueError ("Scale should be between 0 and 1" )
41
41
self .scale = scale
42
42
self .ratio = ratio
43
- self .value = value
43
+ if isinstance (value , (int , float )):
44
+ self .value = [value ]
45
+ elif isinstance (value , str ):
46
+ self .value = None
47
+ elif isinstance (value , tuple ):
48
+ self .value = list (value )
49
+ else :
50
+ self .value = value
44
51
self .inplace = inplace
45
52
46
53
self ._log_ratio = torch .log (torch .tensor (self .ratio ))
47
54
48
55
def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
49
56
img_c , img_h , img_w = query_chw (flat_inputs )
50
57
51
- if isinstance (self .value , (int , float )):
52
- value = [self .value ]
53
- elif isinstance (self .value , str ):
54
- value = None
55
- elif isinstance (self .value , tuple ):
56
- value = list (self .value )
57
- else :
58
- value = self .value
59
-
60
- if value is not None and not (len (value ) in (1 , img_c )):
58
+ if self .value is not None and not (len (self .value ) in (1 , img_c )):
61
59
raise ValueError (
62
60
f"If value is a sequence, it should have either a single value or { img_c } (number of inpt channels)"
63
61
)
@@ -79,10 +77,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
79
77
if not (h < img_h and w < img_w ):
80
78
continue
81
79
82
- if value is None :
80
+ if self . value is None :
83
81
v = torch .empty ([img_c , h , w ], dtype = torch .float32 ).normal_ ()
84
82
else :
85
- v = torch .tensor (value )[:, None , None ]
83
+ v = torch .tensor (self . value )[:, None , None ]
86
84
87
85
i = torch .randint (0 , img_h - h + 1 , size = (1 ,)).item ()
88
86
j = torch .randint (0 , img_w - w + 1 , size = (1 ,)).item ()
@@ -121,8 +119,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
121
119
def _mixup_onehotlabel (self , inpt : features .OneHotLabel , lam : float ) -> features .OneHotLabel :
122
120
if inpt .ndim < 2 :
123
121
raise ValueError ("Need a batch of one hot labels" )
124
- output = inpt .clone ()
125
- output = output .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (output .mul_ (lam ))
122
+ output = inpt .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (inpt .mul (lam ))
126
123
return features .OneHotLabel .wrap_like (inpt , output )
127
124
128
125
@@ -136,8 +133,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
136
133
expected_ndim = 5 if isinstance (inpt , features .Video ) else 4
137
134
if inpt .ndim < expected_ndim :
138
135
raise ValueError ("The transform expects a batched input" )
139
- output = inpt .clone ()
140
- output = output .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (output .mul_ (lam ))
136
+ output = inpt .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (inpt .mul (lam ))
141
137
142
138
if isinstance (inpt , (features .Image , features .Video )):
143
139
output = type (inpt ).wrap_like (inpt , output ) # type: ignore[arg-type]
@@ -243,11 +239,12 @@ def _copy_paste(
243
239
if blending :
244
240
paste_alpha_mask = F .gaussian_blur (paste_alpha_mask .unsqueeze (0 ), kernel_size = [5 , 5 ], sigma = [2.0 ])
245
241
242
+ inverse_paste_alpha_mask = paste_alpha_mask .logical_not ()
246
243
# Copy-paste images:
247
- image = ( image * ( ~ paste_alpha_mask )) + (paste_image * paste_alpha_mask )
244
+ image = image . mul ( inverse_paste_alpha_mask ). add_ (paste_image . mul ( paste_alpha_mask ) )
248
245
249
246
# Copy-paste masks:
250
- masks = masks * ( ~ paste_alpha_mask )
247
+ masks = masks * inverse_paste_alpha_mask
251
248
non_all_zero_masks = masks .sum ((- 1 , - 2 )) > 0
252
249
masks = masks [non_all_zero_masks ]
253
250
0 commit comments