Skip to content

Commit 85b0356

Browse files
Fix mypy errors attributed to pytorch_lightning.core.mixins.device_dtype_mixin (#13704)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 4c35867 commit 85b0356

File tree

2 files changed

+9
-19
lines changed

2 files changed

+9
-19
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,6 @@ module = [
5252
"pytorch_lightning.callbacks.stochastic_weight_avg",
5353
"pytorch_lightning.core.datamodule",
5454
"pytorch_lightning.core.decorators",
55-
"pytorch_lightning.core.mixins.device_dtype_mixin",
5655
"pytorch_lightning.core.module",
5756
"pytorch_lightning.core.saving",
5857
"pytorch_lightning.demos.boring_classes",

src/pytorch_lightning/core/mixins/device_dtype_mixin.py

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -16,16 +16,7 @@
1616

1717
import torch
1818
from torch.nn import Module
19-
20-
try:
21-
from typing_extensions import Self
22-
except ImportError:
23-
# workaround for Python 3.7.
24-
# see https://www.python.org/dev/peps/pep-0673/
25-
from typing import TypeVar
26-
27-
Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin")
28-
19+
from typing_extensions import Self
2920

3021
import pytorch_lightning as pl
3122

@@ -57,7 +48,7 @@ def device(self) -> Union[str, torch.device]:
5748

5849
return device
5950

60-
def to(self, *args: Any, **kwargs: Any) -> Self:
51+
def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type]
6152
"""Moves and/or casts the parameters and buffers.
6253
6354
This can be called as
@@ -121,7 +112,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
121112
self.__update_properties(device=out[0], dtype=out[1])
122113
return super().to(*args, **kwargs)
123114

124-
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
115+
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type]
125116
"""Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
126117
different objects. So it should be called before constructing optimizer if the module will live on GPU
127118
while being optimized.
@@ -134,11 +125,11 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
134125
Module: self
135126
"""
136127
if device is None or isinstance(device, int):
137-
device = torch.device("cuda", index=device)
128+
device = torch.device("cuda", index=(device or 0))
138129
self.__update_properties(device=device)
139130
return super().cuda(device=device)
140131

141-
def cpu(self) -> Self:
132+
def cpu(self) -> Self: # type: ignore[valid-type]
142133
"""Moves all model parameters and buffers to the CPU.
143134
144135
Returns:
@@ -147,7 +138,7 @@ def cpu(self) -> Self:
147138
self.__update_properties(device=torch.device("cpu"))
148139
return super().cpu()
149140

150-
def type(self, dst_type: Union[str, torch.dtype]) -> Self:
141+
def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type]
151142
"""Casts all parameters and buffers to :attr:`dst_type`.
152143
153144
Arguments:
@@ -159,7 +150,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Self:
159150
self.__update_properties(dtype=dst_type)
160151
return super().type(dst_type=dst_type)
161152

162-
def float(self) -> Self:
153+
def float(self) -> Self: # type: ignore[valid-type]
163154
"""Casts all floating point parameters and buffers to ``float`` datatype.
164155
165156
Returns:
@@ -168,7 +159,7 @@ def float(self) -> Self:
168159
self.__update_properties(dtype=torch.float)
169160
return super().float()
170161

171-
def double(self) -> Self:
162+
def double(self) -> Self: # type: ignore[valid-type]
172163
"""Casts all floating point parameters and buffers to ``double`` datatype.
173164
174165
Returns:
@@ -177,7 +168,7 @@ def double(self) -> Self:
177168
self.__update_properties(dtype=torch.double)
178169
return super().double()
179170

180-
def half(self) -> Self:
171+
def half(self) -> Self: # type: ignore[valid-type]
181172
"""Casts all floating point parameters and buffers to ``half`` datatype.
182173
183174
Returns:

0 commit comments

Comments (0)