Skip to content

Commit 85b769d

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix deeplift mypy error (#1459)
Summary: Currently, Captum OSS tests are failing due to mypy failures (likely from new version) in DeepLift test cases. Adds fix for type failure caused by different signature between DeepLift and DeepLiftShap. Differential Revision: D67538043
1 parent 43a647f commit 85b769d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/attr/test_deeplift_classification.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-unsafe
44

5-
from typing import Union
5+
from typing import TypeVar, Union
66

77
import torch
88
from captum._utils.typing import TargetType
@@ -21,6 +21,8 @@
2121
from torch import Tensor
2222
from torch.nn import Module
2323

24+
DeepLiftAttrMethod = TypeVar("DeepLiftAttrMethod", DeepLift, DeepLiftShap)
25+
2426

2527
class Test(BaseTest):
2628
def test_sigmoid_classification(self) -> None:
@@ -155,7 +157,7 @@ def test_convnet_with_maxpool1d_large_baselines(self) -> None:
155157
def softmax_classification(
156158
self,
157159
model: Module,
158-
attr_method: Union[DeepLift, DeepLiftShap],
160+
attr_method: DeepLiftAttrMethod,
159161
input: Tensor,
160162
baselines: Union[float, int, Tensor],
161163
target: TargetType,

0 commit comments

Comments
 (0)