Skip to content

Commit 180920b

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix class summarizer pyre fix me issues
Differential Revision: D67706853
1 parent 0d2d3af commit 180920b

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

captum/attr/_utils/class_summarizer.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
from collections import defaultdict
5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, cast, Dict, Generic, List, Optional, TypeVar, Union
66

77
from captum._utils.common import _format_tensor_into_tuples
88
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
@@ -11,8 +11,10 @@
1111
from captum.log import log_usage
1212
from torch import Tensor
1313

14+
KeyType = TypeVar("KeyType")
1415

15-
class ClassSummarizer(Summarizer):
16+
17+
class ClassSummarizer(Summarizer, Generic[KeyType]):
1618
r"""
1719
Used to keep track of summaries for associated classes. The
1820
classes/labels can be of any type that are supported by `dict`.
@@ -23,8 +25,7 @@ class ClassSummarizer(Summarizer):
2325
@log_usage()
2426
def __init__(self, stats: List[Stat]) -> None:
2527
Summarizer.__init__.__wrapped__(self, stats)
26-
# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
27-
self.summaries: Dict[Any, Summarizer] = defaultdict(
28+
self.summaries: Dict[KeyType, Summarizer] = defaultdict(
2829
lambda: Summarizer(stats=stats)
2930
)
3031

@@ -84,15 +85,15 @@ def update( # type: ignore
8485
tensors_to_summarize_copy = tuple(tensor[i].clone() for tensor in x)
8586
label = labels_typed[0] if len(labels_typed) == 1 else labels_typed[i]
8687

87-
self.summaries[label].update(tensors_to_summarize)
88+
self.summaries[cast(KeyType, label)].update(tensors_to_summarize)
8889
super().update(tensors_to_summarize_copy)
8990

9091
@property
91-
# pyre-fixme[3]: Return annotation cannot contain `Any`.
9292
def class_summaries(
9393
self,
9494
) -> Dict[
95-
Any, Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
95+
KeyType,
96+
Union[None, Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]],
9697
]:
9798
r"""
9899
Returns:

0 commit comments

Comments
 (0)