Skip to content

Commit a81d99b

Browse files
authored
Add static type check with mypy (#2195)
* add mypy config * fix syntax error * fix annotations in torchvision/utils.py * add mypy type check to CircleCI * add mypy cache to ignore files * try fix CI * ignore flake8 F821 since it interferes with mypy * add mypy type check to config generator * explicitly set config files
1 parent f71316f commit a81d99b

File tree

7 files changed

+72
-10
lines changed

7 files changed

+72
-10
lines changed

.circleci/config.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,19 @@ jobs:
109109
- run:
110110
command: |
111111
pip install --user --progress-bar off flake8 typing
112-
flake8 .
112+
flake8 --config=setup.cfg .
113+
114+
python_type_check:
115+
docker:
116+
- image: circleci/python:3.7
117+
steps:
118+
- checkout
119+
- run:
120+
command: |
121+
pip install --user --progress-bar off numpy mypy
122+
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
123+
pip install --user --progress-bar off .
124+
mypy --config-file mypy.ini
113125
114126
clang_format:
115127
docker:
@@ -702,12 +714,14 @@ workflows:
702714
python_version: "3.6"
703715
cu_version: "cu101"
704716
- python_lint
717+
- python_type_check
705718
- clang_format
706719

707720
nightly:
708721
jobs:
709722
- circleci_consistency
710723
- python_lint
724+
- python_type_check
711725
- clang_format
712726
- binary_linux_wheel:
713727
cu_version: cpu

.circleci/config.yml.in

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,19 @@ jobs:
109109
- run:
110110
command: |
111111
pip install --user --progress-bar off flake8 typing
112-
flake8 .
112+
flake8 --config=setup.cfg .
113+
114+
python_type_check:
115+
docker:
116+
- image: circleci/python:3.7
117+
steps:
118+
- checkout
119+
- run:
120+
command: |
121+
pip install --user --progress-bar off numpy mypy
122+
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
123+
pip install --user --progress-bar off .
124+
mypy --config-file mypy.ini
113125

114126
clang_format:
115127
docker:
@@ -398,12 +410,14 @@ workflows:
398410
python_version: "3.6"
399411
cu_version: "cu101"
400412
- python_lint
413+
- python_type_check
401414
- clang_format
402415

403416
nightly:
404417
{%- endif %}
405418
jobs:
406419
- circleci_consistency
407420
- python_lint
421+
- python_type_check
408422
- clang_format
409423
{{ workflows(prefix="nightly_", filter_branch="nightly", upload=True) }}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ htmlcov
2020
*.swp
2121
*.swo
2222
gen.yml
23+
.mypy_cache

mypy.ini

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
[mypy]
2+
3+
files = torchvision
4+
show_error_codes = True
5+
pretty = True
6+
7+
[mypy-torchvision.datasets.*]
8+
9+
ignore_errors = True
10+
11+
[mypy-torchvision.io.*]
12+
13+
ignore_errors = True
14+
15+
[mypy-torchvision.models.*]
16+
17+
ignore_errors = True
18+
19+
[mypy-torchvision.ops.*]
20+
21+
ignore_errors = True
22+
23+
[mypy-torchvision.transforms.*]
24+
25+
ignore_errors = True
26+
27+
[mypy-PIL]
28+
29+
ignore_missing_imports = True
30+

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ max-line-length = 120
99

1010
[flake8]
1111
max-line-length = 120
12-
ignore = F401,E402,F403,W503,W504
12+
ignore = F401,E402,F403,W503,W504,F821
1313
exclude = venv

torchvision/io/_video_opt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def __init__(self):
8484

8585

8686
def _validate_pts(pts_range):
87-
# type: (List[int])
87+
# type: (List[int]) -> None
88+
8889
if pts_range[1] > 0:
8990
assert (
9091
pts_range[0] <= pts_range[1]

torchvision/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union, Optional, Sequence, Tuple, Text, BinaryIO
1+
from typing import Union, Optional, List, Tuple, Text, BinaryIO
22
import io
33
import pathlib
44
import torch
@@ -7,7 +7,7 @@
77

88

99
def make_grid(
10-
tensor: Union[torch.Tensor, Sequence[torch.Tensor]],
10+
tensor: Union[torch.Tensor, List[torch.Tensor]],
1111
nrow: int = 8,
1212
padding: int = 2,
1313
normalize: bool = False,
@@ -91,15 +91,17 @@ def norm_range(t, range):
9191
for x in irange(xmaps):
9292
if k >= nmaps:
9393
break
94-
grid.narrow(1, y * height + padding, height - padding)\
95-
.narrow(2, x * width + padding, width - padding)\
96-
.copy_(tensor[k])
94+
# Tensor.copy_() is a valid method but seems to be missing from the stubs
95+
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
96+
grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
97+
2, x * width + padding, width - padding
98+
).copy_(tensor[k])
9799
k = k + 1
98100
return grid
99101

100102

101103
def save_image(
102-
tensor: Union[torch.Tensor, Sequence[torch.Tensor]],
104+
tensor: Union[torch.Tensor, List[torch.Tensor]],
103105
fp: Union[Text, pathlib.Path, BinaryIO],
104106
nrow: int = 8,
105107
padding: int = 2,

0 commit comments

Comments
 (0)