From 53ba4e23ac6390f98b8f5d0eed33dc925fe2e003 Mon Sep 17 00:00:00 2001 From: Samuel Yusuf Date: Tue, 14 Jan 2025 08:55:08 -0800 Subject: [PATCH] Moving the Influence test helper (#1484) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1484 This Diff is to remove base modules for the testing functionality within Captum. The influence test helper is limited in its scope so it's a smaller Diff. Differential Revision: D68157431 --- .../testing}/helpers/influence/__init__.py | 0 .../testing}/helpers/influence/common.py | 0 tests/influence/_core/test_arnoldi_influence.py | 8 ++++---- tests/influence/_core/test_dataloader.py | 16 ++++++++-------- tests/influence/_core/test_naive_influence.py | 8 ++++---- .../_core/test_tracin_aggregate_influence.py | 6 +++--- .../_core/test_tracin_intermediate_quantities.py | 8 ++++---- .../_core/test_tracin_k_most_influential.py | 12 ++++++------ tests/influence/_core/test_tracin_regression.py | 16 ++++++++-------- .../_core/test_tracin_self_influence.py | 8 ++++---- .../influence/_core/test_tracin_show_progress.py | 10 +++++----- tests/influence/_core/test_tracin_validation.py | 8 ++++---- tests/influence/_core/test_tracin_xor.py | 10 +++++----- 13 files changed, 55 insertions(+), 55 deletions(-) rename {tests => captum/testing}/helpers/influence/__init__.py (100%) rename {tests => captum/testing}/helpers/influence/common.py (100%) diff --git a/tests/helpers/influence/__init__.py b/captum/testing/helpers/influence/__init__.py similarity index 100% rename from tests/helpers/influence/__init__.py rename to captum/testing/helpers/influence/__init__.py diff --git a/tests/helpers/influence/common.py b/captum/testing/helpers/influence/common.py similarity index 100% rename from tests/helpers/influence/common.py rename to captum/testing/helpers/influence/common.py diff --git a/tests/influence/_core/test_arnoldi_influence.py b/tests/influence/_core/test_arnoldi_influence.py index 5f4c7adceb..ad61f9c9cf 100644 --- a/tests/influence/_core/test_arnoldi_influence.py +++ b/tests/influence/_core/test_arnoldi_influence.py @@ -17,10 +17,7 @@ _top_eigen, _unflatten_params_factory, ) -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, DataInfluenceConstructor, @@ -31,6 +28,9 @@ is_gpu, UnpackDataset, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_dataloader.py b/tests/influence/_core/test_dataloader.py index 5fe29d41bf..3237270a91 100644 --- a/tests/influence/_core/test_dataloader.py +++ b/tests/influence/_core/test_dataloader.py @@ -9,15 +9,15 @@ TracInCPFast, TracInCPFastRandProj, ) -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import DataLoader @@ -32,13 +32,13 @@ class TestTracInDataLoader(BaseTest): # `comprehension((reduction, constr, unpack_inputs) for # generators(generator(unpack_inputs in [False, True] if ), # generators(generator((reduction, constr) in - # [("none", tests.helpers.influence.common.DataInfluenceConstructor + # [("none", captum.testing.helpers.influence.common.DataInfluenceConstructor # (captum.influence._core.tracincp.TracInCP)), - # ("sum", tests.helpers.influence.common.DataInfluenceConstructor + # ("sum", captum.testing.helpers.influence.common.DataInfluenceConstructor # (captum.influence._core.tracincp_fast_rand_proj.TracInCPFast)), ("sum", - # tests.helpers.influence.common.DataInfluenceConstructor(captum.influence._core. + # captum.testing.helpers.influence.common.DataInfluenceConstructor(captum.influence._core. # tracincp_fast_rand_proj.TracInCPFastRandProj)), ("sum", - # tests.helpers.influence.common.DataInfluenceConstructor( + # captum.testing.helpers.influence.common.DataInfluenceConstructor( # captum.influence._core.tracincp_fast_rand_proj.TracInCPFastRandProj, # $parameter$name = "TracInCPFastRandProj_1DProj", # $parameter$projection_dim = 1))] if ))))` diff --git a/tests/influence/_core/test_naive_influence.py b/tests/influence/_core/test_naive_influence.py index 1255b26af7..bb2bca2bcc 100644 --- a/tests/influence/_core/test_naive_influence.py +++ b/tests/influence/_core/test_naive_influence.py @@ -12,10 +12,7 @@ _functional_call, _unflatten_params_factory, ) -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, DataInfluenceConstructor, @@ -25,6 +22,9 @@ Linear, UnpackDataset, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual, assertTensorTuplesAlmostEqual from torch.utils.data import DataLoader # TODO: for some unknow reason, this test does not work diff --git a/tests/influence/_core/test_tracin_aggregate_influence.py b/tests/influence/_core/test_tracin_aggregate_influence.py index 99f4098d14..7b293d201b 100644 --- a/tests/influence/_core/test_tracin_aggregate_influence.py +++ b/tests/influence/_core/test_tracin_aggregate_influence.py @@ -9,13 +9,13 @@ import torch.nn as nn from captum.influence._core.tracincp import TracInCP -from parameterized import parameterized -from tests.helpers.basic import assertTensorAlmostEqual, BaseTest -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) +from parameterized import parameterized +from tests.helpers.basic import assertTensorAlmostEqual, BaseTest from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_intermediate_quantities.py b/tests/influence/_core/test_tracin_intermediate_quantities.py index 82d10e5bf6..6298fba3d4 100644 --- a/tests/influence/_core/test_tracin_intermediate_quantities.py +++ b/tests/influence/_core/test_tracin_intermediate_quantities.py @@ -12,15 +12,15 @@ TracInCPFast, TracInCPFastRandProj, ) -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_k_most_influential.py b/tests/influence/_core/test_tracin_k_most_influential.py index 08224a60e5..15cf5097d7 100644 --- a/tests/influence/_core/test_tracin_k_most_influential.py +++ b/tests/influence/_core/test_tracin_k_most_influential.py @@ -6,11 +6,7 @@ import torch import torch.nn as nn from captum.influence._core.tracincp import TracInCP - -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, DataInfluenceConstructor, @@ -19,6 +15,10 @@ is_gpu, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual + class TestTracInGetKMostInfluential(BaseTest): param_list: List[ @@ -76,7 +76,7 @@ class TestTracInGetKMostInfluential(BaseTest): ) # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `tests.helpers.influence.common.build_test_name_func()` + # `captum.testing.helpers.influence.common.build_test_name_func()` # to decorator factory `parameterized.parameterized.expand`. @parameterized.expand( param_list, diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 6e310a96b8..80863494b2 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -13,10 +13,7 @@ TracInCPFast, TracInCPFastRandProj, ) -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _isSorted, _wrap_model_in_dataparallel, build_test_name_func, @@ -25,6 +22,9 @@ IdentityDataset, RangeDataset, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch import Tensor @@ -142,7 +142,7 @@ def _test_tracin_regression_setup( param_list.append((reduction, constructor, mode, dim, use_gpu)) # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `tests.helpers.influence.common.build_test_name_func + # `captum.testing.helpers.influence.common.build_test_name_func # ($parameter$args_to_skip = ["reduction"])` to decorator factory # `parameterized.parameterized.expand`. @parameterized.expand( @@ -258,7 +258,7 @@ def test_tracin_regression( ) # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `tests.helpers.influence.common.build_test_name_func()` + # `captum.testing.helpers.influence.common.build_test_name_func()` # to decorator factory `parameterized.parameterized.expand`. @parameterized.expand( [ @@ -350,7 +350,7 @@ def _test_tracin_identity_regression_setup( return dataset, net # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `tests.helpers.influence.common.build_test_name_func()` + # `captum.testing.helpers.influence.common.build_test_name_func()` # to decorator factory `parameterized.parameterized.expand` @parameterized.expand( [ @@ -465,7 +465,7 @@ def test_tracin_identity_regression( ) # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `tests.helpers.influence.common.build_test_name_func()` + # `captum.testing.helpers.influence.common.build_test_name_func()` # to decorator factory `parameterized.parameterized.expand`. @parameterized.expand( [ diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index f93e6c74f2..fba15734b5 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -8,10 +8,7 @@ from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp import TracInCP, TracInCPBase from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _format_batch_into_tuple, build_test_name_func, DataInfluenceConstructor, @@ -19,6 +16,9 @@ GPU_SETTING_LIST, is_gpu, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual from torch.utils.data import DataLoader diff --git a/tests/influence/_core/test_tracin_show_progress.py b/tests/influence/_core/test_tracin_show_progress.py index ea72dcfabe..e259038acf 100644 --- a/tests/influence/_core/test_tracin_show_progress.py +++ b/tests/influence/_core/test_tracin_show_progress.py @@ -8,13 +8,13 @@ import torch.nn as nn from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) +from parameterized import parameterized +from tests.helpers import BaseTest from torch.utils.data import DataLoader @@ -69,9 +69,9 @@ def _check_error_msg_multiplicity( # pyre-fixme[56]: Pyre was not able to infer the type of argument # `comprehension((reduction, constr, mode) for # generators(generator((reduction, constr) in - # [("none", tests.helpers.influence.common.DataInfluenceConstructor + # [("none", captum.testing.helpers.influence.common.DataInfluenceConstructor # (captum.influence._core.tracincp.TracInCP)), - # ("sum", tests.helpers.influence.common.DataInfluenceConstructor + # ("sum", captum.testing.helpers.influence.common.DataInfluenceConstructor # (captum.influence._core.tracincp_fast_rand_proj.TracInCPFast))] if ), # generators(generator(mode in ["self influence by checkpoints", # "self influence by batches", "influence", "k-most"] if ))))` diff --git a/tests/influence/_core/test_tracin_validation.py b/tests/influence/_core/test_tracin_validation.py index 54969ccf11..57ad1ce0b6 100644 --- a/tests/influence/_core/test_tracin_validation.py +++ b/tests/influence/_core/test_tracin_validation.py @@ -5,15 +5,15 @@ import torch.nn as nn from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast - -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( build_test_name_func, DataInfluenceConstructor, get_random_model_and_data, ) +from parameterized import parameterized +from tests.helpers import BaseTest + class TestTracinValidator(BaseTest): diff --git a/tests/influence/_core/test_tracin_xor.py b/tests/influence/_core/test_tracin_xor.py index a9ed3a389d..307f19df93 100644 --- a/tests/influence/_core/test_tracin_xor.py +++ b/tests/influence/_core/test_tracin_xor.py @@ -9,16 +9,16 @@ import torch.nn as nn import torch.nn.functional as F from captum.influence._core.tracincp import TracInCP -from parameterized import parameterized -from tests.helpers import BaseTest -from tests.helpers.basic import assertTensorAlmostEqual -from tests.helpers.influence.common import ( +from captum.testing.helpers.influence.common import ( _wrap_model_in_dataparallel, BasicLinearNet, BinaryDataset, build_test_name_func, DataInfluenceConstructor, ) +from parameterized import parameterized +from tests.helpers import BaseTest +from tests.helpers.basic import assertTensorAlmostEqual class TestTracInXOR(BaseTest): @@ -225,7 +225,7 @@ def _test_tracin_xor_setup( ) # pyre-fixme[56]: Pyre was not able to infer the type of argument - # `tests.helpers.influence.common.build_test_name_func($parameter$args_to_skip + # `captum.testing.helpers.influence.common.build_test_name_func($parameter$args_to_skip # = ["reduction"])` to decorator factory `parameterized.parameterized.expand`. @parameterized.expand( parametrized_list,