Skip to content

Commit 4c3a9f4

Browse files
gruebelvfdev-5
andauthored
replace Number type with float; remove unneeded type ignores (#1425)
Co-authored-by: vfdev <[email protected]>
1 parent 64d9145 commit 4c3a9f4

File tree

11 files changed

+50
-58
lines changed

11 files changed

+50
-58
lines changed

ignite/distributed/auto.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -284,19 +284,19 @@ def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: O
284284

285285
def __iter__(self) -> Iterator:
286286
# deterministically shuffle based on epoch
287-
torch.manual_seed(self.epoch) # type: ignore[attr-defined]
287+
torch.manual_seed(self.epoch)
288288

289289
indices = [] # type: List
290-
while len(indices) < self.total_size: # type: ignore[attr-defined]
290+
while len(indices) < self.total_size:
291291
indices += list(self.sampler)
292292

293-
if len(indices) > self.total_size: # type: ignore[attr-defined]
294-
indices = indices[: self.total_size] # type: ignore[attr-defined]
293+
if len(indices) > self.total_size:
294+
indices = indices[: self.total_size]
295295

296296
# subsample
297-
indices = indices[self.rank : self.total_size : self.num_replicas] # type: ignore[attr-defined]
298-
if len(indices) != self.num_samples: # type: ignore[attr-defined]
299-
raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) # type: ignore[attr-defined]
297+
indices = indices[self.rank : self.total_size : self.num_replicas]
298+
if len(indices) != self.num_samples:
299+
raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))
300300

301301
return iter(indices)
302302

ignite/distributed/comp_models/base.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,11 @@ def _apply_op(
135135
return tensor
136136

137137
def _collective_op(
138-
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
139-
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
138+
self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
139+
) -> Union[torch.Tensor, float, List[float], List[str]]:
140140
tensor_to_number = tensor_to_str = False
141141
device = self.device()
142-
if isinstance(tensor, Number):
142+
if isinstance(tensor, (Number, float)):
143143
tensor_to_number = True
144144
tensor = torch.tensor(tensor, device=device, dtype=self._collective_op_dtype)
145145
elif isinstance(tensor, str):
@@ -150,28 +150,26 @@ def _collective_op(
150150

151151
if tensor_to_number:
152152
if tensor.numel() == 1:
153-
return cast(Number, tensor.item())
153+
return tensor.item()
154154
else:
155155
return tensor.tolist()
156156
elif tensor_to_str:
157157
return self._decode_str(tensor)
158158
return tensor
159159

160-
def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
160+
def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Union[torch.Tensor, float]:
161161
if not isinstance(tensor, (torch.Tensor, Number)):
162162
raise TypeError("Unhandled input type {}".format(type(tensor)))
163163

164-
return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op))
164+
return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op))
165165

166-
def all_gather(
167-
self, tensor: Union[torch.Tensor, Number, str]
168-
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
166+
def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
169167
if not isinstance(tensor, (torch.Tensor, Number, str)):
170168
raise TypeError("Unhandled input type {}".format(type(tensor)))
171169

172170
return self._collective_op(tensor, self._do_all_gather)
173171

174-
def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
172+
def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
175173
if not isinstance(tensor, (torch.Tensor, Number, str)):
176174
raise TypeError("Unhandled input type {}".format(type(tensor)))
177175

@@ -196,7 +194,7 @@ def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> U
196194
tensor = self._apply_op(tensor, device, self._do_broadcast, src)
197195

198196
if tensor_to_number:
199-
return cast(Number, tensor.item())
197+
return tensor.item()
200198
if tensor_to_str:
201199
list_str = self._decode_str(tensor)
202200
return list_str[0]
@@ -273,17 +271,15 @@ def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_Seria
273271
def spawn(*args: Any, **kwargs: Any) -> None:
274272
raise NotImplementedError("Serial computation model does not implement spawn method")
275273

276-
def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
274+
def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Union[torch.Tensor, float]:
277275
return tensor
278276

279-
def all_gather(
280-
self, tensor: Union[torch.Tensor, Number, str]
281-
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
277+
def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
282278
if isinstance(tensor, torch.Tensor):
283279
return tensor
284-
return cast(Union[List[Number], List[str]], [tensor])
280+
return cast(Union[List[float], List[str]], [tensor])
285281

286-
def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
282+
def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
287283
return tensor
288284

289285
def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:

ignite/distributed/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import socket
22
from functools import wraps
3-
from numbers import Number
43
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
54

65
import torch
@@ -316,7 +315,7 @@ def train_fn(local_rank, a, b, c, d=12):
316315
)
317316

318317

319-
def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
318+
def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
320319
"""Helper method to perform all reduce operation.
321320
322321
Args:
@@ -334,7 +333,7 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
334333
return _model.all_reduce(tensor, op)
335334

336335

337-
def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[Number], List[str]]:
336+
def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
338337
"""Helper method to perform all gather operation.
339338
340339
Args:
@@ -352,7 +351,7 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
352351
return _model.all_gather(tensor)
353352

354353

355-
def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
354+
def broadcast(tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
356355
"""Helper method to perform broadcast operation.
357356
358357
Args:

ignite/engine/deterministic.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def state_dict(self) -> OrderedDict:
183183
def _init_run(self) -> None:
184184
self.state.seed = int(torch.randint(0, int(1e9), (1,)).item())
185185
if not hasattr(self.state, "rng_states"):
186-
self.state.rng_states = None # type: ignore[attr-defined]
186+
setattr(self.state, "rng_states", None)
187187

188188
if torch.cuda.is_available():
189189
torch.backends.cudnn.deterministic = True
@@ -203,21 +203,19 @@ def _setup_engine(self) -> None:
203203
# attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
204204
can_patch_dataloader = True
205205
if hasattr(self.state.dataloader, "_dataset_kind"):
206-
from torch.utils.data.dataloader import _DatasetKind # type: ignore[attr-defined]
206+
from torch.utils.data.dataloader import _DatasetKind
207207

208-
_dataloader_kind = self.state.dataloader._dataset_kind # type: ignore[attr-defined]
208+
_dataloader_kind = self.state.dataloader._dataset_kind
209209
can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
210210
if can_patch_dataloader:
211-
if self._dataloader_len is not None and hasattr(
212-
self.state.dataloader.sampler, "epoch" # type: ignore[attr-defined]
213-
):
211+
if self._dataloader_len is not None and hasattr(self.state.dataloader.sampler, "epoch"):
214212
if self._dataloader_len != self.state.epoch_length:
215213
warnings.warn(
216214
"When defined engine's epoch length is different of input dataloader length, "
217215
"distributed sampler indices can not be setup in a reproducible manner"
218216
)
219217

220-
batch_sampler = self.state.dataloader.batch_sampler # type: ignore[attr-defined]
218+
batch_sampler = self.state.dataloader.batch_sampler
221219
if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
222220
self.state.dataloader = update_dataloader(
223221
self.state.dataloader, ReproducibleBatchSampler(batch_sampler) # type: ignore[arg-type]
@@ -233,9 +231,10 @@ def _setup_engine(self) -> None:
233231

234232
# restore rng state if in the middle
235233
in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
236-
if (getattr(self.state, "rng_states", None) is not None) and in_the_middle:
237-
_set_rng_states(self.state.rng_states) # type: ignore[attr-defined]
238-
self.state.rng_states = None # type: ignore[attr-defined]
234+
rng_states = getattr(self.state, "rng_states", None)
235+
if rng_states is not None and in_the_middle:
236+
_set_rng_states(rng_states)
237+
setattr(self.state, "rng_states", None)
239238

240239
def _from_iteration(self, iteration: int) -> Iterator:
241240
if self.state.dataloader is None:

ignite/engine/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
240240
return handler(*args, **kwargs)
241241

242242
# setup input handler as parent to make has_event_handler work
243-
wrapper._parent = weakref.ref(handler) # type: ignore[attr-defined]
243+
setattr(wrapper, "_parent", weakref.ref(handler))
244244
return wrapper
245245

246246
def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> RemovableEventHandle:

ignite/handlers/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from abc import ABCMeta, abstractmethod
77
from collections import OrderedDict, namedtuple
8-
from tempfile import _TemporaryFileWrapper # type: ignore
8+
from tempfile import _TemporaryFileWrapper # type: ignore[attr-defined]
99
from typing import Callable, Mapping, Optional, Union
1010

1111
import torch

ignite/handlers/terminate_on_nan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, output_transform: Callable = lambda x: x):
4040
def __call__(self, engine: Engine) -> None:
4141
output = self._output_transform(engine.state.output)
4242

43-
def raise_error(x: Union[numbers.Number, torch.Tensor]) -> None:
43+
def raise_error(x: Union[float, torch.Tensor]) -> None:
4444

4545
if isinstance(x, numbers.Number):
4646
x = torch.tensor(x)

ignite/metrics/accumulation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numbers
2-
from typing import Any, Callable, Tuple, Union, cast
2+
from typing import Any, Callable, Tuple, Union
33

44
import torch
55

@@ -57,12 +57,12 @@ def reset(self) -> None:
5757
self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device)
5858
self.num_examples = 0
5959

60-
def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
60+
def _check_output_type(self, output: Union[float, torch.Tensor]) -> None:
6161
if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)):
6262
raise TypeError("Output should be a number or torch.Tensor, but given {}".format(type(output)))
6363

6464
@reinit__is_reduced
65-
def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
65+
def update(self, output: Union[float, torch.Tensor]) -> None:
6666
self._check_output_type(output)
6767

6868
if isinstance(output, torch.Tensor):
@@ -125,14 +125,14 @@ def __init__(
125125
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
126126
):
127127
def _mean_op(a: Union[float, torch.Tensor], x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
128-
if isinstance(x, torch.Tensor) and x.ndim > 1: # type: ignore[attr-defined]
128+
if isinstance(x, torch.Tensor) and x.ndim > 1:
129129
x = x.sum(dim=0)
130130
return a + x
131131

132132
super(Average, self).__init__(op=_mean_op, output_transform=output_transform, device=device)
133133

134134
@sync_all_reduce("accumulator", "num_examples")
135-
def compute(self) -> Union[torch.Tensor, numbers.Number]:
135+
def compute(self) -> Union[float, torch.Tensor]:
136136
if self.num_examples < 1:
137137
raise NotComputableError(
138138
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
@@ -172,18 +172,18 @@ class GeometricAverage(VariableAccumulation):
172172
def __init__(
173173
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
174174
):
175-
def _geom_op(a: torch.Tensor, x: Union[numbers.Number, torch.Tensor]) -> torch.Tensor:
175+
def _geom_op(a: torch.Tensor, x: Union[float, torch.Tensor]) -> torch.Tensor:
176176
if not isinstance(x, torch.Tensor):
177177
x = torch.tensor(x)
178178
x = torch.log(x)
179-
if x.ndim > 1: # type: ignore[attr-defined]
179+
if x.ndim > 1:
180180
x = x.sum(dim=0)
181181
return a + x
182182

183183
super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform, device=device)
184184

185185
@sync_all_reduce("accumulator", "num_examples")
186-
def compute(self) -> Union[torch.Tensor, numbers.Number]:
186+
def compute(self) -> Union[float, torch.Tensor]:
187187
if self.num_examples < 1:
188188
raise NotComputableError(
189189
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
@@ -192,6 +192,6 @@ def compute(self) -> Union[torch.Tensor, numbers.Number]:
192192
tensor = torch.exp(self.accumulator / self.num_examples)
193193

194194
if tensor.numel() == 1:
195-
return cast(numbers.Number, tensor.item())
195+
return tensor.item()
196196

197197
return tensor

ignite/metrics/epoch_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def compute(self) -> float:
145145

146146
if ws > 1:
147147
# broadcast result to all processes
148-
result = cast(float, idist.broadcast(result, src=0)) # type: ignore[arg-type]
148+
result = cast(float, idist.broadcast(result, src=0))
149149

150150
return result
151151

ignite/metrics/metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,10 @@ def __init__(
215215
)
216216

217217
# Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it.
218-
if torch.device(device).type == "xla": # type: ignore[arg-type]
218+
if torch.device(device).type == "xla":
219219
raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")
220220

221-
self._device = torch.device(device) # type: ignore[arg-type]
221+
self._device = torch.device(device)
222222
self._is_reduced = False
223223
self.reset()
224224

0 commit comments

Comments
 (0)