Skip to content

Commit 51b6596

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix pyre errors in Kernel Shap (#1399)
Summary: Initial work on fixing Pyre errors in KernelShap Reviewed By: csauper Differential Revision: D64677350
1 parent 9d0fb13 commit 51b6596

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

captum/attr/_core/kernel_shap.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44

5-
from typing import Any, Callable, Generator, Tuple, Union
5+
from typing import Any, Callable, cast, Generator, Tuple, Union
66

77
import torch
88
from captum._utils.models.linear_model import SkLearnLinearRegression
@@ -27,8 +27,7 @@ class KernelShap(Lime):
2727
https://arxiv.org/abs/1705.07874
2828
"""
2929

30-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
31-
def __init__(self, forward_func: Callable) -> None:
30+
def __init__(self, forward_func: Callable[..., Tensor]) -> None:
3231
r"""
3332
Args:
3433
@@ -50,8 +49,7 @@ def attribute( # type: ignore
5049
inputs: TensorOrTupleOfTensorsGeneric,
5150
baselines: BaselineType = None,
5251
target: TargetType = None,
53-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
54-
additional_forward_args: Any = None,
52+
additional_forward_args: object = None,
5553
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
5654
n_samples: int = 25,
5755
perturbations_per_eval: int = 1,
@@ -279,10 +277,7 @@ def attribute( # type: ignore
279277
)
280278
num_features_list = torch.arange(num_interp_features, dtype=torch.float)
281279
denom = num_features_list * (num_interp_features - num_features_list)
282-
# pyre-fixme[58]: `/` is not supported for operand types
283-
# `int` and `torch._tensor.Tensor`.
284-
probs = (num_interp_features - 1) / denom
285-
# pyre-fixme[16]: `float` has no attribute `__setitem__`.
280+
probs = torch.tensor((num_interp_features - 1)) / denom
286281
probs[0] = 0.0
287282
return self._attribute_kwargs(
288283
inputs=inputs,
@@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel(
309304
_,
310305
__,
311306
interpretable_sample: Tensor,
312-
# pyre-fixme[2]: Parameter must be annotated.
313-
**kwargs,
307+
**kwargs: object,
314308
) -> Tensor:
315309
assert (
316310
"num_interp_features" in kwargs
@@ -332,8 +326,7 @@ def kernel_shap_similarity_kernel(
332326
def kernel_shap_perturb_generator(
333327
self,
334328
original_inp: Union[Tensor, Tuple[Tensor, ...]],
335-
# pyre-fixme[2]: Parameter must be annotated.
336-
**kwargs,
329+
**kwargs: object,
337330
) -> Generator[Tensor, None, None]:
338331
r"""
339332
Perturbations are sampled by the following process:
@@ -361,11 +354,13 @@ def kernel_shap_perturb_generator(
361354
device = original_inp.device
362355
else:
363356
device = original_inp[0].device
364-
num_features = kwargs["num_interp_features"]
357+
num_features = cast(int, kwargs["num_interp_features"])
365358
yield torch.ones(1, num_features, device=device, dtype=torch.long)
366359
yield torch.zeros(1, num_features, device=device, dtype=torch.long)
367360
while True:
368-
num_selected_features = kwargs["num_select_distribution"].sample()
361+
num_selected_features = cast(
362+
Categorical, kwargs["num_select_distribution"]
363+
).sample()
369364
rand_vals = torch.randn(1, num_features)
370365
threshold = torch.kthvalue(
371366
rand_vals, num_features - num_selected_features

0 commit comments

Comments
 (0)