Skip to content

Commit 07306ce

Browse files
yucufacebook-github-bot
authored andcommitted
Enable pyre for Captum open source part- 1/2 (#1318)
Summary: Pull Request resolved: #1318 Differential Revision: D60748396
1 parent 430793e commit 07306ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+719
-17
lines changed

captum/concept/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#!/usr/bin/env python3
2+
3+
# pyre-strict
24
from captum.concept._core.cav import CAV
35
from captum.concept._core.concept import Concept, ConceptInterpreter
46
from captum.concept._core.tcav import TCAV

captum/concept/_core/cav.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
import os
46
from typing import Any, Dict, List, Optional
57

@@ -87,6 +89,7 @@ def assemble_save_path(
8789
file_name = concepts_to_str(concepts) + "-" + layer + ".pkl"
8890
return os.path.join(path, model_id, file_name)
8991

92+
# pyre-fixme[3]: Return type must be annotated.
9093
def save(self):
9194
r"""
9295
Saves a dictionary of the CAV computed values into a pickle file in the
@@ -134,6 +137,7 @@ def create_cav_dir_if_missing(save_path: str, model_id: str) -> None:
134137
os.makedirs(cav_model_id_path)
135138

136139
@staticmethod
140+
# pyre-fixme[3]: Return type must be annotated.
137141
def load(cavs_path: str, model_id: str, concepts: List[Concept], layer: str):
138142
r"""
139143
Loads CAV dictionary from a pickle file for given input

captum/concept/_core/concept.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
from typing import Callable, Union
46

57
import torch
@@ -56,6 +58,7 @@ def __repr__(self) -> str:
5658
return "Concept(%r, %r)" % (self.id, self.name)
5759

5860

61+
# pyre-fixme[13]: Attribute `interpret` is never initialized.
5962
class ConceptInterpreter:
6063
r"""
6164
An abstract class that exposes an abstract interpret method
@@ -70,6 +73,7 @@ def __init__(self, model: Module) -> None:
7073
"""
7174
self.model = model
7275

76+
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
7377
interpret: Callable
7478
r"""
7579
An abstract interpret method that performs concept-based model interpretability

captum/concept/_core/tcav.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
from collections import defaultdict
46
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
57

@@ -53,12 +55,17 @@ def __init__(self, datasets: List[AV.AVDataset], labels: List[int]) -> None:
5355
from itertools import accumulate
5456

5557
offsets = [0] + list(accumulate(map(len, datasets), (lambda x, y: x + y)))
58+
# pyre-fixme[4]: Attribute must be annotated.
5659
self.length = offsets[-1]
5760
self.datasets = datasets
5861
self.labels = labels
62+
# pyre-fixme[4]: Attribute must be annotated.
5963
self.lowers = offsets[:-1]
64+
# pyre-fixme[4]: Attribute must be annotated.
6065
self.uppers = offsets[1:]
6166

67+
# pyre-fixme[3]: Return type must be annotated.
68+
# pyre-fixme[2]: Parameter must be annotated.
6269
def _i_to_k(self, i):
6370

6471
left, right = 0, len(self.uppers)
@@ -71,6 +78,7 @@ def _i_to_k(self, i):
7178
else:
7279
right = mid
7380

81+
# pyre-fixme[3]: Return type must be annotated.
7482
def __getitem__(self, i: int):
7583
"""
7684
Returns a batch of activation vectors, as well as a batch of labels
@@ -89,8 +97,14 @@ def __getitem__(self, i: int):
8997
assert i < self.length
9098
k = self._i_to_k(i)
9199
inputs = self.datasets[k][i - self.lowers[k]]
100+
# pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no
101+
# attribute `shape`.
92102
assert len(inputs.shape) == 2
93103

104+
# pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no
105+
# attribute `size`.
106+
# pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no
107+
# attribute `device`.
94108
labels = torch.tensor([self.labels[k]] * inputs.size(0), device=inputs.device)
95109
return inputs, labels
96110

@@ -102,11 +116,14 @@ def __len__(self) -> int:
102116

103117

104118
def train_cav(
119+
# pyre-fixme[2]: Parameter must be annotated.
105120
model_id,
106121
concepts: List[Concept],
107122
layers: Union[str, List[str]],
108123
classifier: Classifier,
109124
save_path: str,
125+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
126+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
110127
classifier_kwargs: Dict,
111128
) -> Dict[str, Dict[str, CAV]]:
112129
r"""
@@ -159,8 +176,11 @@ def train_cav(
159176

160177
labels = [concept.id for concept in concepts]
161178

179+
# pyre-fixme[22]: The cast is redundant.
162180
labelled_dataset = LabelledDataset(cast(List[AV.AVDataset], datasets), labels)
163181

182+
# pyre-fixme[3]: Return type must be annotated.
183+
# pyre-fixme[2]: Parameter must be annotated.
164184
def batch_collate(batch):
165185
inputs, labels = zip(*batch)
166186
return torch.cat(inputs), torch.cat(labels)
@@ -249,6 +269,7 @@ def __init__(
249269
model_id: str = "default_model_id",
250270
classifier: Optional[Classifier] = None,
251271
layer_attr_method: Optional[LayerAttribution] = None,
272+
# pyre-fixme[2]: Parameter must be annotated.
252273
attribute_to_layer_input=False,
253274
save_path: str = "./cav/",
254275
**classifier_kwargs: Any,
@@ -300,19 +321,28 @@ def __init__(
300321
For more thorough examples, please check out TCAV tutorial and test cases.
301322
"""
302323
ConceptInterpreter.__init__(self, model)
324+
# pyre-fixme[4]: Attribute must be annotated.
303325
self.layers = [layers] if isinstance(layers, str) else layers
304326
self.model_id = model_id
305327
self.concepts: Set[Concept] = set()
306328
self.classifier = classifier
329+
# pyre-fixme[4]: Attribute must be annotated.
307330
self.classifier_kwargs = classifier_kwargs
331+
# pyre-fixme[8]: Attribute has type `Dict[str, Dict[str, CAV]]`; used as
332+
# `DefaultDict[Variable[_KT], DefaultDict[Variable[_KT], Variable[_VT]]]`.
308333
self.cavs: Dict[str, Dict[str, CAV]] = defaultdict(lambda: defaultdict())
309334
if self.classifier is None:
310335
self.classifier = DefaultClassifier()
311336
if layer_attr_method is None:
337+
# pyre-fixme[4]: Attribute must be annotated.
312338
self.layer_attr_method = cast(
313339
LayerAttribution,
314340
LayerGradientXActivation( # type: ignore
315-
model, None, multiply_by_inputs=False
341+
model,
342+
# pyre-fixme[6]: For 2nd argument expected `ModuleOrModuleList`
343+
# but got `None`.
344+
None,
345+
multiply_by_inputs=False,
316346
),
317347
)
318348
else:
@@ -324,6 +354,7 @@ def __init__(
324354
"will use `default_model_id` as its default value."
325355
)
326356

357+
# pyre-fixme[4]: Attribute must be annotated.
327358
self.attribute_to_layer_input = attribute_to_layer_input
328359
self.save_path = save_path
329360

@@ -341,6 +372,8 @@ def generate_all_activations(self) -> None:
341372
for concept in self.concepts:
342373
self.generate_activation(self.layers, concept)
343374

375+
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
376+
# `typing.List[<element type>]` to avoid runtime subscripting errors.
344377
def generate_activation(self, layers: Union[str, List], concept: Concept) -> None:
345378
r"""
346379
Computes layer activations for the specified `concept` and
@@ -361,6 +394,8 @@ def generate_activation(self, layers: Union[str, List], concept: Concept) -> Non
361394
"Data iterator for concept id:",
362395
"{} must be specified".format(concept.id),
363396
)
397+
# pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
398+
# `Optional[DataLoader[typing.Any]]`.
364399
for i, examples in enumerate(concept.data_iter):
365400
activations = layer_act.attribute.__wrapped__( # type: ignore
366401
layer_act,
@@ -447,6 +482,7 @@ def load_cavs(
447482
concept_layers[concept].append(layer)
448483
return layers, concept_layers
449484

485+
# pyre-fixme[3]: Return type must be annotated.
450486
def compute_cavs(
451487
self,
452488
experimental_sets: List[List[Concept]],
@@ -566,6 +602,7 @@ def interpret(
566602
inputs: TensorOrTupleOfTensorsGeneric,
567603
experimental_sets: List[List[Concept]],
568604
target: TargetType = None,
605+
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
569606
additional_forward_args: Any = None,
570607
processes: Optional[int] = None,
571608
**kwargs: Any,
@@ -661,6 +698,9 @@ def interpret(
661698
)
662699
self.compute_cavs(experimental_sets, processes=processes)
663700

701+
# pyre-fixme[9]: scores has type `Dict[str, Dict[str, Dict[str, Tensor]]]`;
702+
# used as `DefaultDict[Variable[_KT], DefaultDict[Variable[_KT],
703+
# Variable[_VT]]]`.
664704
scores: Dict[str, Dict[str, Dict[str, Tensor]]] = defaultdict(
665705
lambda: defaultdict()
666706
)
@@ -704,6 +744,7 @@ def interpret(
704744
attribs = _format_tensor_into_tuples(attribs)
705745
# n_inputs x n_features
706746
attribs = torch.cat(
747+
# pyre-fixme[16]: `None` has no attribute `__iter__`.
707748
[torch.reshape(attrib, (attrib.shape[0], -1)) for attrib in attribs],
708749
dim=1,
709750
)
@@ -713,6 +754,7 @@ def interpret(
713754
classes = []
714755
for concepts in experimental_sets:
715756
concepts_key = concepts_to_str(concepts)
757+
# pyre-fixme[33]: Given annotation cannot contain `Any`.
716758
cavs_stats = cast(Dict[str, Any], self.cavs[concepts_key][layer].stats)
717759
cavs.append(cavs_stats["weights"].float().detach().tolist())
718760
classes.append(cavs_stats["classes"])

captum/concept/_utils/classifier.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
import random
46
import warnings
57
from abc import ABC, abstractmethod
@@ -64,7 +66,11 @@ def __init__(self) -> None:
6466

6567
@abstractmethod
6668
def train_and_eval(
67-
self, dataloader: DataLoader, **kwargs: Any
69+
self,
70+
dataloader: DataLoader,
71+
**kwargs: Any
72+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
73+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
6874
) -> Union[Dict, None]:
6975
r"""
7076
This method is responsible for training a classifier using the data
@@ -132,12 +138,18 @@ def __init__(self) -> None:
132138
" both train and test datasets in the memory. Consider defining"
133139
" your own classifier that doesn't rely heavily on memory, for"
134140
" large number of concepts, by extending"
135-
" `Classifer` abstract class"
141+
" `Classifer` abstract class",
142+
stacklevel=2,
136143
)
137144
self.lm = model.SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3)
138145

139146
def train_and_eval(
140-
self, dataloader: DataLoader, test_split_ratio: float = 0.33, **kwargs: Any
147+
self,
148+
dataloader: DataLoader,
149+
test_split_ratio: float = 0.33,
150+
**kwargs: Any
151+
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
152+
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
141153
) -> Union[Dict, None]:
142154
r"""
143155
Implements Classifier::train_and_eval abstract method for small concept
@@ -169,6 +181,7 @@ def train_and_eval(
169181
inputs.append(input)
170182
labels.append(label)
171183

184+
# pyre-fixme[61]: `input` is undefined, or not always defined.
172185
device = "cpu" if input is None else input.device
173186
x_train, x_test, y_train, y_test = _train_test_split(
174187
torch.cat(inputs), torch.cat(labels), test_split=test_split_ratio

captum/concept/_utils/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
from typing import List
46

57
from captum.concept._core.concept import Concept

captum/concept/_utils/data_iterator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
import glob
46
import os
57
from typing import Callable, Iterator
@@ -13,6 +15,7 @@ class CustomIterableDataset(IterableDataset):
1315
An auxiliary class for iterating through a dataset.
1416
"""
1517

18+
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
1619
def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None:
1720
r"""
1821
Args:
@@ -21,6 +24,7 @@ def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None:
2124
path (str): Path to dataset files. This can be either a path to a
2225
directory or a file where input examples are stored.
2326
"""
27+
# pyre-fixme[4]: Attribute must be annotated.
2428
self.file_itr = None
2529
self.path = path
2630

captum/influence/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#!/usr/bin/env python3
22

3+
# pyre-strict
4+
35
from captum.influence._core.influence import DataInfluence
46
from captum.influence._core.influence_function import NaiveInfluenceFunction
57
from captum.influence._core.similarity_influence import SimilarityInfluence

0 commit comments

Comments
 (0)