Skip to content

Commit 278c6ae

Browse files
yiwen-song and datumbox
authored and committed
[fbsync] Added typing annotations to transforms/functional_pil (#4234)
Summary: * fix * add functional PIL typings * fix types * fix types * fix a small one * small fix * fix type * fix interpolation types Reviewed By: NicolasHug Differential Revision: D30417195 fbshipit-source-id: 5e09a14011e5cca76d87c2a3dfc2872303f40b2c Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 3dd8d65 commit 278c6ae

File tree

1 file changed

+71
-22
lines changed

1 file changed

+71
-22
lines changed

torchvision/transforms/functional_pil.py

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numbers
2-
from typing import Any, List, Sequence
2+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
33

44
import numpy as np
55
import torch
@@ -34,23 +34,23 @@ def _get_image_num_channels(img: Any) -> int:
3434

3535

3636
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Mirror a PIL Image horizontally (left/right flip).

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
4242

4343

4444
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Mirror a PIL Image vertically (top/bottom flip).

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
5050

5151

5252
@torch.jit.unused
53-
def adjust_brightness(img, brightness_factor):
53+
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
5454
if not _is_pil_image(img):
5555
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
5656

@@ -60,7 +60,7 @@ def adjust_brightness(img, brightness_factor):
6060

6161

6262
@torch.jit.unused
63-
def adjust_contrast(img, contrast_factor):
63+
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
6464
if not _is_pil_image(img):
6565
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
6666

@@ -70,7 +70,7 @@ def adjust_contrast(img, contrast_factor):
7070

7171

7272
@torch.jit.unused
73-
def adjust_saturation(img, saturation_factor):
73+
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
7474
if not _is_pil_image(img):
7575
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
7676

@@ -80,7 +80,7 @@ def adjust_saturation(img, saturation_factor):
8080

8181

8282
@torch.jit.unused
83-
def adjust_hue(img, hue_factor):
83+
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
8484
if not(-0.5 <= hue_factor <= 0.5):
8585
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
8686

@@ -104,7 +104,12 @@ def adjust_hue(img, hue_factor):
104104

105105

106106
@torch.jit.unused
107-
def adjust_gamma(img, gamma, gain=1):
107+
def adjust_gamma(
108+
img: Image.Image,
109+
gamma: float,
110+
gain: float = 1.0,
111+
) -> Image.Image:
112+
108113
if not _is_pil_image(img):
109114
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
110115

@@ -121,7 +126,13 @@ def adjust_gamma(img, gamma, gain=1):
121126

122127

123128
@torch.jit.unused
124-
def pad(img, padding, fill=0, padding_mode="constant"):
129+
def pad(
130+
img: Image.Image,
131+
padding: Union[int, List[int], Tuple[int, ...]],
132+
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
133+
padding_mode: str = "constant",
134+
) -> Image.Image:
135+
125136
if not _is_pil_image(img):
126137
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
127138

@@ -196,15 +207,28 @@ def pad(img, padding, fill=0, padding_mode="constant"):
196207

197208

198209
@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop a PIL Image to the ``height`` x ``width`` box whose top-left
    corner is at ``(top, left)``.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL's crop box is (left, upper, right, lower), exclusive on right/lower.
    box = (left, top, left + width, top + height)
    return img.crop(box)
204222

205223

206224
@torch.jit.unused
207-
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
225+
def resize(
226+
img: Image.Image,
227+
size: Union[Sequence[int], int],
228+
interpolation: int = Image.BILINEAR,
229+
max_size: Optional[int] = None,
230+
) -> Image.Image:
231+
208232
if not _is_pil_image(img):
209233
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
210234
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
@@ -242,7 +266,12 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
242266

243267

244268
@torch.jit.unused
245-
def _parse_fill(fill, img, name="fillcolor"):
269+
def _parse_fill(
270+
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
271+
img: Image.Image,
272+
name: str = "fillcolor",
273+
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
274+
246275
# Process fill color for affine transforms
247276
num_bands = len(img.getbands())
248277
if fill is None:
@@ -261,7 +290,13 @@ def _parse_fill(fill, img, name="fillcolor"):
261290

262291

263292
@torch.jit.unused
264-
def affine(img, matrix, interpolation=0, fill=None):
293+
def affine(
294+
img: Image.Image,
295+
matrix: List[float],
296+
interpolation: int = Image.NEAREST,
297+
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
298+
) -> Image.Image:
299+
265300
if not _is_pil_image(img):
266301
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
267302

@@ -271,7 +306,15 @@ def affine(img, matrix, interpolation=0, fill=None):
271306

272307

273308
@torch.jit.unused
274-
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
309+
def rotate(
310+
img: Image.Image,
311+
angle: float,
312+
interpolation: int = Image.NEAREST,
313+
expand: bool = False,
314+
center: Optional[Tuple[int, int]] = None,
315+
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
316+
) -> Image.Image:
317+
275318
if not _is_pil_image(img):
276319
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
277320

@@ -280,7 +323,13 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
280323

281324

282325
@torch.jit.unused
283-
def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
326+
def perspective(
327+
img: Image.Image,
328+
perspective_coeffs: float,
329+
interpolation: int = Image.BICUBIC,
330+
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
331+
) -> Image.Image:
332+
284333
if not _is_pil_image(img):
285334
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
286335

@@ -290,7 +339,7 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
290339

291340

292341
@torch.jit.unused
293-
def to_grayscale(img, num_output_channels):
342+
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
294343
if not _is_pil_image(img):
295344
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
296345

@@ -308,28 +357,28 @@ def to_grayscale(img, num_output_channels):
308357

309358

310359
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert the colors of a PIL Image via ``ImageOps.invert``.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return ImageOps.invert(img)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
315364

316365

317366
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Reduce each color channel of a PIL Image to ``bits`` bits
    via ``ImageOps.posterize``.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    posterized = ImageOps.posterize(img, bits)
    return posterized
322371

323372

324373
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Solarize a PIL Image (invert pixel values above ``threshold``)
    via ``ImageOps.solarize``.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return ImageOps.solarize(img, threshold)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
329378

330379

331380
@torch.jit.unused
332-
def adjust_sharpness(img, sharpness_factor):
381+
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
333382
if not _is_pil_image(img):
334383
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
335384

@@ -339,14 +388,14 @@ def adjust_sharpness(img, sharpness_factor):
339388

340389

341390
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize the contrast of a PIL Image via ``ImageOps.autocontrast``.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    stretched = ImageOps.autocontrast(img)
    return stretched
346395

347396

348397
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the histogram of a PIL Image via ``ImageOps.equalize``.

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return ImageOps.equalize(img)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

0 commit comments

Comments
 (0)