
Commit 71e4c56

revert image size to (width, height)
1 parent ed32288 commit 71e4c56

File tree

5 files changed (+19 −20 lines changed):

torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_auto_augment.py
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_utils.py


torchvision/prototype/transforms/_augment.py

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ def __init__(
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
         img_c = F.get_image_num_channels(image)
-        img_h, img_w = F.get_image_size(image)
+        img_w, img_h = F.get_image_size(image)

         if isinstance(self.value, (int, float)):
             value = [self.value]

@@ -138,7 +138,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

         image = query_image(sample)
-        H, W = F.get_image_size(image)
+        W, H = F.get_image_size(image)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
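The two hunks above only swap the unpacking order at the call sites. A minimal sketch of the reverted convention, using a plain-tensor stand-in for F.get_image_size (the helper body and the CHW layout assumption are mine, not part of this commit):

from typing import Tuple

import torch


def get_image_size(image: torch.Tensor) -> Tuple[int, int]:
    # Stand-in for F.get_image_size: tensor images are assumed to be
    # (..., C, H, W), but the helper reports (width, height) after the revert.
    height, width = image.shape[-2], image.shape[-1]
    return width, height


image = torch.rand(3, 480, 640)        # C=3, H=480, W=640
img_w, img_h = get_image_size(image)   # unpack width first, as in the first hunk
assert (img_w, img_h) == (640, 480)

# As in the second hunk: sample a random pixel location inside the image.
W, H = img_w, img_h
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())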

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 4 additions & 4 deletions
@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
         "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
         "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
         "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),

@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
         "Identity": (lambda num_bins, image_size: None, False),
         "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
         "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
         "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
         "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        height, width = F.get_image_size(image)
+        width, height = F.get_image_size(image)
         area = height * width

         log_ratio = torch.log(torch.tensor(self.ratio))
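Since area = height * width is symmetric, the unpacking order in this hunk only matters for the code further down that uses width and height separately. A sketch of the surrounding RandomResizedCrop-style parameter sampling, assuming F.get_image_size returns (width, height); the scale and ratio values are illustrative, not taken from the transform's defaults:

import torch

width, height = 640, 480                  # order returned by F.get_image_size(image)
area = height * width

scale = (0.08, 1.0)                        # illustrative values
ratio = (3.0 / 4.0, 4.0 / 3.0)
log_ratio = torch.log(torch.tensor(ratio))

target_area = area * torch.empty(1).uniform_(*scale).item()
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0].item(), log_ratio[1].item())).item()

# Candidate crop size; the real transform retries until the crop fits inside the image.
w = int(round((target_area * aspect_ratio) ** 0.5))
h = int(round((target_area / aspect_ratio) ** 0.5))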

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 8 additions & 8 deletions
@@ -40,7 +40,7 @@ def resize_image_tensor(
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
     new_height, new_width = size
-    old_height, old_width = _FT.get_image_size(image)
+    old_width, old_height = _FT.get_image_size(image)
     num_channels = _FT.get_image_num_channels(image)
     batch_shape = image.shape[:-3]
     return _FT.resize(

@@ -143,7 +143,7 @@ def affine_image_tensor(

     center_f = [0.0, 0.0]
     if center is not None:
-        height, width = get_image_size(img)
+        width, height = get_image_size(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -169,7 +169,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        height, width = get_image_size(img)
+        width, height = get_image_size(img)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -186,7 +186,7 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        height, width = get_image_size(img)
+        width, height = get_image_size(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -262,13 +262,13 @@ def _center_crop_compute_crop_anchor(

 def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_image_size(img)
+    image_width, image_height = get_image_size(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_tensor(img, padding_ltrb, fill=0)

-    image_height, image_width = get_image_size(img)
+    image_width, image_height = get_image_size(img)
     if crop_width == image_width and crop_height == image_height:
         return img

@@ -278,13 +278,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:

 def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_image_size(img)
+    image_width, image_height = get_image_size(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_pil(img, padding_ltrb, fill=0)

-    image_height, image_width = get_image_size(img)
+    image_width, image_height = get_image_size(img)
     if crop_width == image_width and crop_height == image_height:
         return img
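These hunks all feed two patterns: shifting the affine/rotation center so (0, 0) is the image center, and deciding whether a center crop needs padding first, both now starting from (width, height). A compact sketch of both with made-up sizes; center_offsets is a hypothetical helper, not a function from this diff:

from typing import List, Sequence, Tuple

import torch


def center_offsets(center: Sequence[float], image_size_wh: Tuple[int, int]) -> List[float]:
    # Shift pixel coordinates so that (0, 0) corresponds to the image center,
    # as in the affine/rotate hunks above.
    width, height = image_size_wh
    return [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]


img = torch.rand(3, 480, 640)                      # (C, H, W)
image_width, image_height = img.shape[-1], img.shape[-2]

# Rotating around the geometric center yields a zero offset.
assert center_offsets([image_width * 0.5, image_height * 0.5], (image_width, image_height)) == [0.0, 0.0]

# Center crop: pad first only if the requested crop is larger than the image.
crop_height, crop_width = 512, 512
needs_padding = crop_height > image_height or crop_width > image_width
assert needs_padding  # 512 > 480 in this example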

torchvision/prototype/transforms/functional/_utils.py

Lines changed: 4 additions & 5 deletions
@@ -8,13 +8,12 @@

 def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]:
     if isinstance(image, features.Image):
-        return image.image_size
+        height, width = image.image_size
+        return width, height
     elif isinstance(image, torch.Tensor):
-        width, height = _FT.get_image_size(image)
-        return height, width
+        return cast(Tuple[int, int], tuple(_FT.get_image_size(image)))
     if isinstance(image, PIL.Image.Image):
-        width, height = _FP.get_image_size(image)
-        return height, width
+        return cast(Tuple[int, int], tuple(_FP.get_image_size(image)))
     else:
         raise TypeError(f"unable to get image size from object of type {type(image).__name__}")

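After the revert, get_image_size reports (width, height) for every input kind: the Tensor and PIL branches forward the underlying kernel results unchanged, while features.Image.image_size, which stores (height, width), still needs swapping. A self-contained stand-in that mimics the dispatch for tensor and PIL inputs (the helper body below is my sketch, not the dispatcher from this diff):

from typing import Tuple

import torch


def get_image_size(image) -> Tuple[int, int]:
    # Stand-in for the dispatcher above: always report (width, height).
    if isinstance(image, torch.Tensor):
        height, width = image.shape[-2], image.shape[-1]   # tensor layout is (..., C, H, W)
        return width, height
    try:
        import PIL.Image
        if isinstance(image, PIL.Image.Image):
            return image.size                              # PIL already stores (width, height)
    except ImportError:
        pass
    raise TypeError(f"unable to get image size from object of type {type(image).__name__}")


assert get_image_size(torch.rand(3, 480, 640)) == (640, 480)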
