Skip to content

Commit f193598

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Fix incompatible type breaking OSS tests
Summary: Correct typing in dataloader attr to prevent tests from breaking Differential Revision: D64412835
1 parent 4cb2808 commit f193598

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

captum/attr/_core/dataloader_attr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44
from collections import defaultdict
55
from copy import copy
6-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
77

88
import torch
99
from captum._utils.common import (
@@ -193,8 +193,7 @@ def _forward_with_dataloader(
193193
feature_mask: Tuple[Tensor, ...],
194194
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
195195
reduce: Callable,
196-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
197-
to_metric: Optional[Callable],
196+
to_metric: Optional[Callable[[Tensor], Tensor]],
198197
show_progress: bool,
199198
feature_idx_to_mask_idx: Dict[int, List[int]],
200199
) -> Tensor:
@@ -243,7 +242,8 @@ def _forward_with_dataloader(
243242

244243
accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)
245244

246-
accum_results = [
245+
accum_states = cast(List[Tensor], accum_states)
246+
accum_results: List[Tensor] = [
247247
to_metric(accum) if to_metric else accum for accum in accum_states
248248
]
249249

0 commit comments

Comments
 (0)