Skip to content

Commit 5594316

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix progress pyre fixme issues
Summary: Fixing unresolved pyre fixme issues in corresponding file Differential Revision: D67725994
1 parent d0da727 commit 5594316

File tree

4 files changed

+126
-64
lines changed

4 files changed

+126
-64
lines changed

captum/_utils/progress.py

Lines changed: 79 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,34 @@
33
# pyre-strict
44

55
import sys
6+
import typing
67
import warnings
78
from time import time
8-
from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO
9+
from types import TracebackType
10+
from typing import (
11+
Any,
12+
Callable,
13+
cast,
14+
Generic,
15+
Iterable,
16+
Iterator,
17+
Literal,
18+
Optional,
19+
Sized,
20+
TextIO,
21+
Type,
22+
TypeVar,
23+
Union,
24+
)
925

1026
try:
1127
from tqdm.auto import tqdm
1228
except ImportError:
1329
tqdm = None
1430

31+
T = TypeVar("T")
32+
IterableType = TypeVar("IterableType")
33+
1534

1635
class DisableErrorIOWrapper(object):
1736
def __init__(self, wrapped: TextIO) -> None:
@@ -21,15 +40,13 @@ def __init__(self, wrapped: TextIO) -> None:
2140
"""
2241
self._wrapped = wrapped
2342

24-
# pyre-fixme[3]: Return type must be annotated.
25-
# pyre-fixme[2]: Parameter must be annotated.
26-
def __getattr__(self, name):
43+
def __getattr__(self, name: str) -> object:
2744
return getattr(self._wrapped, name)
2845

2946
@staticmethod
30-
# pyre-fixme[3]: Return type must be annotated.
31-
# pyre-fixme[2]: Parameter must be annotated.
32-
def _wrapped_run(func, *args, **kwargs):
47+
def _wrapped_run(
48+
func: Callable[..., T], *args: object, **kwargs: object
49+
) -> Union[T, None]:
3350
try:
3451
return func(*args, **kwargs)
3552
except OSError as e:
@@ -38,19 +55,16 @@ def _wrapped_run(func, *args, **kwargs):
3855
except ValueError as e:
3956
if "closed" not in str(e):
4057
raise
58+
return None
4159

42-
# pyre-fixme[3]: Return type must be annotated.
43-
# pyre-fixme[2]: Parameter must be annotated.
44-
def write(self, *args, **kwargs):
60+
def write(self, *args: object, **kwargs: object) -> Optional[int]:
4561
return self._wrapped_run(self._wrapped.write, *args, **kwargs)
4662

47-
# pyre-fixme[3]: Return type must be annotated.
48-
# pyre-fixme[2]: Parameter must be annotated.
49-
def flush(self, *args, **kwargs):
63+
def flush(self, *args: object, **kwargs: object) -> None:
5064
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
5165

5266

53-
class NullProgress:
67+
class NullProgress(Iterable[IterableType]):
5468
"""Passthrough class that implements the progress API.
5569
5670
This class implements the tqdm and SimpleProgressBar api but
@@ -61,27 +75,28 @@ class NullProgress:
6175

6276
def __init__(
6377
self,
64-
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
65-
iterable: Optional[Iterable] = None,
78+
iterable: Optional[Iterable[IterableType]] = None,
6679
*args: Any,
6780
**kwargs: Any,
6881
) -> None:
6982
del args, kwargs
7083
self.iterable = iterable
7184

72-
def __enter__(self) -> "NullProgress":
85+
def __enter__(self) -> "NullProgress[IterableType]":
7386
return self
7487

75-
# pyre-fixme[2]: Parameter must be annotated.
76-
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
88+
def __exit__(
89+
self,
90+
exc_type: Union[Type[BaseException], None],
91+
exc_value: Union[BaseException, None],
92+
exc_traceback: Union[TracebackType, None],
93+
) -> Literal[False]:
7794
return False
7895

79-
# pyre-fixme[3]: Return type must be annotated.
80-
def __iter__(self):
96+
def __iter__(self) -> Iterator[IterableType]:
8197
if not self.iterable:
8298
return
83-
# pyre-fixme[16]: `Optional` has no attribute `__iter__`.
84-
for it in self.iterable:
99+
for it in cast(Iterable[IterableType], self.iterable):
85100
yield it
86101

87102
def update(self, amount: int = 1) -> None:
@@ -91,11 +106,10 @@ def close(self) -> None:
91106
pass
92107

93108

94-
class SimpleProgress:
109+
class SimpleProgress(Iterable[IterableType]):
95110
def __init__(
96111
self,
97-
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
98-
iterable: Optional[Iterable] = None,
112+
iterable: Optional[Iterable[IterableType]] = None,
99113
desc: Optional[str] = None,
100114
total: Optional[int] = None,
101115
file: Optional[TextIO] = None,
@@ -117,34 +131,33 @@ def __init__(
117131

118132
self.desc = desc
119133

120-
# pyre-fixme[9]: file has type `Optional[TextIO]`; used as
121-
# `DisableErrorIOWrapper`.
122-
file = DisableErrorIOWrapper(file if file else sys.stderr)
123-
cast(TextIO, file)
124-
self.file = file
134+
file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr)
135+
self.file: DisableErrorIOWrapper = file_wrapper
125136

126137
self.mininterval = mininterval
127138
self.last_print_t = 0.0
128139
self.closed = False
129140
self._is_parent = False
130141

131-
def __enter__(self) -> "SimpleProgress":
142+
def __enter__(self) -> "SimpleProgress[IterableType]":
132143
self._is_parent = True
133144
self._refresh()
134145
return self
135146

136-
# pyre-fixme[2]: Parameter must be annotated.
137-
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
147+
def __exit__(
148+
self,
149+
exc_type: Union[Type[BaseException], None],
150+
exc_value: Union[BaseException, None],
151+
exc_traceback: Union[TracebackType, None],
152+
) -> Literal[False]:
138153
self.close()
139154
return False
140155

141-
# pyre-fixme[3]: Return type must be annotated.
142-
def __iter__(self):
156+
def __iter__(self) -> Iterator[IterableType]:
143157
if self.closed or not self.iterable:
144158
return
145159
self._refresh()
146-
# pyre-fixme[16]: `Optional` has no attribute `__iter__`.
147-
for it in self.iterable:
160+
for it in cast(Iterable[IterableType], self.iterable):
148161
yield it
149162
self.update()
150163
self.close()
@@ -153,9 +166,7 @@ def _refresh(self) -> None:
153166
progress_str = self.desc + ": " if self.desc else ""
154167
if self.total:
155168
# e.g., progress: 60% 3/5
156-
# pyre-fixme[58]: `//` is not supported for operand types `int` and
157-
# `Optional[int]`.
158-
progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
169+
progress_str += f"{100 * self.cur // cast(int, self.total)}% {self.cur}/{cast(int, self.total)}"
159170
else:
160171
# e.g., progress: .....
161172
progress_str += "." * self.cur
@@ -179,18 +190,39 @@ def close(self) -> None:
179190
self.closed = True
180191

181192

182-
# pyre-fixme[3]: Return type must be annotated.
193+
@typing.overload
194+
def progress(
195+
iterable: None = None,
196+
desc: Optional[str] = None,
197+
total: Optional[int] = None,
198+
use_tqdm: bool = True,
199+
file: Optional[TextIO] = None,
200+
mininterval: float = 0.5,
201+
**kwargs: object,
202+
) -> Union[SimpleProgress[None], tqdm]: ...
203+
204+
205+
@typing.overload
206+
def progress(
207+
iterable: Iterable[IterableType],
208+
desc: Optional[str] = None,
209+
total: Optional[int] = None,
210+
use_tqdm: bool = True,
211+
file: Optional[TextIO] = None,
212+
mininterval: float = 0.5,
213+
**kwargs: object,
214+
) -> Union[SimpleProgress[IterableType], tqdm]: ...
215+
216+
183217
def progress(
184-
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
185-
iterable: Optional[Iterable] = None,
218+
iterable: Optional[Iterable[IterableType]] = None,
186219
desc: Optional[str] = None,
187220
total: Optional[int] = None,
188221
use_tqdm: bool = True,
189222
file: Optional[TextIO] = None,
190223
mininterval: float = 0.5,
191-
# pyre-fixme[2]: Parameter must be annotated.
192-
**kwargs,
193-
):
224+
**kwargs: object,
225+
) -> Union[SimpleProgress[IterableType], tqdm]:
194226
# Try to use tqdm is possible. Fall back to simple progress print
195227
if tqdm and use_tqdm:
196228
return tqdm(

captum/influence/_core/tracincp.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@
66
import warnings
77
from abc import abstractmethod
88
from os.path import join
9-
from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union
9+
from typing import (
10+
Any,
11+
Callable,
12+
cast,
13+
Iterable,
14+
Iterator,
15+
List,
16+
Optional,
17+
Tuple,
18+
Type,
19+
Union,
20+
)
1021

1122
import torch
1223
from captum._utils.av import AV
@@ -1033,10 +1044,12 @@ def _influence(
10331044
inputs = _format_inputs_dataset(inputs)
10341045

10351046
train_dataloader = self.train_dataloader
1036-
1047+
data_iterable: Union[Iterable[Tuple[object, ...]], DataLoader] = (
1048+
train_dataloader
1049+
)
10371050
if show_progress:
1038-
train_dataloader = progress(
1039-
train_dataloader,
1051+
data_iterable = progress(
1052+
cast(Iterable[Tuple[object, ...]], train_dataloader),
10401053
desc=(
10411054
f"Using {self.get_name()} to compute "
10421055
"influence for training batches"
@@ -1053,7 +1066,7 @@ def _influence(
10531066
return torch.cat(
10541067
[
10551068
self._influence_batch_tracincp(inputs_checkpoint_jacobians, batch)
1056-
for batch in train_dataloader
1069+
for batch in data_iterable
10571070
],
10581071
dim=1,
10591072
)
@@ -1250,7 +1263,7 @@ def get_checkpoint_contribution(checkpoint: str) -> Tensor:
12501263
# the same)
12511264
checkpoint_contribution = []
12521265

1253-
_inputs = inputs
1266+
_inputs: Union[DataLoader, Iterable[Tuple[Tensor, ...]]] = inputs
12541267
# If `show_progress` is true, create an inner progress bar that keeps track
12551268
# of how many batches have been processed for the current checkpoint
12561269
if show_progress:
@@ -1266,8 +1279,8 @@ def get_checkpoint_contribution(checkpoint: str) -> Tensor:
12661279
for batch in _inputs:
12671280

12681281
layer_jacobians = self._basic_computation_tracincp(
1269-
batch[0:-1],
1270-
batch[-1],
1282+
cast(Tuple[Tensor, ...], batch)[0:-1],
1283+
cast(Tuple[Tensor, ...], batch)[-1],
12711284
self.loss_fn,
12721285
self.reduction_type,
12731286
)

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55
import threading
66
import warnings
77
from collections import defaultdict
8-
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
8+
from typing import (
9+
Any,
10+
Callable,
11+
cast,
12+
Dict,
13+
Iterable,
14+
Iterator,
15+
List,
16+
Optional,
17+
Tuple,
18+
Union,
19+
)
920

1021
import torch
1122
from captum._utils.common import _get_module_from_name, _sort_key_list
@@ -418,10 +429,13 @@ def _influence( # type: ignore[override]
418429
"""
419430

420431
train_dataloader = self.train_dataloader
432+
train_dataloader_iterable: Union[DataLoader, Iterable[Tuple[object, ...]]] = (
433+
train_dataloader
434+
)
421435

422436
if show_progress:
423-
train_dataloader = progress(
424-
train_dataloader,
437+
train_dataloader_iterable = progress(
438+
cast(Iterable[Tuple[object, ...]], train_dataloader),
425439
desc=(
426440
f"Using {self.get_name()} to compute "
427441
"influence for training batches"
@@ -432,7 +446,7 @@ def _influence( # type: ignore[override]
432446
return torch.cat(
433447
[
434448
self._influence_batch_tracincp_fast(inputs, batch)
435-
for batch in train_dataloader
449+
for batch in train_dataloader_iterable
436450
],
437451
dim=1,
438452
)

0 commit comments

Comments
 (0)