Skip to content

Commit feba98a

Browse files
committed
nits
1 parent f7513a4 commit feba98a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
278278
sample = inputs if len(inputs) > 1 else inputs[0]
279279

280280
image = query_image(sample)
281-
_, height, width = get_image_dims(image)
281+
_, *image_size = get_image_dims(image)
282282

283283
policy = self._policies[int(torch.randint(len(self._policies), ()))]
284284

@@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any:
288288

289289
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
290290

291-
magnitudes = magnitudes_fn(10, (height, width))
291+
magnitudes = magnitudes_fn(10, image_size)
292292
if magnitudes is not None:
293293
magnitude = float(magnitudes[magnitude_idx])
294294
if signed and torch.rand(()) <= 0.5:
@@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any:
334334
sample = inputs if len(inputs) > 1 else inputs[0]
335335

336336
image = query_image(sample)
337-
_, height, width = get_image_dims(image)
337+
_, *image_size = get_image_dims(image)
338338

339339
for _ in range(self.num_ops):
340340
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
341341

342-
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
342+
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
343343
if magnitudes is not None:
344344
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
345345
if signed and torch.rand(()) <= 0.5:
@@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any:
383383
sample = inputs if len(inputs) > 1 else inputs[0]
384384

385385
image = query_image(sample)
386-
_, height, width = get_image_dims(image)
386+
_, *image_size = get_image_dims(image)
387387

388388
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
389389

390-
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
390+
magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
391391
if magnitudes is not None:
392392
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
393393
if signed and torch.rand(()) <= 0.5:

0 commit comments

Comments
 (0)