22
33# pyre-strict
44from collections import defaultdict
5- from typing import Any , Dict , List , Optional , Union
5+ from typing import Any , cast , Dict , Generic , List , Optional , TypeVar , Union
66
77from captum ._utils .common import _format_tensor_into_tuples
88from captum ._utils .typing import TargetType , TensorOrTupleOfTensorsGeneric
1111from captum .log import log_usage
1212from torch import Tensor
1313
14+ KeyType = TypeVar ("KeyType" )
1415
15- class ClassSummarizer (Summarizer ):
16+
17+ class ClassSummarizer (Summarizer , Generic [KeyType ]):
1618 r"""
1719 Used to keep track of summaries for associated classes. The
1820 classes/labels can be of any type that are supported by `dict`.
@@ -23,8 +25,7 @@ class ClassSummarizer(Summarizer):
2325 @log_usage ()
2426 def __init__ (self , stats : List [Stat ]) -> None :
2527 Summarizer .__init__ .__wrapped__ (self , stats )
26- # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
27- self .summaries : Dict [Any , Summarizer ] = defaultdict (
28+ self .summaries : Dict [KeyType , Summarizer ] = defaultdict (
2829 lambda : Summarizer (stats = stats )
2930 )
3031
@@ -84,15 +85,15 @@ def update( # type: ignore
8485 tensors_to_summarize_copy = tuple (tensor [i ].clone () for tensor in x )
8586 label = labels_typed [0 ] if len (labels_typed ) == 1 else labels_typed [i ]
8687
87- self .summaries [label ].update (tensors_to_summarize )
88+ self .summaries [cast ( KeyType , label ) ].update (tensors_to_summarize )
8889 super ().update (tensors_to_summarize_copy )
8990
9091 @property
91- # pyre-fixme[3]: Return annotation cannot contain `Any`.
9292 def class_summaries (
9393 self ,
9494 ) -> Dict [
95- Any , Union [None , Dict [str , Optional [Tensor ]], List [Dict [str , Optional [Tensor ]]]]
95+ KeyType ,
96+ Union [None , Dict [str , Optional [Tensor ]], List [Dict [str , Optional [Tensor ]]]],
9697 ]:
9798 r"""
9899 Returns:
0 commit comments