Final round of Pyre enablement #1330

Closed
wants to merge 1 commit into from
27 changes: 10 additions & 17 deletions captum/_utils/common.py
@@ -363,10 +363,9 @@ def _expand_target(
return target


# pyre-fixme[3]: Return type must be annotated.
def _expand_feature_mask(
feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
):
) -> Tuple[Tensor, ...]:
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor,
# typing.Tuple[Tensor, ...]]`.
is_feature_mask_tuple = _is_tuple(feature_mask)
@@ -379,18 +378,17 @@ def _expand_feature_mask(
)
for feature_mask_elem in feature_mask
)
return _format_output(is_feature_mask_tuple, feature_mask_new)
return _format_output(is_feature_mask_tuple, feature_mask_new) # type: ignore


# pyre-fixme[3]: Return type must be annotated.
def _expand_and_update_baselines(
inputs: Tuple[Tensor, ...],
n_samples: int,
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
kwargs: dict,
draw_baseline_from_distrib: bool = False,
):
) -> None:
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def get_random_baseline_indices(bsz, baseline):
@@ -432,10 +430,9 @@ def get_random_baseline_indices(bsz, baseline):
kwargs["baselines"] = baselines


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict) -> None:
if "additional_forward_args" not in kwargs:
return
additional_forward_args = kwargs["additional_forward_args"]
@@ -451,10 +448,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
kwargs["additional_forward_args"] = additional_forward_args


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def _expand_and_update_target(n_samples: int, kwargs: dict):
def _expand_and_update_target(n_samples: int, kwargs: dict) -> None:
if "target" not in kwargs:
return
target = kwargs["target"]
@@ -465,10 +461,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
kwargs["target"] = target


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
def _expand_and_update_feature_mask(n_samples: int, kwargs: dict) -> None:
if "feature_mask" not in kwargs:
return

@@ -573,10 +568,9 @@ def _format_outputs(
# pyre-fixme[24] Callable requires 2 arguments
def _construct_future_forward(original_forward: Callable) -> Callable:
# pyre-fixme[3] return type not specified
# pyre-ignore
def future_forward(*args, **kwargs):
# pyre-ignore
fut = torch.futures.Future()
def future_forward(*args: Any, **kwargs: Any):
# pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
fut: torch.futures.Future[Tensor] = torch.futures.Future()
fut.set_result(original_forward(*args, **kwargs))
return fut

@@ -921,8 +915,7 @@ def input_tensor_hook(input_grad: Tensor):
]


# pyre-fixme[3]: Return type must be annotated.
def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]):
def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
"""
Returns the max feature mask index
The feature mask should be formatted to tuple of tensors at first.
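For reference, the torch.futures.Future[Tensor] pattern annotated in _construct_future_forward above can be exercised on its own. A minimal sketch follows; the wrap_in_future name is illustrative and not part of captum:

import torch
from torch import Tensor
from typing import Any, Callable

def wrap_in_future(original_forward: Callable[..., Tensor]) -> Callable:
    # Wrap a synchronous forward so callers expecting an async-style Future
    # still work: the Future is created and resolved immediately.
    def future_forward(*args: Any, **kwargs: Any) -> "torch.futures.Future[Tensor]":
        fut: torch.futures.Future[Tensor] = torch.futures.Future()
        fut.set_result(original_forward(*args, **kwargs))
        return fut
    return future_forward

# .wait() returns the already-set tensor result.
doubled = wrap_in_future(lambda x: x * 2)(torch.ones(3)).wait()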
6 changes: 5 additions & 1 deletion captum/_utils/gradient.py
@@ -986,7 +986,11 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
out = loss

sample_grad_wrapper.compute_param_sample_gradients(
out, loss_mode=reduction_type
out,
# pyre-fixme[6]: In call `SampleGradientWrapper.
# compute_param_sample_gradients`, for argument `loss_mode`,
# expected `str` but got `Optional[str]`.
loss_mode=reduction_type, # type: ignore
)
if layer_modules is not None:
layer_parameters = _extract_parameters_from_layers(layer_modules)
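An alternative to the # type: ignore above, sketched here outside captum's code and under the assumption that reduction_type must already be set by this point, is to narrow the Optional before the call:

from typing import Optional

def narrow_loss_mode(reduction_type: Optional[str]) -> str:
    # Narrowing Optional[str] to str satisfies a callee annotated `loss_mode: str`
    # without an ignore comment.
    if reduction_type is None:
        raise ValueError("reduction_type must be set before computing sample gradients")
    return reduction_type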
3 changes: 1 addition & 2 deletions captum/_utils/models/linear_model/model.py
@@ -41,7 +41,6 @@ def __init__(self, train_fn: Callable, **kwargs) -> None:
# pyre-fixme[4]: Attribute must be annotated.
self.construct_kwargs = kwargs

# pyre-fixme[3]: Return type must be annotated.
def _construct_model_params(
self,
in_features: Optional[int] = None,
@@ -52,7 +51,7 @@ def _construct_model_params(
weight_values: Optional[Tensor] = None,
bias_value: Optional[Tensor] = None,
classes: Optional[Tensor] = None,
):
) -> None:
r"""
Lazily initializes a linear model. This will be called for you in a
train method.
8 changes: 3 additions & 5 deletions captum/_utils/models/linear_model/train.py
@@ -9,9 +9,8 @@
from torch.utils.data import DataLoader


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def l2_loss(x1, x2, weights=None):
def l2_loss(x1, x2, weights=None) -> torch.Tensor:
if weights is None:
return torch.mean((x1 - x2) ** 2) / 2.0
else:
@@ -236,7 +235,7 @@ def get_point(datapoint):

class NormLayer(nn.Module):
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, mean, std, n=None, eps=1e-8) -> None:
def __init__(self, mean, std, n=None, eps: float = 1e-8) -> None:
super().__init__()
# pyre-fixme[4]: Attribute must be annotated.
self.mean = mean
@@ -251,7 +250,6 @@ def forward(self, x):
return (x - self.mean) / (self.std + self.eps)


# pyre-fixme[3]: Return type must be annotated.
def sklearn_train_linear_model(
model: LinearModel,
dataloader: DataLoader,
@@ -260,7 +258,7 @@ def sklearn_train_linear_model(
norm_input: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**fit_kwargs,
):
) -> Dict[str, float]:
r"""
Alternative method to train with sklearn. This does introduce some slight
overhead as we convert the tensors to numpy and then convert the resulting
37 changes: 16 additions & 21 deletions captum/_utils/progress.py
@@ -5,7 +5,7 @@
import sys
import warnings
from time import time
from typing import cast, Iterable, Optional, Sized, TextIO
from typing import Any, cast, Iterable, Optional, Sized, TextIO

from captum._utils.typing import Literal

@@ -61,15 +61,17 @@ class NullProgress:
progress bars.
"""

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, iterable: Optional[Iterable] = None, *args, **kwargs):
def __init__(
self,
# pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter.
iterable: Optional[Iterable] = None,
*args: Any,
**kwargs: Any,
) -> None:
del args, kwargs
self.iterable = iterable

# pyre-fixme[3]: Return type must be annotated.
def __enter__(self):
def __enter__(self) -> "NullProgress":
return self

# pyre-fixme[2]: Parameter must be annotated.
@@ -87,12 +89,10 @@ def __iter__(self):
for it in self.iterable:
yield it

# pyre-fixme[3]: Return type must be annotated.
def update(self, amount: int = 1):
def update(self, amount: int = 1) -> None:
pass

# pyre-fixme[3]: Return type must be annotated.
def close(self):
def close(self) -> None:
pass


@@ -133,8 +133,7 @@ def __init__(
self.closed = False
self._is_parent = False

# pyre-fixme[3]: Return type must be annotated.
def __enter__(self):
def __enter__(self) -> "SimpleProgress":
self._is_parent = True
self._refresh()
return self
@@ -158,8 +157,7 @@ def __iter__(self):
self.update()
self.close()

# pyre-fixme[3]: Return type must be annotated.
def _refresh(self):
def _refresh(self) -> None:
progress_str = self.desc + ": " if self.desc else ""
if self.total:
# e.g., progress: 60% 3/5
@@ -172,8 +170,7 @@ def _refresh(self):
end = "\n" if self._is_parent else ""
print("\r" + progress_str, end=end, file=self.file)

# pyre-fixme[3]: Return type must be annotated.
def update(self, amount: int = 1):
def update(self, amount: int = 1) -> None:
if self.closed:
return
self.cur += amount
@@ -183,8 +180,7 @@ def update(self, amount: int = 1):
self._refresh()
self.last_print_t = cur_t

# pyre-fixme[3]: Return type must be annotated.
def close(self):
def close(self) -> None:
if not self.closed and not self._is_parent:
self._refresh()
print(file=self.file) # end with new line
@@ -197,8 +193,7 @@ def progress(
iterable: Optional[Iterable] = None,
desc: Optional[str] = None,
total: Optional[int] = None,
# pyre-fixme[2]: Parameter must be annotated.
use_tqdm=True,
use_tqdm: bool = True,
file: Optional[TextIO] = None,
mininterval: float = 0.5,
# pyre-fixme[2]: Parameter must be annotated.
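The context-manager and update/close signatures annotated above are easiest to see in use. A small sketch, assuming the captum._utils.progress import path shown in this file:

from captum._utils.progress import NullProgress

# NullProgress is a no-op stand-in for a progress bar: it honours the same
# context-manager, iterator, update, and close protocol, so calling code
# needs no special-casing when progress display is disabled.
with NullProgress(iterable=range(3)) as pbar:
    for _ in pbar:
        pbar.update()  # does nothing
    pbar.close()       # does nothing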
7 changes: 4 additions & 3 deletions captum/_utils/sample_gradient.py
@@ -103,7 +103,7 @@ class SampleGradientWrapper:
"""

# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, model, layer_modules=None) -> None:
def __init__(self, model, layer_modules: Optional[List[Module]] = None) -> None:
# pyre-fixme[4]: Attribute must be annotated.
self.model = model
self.hooks_added = False
@@ -162,8 +162,9 @@ def _reset(self) -> None:
self.activation_dict = defaultdict(list)
self.gradient_dict = defaultdict(list)

# pyre-fixme[2]: Parameter must be annotated.
def compute_param_sample_gradients(self, loss_blob, loss_mode="mean") -> None:
def compute_param_sample_gradients(
self, loss_blob: Tensor, loss_mode: str = "mean"
) -> None:
assert (
loss_mode.upper() in LossMode.__members__
), f"Provided loss mode {loss_mode} is not valid"
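The loss_mode: str = "mean" parameter annotated above is checked against an enum's member names. A self-contained sketch of that validation pattern; the enum values here are illustrative, not captum's:

from enum import Enum

class LossMode(Enum):
    SUM = 0
    MEAN = 1

def resolve_loss_mode(loss_mode: str = "mean") -> LossMode:
    # Validate the free-form string against the enum member names, then look
    # the member up by name, mirroring the assert in compute_param_sample_gradients.
    assert (
        loss_mode.upper() in LossMode.__members__
    ), f"Provided loss mode {loss_mode} is not valid"
    return LossMode[loss_mode.upper()]

mode = resolve_loss_mode("mean")  # LossMode.MEAN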
12 changes: 4 additions & 8 deletions captum/attr/_core/dataloader_attr.py
@@ -30,9 +30,8 @@ class InputRole:


# default reducer when reduce is None. Simply concat the outputs by the batch dimension
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _concat_tensors(accum, cur_output, _):
def _concat_tensors(accum, cur_output, _) -> Tensor:
return cur_output if accum is None else torch.cat([accum, cur_output])


@@ -185,7 +184,6 @@ def __init__(self, attr_method: Attribution) -> None:

self.attr_method.forward_func = self._forward_with_dataloader

# pyre-fixme[3]: Return type must be annotated.
def _forward_with_dataloader(
self,
batched_perturbed_feature_indices: Tensor,
@@ -199,7 +197,7 @@ def _forward_with_dataloader(
to_metric: Optional[Callable],
show_progress: bool,
feature_idx_to_mask_idx: Dict[int, List[int]],
):
) -> Tensor:
"""
Wrapper of the original given forward_func to be used in the attribution method
It iterates over the dataloader with the given forward_func
@@ -468,10 +466,8 @@ def attribute(

return _format_output(is_inputs_tuple, attr)

# pyre-fixme[3]: Return type must be annotated.
def attribute_future(
self,
):
# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
def attribute_future(self) -> Callable:
r"""
This method is not implemented for DataLoaderAttribution.
"""
Expand Down
Loading
Loading