|
3 | 3 | # pyre-strict |
4 | 4 | from collections import defaultdict |
5 | 5 | from copy import copy |
6 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
| 6 | +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union |
7 | 7 |
|
8 | 8 | import torch |
9 | 9 | from captum._utils.common import ( |
@@ -193,8 +193,7 @@ def _forward_with_dataloader( |
193 | 193 | feature_mask: Tuple[Tensor, ...], |
194 | 194 | # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
195 | 195 | reduce: Callable, |
196 | | - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
197 | | - to_metric: Optional[Callable], |
| 196 | + to_metric: Optional[Callable[[Tensor], Tensor]], |
198 | 197 | show_progress: bool, |
199 | 198 | feature_idx_to_mask_idx: Dict[int, List[int]], |
200 | 199 | ) -> Tensor: |
@@ -243,7 +242,8 @@ def _forward_with_dataloader( |
243 | 242 |
|
244 | 243 | accum_states[i] = reduce(accum_states[i], output, perturbed_inputs) |
245 | 244 |
|
246 | | - accum_results = [ |
| 245 | + accum_states = cast(List[Tensor], accum_states) |
| 246 | + accum_results: List[Tensor] = [ |
247 | 247 | to_metric(accum) if to_metric else accum for accum in accum_states |
248 | 248 | ] |
249 | 249 |
|
@@ -276,7 +276,7 @@ def attribute( |
276 | 276 | Args: |
277 | 277 |
|
278 | 278 | dataloader (torch.Dataloader): the dataloader to attribute, which should |
279 | | - return a tuple of consistant size for every iteration |
| 279 | + return a tuple of consistent size for every iteration |
280 | 280 | input_roles (tuple[int, ...], optional): a tuple of integers to define the |
281 | 281 | role of each element returned from the dataloader. It should |
282 | 282 | have the same size as the return of the dataloader. |
@@ -326,7 +326,7 @@ def attribute( |
326 | 326 | traverses needed is |
327 | 327 | ceil(n_perturbations / perturbations_per_pass). |
328 | 328 |
|
329 | | - This arguement offers control of the trade-off between memory |
| 329 | + This argument offers control of the trade-off between memory |
330 | 330 | and efficiency. If the dataloader involves slow operations like |
331 | 331 | remote request or file I/O, multiple traversals can be |
332 | 332 | inefficient. On the other hand, each perturbation needs to |
|
0 commit comments