16
16
17
17
import torch
18
18
from torch .nn import Module
19
-
20
- try :
21
- from typing_extensions import Self
22
- except ImportError :
23
- # workaround for Python 3.7.
24
- # see https://www.python.org/dev/peps/pep-0673/
25
- from typing import TypeVar
26
-
27
- Self = TypeVar ("TDeviceDtypeModuleMixin" , bound = "DeviceDtypeModuleMixin" )
28
-
19
+ from typing_extensions import Self
29
20
30
21
import pytorch_lightning as pl
31
22
@@ -57,7 +48,7 @@ def device(self) -> Union[str, torch.device]:
57
48
58
49
return device
59
50
60
- def to (self , * args : Any , ** kwargs : Any ) -> Self :
51
+ def to (self , * args : Any , ** kwargs : Any ) -> Self : # type: ignore[valid-type]
61
52
"""Moves and/or casts the parameters and buffers.
62
53
63
54
This can be called as
@@ -121,7 +112,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
121
112
self .__update_properties (device = out [0 ], dtype = out [1 ])
122
113
return super ().to (* args , ** kwargs )
123
114
124
- def cuda (self , device : Optional [Union [torch .device , int ]] = None ) -> Self :
115
+ def cuda (self , device : Optional [Union [torch .device , int ]] = None ) -> Self : # type: ignore[valid-type]
125
116
"""Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
126
117
different objects. So it should be called before constructing optimizer if the module will live on GPU
127
118
while being optimized.
@@ -134,11 +125,11 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
134
125
Module: self
135
126
"""
136
127
if device is None or isinstance (device , int ):
137
- device = torch .device ("cuda" , index = device )
128
+ device = torch .device ("cuda" , index = ( device or 0 ) )
138
129
self .__update_properties (device = device )
139
130
return super ().cuda (device = device )
140
131
141
- def cpu (self ) -> Self :
132
+ def cpu (self ) -> Self : # type: ignore[valid-type]
142
133
"""Moves all model parameters and buffers to the CPU.
143
134
144
135
Returns:
@@ -147,7 +138,7 @@ def cpu(self) -> Self:
147
138
self .__update_properties (device = torch .device ("cpu" ))
148
139
return super ().cpu ()
149
140
150
- def type (self , dst_type : Union [str , torch .dtype ]) -> Self :
141
+ def type (self , dst_type : Union [str , torch .dtype ]) -> Self : # type: ignore[valid-type]
151
142
"""Casts all parameters and buffers to :attr:`dst_type`.
152
143
153
144
Arguments:
@@ -159,7 +150,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Self:
159
150
self .__update_properties (dtype = dst_type )
160
151
return super ().type (dst_type = dst_type )
161
152
162
- def float (self ) -> Self :
153
+ def float (self ) -> Self : # type: ignore[valid-type]
163
154
"""Casts all floating point parameters and buffers to ``float`` datatype.
164
155
165
156
Returns:
@@ -168,7 +159,7 @@ def float(self) -> Self:
168
159
self .__update_properties (dtype = torch .float )
169
160
return super ().float ()
170
161
171
- def double (self ) -> Self :
162
+ def double (self ) -> Self : # type: ignore[valid-type]
172
163
"""Casts all floating point parameters and buffers to ``double`` datatype.
173
164
174
165
Returns:
@@ -177,7 +168,7 @@ def double(self) -> Self:
177
168
self .__update_properties (dtype = torch .double )
178
169
return super ().double ()
179
170
180
- def half (self ) -> Self :
171
+ def half (self ) -> Self : # type: ignore[valid-type]
181
172
"""Casts all floating point parameters and buffers to ``half`` datatype.
182
173
183
174
Returns:
0 commit comments