Commit 5eb5498

vivekmig authored and facebook-github-bot committed
Fix Black version and OSS Failures (#1241)
Summary:
Currently, OSS GitHub Actions tests are failing due to test, lint, and typing issues.

This updates the black version used externally (and the corresponding Python version required by the latest black) to match the internal updates in D54447730, and also updates flake8 settings to avoid incompatibilities. Typing issues are resolved, and imports from torch._tensor are removed, since these are not supported in previous torch versions.

Pull Request resolved: #1241

Reviewed By: cyrjano

Differential Revision: D54901754

Pulled By: vivekmig

fbshipit-source-id: 2b94bf36488b11b6c145175cfe10fc5433b014fe
1 parent 837168f · commit 5eb5498

27 files changed: +78 −65 lines

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     with:
       runner: linux.12xlarge
-      docker-image: cimg/python:3.6
+      docker-image: cimg/python:3.9
       repository: pytorch/captum
       script: |
         sudo chmod -R 777 .

captum/_utils/common.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from enum import Enum
 from functools import reduce
 from inspect import signature
-from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -683,7 +683,7 @@ def _extract_device(


 def _reduce_list(
-    val_list: List[TupleOrTensorOrBoolGeneric],
+    val_list: Sequence[TupleOrTensorOrBoolGeneric],
     red_func: Callable[[List], Any] = torch.cat,
 ) -> TupleOrTensorOrBoolGeneric:
     """

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 6 additions & 10 deletions
@@ -82,7 +82,7 @@ class TracInCPFast(TracInCPBase):
     def __init__(
         self,
         model: Module,
-        final_fc_layer: Union[Module, str],
+        final_fc_layer: Module,
         train_dataset: Union[Dataset, DataLoader],
         checkpoints: Union[str, List[str], Iterator],
         checkpoints_load_func: Callable = _load_flexible_state_dict,
@@ -96,11 +96,9 @@ def __init__(

             model (torch.nn.Module): An instance of pytorch model. This model should
                     define all of its layers as attributes of the model.
-            final_fc_layer (torch.nn.Module or str): The last fully connected layer in
+            final_fc_layer (torch.nn.Module): The last fully connected layer in
                     the network for which gradients will be approximated via fast random
-                    projection method. Can be either the layer module itself, or the
-                    fully qualified name of the layer if it is a defined attribute of
-                    the passed `model`.
+                    projection method.
             train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader):
                     In the `influence` method, we compute the influence score of
                     training examples on examples in a test batch.
@@ -869,7 +867,7 @@ class TracInCPFastRandProj(TracInCPFast):
     def __init__(
         self,
         model: Module,
-        final_fc_layer: Union[Module, str],
+        final_fc_layer: Module,
         train_dataset: Union[Dataset, DataLoader],
         checkpoints: Union[str, List[str], Iterator],
         checkpoints_load_func: Callable = _load_flexible_state_dict,
@@ -886,11 +884,9 @@ def __init__(

             model (torch.nn.Module): An instance of pytorch model. This model should
                     define all of its layers as attributes of the model.
-            final_fc_layer (torch.nn.Module or str): The last fully connected layer in
+            final_fc_layer (torch.nn.Module): The last fully connected layer in
                     the network for which gradients will be approximated via fast random
-                    projection method. Can be either the layer module itself, or the
-                    fully qualified name of the layer if it is a defined attribute of
-                    the passed `model`.
+            projection method.
             train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader):
                     In the `influence` method, we compute the influence score of
                     training examples on examples in a test batch.
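
Note: this narrows the public API, since final_fc_layer no longer accepts the layer's string name. A hedged migration sketch for callers that previously passed a name ("fc" is a hypothetical attribute here, and get_submodule requires a reasonably recent torch):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Linear(10, 10)
        self.fc = nn.Linear(10, 2)  # the last fully connected layer

model = Net()

# Previously: TracInCPFast(model, "fc", ...) accepted the attribute name.
# Now resolve the name to the module object before constructing:
final_fc_layer = model.get_submodule("fc")  # equivalently, model.fc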

captum/insights/attr_vis/attribution_calculation.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def calculate_attribution(
         )
         if "baselines" in inspect.signature(attribution_method.attribute).parameters:
             attribution_arguments["baselines"] = baseline
-        attr = attribution_method.attribute.__wrapped__(
+        attr = attribution_method.attribute.__wrapped__(  # type: ignore
             attribution_method,  # self
             data,
             additional_forward_args=additional_forward_args,
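
Note: the added # type: ignore is needed because __wrapped__ is an attribute that functools.wraps sets at runtime, so static checkers do not see it on a decorated method. A self-contained illustration, using a hypothetical stand-in decorator rather than Captum's actual log_usage:

import functools

def log_usage(func):
    @functools.wraps(func)  # wraps() sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

class Method:
    @log_usage
    def attribute(self, x):
        return x

m = Method()
# Bypass the decorator by calling the undecorated function directly.
# mypy doesn't know __wrapped__ exists on the method, hence the ignore.
print(Method.attribute.__wrapped__(m, 42))  # -> 42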

captum/insights/attr_vis/features.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from captum._utils.common import safe_div
 from captum.attr._utils import visualization as viz
 from captum.insights.attr_vis._utils.transforms import format_transforms
-from torch._tensor import Tensor
+from torch import Tensor

 FeatureOutput = namedtuple("FeatureOutput", "name base modified type contribution")

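Note: torch._tensor is a private module that only exists in newer torch releases, which the commit message cites as the reason for this sweep (the same one-line swap recurs in several files below). The public spelling is equivalent where both exist:

import torch
from torch import Tensor  # public path, stable across torch versions

# from torch._tensor import Tensor  # private; absent in older releases

assert isinstance(torch.zeros(1), Tensor)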

captum/insights/attr_vis/server.py

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@
 import socket
 import threading
 from time import sleep
-from typing import Optional
+from typing import cast, Dict, Optional

 from captum.log import log_usage
 from flask import Flask, jsonify, render_template, request
@@ -41,10 +41,10 @@ def namedtuple_to_dict(obj):
 def attribute() -> Response:
     # force=True needed for Colab notebooks, which doesn't use the correct
     # Content-Type header when forwarding requests through the Colab proxy
-    r = request.get_json(force=True)
+    r = cast(Dict, request.get_json(force=True))
     return jsonify(
         namedtuple_to_dict(
-            visualizer._calculate_attribution_from_cache(
+            visualizer._calculate_attribution_from_cache(  # type: ignore
                 r["inputIndex"], r["modelIndex"], r["labelIndex"]
             )
         )
@@ -54,15 +54,15 @@ def attribute() -> Response:
 @app.route("/fetch", methods=["POST"])
 def fetch() -> Response:
     # force=True needed, see comment for "/attribute" route above
-    visualizer._update_config(request.get_json(force=True))
-    visualizer_output = visualizer.visualize()
+    visualizer._update_config(request.get_json(force=True))  # type: ignore
+    visualizer_output = visualizer.visualize()  # type: ignore
     clean_output = namedtuple_to_dict(visualizer_output)
     return jsonify(clean_output)


 @app.route("/init")
 def init() -> Response:
-    return jsonify(visualizer.get_insights_config())
+    return jsonify(visualizer.get_insights_config())  # type: ignore


 @app.route("/")
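
Note: Flask types request.get_json as returning an optional value, so subscripting the result fails under a strict checker. cast is a runtime no-op that merely records the expected payload shape; a minimal sketch with a hypothetical handler, not the server's actual route:

from typing import Dict, Optional, cast

def handle(payload: Optional[Dict]) -> int:
    # cast() does nothing at runtime; it only narrows the type for mypy
    # so the subscript below type-checks.
    r = cast(Dict, payload)
    return r["inputIndex"]

print(handle({"inputIndex": 3}))  # -> 3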

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 [flake8]
 # E203: black and flake8 disagree on whitespace before ':'
 # W503: black and flake8 disagree on how to place operators
-ignore = E203, W503
+ignore = E203, W503, E704
 max-line-length = 88
 exclude =
     build, dist, tutorials, website
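
Note: E704 is flake8's "statement on same line as def". Recent black versions collapse ellipsis-only bodies, typically @overload stubs, onto one line, so the two tools would otherwise fight. An illustrative (assumed, not from this repo) snippet that black formats this way and flake8 would flag without the ignore:

from typing import overload

@overload
def scale(x: int) -> int: ...
@overload
def scale(x: float) -> float: ...
def scale(x):
    return x * 2

print(scale(21))  # -> 42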

setup.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def report(*args):
     INSIGHTS_REQUIRES
     + TEST_REQUIRES
     + [
-        "black==22.3.0",
+        "black",
         "flake8",
         "sphinx",
         "sphinx-autodoc-typehints",

tests/attr/helpers/conductance_reference.py

Lines changed: 7 additions & 6 deletions
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-from typing import Optional, Tuple
+from typing import cast, Tuple, Union

 import numpy as np
 import torch
@@ -10,7 +10,8 @@
 from captum.attr._utils.approximation_methods import approximation_parameters
 from captum.attr._utils.attribution import LayerAttribution
 from captum.attr._utils.common import _reshape_and_sum
-from torch._tensor import Tensor
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle

 """
 Note: This implementation of conductance follows the procedure described in the original
@@ -55,7 +56,7 @@ def forward_hook(module, inp, out):
         # The hidden layer tensor is assumed to have dimension (num_hidden, ...)
         # where the product of the dimensions >= 1 correspond to the total
         # number of hidden neurons in the layer.
-        layer_size = tuple(saved_tensor.size())[1:]
+        layer_size = tuple(cast(Tensor, saved_tensor).size())[1:]
         layer_units = int(np.prod(layer_size))

         # Remove unnecessary forward hook.
@@ -101,12 +102,12 @@ def forward_hook_register_back(module, inp, out):
         input_grads = torch.autograd.grad(torch.unbind(output), expanded_input)

         # Remove backwards hook
-        back_hook.remove()
+        cast(RemovableHandle, back_hook).remove()

         # Remove duplicates in gradient with respect to hidden layer,
         # choose one for each layer_units indices.
         output_mid_grads = torch.index_select(
-            saved_grads,
+            cast(Tensor, saved_grads),
             0,
             torch.tensor(range(0, input_grads[0].shape[0], layer_units)),
         )
@@ -115,7 +116,7 @@ def forward_hook_register_back(module, inp, out):
     def attribute(
         self,
         inputs,
-        baselines: Optional[int] = None,
+        baselines: Union[None, int, Tensor] = None,
         target=None,
         n_steps: int = 500,
         method: str = "riemann_trapezoid",
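
Note: the casts above follow a common pattern: a local is initialized to None and assigned inside a forward hook, so mypy still sees it as Optional after the forward pass has run. A standalone sketch of the idiom, using a toy ReLU layer rather than the conductance code:

from typing import Optional, cast

import torch
from torch import Tensor

saved: Optional[Tensor] = None

def forward_hook(module, inp, out) -> None:
    global saved
    saved = out  # only populated once the forward pass runs

layer = torch.nn.ReLU()
handle = layer.register_forward_hook(forward_hook)
layer(torch.ones(2, 3))
handle.remove()

# mypy still types `saved` as Optional[Tensor]; the cast records the
# invariant that the call above filled it in.
layer_size = tuple(cast(Tensor, saved).size())[1:]
print(layer_size)  # -> (3,)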

tests/attr/layer/test_layer_lrp.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# mypy: ignore-errors

 from typing import Any, Tuple

@@ -9,7 +10,7 @@

 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel
-from torch._tensor import Tensor
+from torch import Tensor


 def _get_basic_config() -> Tuple[BasicModel_ConvNet_One_Conv, Tensor]:

tests/attr/models/test_pytext.py

Lines changed: 4 additions & 4 deletions
@@ -5,10 +5,9 @@
 import os
 import tempfile
 import unittest
-from typing import Dict, List, NoReturn, Optional
+from typing import Dict, List

 import torch
-from pytext.data.data_handler import CommonMetadata

 HAS_PYTEXT = True
 try:
@@ -20,6 +19,7 @@
     from pytext.config.component import create_featurizer, create_model
     from pytext.config.doc_classification import ModelInputConfig, TargetConfig
     from pytext.config.field_config import FeatureConfig, WordFeatConfig
+    from pytext.data.data_handler import CommonMetadata
     from pytext.data.doc_classification_data_handler import (  # @manual=//pytext:main_lib # noqa
         DocClassificationDataHandler,
     )
@@ -43,7 +43,7 @@ def __init__(self) -> None:


 class TestWordEmbeddings(unittest.TestCase):
-    def setUp(self) -> Optional[NoReturn]:
+    def setUp(self) -> None:
         if not HAS_PYTEXT:
             return self.skipTest("Skip the test since PyText is not installed")

@@ -143,7 +143,7 @@ def _create_dummy_model(self):
             self._create_dummy_meta_data(),
         )

-    def _create_dummy_meta_data(self) -> CommonMetadata:
+    def _create_dummy_meta_data(self):
         text_field_meta = FieldMeta()
         text_field_meta.vocab = VocabStub()
         text_field_meta.vocab_size = 4
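
Note: the fix moves the CommonMetadata import inside the try block, so merely importing the test module no longer crashes when PyText is absent; the return annotation that referenced it is dropped for the same reason. The optional-dependency guard pattern, in miniature:

HAS_PYTEXT = True
try:
    # Every import of the optional package lives inside the guard; a bare
    # top-level import would fail as soon as the module is collected.
    from pytext.data.data_handler import CommonMetadata  # noqa: F401
except ImportError:
    HAS_PYTEXT = False

print(HAS_PYTEXT)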

tests/attr/test_class_summarizer.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from typing import List
+
 import torch
 from captum.attr import ClassSummarizer, CommonStats
 from tests.helpers.basic import BaseTest
@@ -45,7 +47,7 @@ def test_classes(self) -> None:
             ((3, 2, 10, 3), (1,)),
             # ((20,),),
         ]
-        list_of_classes = [
+        list_of_classes: List[List] = [
             list(range(100)),
             ["%d" % i for i in range(100)],
             list(range(300, 400)),
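
Note: without the annotation, mypy infers the element type from the literal's contents, and the mixed List[int] / List[str] entries can trip strict checking downstream; List[List] (i.e. List[List[Any]]) states the heterogeneity up front. The same idea in isolation:

from typing import List

mixed: List[List] = [
    list(range(3)),                # List[int]
    ["%d" % i for i in range(3)],  # List[str]
]
print(mixed)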

tests/attr/test_deeplift_classification.py

Lines changed: 5 additions & 0 deletions
@@ -155,6 +155,11 @@ def softmax_classification(
         target: TargetType,
     ) -> None:
         # TODO add test cases for multiple different layers
+        if isinstance(attr_method, DeepLiftShap):
+            assert isinstance(
+                baselines, Tensor
+            ), "Non-tensor baseline not supported for DeepLiftShap"
+
         model.zero_grad()
         attributions, delta = attr_method.attribute(
             input, baselines=baselines, target=target, return_convergence_delta=True
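
Note: besides guarding the test, the isinstance assert narrows `baselines` from its broader union type to Tensor for the call that follows. The narrowing idiom in isolation, with a hypothetical function:

from typing import Union

import torch
from torch import Tensor

def use_baselines(baselines: Union[int, Tensor]) -> Tensor:
    assert isinstance(baselines, Tensor), "Non-tensor baseline not supported"
    # After the assert, mypy treats `baselines` as Tensor.
    return baselines * 2

print(use_baselines(torch.ones(2)))  # -> tensor([2., 2.])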

tests/attr/test_guided_grad_cam.py

Lines changed: 3 additions & 3 deletions
@@ -1,14 +1,14 @@
 #!/usr/bin/env python3

 import unittest
-from typing import Any
+from typing import Any, List, Tuple, Union

 import torch
 from captum._utils.typing import TensorOrTupleOfTensorsGeneric
 from captum.attr._core.guided_grad_cam import GuidedGradCam
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv
-from torch._tensor import Tensor
+from torch import Tensor
 from torch.nn import Module


@@ -107,7 +107,7 @@ def _guided_grad_cam_test_assert(
         model: Module,
         target_layer: Module,
         test_input: TensorOrTupleOfTensorsGeneric,
-        expected: Tensor,
+        expected: Union[Tensor, List, Tuple],
         additional_input: Any = None,
         interpolate_mode: str = "nearest",
         attribute_to_layer_input: bool = False,

tests/attr/test_input_layer_wrapper.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
     BasicModel_MultiLayer_TrueMultiInput,
     MixedKwargsAndArgsModule,
 )
+from torch.nn import Module

 layer_methods_to_test_with_equiv = [
     # layer_method, equiv_method, whether or not to use multiple layers
@@ -115,7 +116,7 @@ def layer_method_with_input_layer_patches(
         assertTensorTuplesAlmostEqual(self, a1, real_attributions)

     def forward_eval_layer_with_inputs_helper(
-        self, model: ModelInputWrapper, inputs_to_test
+        self, model: Module, inputs_to_test
     ) -> None:
         # hard coding for simplicity
         # 0 if using args, 1 if using kwargs

tests/attr/test_interpretable_input.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
-from torch._tensor import Tensor
+from torch import Tensor


 class DummyTokenizer:

tests/attr/test_lime.py

Lines changed: 1 addition & 1 deletion
@@ -494,7 +494,7 @@ def _lime_test_assert(
         model: Callable,
         test_input: TensorOrTupleOfTensorsGeneric,
         expected_attr,
-        expected_coefs_only: Optional[Tensor] = None,
+        expected_coefs_only: Union[None, List, Tensor] = None,
         feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
         additional_input: Any = None,
         perturbations_per_eval: Tuple[int, ...] = (1,),

tests/attr/test_stat.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import random
+from typing import Callable, List

 import torch
 from captum.attr import Max, Mean, Min, MSE, StdDev, Sum, Summarizer, Var
@@ -140,7 +141,7 @@ def test_stats_random_data(self) -> None:
             "sum",
             "mse",
         ]
-        gt_fns = [
+        gt_fns: List[Callable] = [
             torch.mean,
             lambda x: torch.var(x, unbiased=False),
             lambda x: torch.var(x, unbiased=True),

tests/concept/test_tcav.py

Lines changed: 1 addition & 2 deletions
@@ -14,7 +14,6 @@
     Iterator,
     List,
     Set,
-    SupportsIndex,
     Tuple,
     Union,
 )
@@ -174,7 +173,7 @@ def __init__(
         self,
         get_tensor_from_filename_func: Callable,
         path: str,
-        num_samples: SupportsIndex = 100,
+        num_samples: int = 100,
     ) -> None:
         r"""
         Args:

tests/helpers/basic.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 from captum.log import patch_methods
-from torch._tensor import Tensor


 def deep_copy_args(func: Callable):
@@ -21,7 +20,7 @@ def copy_args(*args, **kwargs):


 def assertTensorAlmostEqual(
-    test, actual: Tensor, expected: Tensor, delta: float = 0.0001, mode: str = "sum"
+    test, actual, expected, delta: float = 0.0001, mode: str = "sum"
 ) -> None:
     assert isinstance(actual, torch.Tensor), (
         "Actual parameter given for " "comparison must be a tensor."
