Skip to content

Commit ac62dc5

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Fix incompatible type breaking OSS tests (meta-pytorch#1371)
Summary: Correct typing in dataloader attr to prevent tests from breaking Differential Revision: D64412835
1 parent 4cb2808 commit ac62dc5

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

captum/attr/_core/dataloader_attr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44
from collections import defaultdict
55
from copy import copy
6-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
77

88
import torch
99
from captum._utils.common import (
@@ -193,8 +193,7 @@ def _forward_with_dataloader(
193193
feature_mask: Tuple[Tensor, ...],
194194
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
195195
reduce: Callable,
196-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
197-
to_metric: Optional[Callable],
196+
to_metric: Optional[Callable[[Tensor], Tensor]],
198197
show_progress: bool,
199198
feature_idx_to_mask_idx: Dict[int, List[int]],
200199
) -> Tensor:
@@ -243,7 +242,8 @@ def _forward_with_dataloader(
243242

244243
accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)
245244

246-
accum_results = [
245+
accum_states = cast(List[Tensor], accum_states)
246+
accum_results: List[Tensor] = [
247247
to_metric(accum) if to_metric else accum for accum in accum_states
248248
]
249249

captum/attr/_core/lime.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,9 @@ def attribute(
522522
if show_progress:
523523
attr_progress.close()
524524

525-
combined_interp_inps = torch.cat(interpretable_inps).float()
525+
# Argument 1 to "cat" has incompatible type "list[Tensor | tuple[Tensor, ...]]";
526+
# expected "tuple[Tensor, ...] | list[Tensor]" [arg-type]
527+
combined_interp_inps = torch.cat(interpretable_inps).float() # type: ignore
526528
combined_outputs = (
527529
torch.cat(outputs)
528530
if len(outputs[0].shape) > 0

captum/concept/_utils/classifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def train_and_eval(
186186
x_train, x_test, y_train, y_test = _train_test_split(
187187
torch.cat(inputs), torch.cat(labels), test_split=test_split_ratio
188188
)
189-
self.lm.device = device
189+
# error: Incompatible types in assignment (expression has type "str | Any",
190+
# variable has type "Tensor | Module") [assignment]
191+
self.lm.device = device # type: ignore
190192
self.lm.fit(DataLoader(TensorDataset(x_train, y_train)))
191193

192194
predict = self.lm(x_test)

0 commit comments

Comments
 (0)