11#!/usr/bin/env python3
22
3+ # pyre-strict
4+
35from collections import defaultdict
46from typing import Any , cast , Dict , List , Optional , Set , Tuple , Union
57
@@ -53,12 +55,17 @@ def __init__(self, datasets: List[AV.AVDataset], labels: List[int]) -> None:
5355 from itertools import accumulate
5456
5557 offsets = [0 ] + list (accumulate (map (len , datasets ), (lambda x , y : x + y )))
58+ # pyre-fixme[4]: Attribute must be annotated.
5659 self .length = offsets [- 1 ]
5760 self .datasets = datasets
5861 self .labels = labels
62+ # pyre-fixme[4]: Attribute must be annotated.
5963 self .lowers = offsets [:- 1 ]
64+ # pyre-fixme[4]: Attribute must be annotated.
6065 self .uppers = offsets [1 :]
6166
67+ # pyre-fixme[3]: Return type must be annotated.
68+ # pyre-fixme[2]: Parameter must be annotated.
6269 def _i_to_k (self , i ):
6370
6471 left , right = 0 , len (self .uppers )
@@ -71,6 +78,7 @@ def _i_to_k(self, i):
7178 else :
7279 right = mid
7380
81+ # pyre-fixme[3]: Return type must be annotated.
7482 def __getitem__ (self , i : int ):
7583 """
7684 Returns a batch of activation vectors, as well as a batch of labels
@@ -89,8 +97,14 @@ def __getitem__(self, i: int):
8997 assert i < self .length
9098 k = self ._i_to_k (i )
9199 inputs = self .datasets [k ][i - self .lowers [k ]]
100+ # pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no
101+ # attribute `shape`.
92102 assert len (inputs .shape ) == 2
93103
104+ # pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no
105+ # attribute `size`.
106+ # pyre-fixme[16]: Item `tuple` of `Union[Tensor, Tuple[Tensor, ...]]` has no
107+ # attribute `device`.
94108 labels = torch .tensor ([self .labels [k ]] * inputs .size (0 ), device = inputs .device )
95109 return inputs , labels
96110
@@ -102,11 +116,14 @@ def __len__(self) -> int:
102116
103117
104118def train_cav (
119+ # pyre-fixme[2]: Parameter must be annotated.
105120 model_id ,
106121 concepts : List [Concept ],
107122 layers : Union [str , List [str ]],
108123 classifier : Classifier ,
109124 save_path : str ,
125+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
126+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
110127 classifier_kwargs : Dict ,
111128) -> Dict [str , Dict [str , CAV ]]:
112129 r"""
@@ -159,8 +176,11 @@ def train_cav(
159176
160177 labels = [concept .id for concept in concepts ]
161178
179+ # pyre-fixme[22]: The cast is redundant.
162180 labelled_dataset = LabelledDataset (cast (List [AV .AVDataset ], datasets ), labels )
163181
182+ # pyre-fixme[3]: Return type must be annotated.
183+ # pyre-fixme[2]: Parameter must be annotated.
164184 def batch_collate (batch ):
165185 inputs , labels = zip (* batch )
166186 return torch .cat (inputs ), torch .cat (labels )
@@ -249,6 +269,7 @@ def __init__(
249269 model_id : str = "default_model_id" ,
250270 classifier : Optional [Classifier ] = None ,
251271 layer_attr_method : Optional [LayerAttribution ] = None ,
272+ # pyre-fixme[2]: Parameter must be annotated.
252273 attribute_to_layer_input = False ,
253274 save_path : str = "./cav/" ,
254275 ** classifier_kwargs : Any ,
@@ -300,19 +321,28 @@ def __init__(
300321 For more thorough examples, please check out TCAV tutorial and test cases.
301322 """
302323 ConceptInterpreter .__init__ (self , model )
324+ # pyre-fixme[4]: Attribute must be annotated.
303325 self .layers = [layers ] if isinstance (layers , str ) else layers
304326 self .model_id = model_id
305327 self .concepts : Set [Concept ] = set ()
306328 self .classifier = classifier
329+ # pyre-fixme[4]: Attribute must be annotated.
307330 self .classifier_kwargs = classifier_kwargs
331+ # pyre-fixme[8]: Attribute has type `Dict[str, Dict[str, CAV]]`; used as
332+ # `DefaultDict[Variable[_KT], DefaultDict[Variable[_KT], Variable[_VT]]]`.
308333 self .cavs : Dict [str , Dict [str , CAV ]] = defaultdict (lambda : defaultdict ())
309334 if self .classifier is None :
310335 self .classifier = DefaultClassifier ()
311336 if layer_attr_method is None :
337+ # pyre-fixme[4]: Attribute must be annotated.
312338 self .layer_attr_method = cast (
313339 LayerAttribution ,
314340 LayerGradientXActivation ( # type: ignore
315- model , None , multiply_by_inputs = False
341+ model ,
342+ # pyre-fixme[6]: For 2nd argument expected `ModuleOrModuleList`
343+ # but got `None`.
344+ None ,
345+ multiply_by_inputs = False ,
316346 ),
317347 )
318348 else :
@@ -324,6 +354,7 @@ def __init__(
324354 "will use `default_model_id` as its default value."
325355 )
326356
357+ # pyre-fixme[4]: Attribute must be annotated.
327358 self .attribute_to_layer_input = attribute_to_layer_input
328359 self .save_path = save_path
329360
@@ -341,6 +372,8 @@ def generate_all_activations(self) -> None:
341372 for concept in self .concepts :
342373 self .generate_activation (self .layers , concept )
343374
375+ # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
376+ # `typing.List[<element type>]` to avoid runtime subscripting errors.
344377 def generate_activation (self , layers : Union [str , List ], concept : Concept ) -> None :
345378 r"""
346379 Computes layer activations for the specified `concept` and
@@ -361,6 +394,8 @@ def generate_activation(self, layers: Union[str, List], concept: Concept) -> Non
361394 "Data iterator for concept id:" ,
362395 "{} must be specified" .format (concept .id ),
363396 )
397+ # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
398+ # `Optional[DataLoader[typing.Any]]`.
364399 for i , examples in enumerate (concept .data_iter ):
365400 activations = layer_act .attribute .__wrapped__ ( # type: ignore
366401 layer_act ,
@@ -447,6 +482,7 @@ def load_cavs(
447482 concept_layers [concept ].append (layer )
448483 return layers , concept_layers
449484
485+ # pyre-fixme[3]: Return type must be annotated.
450486 def compute_cavs (
451487 self ,
452488 experimental_sets : List [List [Concept ]],
@@ -566,6 +602,7 @@ def interpret(
566602 inputs : TensorOrTupleOfTensorsGeneric ,
567603 experimental_sets : List [List [Concept ]],
568604 target : TargetType = None ,
605+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
569606 additional_forward_args : Any = None ,
570607 processes : Optional [int ] = None ,
571608 ** kwargs : Any ,
@@ -661,6 +698,9 @@ def interpret(
661698 )
662699 self .compute_cavs (experimental_sets , processes = processes )
663700
701+ # pyre-fixme[9]: scores has type `Dict[str, Dict[str, Dict[str, Tensor]]]`;
702+ # used as `DefaultDict[Variable[_KT], DefaultDict[Variable[_KT],
703+ # Variable[_VT]]]`.
664704 scores : Dict [str , Dict [str , Dict [str , Tensor ]]] = defaultdict (
665705 lambda : defaultdict ()
666706 )
@@ -704,6 +744,7 @@ def interpret(
704744 attribs = _format_tensor_into_tuples (attribs )
705745 # n_inputs x n_features
706746 attribs = torch .cat (
747+ # pyre-fixme[16]: `None` has no attribute `__iter__`.
707748 [torch .reshape (attrib , (attrib .shape [0 ], - 1 )) for attrib in attribs ],
708749 dim = 1 ,
709750 )
@@ -713,6 +754,7 @@ def interpret(
713754 classes = []
714755 for concepts in experimental_sets :
715756 concepts_key = concepts_to_str (concepts )
757+ # pyre-fixme[33]: Given annotation cannot contain `Any`.
716758 cavs_stats = cast (Dict [str , Any ], self .cavs [concepts_key ][layer ].stats )
717759 cavs .append (cavs_stats ["weights" ].float ().detach ().tolist ())
718760 classes .append (cavs_stats ["classes" ])
0 commit comments