
Commit de2b0b5

Vivek Miglani authored and facebook-github-bot committed
Fix pyre errors in Lime (#1400)
Summary: Initial work on fixing Pyre errors in Lime

Reviewed By: csauper

Differential Revision: D64677340
1 parent: 53c19e4 · commit: de2b0b5

File tree

1 file changed: +43 −57 lines changed

captum/attr/_core/lime.py

Lines changed: 43 additions & 57 deletions
@@ -6,7 +6,7 @@
 import typing
 import warnings
 from collections.abc import Iterator
-from typing import Any, Callable, cast, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Generator, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (
@@ -23,12 +23,7 @@
 from captum._utils.models.linear_model import SkLearnLasso
 from captum._utils.models.model import Model
 from captum._utils.progress import progress
-from captum._utils.typing import (
-    BaselineType,
-    Literal,
-    TargetType,
-    TensorOrTupleOfTensorsGeneric,
-)
+from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import PerturbationAttribution
 from captum.attr._utils.batching import _batch_example_iterator
 from captum.attr._utils.common import (
@@ -73,18 +68,18 @@ class LimeBase(PerturbationAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         interpretable_model: Model,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        similarity_func: Callable,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        perturb_func: Callable,
+        similarity_func: Callable[
+            ...,
+            Union[float, Tensor],
+        ],
+        perturb_func: Callable[..., object],
         perturb_interpretable_space: bool,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        from_interp_rep_transform: Optional[Callable],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        to_interp_rep_transform: Optional[Callable],
+        from_interp_rep_transform: Optional[
+            Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
+        ],
+        to_interp_rep_transform: Optional[Callable[..., Tensor]],
     ) -> None:
         r"""
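These `Callable[...]` forms are what clear Pyre's pyre-fixme[24] complaints: `...` leaves the parameter list unchecked while still pinning down the return type. A minimal sketch of a similarity function that satisfies the new `Callable[..., Union[float, Tensor]]` annotation (the function name and kernel are illustrative, not part of this diff):

from typing import Union

import torch
from torch import Tensor


def l2_exp_similarity(
    original_inp: Tensor, perturbed_inp: Tensor, _: object, **kwargs: object
) -> Union[float, Tensor]:
    # Any callable returning float or Tensor type-checks against
    # Callable[..., Union[float, Tensor]]; the `...` means Pyre places
    # no constraint on the parameters themselves.
    distance = torch.norm(original_inp - perturbed_inp)
    return torch.exp(-(distance**2))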
@@ -249,13 +244,11 @@ def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         n_samples: int = 50,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        **kwargs: object,
     ) -> Tensor:
         r"""
         This method attributes the output of the model with given target index
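Swapping `Any` for `object` here is more than cosmetic: `Any` turns type checking off entirely, while `object` accepts any argument but forces the callee to narrow the value before using it. A hedged sketch of the distinction, with hypothetical names:

from typing import Any


def with_any(additional_forward_args: Any = None) -> None:
    # `Any` disables checking: Pyre accepts this attribute access even
    # though it may raise AttributeError at runtime.
    additional_forward_args.launch()


def with_object(additional_forward_args: object = None) -> None:
    # `object` still accepts every value, but use sites must narrow
    # first (isinstance, cast), so misuse is caught statically.
    if isinstance(additional_forward_args, tuple):
        print(len(additional_forward_args))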
@@ -551,7 +544,7 @@ def generate_perturbation() -> (
                     curr_sample, inputs, **kwargs
                 )
 
-            return interpretable_inp, curr_model_input
+            return interpretable_inp, curr_model_input  # type: ignore
 
         return generate_perturbation
 
@@ -568,8 +561,7 @@ def _evaluate_batch(
         self,
         curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
         expanded_target: TargetType,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        expanded_additional_args: Any,
+        expanded_additional_args: object,
         device: torch.device,
     ) -> Tensor:
         model_out = _run_forward(
@@ -630,8 +622,7 @@ def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
 def get_exp_kernel_similarity_function(
     distance_mode: str = "cosine",
     kernel_width: float = 1.0,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-) -> Callable:
+) -> Callable[..., float]:
     r"""
     This method constructs an appropriate similarity function to compute
     weights for perturbed sample in LIME. Distance between the original
@@ -680,8 +671,9 @@ def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
     return default_exp_kernel
 
 
-# pyre-fixme[2]: Parameter must be annotated.
-def default_perturb_func(original_inp, **kwargs) -> Tensor:
+def default_perturb_func(
+    original_inp: TensorOrTupleOfTensorsGeneric, **kwargs: object
+) -> Tensor:
     assert (
         "num_interp_features" in kwargs
     ), "Must provide num_interp_features to use default interpretable sampling function"
@@ -690,25 +682,25 @@ def default_perturb_func(original_inp, **kwargs) -> Tensor:
     else:
         device = original_inp[0].device
 
-    probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5
+    probs = torch.ones(1, cast(int, kwargs["num_interp_features"])) * 0.5
     return torch.bernoulli(probs).to(device=device).long()
 
 
 def construct_feature_mask(
     feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
     formatted_inputs: Tuple[Tensor, ...],
 ) -> Tuple[Tuple[Tensor, ...], int]:
+    feature_mask_tuple: Tuple[Tensor, ...]
     if feature_mask is None:
-        feature_mask, num_interp_features = _construct_default_feature_mask(
+        feature_mask_tuple, num_interp_features = _construct_default_feature_mask(
             formatted_inputs
         )
     else:
-        feature_mask = _format_tensor_into_tuples(feature_mask)
+        feature_mask_tuple = _format_tensor_into_tuples(feature_mask)
         min_interp_features = int(
             min(
                 torch.min(single_mask).item()
-                # pyre-fixme[16]: `None` has no attribute `__iter__`.
-                for single_mask in feature_mask
+                for single_mask in feature_mask_tuple
                 if single_mask.numel()
             )
         )
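With `**kwargs: object`, `kwargs["num_interp_features"]` is typed `object`, which `torch.ones` rejects as a size, hence the `cast(int, ...)`. `typing.cast` is purely a static assertion with no runtime check or conversion; a small sketch of the same pattern:

from typing import cast

import torch
from torch import Tensor


def sample_interp_mask(**kwargs: object) -> Tensor:
    # cast() tells the checker this kwargs value is an int; it performs
    # no runtime validation, so callers must actually pass an int.
    num_interp_features = cast(int, kwargs["num_interp_features"])
    probs = torch.ones(1, num_interp_features) * 0.5
    return torch.bernoulli(probs).long()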
@@ -718,14 +710,12 @@ def construct_feature_mask(
             " start at 0.",
             stacklevel=2,
         )
-        feature_mask = tuple(
-            single_mask - min_interp_features for single_mask in feature_mask
+        feature_mask_tuple = tuple(
+            single_mask - min_interp_features for single_mask in feature_mask_tuple
         )
 
-    # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
-    #  `Optional[typing.Tuple[typing.Any, ...]]`.
-    num_interp_features = _get_max_feature_index(feature_mask) + 1
-    return feature_mask, num_interp_features
+    num_interp_features = _get_max_feature_index(feature_mask_tuple) + 1
+    return feature_mask_tuple, num_interp_features
 
 
 class Lime(LimeBase):
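The `feature_mask_tuple` local exists because reassigning `feature_mask` (declared `Union[None, Tensor, Tuple[Tensor, ...]]`) kept its Optional type alive, producing the pyre-fixme[16] and [6] suppressions removed above. Declaring a fresh, precisely typed local lets every later use assume a tuple. A generic sketch of the refactor, with hypothetical names:

from typing import Optional, Tuple

from torch import Tensor


def as_mask_tuple(feature_mask: Optional[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    # Bind a new local instead of reassigning the Optional parameter:
    # past this point the checker knows the value is a tuple, so no
    # None-related suppressions are needed downstream.
    mask_tuple: Tuple[Tensor, ...] = feature_mask if feature_mask is not None else ()
    return mask_tuple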
@@ -766,8 +756,7 @@ class Lime(LimeBase):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         interpretable_model: Optional[Model] = None,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         similarity_func: Optional[Callable] = None,
@@ -887,8 +876,7 @@ def attribute( # type: ignore
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         n_samples: int = 25,
         perturbations_per_eval: int = 1,
@@ -1133,18 +1121,14 @@ def _attribute_kwargs( # type: ignore
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         n_samples: int = 25,
         perturbations_per_eval: int = 1,
         return_input_shape: bool = True,
         show_progress: bool = False,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        **kwargs: object,
     ) -> TensorOrTupleOfTensorsGeneric:
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        #  `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
         formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
         bsz = formatted_inputs[0].shape[0]
@@ -1263,33 +1247,35 @@ def _attribute_kwargs( # type: ignore
             return coefs
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
-    #  all possible arguments of overload defined on line `1201`.
     def _convert_output_shape(
         self,
         formatted_inp: Tuple[Tensor, ...],
         feature_mask: Tuple[Tensor, ...],
         coefs: Tensor,
         num_interp_features: int,
-        # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         is_inputs_tuple: Literal[True],
     ) -> Tuple[Tensor, ...]: ...
 
     @typing.overload
-    # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept
-    #  all possible arguments of overload defined on line `1211`.
     def _convert_output_shape(  # type: ignore
         self,
         formatted_inp: Tuple[Tensor, ...],
         feature_mask: Tuple[Tensor, ...],
         coefs: Tensor,
         num_interp_features: int,
-        # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
-        # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
         is_inputs_tuple: Literal[False],
     ) -> Tensor: ...
 
+    @typing.overload
+    def _convert_output_shape(
+        self,
+        formatted_inp: Tuple[Tensor, ...],
+        feature_mask: Tuple[Tensor, ...],
+        coefs: Tensor,
+        num_interp_features: int,
+        is_inputs_tuple: bool,
+    ) -> Union[Tensor, Tuple[Tensor, ...]]: ...
+
     def _convert_output_shape(
         self,
         formatted_inp: Tuple[Tensor, ...],
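The newly added third overload is what resolves the pyre-fixme[43] notes: the `Literal[True]`/`Literal[False]` overloads only match call sites that pass a literal, so a plain `bool` variable needs its own overload returning the union of both result types. A condensed sketch of the three-overload pattern (names are illustrative):

import typing
from typing import Literal, Tuple, Union

from torch import Tensor


@typing.overload
def reshape_output(coefs: Tensor, is_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...
@typing.overload
def reshape_output(coefs: Tensor, is_tuple: Literal[False]) -> Tensor: ...
@typing.overload  # fallback for a bool that is not a literal
def reshape_output(coefs: Tensor, is_tuple: bool) -> Union[Tensor, Tuple[Tensor, ...]]: ...


def reshape_output(coefs: Tensor, is_tuple: bool) -> Union[Tensor, Tuple[Tensor, ...]]:
    # Literal overloads give precise return types when callers pass
    # True/False literally; the bool overload keeps variables checkable.
    return (coefs,) if is_tuple else coefs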
