@@ -28,6 +28,9 @@ class Summarizer:
2828 >>>print(summ.summary['mean'])
2929 """
3030
31+ _stats : List [Stat ]
32+ _summary_stats_indicies : List [int ]
33+
3134 @log_usage ()
3235 def __init__ (self , stats : List [Stat ]) -> None :
3336 r"""
@@ -37,11 +40,9 @@ def __init__(self, stats: List[Stat]) -> None:
3740 """
3841 self ._summarizers : List [SummarizerSingleTensor ] = []
3942 self ._is_inputs_tuple : Optional [bool ] = None
40- # pyre-fixme[4]: Attribute must be annotated.
4143 self ._stats , self ._summary_stats_indicies = _reorder_stats (stats )
4244
43- # pyre-fixme[3]: Return type must be annotated.
44- def _copy_stats (self ):
45+ def _copy_stats (self ) -> List [Stat ]:
4546 import copy
4647
4748 return copy .deepcopy (self ._stats )
@@ -125,48 +126,37 @@ def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
125126 dep_order = [StdDev , Var , MSE , Mean , Count ]
126127
127128 # remove dupe stats
128- # pyre-fixme[9]: stats has type `List[Stat]`; used as `Set[Stat]`.
129- stats = set (stats )
129+ stats_set = set (stats )
130130 summary_stats = set (stats )
131131
132132 from collections import defaultdict
133133
134- # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
135- # `typing.Type[<base type>]` to avoid runtime subscripting errors.
136- stats_by_module : Dict [Type , List [Stat ]] = defaultdict (list )
137- for stat in stats :
134+ stats_by_module : Dict [Type [Stat ], List [Stat ]] = defaultdict (list )
135+ for stat in stats_set :
138136 stats_by_module [stat .__class__ ].append (stat )
139137
140138 # StdDev is an odd case since it is parameterized, thus
141139 # for each StdDev(order) we must ensure there is an associated Var(order)
142140 for std_dev in stats_by_module [StdDev ]:
143141 stat_to_add = Var (order = std_dev .order ) # type: ignore
144- # pyre-fixme[16]: `List` has no attribute `add`.
145- stats .add (stat_to_add )
142+ stats_set .add (stat_to_add )
146143 stats_by_module [stat_to_add .__class__ ].append (stat_to_add )
147144
148145 # For the other modules (deps[1:n-1]): if i exists =>
149146 # we want to ensure i...n-1 exists
150147 for i , dep in enumerate (dep_order [1 :]):
151148 if dep in stats_by_module :
152- # pyre-fixme[16]: `List` has no attribute `update`.
153- stats .update ([mod () for mod in dep_order [i + 1 :]])
149+ stats_set .update ([mod () for mod in dep_order [i + 1 :]])
154150 break
155151
156152 # Step 2: get the correct order
157153 # NOTE: we are sorting via a given topological order
158- sort_order = {mod : i for i , mod in enumerate (dep_order )}
159- # pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
160- # Var]]` but got `Type[Min]`.
154+ sort_order : Dict [Type [Stat ], int ] = {mod : i for i , mod in enumerate (dep_order )}
161155 sort_order [Min ] = - 1
162- # pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
163- # Var]]` but got `Type[Max]`.
164156 sort_order [Max ] = - 1
165- # pyre-fixme[6]: For 1st argument expected `Type[Union[Count, MSE, Mean, StdDev,
166- # Var]]` but got `Type[Sum]`.
167157 sort_order [Sum ] = - 1
168158
169- stats = list (stats )
159+ stats = list (stats_set )
170160 stats .sort (key = lambda x : sort_order [x .__class__ ], reverse = True )
171161
172162 # get the summary stat indices
@@ -185,6 +175,10 @@ class SummarizerSingleTensor:
185175 If possible use `Summarizer` instead.
186176 """
187177
178+ _stats : List [Stat ]
179+ _stat_to_stat : Dict [Stat , Stat ]
180+ _summary_stats : List [Stat ]
181+
188182 def __init__ (self , stats : List [Stat ], summary_stats_indices : List [int ]) -> None :
189183 r"""
190184 Args:
@@ -196,9 +190,7 @@ def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
196190 does not require any specific order.
197191 """
198192 self ._stats = stats
199- # pyre-fixme[4]: Attribute must be annotated.
200193 self ._stat_to_stat = {stat : stat for stat in self ._stats }
201- # pyre-fixme[4]: Attribute must be annotated.
202194 self ._summary_stats = [stats [i ] for i in summary_stats_indices ]
203195
204196 for stat in stats :
0 commit comments