Commit 9bac757

Narine Kokhlikyan authored and facebook-github-bot committed
add influence gpu tests not using DataParallel (#1185)
Summary: Currently, when testing implementations of `TracInCPBase`, if the model to be tested is on gpu, we always wrap it in `DataParallel`. However, it is also worth testing the case where the model is on gpu but is *not* wrapped in `DataParallel`. Whether the model is on gpu is currently specified by a boolean `use_gpu` flag. In this diff, we change `use_gpu` to have type `Union[bool, str]`, with allowable values of `False` (model on cpu), `'cuda'` (model on gpu, not using `DataParallel`), and `'cuda_data_parallel'` (model on gpu, using `DataParallel`). This is backwards compatible with classes like `ExplicitDataset`, which move data to gpu `if use_gpu`, since non-empty strings are truthy.

In further detail, the changes are as follows:

- for tests (`TestTracInSelfInfluence`, `TestTracInGetKMostInfluential`) that previously used `use_gpu=True`, run them with values of `'cuda'` and `'cuda_data_parallel'` (in addition to `False`)
- in those tests, give the layer names the 'module.' prefix only when `use_gpu='cuda_data_parallel'`
- change `get_random_model_and_data`, which is where the `use_gpu` flag is used to create model and data, to reflect the new logic

Differential Revision: D47190429
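The backwards-compatibility claim rests on Python truthiness: any site that only gates on `if use_gpu` treats both new string values exactly like the old `True`. A minimal illustration (the `move_to_gpu` helper is hypothetical, standing in for the `if use_gpu` checks in classes like `ExplicitDataset`):

```python
# Non-empty strings are truthy, so `if use_gpu` behaves as before for the new
# string values; only sites that must distinguish DataParallel from plain cuda
# compare against "cuda_data_parallel" explicitly.
def move_to_gpu(tensor, use_gpu):
    # hypothetical stand-in for the data-moving checks in `ExplicitDataset`
    return tensor.cuda() if use_gpu else tensor


assert bool(False) is False
assert bool("cuda") is True
assert bool("cuda_data_parallel") is True
```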
1 parent b321570 commit 9bac757

3 files changed: 37 additions, 12 deletions

tests/influence/_core/test_tracin_k_most_influential.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -1,5 +1,5 @@
 import tempfile
-from typing import Callable
+from typing import Callable, Union

 import torch
 import torch.nn as nn
@@ -18,7 +18,7 @@
 class TestTracInGetKMostInfluential(BaseTest):

     use_gpu_list = (
-        [True, False]
+        [False, "cuda", "cuda_data_parallel"]
         if torch.cuda.is_available() and torch.cuda.device_count() != 0
         else [False]
     )
@@ -48,7 +48,9 @@ class TestTracInGetKMostInfluential(BaseTest):
                 DataInfluenceConstructor(
                     TracInCP,
                     name="linear2",
-                    layers=["module.linear2"] if use_gpu else ["linear2"],
+                    layers=["module.linear2"]
+                    if use_gpu == "cuda_data_parallel"
+                    else ["linear2"],
                 ),
                 False,
             ),
@@ -83,7 +85,7 @@ def test_tracin_k_most_influential(
         proponents: bool,
         batch_size: int,
         k: int,
-        use_gpu: bool,
+        use_gpu: Union[bool, str],
         aggregate: bool,
     ) -> None:
         """
```

tests/influence/_core/test_tracin_self_influence.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -1,5 +1,5 @@
 import tempfile
-from typing import Callable
+from typing import Callable, Union

 import torch
 import torch.nn as nn
@@ -19,7 +19,7 @@
 class TestTracInSelfInfluence(BaseTest):

     use_gpu_list = (
-        [True, False]
+        [False, "cuda", "cuda_data_parallel"]
         if torch.cuda.is_available() and torch.cuda.device_count() != 0
         else [False]
     )
@@ -37,7 +37,9 @@ class TestTracInSelfInfluence(BaseTest):
                 DataInfluenceConstructor(
                     TracInCP,
                     name="TracInCP_linear1",
-                    layers=["module.linear1"] if use_gpu else ["linear1"],
+                    layers=["module.linear1"]
+                    if use_gpu == "cuda_data_parallel"
+                    else ["linear1"],
                 ),
             ),
             (
@@ -46,7 +48,7 @@ class TestTracInSelfInfluence(BaseTest):
                     TracInCP,
                     name="TracInCP_linear1_linear2",
                     layers=["module.linear1", "module.linear2"]
-                    if use_gpu
+                    if use_gpu == "cuda_data_parallel"
                     else ["linear1", "linear2"],
                 ),
             ),
@@ -87,7 +89,7 @@ def test_tracin_self_influence(
         reduction: str,
         tracin_constructor: Callable,
         unpack_inputs: bool,
-        use_gpu: bool,
+        use_gpu: Union[bool, str],
     ) -> None:
         with tempfile.TemporaryDirectory() as tmpdir:
             (net, train_dataset,) = get_random_model_and_data(
```

tests/influence/_utils/common.py

Lines changed: 24 additions & 3 deletions
```diff
@@ -183,6 +183,16 @@ def forward(self, *inputs):
 def get_random_model_and_data(
     tmpdir, unpack_inputs, return_test_data=True, use_gpu=False
 ):
+    """
+    `use_gpu` can either be
+    - `False`: returned model is on cpu
+    - `'cuda'`: returned model is on gpu
+    - `'cuda_data_parallel'`: returned model is a `DataParallel` model, and on gpu
+    The need to differentiate between `'cuda'` and `'cuda_data_parallel'`
+    is that sometimes we may want to test a model that is on gpu, but is *not*
+    wrapped in `DataParallel`.
+    """
+    assert use_gpu in [False, "cuda", "cuda_data_parallel"]

     in_features, hidden_nodes, out_features = 5, 4, 3
     num_inputs = 2
@@ -209,7 +219,11 @@ def get_random_model_and_data(
         if hasattr(net, "pre"):
             net.pre.weight.data = net.pre.weight.data.double()
         checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"])
-        net_adjusted = _wrap_model_in_dataparallel(net) if use_gpu else net
+        net_adjusted = (
+            _wrap_model_in_dataparallel(net)
+            if use_gpu == "cuda_data_parallel"
+            else (net.to(device="cuda") if use_gpu == "cuda" else net)
+        )
         torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))

     num_samples = 50
@@ -238,7 +252,9 @@ def get_random_model_and_data(

     if return_test_data:
         return (
-            _wrap_model_in_dataparallel(net) if use_gpu else net,
+            _wrap_model_in_dataparallel(net)
+            if use_gpu == "cuda_data_parallel"
+            else (net.to(device="cuda") if use_gpu == "cuda" else net),
             dataset,
             _move_sample_to_cuda(test_samples)
             if isinstance(test_samples, list) and use_gpu
@@ -248,7 +264,12 @@ def get_random_model_and_data(
             test_labels.cuda() if use_gpu else test_labels,
         )
     else:
-        return _wrap_model_in_dataparallel(net) if use_gpu else net, dataset
+        return (
+            _wrap_model_in_dataparallel(net)
+            if use_gpu == "cuda_data_parallel"
+            else (net.to(device="cuda") if use_gpu == "cuda" else net),
+            dataset,
+        )


 class DataInfluenceConstructor:
```
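Since the same three-way conditional now appears in three places in `get_random_model_and_data`, it can be read in isolation as follows. This is a minimal sketch assuming a toy model; `_adjust_model` is a hypothetical name, and in the tests the `'cuda_data_parallel'` branch goes through the existing `_wrap_model_in_dataparallel` helper:

```python
import torch.nn as nn


def _adjust_model(net: nn.Module, use_gpu):
    # Hypothetical helper mirroring the conditional in get_random_model_and_data.
    assert use_gpu in [False, "cuda", "cuda_data_parallel"]
    if use_gpu == "cuda_data_parallel":
        # In the tests this branch is `_wrap_model_in_dataparallel(net)`,
        # which moves the model to gpu and wraps it in DataParallel.
        return nn.DataParallel(net.to(device="cuda"))
    elif use_gpu == "cuda":
        # On gpu, but deliberately not wrapped in DataParallel.
        return net.to(device="cuda")
    return net  # use_gpu is False: leave the model on cpu
```

Note that the data-moving expressions (e.g. `test_labels.cuda() if use_gpu else test_labels`) still rely on truthiness, which is exactly why the string values stay backwards compatible.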
