Skip to content

Commit 85d7c4d

Browse files
Configure mypy to install dependencies in CI and update pyproject.toml (#10682)
* mypy: install dependencies in CI
* fix dependency specification
* add examples
* fix type errors
* fix type error
* fix
* fix
* update pyproject.toml
1 parent f8b2d5b commit 85d7c4d

File tree

4 files changed

+13
-26
lines changed

4 files changed

+13
-26
lines changed

.github/workflows/code-checks.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@ jobs:
1414
- uses: actions/setup-python@v2
1515
with:
1616
python-version: 3.9
17-
- name: Install mypy
17+
- name: Install dependencies
1818
run: |
19-
grep mypy requirements/test.txt | xargs -0 pip install
19+
pip install '.[dev]'
2020
pip list
2121
- run: mypy --install-types --non-interactive

pyproject.toml

Lines changed: 8 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -36,41 +36,31 @@ disable_error_code = "attr-defined"
3636
# style choices
3737
warn_no_return = "False"
3838

39-
# TODO: Fix typing for these modules
39+
# Changes mypy default to ignore all errors
4040
[[tool.mypy.overrides]]
4141
module = [
42-
"pytorch_lightning.callbacks.*",
43-
"pytorch_lightning.core.*",
44-
"pytorch_lightning.loggers.*",
45-
"pytorch_lightning.loops.*",
46-
"pytorch_lightning.overrides.*",
47-
"pytorch_lightning.plugins.environments.*",
48-
"pytorch_lightning.plugins.training_type.*",
49-
"pytorch_lightning.profiler.*",
50-
"pytorch_lightning.trainer.*",
51-
"pytorch_lightning.distributed.*",
52-
"pytorch_lightning.tuner.*",
53-
"pytorch_lightning.utilities.*",
42+
"pytorch_lightning.*",
5443
]
5544
ignore_errors = "True"
5645

46+
# Override the default for files where we would like to enable type checking
47+
# TODO: Bring more files into this section
5748
[[tool.mypy.overrides]]
5849
module = [
5950
"pytorch_lightning.callbacks.device_stats_monitor",
6051
"pytorch_lightning.callbacks.early_stopping",
6152
"pytorch_lightning.callbacks.gpu_stats_monitor",
6253
"pytorch_lightning.callbacks.gradient_accumulation_scheduler",
63-
"pytorch_lightning.callbacks.lr_monitor",
6454
"pytorch_lightning.callbacks.model_summary",
6555
"pytorch_lightning.callbacks.progress",
6656
"pytorch_lightning.callbacks.pruning",
6757
"pytorch_lightning.callbacks.rich_model_summary",
6858
"pytorch_lightning.core.optimizer",
69-
"pytorch_lightning.lite.*",
70-
"pytorch_lightning.loops.optimization.*",
59+
"pytorch_lightning.loops.optimization.closure.py",
60+
"pytorch_lightning.loops.optimization.manual_loop.py",
7161
"pytorch_lightning.loops.evaluation_loop",
72-
"pytorch_lightning.trainer.connectors.checkpoint_connector",
73-
"pytorch_lightning.trainer.connectors.logger_connector.*",
62+
"pytorch_lightning.trainer.connectors.logger_connector.py",
63+
"pytorch_lightning.trainer.connectors.logger_connector.fx_validator.py",
7464
"pytorch_lightning.trainer.connectors.signal_connector",
7565
"pytorch_lightning.trainer.progress.*",
7666
"pytorch_lightning.tuner.auto_gpu_select",
@@ -80,8 +70,6 @@ module = [
8070
"pytorch_lightning.utilities.cloud_io",
8171
"pytorch_lightning.utilities.device_dtype_mixin",
8272
"pytorch_lightning.utilities.device_parser",
83-
"pytorch_lightning.utilities.distributed",
84-
"pytorch_lightning.utilities.memory",
8573
"pytorch_lightning.utilities.model_summary",
8674
"pytorch_lightning.utilities.parameter_tying",
8775
"pytorch_lightning.utilities.parsing",

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -280,8 +280,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
280280
)
281281

282282
# return cached value
283-
if self._computed is not None: # type: ignore
284-
return self._computed # type: ignore
283+
if self._computed is not None:
284+
return self._computed
285285
self._computed = compute(*args, **kwargs)
286286
return self._computed
287287

pytorch_lightning/utilities/parameter_tying.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from typing import Dict, List, Optional
2020

2121
from torch import nn
22-
from torch.nn import Parameter
2322

2423

2524
def find_shared_parameters(module: nn.Module) -> List[str]:
@@ -64,7 +63,7 @@ def _get_module_by_path(module: nn.Module, path: str) -> nn.Module:
6463
return module
6564

6665

67-
def _set_module_by_path(module: nn.Module, path: str, value: Parameter) -> None:
66+
def _set_module_by_path(module: nn.Module, path: str, value: nn.Module) -> None:
6867
path = path.split(".")
6968
for name in path[:-1]:
7069
module = getattr(module, name)

0 commit comments

Comments (0)