Skip to content

Commit 4b7879c

Browse files
styusuffacebook-github-bot
authored andcommitted
Moving the Influence test helper (#1484)
Summary: Pull Request resolved: #1484 This Diff is to remove base modules for the testing functionality within Captum. The influence test helper is limited in its scope so it's a smaller Diff. Reviewed By: cyrjano Differential Revision: D68157431 fbshipit-source-id: 90ae5c79521a5b6f4609a33f7f90146ec3c44b58
1 parent d847549 commit 4b7879c

13 files changed

+55
-55
lines changed
File renamed without changes.
File renamed without changes.

tests/influence/_core/test_arnoldi_influence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
_top_eigen,
1818
_unflatten_params_factory,
1919
)
20-
from parameterized import parameterized
21-
from tests.helpers import BaseTest
22-
from tests.helpers.basic import assertTensorAlmostEqual
23-
from tests.helpers.influence.common import (
20+
from captum.testing.helpers.influence.common import (
2421
_format_batch_into_tuple,
2522
build_test_name_func,
2623
DataInfluenceConstructor,
@@ -31,6 +28,9 @@
3128
is_gpu,
3229
UnpackDataset,
3330
)
31+
from parameterized import parameterized
32+
from tests.helpers import BaseTest
33+
from tests.helpers.basic import assertTensorAlmostEqual
3434
from torch import Tensor
3535
from torch.utils.data import DataLoader
3636

tests/influence/_core/test_dataloader.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
TracInCPFast,
1010
TracInCPFastRandProj,
1111
)
12-
from parameterized import parameterized
13-
from tests.helpers import BaseTest
14-
from tests.helpers.basic import assertTensorAlmostEqual
15-
from tests.helpers.influence.common import (
12+
from captum.testing.helpers.influence.common import (
1613
_format_batch_into_tuple,
1714
build_test_name_func,
1815
DataInfluenceConstructor,
1916
get_random_model_and_data,
2017
)
18+
from parameterized import parameterized
19+
from tests.helpers import BaseTest
20+
from tests.helpers.basic import assertTensorAlmostEqual
2121
from torch.utils.data import DataLoader
2222

2323

@@ -32,13 +32,13 @@ class TestTracInDataLoader(BaseTest):
3232
# `comprehension((reduction, constr, unpack_inputs) for
3333
# generators(generator(unpack_inputs in [False, True] if ),
3434
# generators(generator((reduction, constr) in
35-
# [("none", tests.helpers.influence.common.DataInfluenceConstructor
35+
# [("none", captum.testing.helpers.influence.common.DataInfluenceConstructor
3636
# (captum.influence._core.tracincp.TracInCP)),
37-
# ("sum", tests.helpers.influence.common.DataInfluenceConstructor
37+
# ("sum", captum.testing.helpers.influence.common.DataInfluenceConstructor
3838
# (captum.influence._core.tracincp_fast_rand_proj.TracInCPFast)), ("sum",
39-
# tests.helpers.influence.common.DataInfluenceConstructor(captum.influence._core.
39+
# captum.testing.helpers.influence.common.DataInfluenceConstructor(captum.influence._core.
4040
# tracincp_fast_rand_proj.TracInCPFastRandProj)), ("sum",
41-
# tests.helpers.influence.common.DataInfluenceConstructor(
41+
# captum.testing.helpers.influence.common.DataInfluenceConstructor(
4242
# captum.influence._core.tracincp_fast_rand_proj.TracInCPFastRandProj,
4343
# $parameter$name = "TracInCPFastRandProj_1DProj",
4444
# $parameter$projection_dim = 1))] if ))))`

tests/influence/_core/test_naive_influence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
_functional_call,
1313
_unflatten_params_factory,
1414
)
15-
from parameterized import parameterized
16-
from tests.helpers import BaseTest
17-
from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual
18-
from tests.helpers.influence.common import (
15+
from captum.testing.helpers.influence.common import (
1916
_format_batch_into_tuple,
2017
build_test_name_func,
2118
DataInfluenceConstructor,
@@ -25,6 +22,9 @@
2522
Linear,
2623
UnpackDataset,
2724
)
25+
from parameterized import parameterized
26+
from tests.helpers import BaseTest
27+
from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual
2828
from torch.utils.data import DataLoader
2929

3030
# TODO: for some unknow reason, this test does not work

tests/influence/_core/test_tracin_aggregate_influence.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
import torch.nn as nn
1111
from captum.influence._core.tracincp import TracInCP
12-
from parameterized import parameterized
13-
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
14-
from tests.helpers.influence.common import (
12+
from captum.testing.helpers.influence.common import (
1513
build_test_name_func,
1614
DataInfluenceConstructor,
1715
get_random_model_and_data,
1816
)
17+
from parameterized import parameterized
18+
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
1919
from torch.utils.data import DataLoader
2020

2121

tests/influence/_core/test_tracin_intermediate_quantities.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
TracInCPFast,
1313
TracInCPFastRandProj,
1414
)
15-
from parameterized import parameterized
16-
from tests.helpers import BaseTest
17-
from tests.helpers.basic import assertTensorAlmostEqual
18-
from tests.helpers.influence.common import (
15+
from captum.testing.helpers.influence.common import (
1916
_format_batch_into_tuple,
2017
build_test_name_func,
2118
DataInfluenceConstructor,
2219
get_random_model_and_data,
2320
)
21+
from parameterized import parameterized
22+
from tests.helpers import BaseTest
23+
from tests.helpers.basic import assertTensorAlmostEqual
2424
from torch.utils.data import DataLoader
2525

2626

tests/influence/_core/test_tracin_k_most_influential.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@
66
import torch
77
import torch.nn as nn
88
from captum.influence._core.tracincp import TracInCP
9-
10-
from parameterized import parameterized
11-
from tests.helpers import BaseTest
12-
from tests.helpers.basic import assertTensorAlmostEqual
13-
from tests.helpers.influence.common import (
9+
from captum.testing.helpers.influence.common import (
1410
_format_batch_into_tuple,
1511
build_test_name_func,
1612
DataInfluenceConstructor,
@@ -19,6 +15,10 @@
1915
is_gpu,
2016
)
2117

18+
from parameterized import parameterized
19+
from tests.helpers import BaseTest
20+
from tests.helpers.basic import assertTensorAlmostEqual
21+
2222

2323
class TestTracInGetKMostInfluential(BaseTest):
2424
param_list: List[
@@ -76,7 +76,7 @@ class TestTracInGetKMostInfluential(BaseTest):
7676
)
7777

7878
# pyre-fixme[56]: Pyre was not able to infer the type of argument
79-
# `tests.helpers.influence.common.build_test_name_func()`
79+
# `captum.testing.helpers.influence.common.build_test_name_func()`
8080
# to decorator factory `parameterized.parameterized.expand`.
8181
@parameterized.expand(
8282
param_list,

tests/influence/_core/test_tracin_regression.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
TracInCPFast,
1414
TracInCPFastRandProj,
1515
)
16-
from parameterized import parameterized
17-
from tests.helpers import BaseTest
18-
from tests.helpers.basic import assertTensorAlmostEqual
19-
from tests.helpers.influence.common import (
16+
from captum.testing.helpers.influence.common import (
2017
_isSorted,
2118
_wrap_model_in_dataparallel,
2219
build_test_name_func,
@@ -25,6 +22,9 @@
2522
IdentityDataset,
2623
RangeDataset,
2724
)
25+
from parameterized import parameterized
26+
from tests.helpers import BaseTest
27+
from tests.helpers.basic import assertTensorAlmostEqual
2828
from torch import Tensor
2929

3030

@@ -142,7 +142,7 @@ def _test_tracin_regression_setup(
142142
param_list.append((reduction, constructor, mode, dim, use_gpu))
143143

144144
# pyre-fixme[56]: Pyre was not able to infer the type of argument
145-
# `tests.helpers.influence.common.build_test_name_func
145+
# `captum.testing.helpers.influence.common.build_test_name_func
146146
# ($parameter$args_to_skip = ["reduction"])` to decorator factory
147147
# `parameterized.parameterized.expand`.
148148
@parameterized.expand(
@@ -258,7 +258,7 @@ def test_tracin_regression(
258258
)
259259

260260
# pyre-fixme[56]: Pyre was not able to infer the type of argument
261-
# `tests.helpers.influence.common.build_test_name_func()`
261+
# `captum.testing.helpers.influence.common.build_test_name_func()`
262262
# to decorator factory `parameterized.parameterized.expand`.
263263
@parameterized.expand(
264264
[
@@ -350,7 +350,7 @@ def _test_tracin_identity_regression_setup(
350350
return dataset, net
351351

352352
# pyre-fixme[56]: Pyre was not able to infer the type of argument
353-
# `tests.helpers.influence.common.build_test_name_func()`
353+
# `captum.testing.helpers.influence.common.build_test_name_func()`
354354
# to decorator factory `parameterized.parameterized.expand`
355355
@parameterized.expand(
356356
[
@@ -465,7 +465,7 @@ def test_tracin_identity_regression(
465465
)
466466

467467
# pyre-fixme[56]: Pyre was not able to infer the type of argument
468-
# `tests.helpers.influence.common.build_test_name_func()`
468+
# `captum.testing.helpers.influence.common.build_test_name_func()`
469469
# to decorator factory `parameterized.parameterized.expand`.
470470
@parameterized.expand(
471471
[

tests/influence/_core/test_tracin_self_influence.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
from captum.influence._core.influence_function import NaiveInfluenceFunction
99
from captum.influence._core.tracincp import TracInCP, TracInCPBase
1010
from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast
11-
from parameterized import parameterized
12-
from tests.helpers import BaseTest
13-
from tests.helpers.basic import assertTensorAlmostEqual
14-
from tests.helpers.influence.common import (
11+
from captum.testing.helpers.influence.common import (
1512
_format_batch_into_tuple,
1613
build_test_name_func,
1714
DataInfluenceConstructor,
1815
get_random_model_and_data,
1916
GPU_SETTING_LIST,
2017
is_gpu,
2118
)
19+
from parameterized import parameterized
20+
from tests.helpers import BaseTest
21+
from tests.helpers.basic import assertTensorAlmostEqual
2222
from torch.utils.data import DataLoader
2323

2424

0 commit comments

Comments
 (0)