9 changes: 9 additions & 0 deletions captum/_utils/av.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

# pyre-strict

import glob
import os
import re
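For context on what this diff adds throughout both files: `# pyre-strict` opts a module into Pyre's strict mode, which requires complete type annotations, while a `# pyre-fixme[N]` comment suppresses the Pyre error with code `N` on the line that follows it. A minimal sketch of the mechanism, using a hypothetical module rather than code from this PR:

```python
# pyre-strict

# Under strict mode, Pyre flags this definition because the return type is
# missing; the fixme comment suppresses that single error until it is fixed.
# pyre-fixme[3]: Return type must be annotated.
def double(x: int):
    return x * 2

# The fully annotated version needs no suppression:
def double_annotated(x: int) -> int:
    return x * 2
```

Each suppression records the checker's exact complaint, so a codebase can adopt strict mode immediately and burn down the fixmes incrementally.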
@@ -66,12 +68,14 @@ def __init__(
which the activation vectors are computed
"""

# pyre-fixme[4]: Attribute must be annotated.
self.av_filesearch = AV._construct_file_search(
path, model_id, identifier, layer, num_id
)

files = glob.glob(self.av_filesearch)

# pyre-fixme[4]: Attribute must be annotated.
self.files = AV.sort_files(files)

def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, ...]]:
@@ -346,6 +350,7 @@ def _compute_and_save_activations(
inputs: Union[Tensor, Tuple[Tensor, ...]],
identifier: str,
num_id: str,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
load_from_disk: bool = True,
) -> None:
@@ -395,6 +400,8 @@ def _compute_and_save_activations(
AV.save(path, model_id, identifier, unsaved_layers, new_activations, num_id)

@staticmethod
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def _unpack_data(data: Union[Any, Tuple[Any, Any]]) -> Any:
r"""
Helper to extract input from labels when getting items from a Dataset. Assumes
@@ -490,6 +497,8 @@ def sort_files(files: List[str]) -> List[str]:
lexicographical sort.
"""

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def split_alphanum(s):
r"""
Splits string into a list of strings and numbers
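The two `pyre-fixme[4]` suppressions in `__init__` above could eventually be replaced by explicit attribute annotations. A sketch, assuming `_construct_file_search` returns a glob pattern string (its signature is not shown in this diff):

```python
# Hypothetical follow-up, not part of this PR: annotate instead of suppress.
self.av_filesearch: str = AV._construct_file_search(
    path, model_id, identifier, layer, num_id
)

files = glob.glob(self.av_filesearch)  # glob.glob returns List[str]

self.files: List[str] = AV.sort_files(files)
```

The `List[str]` annotation follows from `sort_files`, whose signature appears later in this file as `sort_files(files: List[str]) -> List[str]`.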
98 changes: 93 additions & 5 deletions captum/_utils/common.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3

# pyre-strict
import typing
from enum import Enum
from functools import reduce
@@ -68,10 +70,18 @@ def safe_div(


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
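These overloads follow a standard narrowing pattern: callers passing a single `Tensor` get a `Literal[False]` return type and callers passing a tuple get `Literal[True]`, while the implementation, which the diff viewer elides here, returns a plain `bool`. A self-contained sketch of the pattern, with an assumed implementation body:

```python
from typing import Literal, Tuple, Union, overload

from torch import Tensor


@overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
@overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # The pyre-fixme[43] suppressions in the diff exist because Pyre could
    # not reconcile the overloads' Literal return types with this bool return.
    return isinstance(inputs, tuple)
```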


@@ -230,6 +240,8 @@ def _format_tensor_into_tuples(
return inputs


# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
return (
inputs
@@ -257,16 +269,21 @@ def _format_additional_forward_args(additional_forward_args: None) -> None: ...

@overload
def _format_additional_forward_args(
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Union[Tensor, Tuple]
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Tuple: ...


@overload
def _format_additional_forward_args(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]: ...


# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
if additional_forward_args is not None and not isinstance(
additional_forward_args, tuple
@@ -276,9 +293,11 @@ def _format_additional_forward_args(additional_forward_args: Any) -> Union[None,


def _expand_additional_forward_args(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
n_steps: int,
expansion_type: ExpansionTypes = ExpansionTypes.repeat,
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]:
def _expand_tensor_forward_arg(
additional_forward_arg: Tensor,
@@ -343,9 +362,12 @@ def _expand_target(
return target


# pyre-fixme[3]: Return type must be annotated.
def _expand_feature_mask(
feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
):
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor,
# typing.Tuple[Tensor, ...]]`.
is_feature_mask_tuple = _is_tuple(feature_mask)
feature_mask = _format_tensor_into_tuples(feature_mask)
feature_mask_new = tuple(
@@ -359,12 +381,17 @@ def _expand_feature_mask(
return _format_output(is_feature_mask_tuple, feature_mask_new)


# pyre-fixme[3]: Return type must be annotated.
def _expand_and_update_baselines(
inputs: Tuple[Tensor, ...],
n_samples: int,
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
kwargs: dict,
draw_baseline_from_distrib: bool = False,
):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def get_random_baseline_indices(bsz, baseline):
num_ref_samples = baseline.shape[0]
return np.random.choice(num_ref_samples, n_samples * bsz).tolist()
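The nested helper above draws two suppressions because neither its parameters nor its return type are annotated. A hedged sketch of the annotated form, as it would sit inside `_expand_and_update_baselines` (it closes over `n_samples`, and assumes `baseline` is a tensor whose first dimension indexes reference samples, as the body implies):

```python
def get_random_baseline_indices(bsz: int, baseline: Tensor) -> List[int]:
    # Sample n_samples * bsz reference indices (with replacement) from the
    # pool along baseline's first dimension.
    num_ref_samples = baseline.shape[0]
    return np.random.choice(num_ref_samples, n_samples * bsz).tolist()
```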
@@ -404,6 +431,9 @@ def get_random_baseline_indices(bsz, baseline):
kwargs["baselines"] = baselines


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
if "additional_forward_args" not in kwargs:
return
@@ -420,6 +450,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
kwargs["additional_forward_args"] = additional_forward_args


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def _expand_and_update_target(n_samples: int, kwargs: dict):
if "target" not in kwargs:
return
@@ -431,6 +464,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
kwargs["target"] = target


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
if "feature_mask" not in kwargs:
return
@@ -444,14 +480,24 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):


@typing.overload
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
# possible arguments of overload defined on line `449`.
def _format_output(
-    is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
+    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_inputs_tuple: Literal[True],
+    output: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_format_output` does not accept all
# possible arguments of overload defined on line `455`.
def _format_output(
-    is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...]
+    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_inputs_tuple: Literal[False],
+    output: Tuple[Tensor, ...],
) -> Tensor: ...


@@ -474,18 +520,30 @@ def _format_output(
"The input is a single tensor however the output isn't."
"The number of output tensors is: {}".format(len(output))
)
# pyre-fixme[7]: Expected `Union[Tensor, typing.Tuple[Tensor, ...]]` but got
# `Union[tuple[Tensor], Tensor]`.
return output if is_inputs_tuple else output[0]
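The `pyre-fixme[7]` suppression reflects a checker limitation: Pyre cannot connect the runtime `is_inputs_tuple` flag to the overloads' return types. Behaviorally, the function either preserves the tuple or unwraps a singleton; a small usage sketch:

```python
import torch

from captum._utils.common import _format_output

t = torch.ones(2)
_format_output(True, (t,))   # -> (t,): tuple inputs yield tuple outputs
_format_output(False, (t,))  # -> t: a singleton tuple is unwrapped
```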


@typing.overload
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
# possible arguments of overload defined on line `483`.
def _format_outputs(
-    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
+    # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_multiple_inputs: Literal[False],
+    outputs: List[Tuple[Tensor, ...]],
) -> Union[Tensor, Tuple[Tensor, ...]]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
# possible arguments of overload defined on line `489`.
def _format_outputs(
-    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
+    # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
+    # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
+    is_multiple_inputs: Literal[True],
+    outputs: List[Tuple[Tensor, ...]],
) -> List[Union[Tensor, Tuple[Tensor, ...]]]: ...


@@ -512,9 +570,12 @@ def _format_outputs(


def _run_forward(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
inputs: Any,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
) -> Union[Tensor, Future[Tensor]]:
forward_func_args = signature(forward_func).parameters
@@ -529,6 +590,8 @@

output = forward_func(
*(
# pyre-fixme[60]: Concatenation not yet support for multiple variadic
# tuples: `*inputs, *additional_forward_args`.
(*inputs, *additional_forward_args)
if additional_forward_args is not None
else inputs
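The `pyre-fixme[60]` above suppresses a Pyre limitation with splicing two variadic tuples into one call; at runtime the expression simply passes the formatted inputs followed by the extra arguments as positional parameters. Roughly, with hypothetical values:

```python
inputs = (x1, x2)                  # hypothetical formatted inputs
additional_forward_args = (mask,)  # hypothetical extra argument
# The starred expression above is then equivalent to calling:
forward_func(x1, x2, mask)
```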
@@ -606,6 +669,8 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
elif isinstance(target[0], tuple):
return torch.stack(
[
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type
# parameter.
output[(i,) + cast(Tuple, targ_elem)]
for i, targ_elem in enumerate(target)
]
@@ -639,9 +704,11 @@ def _verify_select_column(

def _verify_select_neuron(
layer_output: Tuple[Tensor, ...],
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tensor:
if callable(selector):
# pyre-fixme[7]: Expected `Tensor` but got `object`.
return selector(layer_output if len(layer_output) > 1 else layer_output[0])

assert len(layer_output) == 1, (
@@ -688,6 +755,9 @@ def _extract_device(

def _reduce_list(
val_list: Sequence[TupleOrTensorOrBoolGeneric],
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List[<element type>]` to avoid runtime subscripting errors.
red_func: Callable[[List], Any] = torch.cat,
) -> TupleOrTensorOrBoolGeneric:
"""
@@ -702,21 +772,28 @@
"""
assert len(val_list) > 0, "Cannot reduce empty list!"
if isinstance(val_list[0], torch.Tensor):
# pyre-fixme[16]: `bool` has no attribute `device`.
first_device = val_list[0].device
# pyre-fixme[16]: `bool` has no attribute `to`.
return red_func([elem.to(first_device) for elem in val_list])
elif isinstance(val_list[0], bool):
# pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
return any(val_list)
elif isinstance(val_list[0], tuple):
final_out = []
# pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
# but got `TupleOrTensorOrBoolGeneric`.
for i in range(len(val_list[0])):
final_out.append(
# pyre-fixme[16]: `bool` has no attribute `__getitem__`.
_reduce_list([val_elem[i] for val_elem in val_list], red_func)
)
else:
raise AssertionError(
"Elements to be reduced can only be"
"either Tensors or tuples containing Tensors."
)
# pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`.
return tuple(final_out)
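Taken together, the branches give `_reduce_list` this behavior: tensors are reduced with `red_func` after moving everything to the first element's device, booleans are OR-ed, and tuples are reduced elementwise by recursion. A usage sketch with hypothetical values:

```python
import torch

from captum._utils.common import _reduce_list

vals = [
    (torch.ones(2, 3), False),
    (torch.zeros(1, 3), True),
]
# Position 0 tensors are torch.cat'ed along dim 0; position 1 bools are any()'ed.
merged = _reduce_list(vals)  # -> (tensor of shape (3, 3), True)
```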


@@ -756,6 +833,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
return torch.cat([single_inp.flatten() for single_inp in inp])


# pyre-fixme[3]: Return annotation cannot be `Any`.
def _get_module_from_name(model: Module, layer_name: str) -> Any:
r"""
Returns the module (layer) object, given its (string) name
@@ -772,7 +850,11 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any:


def _register_backward_hook(
-    module: Module, hook: Callable, attr_obj: Any
+    module: Module,
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    hook: Callable,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+    attr_obj: Any,
) -> List[torch.utils.hooks.RemovableHandle]:
grad_out: Dict[device, Tensor] = {}

@@ -784,6 +866,7 @@ def forward_hook(
nonlocal grad_out
grad_out = {}

# pyre-fixme[53]: Captured variable `grad_out` is not annotated.
def output_tensor_hook(output_grad: Tensor) -> None:
grad_out[output_grad.device] = output_grad

@@ -795,7 +878,11 @@ def output_tensor_hook(output_grad: Tensor) -> None:
else:
out.register_hook(output_tensor_hook)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def pre_hook(module, inp):
# pyre-fixme[53]: Captured variable `module` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
def input_tensor_hook(input_grad: Tensor):
if len(grad_out) == 0:
return
Expand All @@ -820,6 +907,7 @@ def input_tensor_hook(input_grad: Tensor):
]


# pyre-fixme[3]: Return type must be annotated.
def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]):
"""
Returns the max feature mask index
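This last hunk leaves `_get_max_feature_index` with a suppressed missing return annotation. Since feature masks hold integer group ids, the annotated form would plausibly look like the following sketch (the body is elided by the diff, so this is an assumption, not the PR's code):

```python
def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
    """Returns the max feature mask index across all mask tensors."""
    # Assumed body: take the largest integer id over every mask tensor.
    return int(max(int(mask.max()) for mask in feature_mask))
```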