Skip to content

Commit a95eb46

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix summarizer pyre fix me issues (#1479)
Summary: Fixing unresolved pyre fixme issues in corresponding file Reviewed By: cyrjano Differential Revision: D67707848
1 parent c0b1dda commit a95eb46

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

captum/attr/_utils/summarizer.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class Summarizer:
2828
>>>print(summ.summary['mean'])
2929
"""
3030

31+
_stats: List[Stat]
32+
_summary_stats_indicies: List[int]
33+
3134
@log_usage()
3235
def __init__(self, stats: List[Stat]) -> None:
3336
r"""
@@ -37,11 +40,9 @@ def __init__(self, stats: List[Stat]) -> None:
3740
"""
3841
self._summarizers: List[SummarizerSingleTensor] = []
3942
self._is_inputs_tuple: Optional[bool] = None
40-
# pyre-fixme[4]: Attribute must be annotated.
4143
self._stats, self._summary_stats_indicies = _reorder_stats(stats)
4244

43-
# pyre-fixme[3]: Return type must be annotated.
44-
def _copy_stats(self):
45+
def _copy_stats(self) -> List[Stat]:
4546
import copy
4647

4748
return copy.deepcopy(self._stats)
@@ -125,48 +126,37 @@ def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
125126
dep_order = [StdDev, Var, MSE, Mean, Count]
126127

127128
# remove dupe stats
128-
# pyre-fixme[9]: stats has type `List[Stat]`; used as `Set[Stat]`.
129-
stats = set(stats)
129+
stats_set = set(stats)
130130
summary_stats = set(stats)
131131

132132
from collections import defaultdict
133133

134-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
135-
# `typing.Type[<base type>]` to avoid runtime subscripting errors.
136-
stats_by_module: Dict[Type, List[Stat]] = defaultdict(list)
137-
for stat in stats:
134+
stats_by_module: Dict[Type[Stat], List[Stat]] = defaultdict(list)
135+
for stat in stats_set:
138136
stats_by_module[stat.__class__].append(stat)
139137

140138
# StdDev is an odd case since it is parameterized, thus
141139
# for each StdDev(order) we must ensure there is an associated Var(order)
142140
for std_dev in stats_by_module[StdDev]:
143141
stat_to_add = Var(order=std_dev.order) # type: ignore
144-
# pyre-fixme[16]: `List` has no attribute `add`.
145-
stats.add(stat_to_add)
142+
stats_set.add(stat_to_add)
146143
stats_by_module[stat_to_add.__class__].append(stat_to_add)
147144

148145
# For the other modules (deps[1:n-1]): if i exists =>
149146
# we want to ensure i...n-1 exists
150147
for i, dep in enumerate(dep_order[1:]):
151148
if dep in stats_by_module:
152-
# pyre-fixme[16]: `List` has no attribute `update`.
153-
stats.update([mod() for mod in dep_order[i + 1 :]])
149+
stats_set.update([mod() for mod in dep_order[i + 1 :]])
154150
break
155151

156152
# Step 2: get the correct order
157153
# NOTE: we are sorting via a given topological order
158-
sort_order = {mod: i for i, mod in enumerate(dep_order)}
159-
# pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
160-
# Var]]` but got `Type[Min]`.
154+
sort_order: Dict[Type[Stat], int] = {mod: i for i, mod in enumerate(dep_order)}
161155
sort_order[Min] = -1
162-
# pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
163-
# Var]]` but got `Type[Max]`.
164156
sort_order[Max] = -1
165-
# pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
166-
# Var]]` but got `Type[Sum]`.
167157
sort_order[Sum] = -1
168158

169-
stats = list(stats)
159+
stats = list(stats_set)
170160
stats.sort(key=lambda x: sort_order[x.__class__], reverse=True)
171161

172162
# get the summary stat indices
@@ -185,6 +175,10 @@ class SummarizerSingleTensor:
185175
If possible use `Summarizer` instead.
186176
"""
187177

178+
_stats: List[Stat]
179+
_stat_to_stat: Dict[Stat, Stat]
180+
_summary_stats: List[Stat]
181+
188182
def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
189183
r"""
190184
Args:
@@ -196,9 +190,7 @@ def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
196190
does not require any specific order.
197191
"""
198192
self._stats = stats
199-
# pyre-fixme[4]: Attribute must be annotated.
200193
self._stat_to_stat = {stat: stat for stat in self._stats}
201-
# pyre-fixme[4]: Attribute must be annotated.
202194
self._summary_stats = [stats[i] for i in summary_stats_indices]
203195

204196
for stat in stats:

0 commit comments

Comments
 (0)