Skip to content

Commit 397d46f

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix layer gradient x activation pyre fixme issues (#1472)
Summary: Pull Request resolved: #1472 Fixing unresolved pyre fixme issues in corresponding file Reviewed By: cyrjano Differential Revision: D67705758 fbshipit-source-id: e05d67d46a59211cd959f8824f6069d990f13720
1 parent 0f50fa8 commit 397d46f

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

captum/attr/_core/layer/layer_gradient_x_activation.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
55

66
from captum._utils.common import (
77
_format_additional_forward_args,
@@ -24,8 +24,7 @@ class LayerGradientXActivation(LayerAttribution, GradientAttribution):
2424

2525
def __init__(
2626
self,
27-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
28-
forward_func: Callable,
27+
forward_func: Callable[..., Tensor],
2928
layer: ModuleOrModuleList,
3029
device_ids: Union[None, List[int]] = None,
3130
multiply_by_inputs: bool = True,
@@ -186,11 +185,10 @@ def attribute(
186185
if isinstance(self.layer, Module):
187186
return _format_output(
188187
len(layer_evals) > 1,
189-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
190-
# got `List[typing.Tuple[Tensor, ...]]`.
191-
# pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but
192-
# got `List[typing.Tuple[Tensor, ...]]`.
193-
self.multiply_gradient_acts(layer_gradients, layer_evals),
188+
self.multiply_gradient_acts(
189+
cast(Tuple[Tensor, ...], layer_gradients),
190+
cast(Tuple[Tensor, ...], layer_evals),
191+
),
194192
)
195193
else:
196194
return [

0 commit comments

Comments
 (0)