22
33# pyre-strict
44
5- from typing import Any , Callable , Generator , Tuple , Union
5+ from typing import Any , Callable , cast , Generator , Tuple , Union
66
77import torch
88from captum ._utils .models .linear_model import SkLearnLinearRegression
@@ -27,8 +27,7 @@ class KernelShap(Lime):
2727 https://arxiv.org/abs/1705.07874
2828 """
2929
30- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
31- def __init__ (self , forward_func : Callable ) -> None :
30+ def __init__ (self , forward_func : Callable [..., Tensor ]) -> None :
3231 r"""
3332 Args:
3433
@@ -50,8 +49,7 @@ def attribute( # type: ignore
5049 inputs : TensorOrTupleOfTensorsGeneric ,
5150 baselines : BaselineType = None ,
5251 target : TargetType = None ,
53- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
54- additional_forward_args : Any = None ,
52+ additional_forward_args : object = None ,
5553 feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
5654 n_samples : int = 25 ,
5755 perturbations_per_eval : int = 1 ,
@@ -279,10 +277,7 @@ def attribute( # type: ignore
279277 )
280278 num_features_list = torch .arange (num_interp_features , dtype = torch .float )
281279 denom = num_features_list * (num_interp_features - num_features_list )
282- # pyre-fixme[58]: `/` is not supported for operand types
283- # `int` and `torch._tensor.Tensor`.
284- probs = (num_interp_features - 1 ) / denom
285- # pyre-fixme[16]: `float` has no attribute `__setitem__`.
280+ probs = torch .tensor ((num_interp_features - 1 )) / denom
286281 probs [0 ] = 0.0
287282 return self ._attribute_kwargs (
288283 inputs = inputs ,
@@ -309,8 +304,7 @@ def kernel_shap_similarity_kernel(
309304 _ ,
310305 __ ,
311306 interpretable_sample : Tensor ,
312- # pyre-fixme[2]: Parameter must be annotated.
313- ** kwargs ,
307+ ** kwargs : object ,
314308 ) -> Tensor :
315309 assert (
316310 "num_interp_features" in kwargs
@@ -332,8 +326,7 @@ def kernel_shap_similarity_kernel(
332326 def kernel_shap_perturb_generator (
333327 self ,
334328 original_inp : Union [Tensor , Tuple [Tensor , ...]],
335- # pyre-fixme[2]: Parameter must be annotated.
336- ** kwargs ,
329+ ** kwargs : object ,
337330 ) -> Generator [Tensor , None , None ]:
338331 r"""
339332 Perturbations are sampled by the following process:
@@ -361,11 +354,13 @@ def kernel_shap_perturb_generator(
361354 device = original_inp .device
362355 else :
363356 device = original_inp [0 ].device
364- num_features = kwargs ["num_interp_features" ]
357+ num_features = cast ( int , kwargs ["num_interp_features" ])
365358 yield torch .ones (1 , num_features , device = device , dtype = torch .long )
366359 yield torch .zeros (1 , num_features , device = device , dtype = torch .long )
367360 while True :
368- num_selected_features = kwargs ["num_select_distribution" ].sample ()
361+ num_selected_features = cast (
362+ Categorical , kwargs ["num_select_distribution" ]
363+ ).sample ()
369364 rand_vals = torch .randn (1 , num_features )
370365 threshold = torch .kthvalue (
371366 rand_vals , num_features - num_selected_features
0 commit comments