
Commit ca604bf

Fulton Wang authored and facebook-github-bot committed
add influence gpu tests not using DataParallel (#1185)
Summary: Currently, when testing implementations of `TracInCPBase`, if the model to be tested is on gpu, we always wrap it in `DataParallel`. However, it is also worth testing when the model is on gpu but is *not* wrapped in `DataParallel`. Whether the model is on gpu is currently specified by a boolean `use_gpu` flag. In this diff, we change `use_gpu` to have type `Union[bool, str]`, with allowable values of `False` (model on cpu), `'cuda'` (model on gpu, not wrapped in `DataParallel`), and `'cuda_data_parallel'` (model on gpu, wrapped in `DataParallel`). This is backwards compatible with classes like `ExplicitDataset`, which moves data to gpu `if use_gpu`, since non-empty strings are truthy.

In further detail, the changes are as follows:
- for tests (`TestTracInSelfInfluence`, `TestTracInKMostInfluential`) where `use_gpu` was called with `True`, now call them with values of `'cuda'` and `'cuda_data_parallel'` (in addition to `False`)
- in those tests, make the layer names have the 'module' prefix only when `use_gpu='cuda_data_parallel'`
- change `get_random_model_and_data`, which is where the `use_gpu` flag is used to create the model and data, to reflect the new logic

Reviewed By: NarineK

Differential Revision: D47190429
1 parent 104ec53 commit ca604bf
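For reference, the placement logic the summary describes reduces to the following minimal sketch. `_place_model` is a hypothetical helper name used only for illustration; the tests themselves use `_wrap_model_in_dataparallel` from `tests/influence/_utils/common.py`.

from typing import Union

import torch.nn as nn


def _place_model(net: nn.Module, use_gpu: Union[bool, str]) -> nn.Module:
    # The three allowable `use_gpu` values and their meaning:
    #   False                -> model stays on cpu
    #   "cuda"               -> model on gpu, *not* wrapped in `DataParallel`
    #   "cuda_data_parallel" -> model on gpu, wrapped in `DataParallel`
    assert use_gpu in [False, "cuda", "cuda_data_parallel"]
    if use_gpu == "cuda_data_parallel":
        return nn.DataParallel(net.to(device="cuda"))
    elif use_gpu == "cuda":
        return net.to(device="cuda")
    return net

Because `bool("cuda")` and `bool("cuda_data_parallel")` are both `True`, existing code that branches on `if use_gpu` continues to treat either string as "model is on gpu".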

File tree: 3 files changed, +36 −12 lines


tests/influence/_core/test_tracin_k_most_influential.py

Lines changed: 6 additions & 4 deletions
@@ -1,5 +1,5 @@
 import tempfile
-from typing import Callable
+from typing import Callable, Union

 import torch
 import torch.nn as nn
@@ -18,7 +18,7 @@
 class TestTracInGetKMostInfluential(BaseTest):

     use_gpu_list = (
-        [True, False]
+        [False, "cuda", "cuda_data_parallel"]
         if torch.cuda.is_available() and torch.cuda.device_count() != 0
         else [False]
     )
@@ -48,7 +48,9 @@ class TestTracInGetKMostInfluential(BaseTest):
                 DataInfluenceConstructor(
                     TracInCP,
                     name="linear2",
-                    layers=["module.linear2"] if use_gpu else ["linear2"],
+                    layers=["module.linear2"]
+                    if use_gpu == "cuda_data_parallel"
+                    else ["linear2"],
                 ),
                 False,
             ),
@@ -83,7 +85,7 @@ def test_tracin_k_most_influential(
         proponents: bool,
         batch_size: int,
         k: int,
-        use_gpu: bool,
+        use_gpu: Union[bool, str],
         aggregate: bool,
     ) -> None:
         """

tests/influence/_core/test_tracin_self_influence.py

Lines changed: 7 additions & 5 deletions
@@ -1,5 +1,5 @@
 import tempfile
-from typing import Callable
+from typing import Callable, Union

 import torch
 import torch.nn as nn
@@ -19,7 +19,7 @@
 class TestTracInSelfInfluence(BaseTest):

     use_gpu_list = (
-        [True, False]
+        [False, "cuda", "cuda_data_parallel"]
         if torch.cuda.is_available() and torch.cuda.device_count() != 0
         else [False]
     )
@@ -37,7 +37,9 @@ class TestTracInSelfInfluence(BaseTest):
                 DataInfluenceConstructor(
                     TracInCP,
                     name="TracInCP_linear1",
-                    layers=["module.linear1"] if use_gpu else ["linear1"],
+                    layers=["module.linear1"]
+                    if use_gpu == "cuda_data_parallel"
+                    else ["linear1"],
                 ),
             ),
             (
@@ -46,7 +48,7 @@ class TestTracInSelfInfluence(BaseTest):
                     TracInCP,
                     name="TracInCP_linear1_linear2",
                     layers=["module.linear1", "module.linear2"]
-                    if use_gpu
+                    if use_gpu == "cuda_data_parallel"
                     else ["linear1", "linear2"],
                 ),
             ),
@@ -87,7 +89,7 @@ def test_tracin_self_influence(
         reduction: str,
         tracin_constructor: Callable,
         unpack_inputs: bool,
-        use_gpu: bool,
+        use_gpu: Union[bool, str],
     ) -> None:
         with tempfile.TemporaryDirectory() as tmpdir:
             (net, train_dataset,) = get_random_model_and_data(

tests/influence/_utils/common.py

Lines changed: 23 additions & 3 deletions
@@ -183,6 +183,15 @@ def forward(self, *inputs):
 def get_random_model_and_data(
     tmpdir, unpack_inputs, return_test_data=True, use_gpu=False
 ):
+    """
+    `use_gpu` can either be
+    - `False`: returned model is on cpu
+    - `'cuda'`: returned model is on gpu
+    - `'cuda_data_parallel'`: returned model is a `DataParallel` model, and on gpu
+    The need to differentiate between `'cuda'` and `'cuda_data_parallel'` is that
+    sometimes we may want to test a model that is on gpu but is *not* wrapped in `DataParallel`.
+    """
+    assert use_gpu in [False, "cuda", "cuda_data_parallel"]

     in_features, hidden_nodes, out_features = 5, 4, 3
     num_inputs = 2
@@ -209,7 +218,11 @@ def get_random_model_and_data(
         if hasattr(net, "pre"):
             net.pre.weight.data = net.pre.weight.data.double()
         checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"])
-        net_adjusted = _wrap_model_in_dataparallel(net) if use_gpu else net
+        net_adjusted = (
+            _wrap_model_in_dataparallel(net)
+            if use_gpu == "cuda_data_parallel"
+            else (net.to(device="cuda") if use_gpu == "cuda" else net)
+        )
         torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))

     num_samples = 50
@@ -238,7 +251,9 @@ def get_random_model_and_data(

     if return_test_data:
         return (
-            _wrap_model_in_dataparallel(net) if use_gpu else net,
+            _wrap_model_in_dataparallel(net)
+            if use_gpu == "cuda_data_parallel"
+            else (net.to(device="cuda") if use_gpu == "cuda" else net),
             dataset,
             _move_sample_to_cuda(test_samples)
             if isinstance(test_samples, list) and use_gpu
@@ -248,7 +263,12 @@ def get_random_model_and_data(
             test_labels.cuda() if use_gpu else test_labels,
         )
     else:
-        return _wrap_model_in_dataparallel(net) if use_gpu else net, dataset
+        return (
+            _wrap_model_in_dataparallel(net)
+            if use_gpu == "cuda_data_parallel"
+            else (net.to(device="cuda") if use_gpu == "cuda" else net),
+            dataset,
+        )


 class DataInfluenceConstructor:
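As a usage sketch of the updated helper: the loop below mirrors how the parameterized tests above consume the flag, but the loop body itself is illustrative rather than quoted from the commit (`get_random_model_and_data` and the three flag values are from this diff; the variable names here are hypothetical).

import tempfile

import torch

# Illustrative only: iterate over the same flag values the tests enumerate.
for use_gpu in [False, "cuda", "cuda_data_parallel"]:
    if use_gpu and not torch.cuda.is_available():
        continue  # skip the gpu variants on cpu-only machines
    with tempfile.TemporaryDirectory() as tmpdir:
        net, dataset = get_random_model_and_data(
            tmpdir, unpack_inputs=False, return_test_data=False, use_gpu=use_gpu
        )
        # only the `DataParallel` variant prefixes layer names with "module."
        layers = (
            ["module.linear2"] if use_gpu == "cuda_data_parallel" else ["linear2"]
        )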
