@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
278
278
sample = inputs if len (inputs ) > 1 else inputs [0 ]
279
279
280
280
image = query_image (sample )
281
- _ , height , width = get_image_dims (image )
281
+ _ , * image_size = get_image_dims (image )
282
282
283
283
policy = self ._policies [int (torch .randint (len (self ._policies ), ()))]
284
284
@@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any:
288
288
289
289
magnitudes_fn , signed = self ._AUGMENTATION_SPACE [transform_id ]
290
290
291
- magnitudes = magnitudes_fn (10 , ( height , width ) )
291
+ magnitudes = magnitudes_fn (10 , image_size )
292
292
if magnitudes is not None :
293
293
magnitude = float (magnitudes [magnitude_idx ])
294
294
if signed and torch .rand (()) <= 0.5 :
@@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any:
334
334
sample = inputs if len (inputs ) > 1 else inputs [0 ]
335
335
336
336
image = query_image (sample )
337
- _ , height , width = get_image_dims (image )
337
+ _ , * image_size = get_image_dims (image )
338
338
339
339
for _ in range (self .num_ops ):
340
340
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
341
341
342
- magnitudes = magnitudes_fn (self .num_magnitude_bins , ( height , width ) )
342
+ magnitudes = magnitudes_fn (self .num_magnitude_bins , image_size )
343
343
if magnitudes is not None :
344
344
magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
345
345
if signed and torch .rand (()) <= 0.5 :
@@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any:
383
383
sample = inputs if len (inputs ) > 1 else inputs [0 ]
384
384
385
385
image = query_image (sample )
386
- _ , height , width = get_image_dims (image )
386
+ _ , * image_size = get_image_dims (image )
387
387
388
388
transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
389
389
390
- magnitudes = magnitudes_fn (self .num_magnitude_bins , ( height , width ) )
390
+ magnitudes = magnitudes_fn (self .num_magnitude_bins , image_size )
391
391
if magnitudes is not None :
392
392
magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
393
393
if signed and torch .rand (()) <= 0.5 :
0 commit comments