
Commit 44e26fa

craymichael authored and facebook-github-bot committed
Reduce complexity of DataParallelMeta.make_single_dp_test (#1369)
Summary: Reduce complexity of DataParallelMeta.make_single_dp_test

Differential Revision: D64370178
1 parent 29eb6c7 commit 44e26fa

File tree

1 file changed: +114 -84 lines changed

tests/attr/test_data_parallel.py

@@ -4,7 +4,7 @@
 import copy
 import os
 from enum import Enum
-from typing import Any, Callable, cast, Dict, Optional, Tuple, Type
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
@@ -136,91 +136,22 @@ def data_parallel_test_assert(self) -> None:
                 else:
                     cuda_args[key] = args[key]

-            alt_device_ids = None
             cuda_model = copy.deepcopy(model).cuda()
-            # Initialize models based on DataParallelCompareMode
-            if mode is DataParallelCompareMode.cpu_cuda:
-                model_1, model_2 = model, cuda_model
-                args_1, args_2 = args, cuda_args
-            elif mode is DataParallelCompareMode.data_parallel_default:
-                model_1, model_2 = (
-                    cuda_model,
-                    torch.nn.parallel.DataParallel(cuda_model),
-                )
-                args_1, args_2 = cuda_args, cuda_args
-            elif mode is DataParallelCompareMode.data_parallel_alt_dev_ids:
-                alt_device_ids = [0] + [
-                    x for x in range(torch.cuda.device_count() - 1, 0, -1)
-                ]
-                model_1, model_2 = (
-                    cuda_model,
-                    torch.nn.parallel.DataParallel(
-                        cuda_model, device_ids=alt_device_ids
-                    ),
-                )
-                args_1, args_2 = cuda_args, cuda_args
-            elif mode is DataParallelCompareMode.dist_data_parallel:
-
-                model_1, model_2 = (
-                    cuda_model,
-                    torch.nn.parallel.DistributedDataParallel(
-                        cuda_model, device_ids=[0], output_device=0
-                    ),
-                )
-                args_1, args_2 = cuda_args, cuda_args
-            else:
-                raise AssertionError("DataParallel compare mode type is not valid.")
-
-            attr_method_1: Attribution
-            attr_method_2: Attribution
-            if target_layer:
-                internal_algorithm = cast(Type[InternalAttribution], algorithm)
-                attr_method_1 = internal_algorithm(
-                    model_1, get_target_layer(model_1, target_layer)
-                )
-                # cuda_model is used to obtain target_layer since DataParallel
-                # adds additional wrapper.
-                # model_2 is always either the CUDA model itself or DataParallel
-                if alt_device_ids is None:
-                    attr_method_2 = internal_algorithm(
-                        model_2, get_target_layer(cuda_model, target_layer)
-                    )
-                else:
-                    # LayerDeepLift and LayerDeepLiftShap do not take device ids
-                    # as a parameter, since they must always have the DataParallel
-                    # model object directly.
-                    # Some neuron methods and GuidedGradCAM also require the
-                    # model and cannot take a forward function.
-                    if issubclass(
-                        internal_algorithm,
-                        (
-                            LayerDeepLift,
-                            LayerDeepLiftShap,
-                            LayerLRP,
-                            NeuronDeepLift,
-                            NeuronDeepLiftShap,
-                            NeuronDeconvolution,
-                            NeuronGuidedBackprop,
-                            GuidedGradCam,
-                        ),
-                    ):
-                        attr_method_2 = internal_algorithm(
-                            model_2,
-                            get_target_layer(cuda_model, target_layer),  # type: ignore
-                        )
-                    else:
-                        attr_method_2 = internal_algorithm(
-                            model_2.forward,
-                            get_target_layer(cuda_model, target_layer),
-                            device_ids=alt_device_ids,
-                        )
-            else:
-                attr_method_1 = algorithm(model_1)
-                attr_method_2 = algorithm(model_2)
+            # Set up test arguments based on DataParallelCompareMode
+            model_1, model_2, args_1, args_2, alt_device_ids = _get_dp_test_args(
+                cuda_model, model, cuda_args, args, mode
+            )

-            if noise_tunnel:
-                attr_method_1 = NoiseTunnel(attr_method_1)
-                attr_method_2 = NoiseTunnel(attr_method_2)
+            # Construct attribution methods
+            attr_method_1, attr_method_2 = _get_dp_attr_methods(
+                algorithm,
+                target_layer,
+                model_1,
+                model_2,
+                cuda_model,
+                alt_device_ids,
+                noise_tunnel,
+            )
             if attr_method_1.has_convergence_delta():
                 attributions_1, delta_1 = attr_method_1.attribute(
                     return_convergence_delta=True, **args_1
@@ -266,6 +197,105 @@ def data_parallel_test_assert(self) -> None:
         return data_parallel_test_assert


+def _get_dp_test_args(
+    cuda_model: Module,
+    model: Module,
+    cuda_args: Dict[str, Any],
+    args: Dict[str, Any],
+    mode: DataParallelCompareMode,
+) -> Tuple[Module, Module, Dict[str, Any], Dict[str, Any], Optional[List[int]]]:
+    # Initialize models based on DataParallelCompareMode
+    alt_device_ids = None
+    if mode is DataParallelCompareMode.cpu_cuda:
+        model_1, model_2 = model, cuda_model
+        args_1, args_2 = args, cuda_args
+    elif mode is DataParallelCompareMode.data_parallel_default:
+        model_1, model_2 = (
+            cuda_model,
+            torch.nn.parallel.DataParallel(cuda_model),
+        )
+        args_1, args_2 = cuda_args, cuda_args
+    elif mode is DataParallelCompareMode.data_parallel_alt_dev_ids:
+        alt_device_ids = [0] + list(range(torch.cuda.device_count() - 1, 0, -1))
+        model_1, model_2 = (
+            cuda_model,
+            torch.nn.parallel.DataParallel(cuda_model, device_ids=alt_device_ids),
+        )
+        args_1, args_2 = cuda_args, cuda_args
+    elif mode is DataParallelCompareMode.dist_data_parallel:
+
+        model_1, model_2 = (
+            cuda_model,
+            torch.nn.parallel.DistributedDataParallel(
+                cuda_model, device_ids=[0], output_device=0
+            ),
+        )
+        args_1, args_2 = cuda_args, cuda_args
+    else:
+        raise AssertionError("DataParallel compare mode type is not valid.")
+
+    return model_1, model_2, args_1, args_2, alt_device_ids
+
+
+def _get_dp_attr_methods(
+    algorithm: Type[Attribution],
+    target_layer: Optional[str],
+    model_1: Module,
+    model_2: Module,
+    cuda_model: Module,
+    alt_device_ids: Optional[List[int]],
+    noise_tunnel: bool,
+) -> Tuple[Attribution, Attribution]:
+    if target_layer:
+        internal_algorithm = cast(Type[InternalAttribution], algorithm)
+        attr_method_1 = internal_algorithm(
+            model_1, get_target_layer(model_1, target_layer)
+        )
+        # cuda_model is used to obtain target_layer since DataParallel
+        # adds additional wrapper.
+        # model_2 is always either the CUDA model itself or DataParallel
+        if alt_device_ids is None:
+            attr_method_2 = internal_algorithm(
+                model_2, get_target_layer(cuda_model, target_layer)
+            )
+        else:
+            # LayerDeepLift and LayerDeepLiftShap do not take device ids
+            # as a parameter, since they must always have the DataParallel
+            # model object directly.
+            # Some neuron methods and GuidedGradCAM also require the
+            # model and cannot take a forward function.
+            if issubclass(
+                internal_algorithm,
+                (
+                    LayerDeepLift,
+                    LayerDeepLiftShap,
+                    LayerLRP,
+                    NeuronDeepLift,
+                    NeuronDeepLiftShap,
+                    NeuronDeconvolution,
+                    NeuronGuidedBackprop,
+                    GuidedGradCam,
+                ),
+            ):
+                attr_method_2 = internal_algorithm(
+                    model_2,
+                    get_target_layer(cuda_model, target_layer),  # type: ignore
+                )
+            else:
+                attr_method_2 = internal_algorithm(
+                    model_2.forward,
+                    get_target_layer(cuda_model, target_layer),
+                    device_ids=alt_device_ids,
+                )
+    else:
+        attr_method_1 = algorithm(model_1)
+        attr_method_2 = algorithm(model_2)
+    if noise_tunnel:
+        attr_method_1 = NoiseTunnel(attr_method_1)
+        attr_method_2 = NoiseTunnel(attr_method_2)
+    return attr_method_1, attr_method_2
+
+
 if torch.cuda.is_available() and torch.cuda.device_count() != 0:

     class DataParallelTest(BaseTest, metaclass=DataParallelMeta):
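
A note on the comments carried over into _get_dp_attr_methods: "DataParallel adds additional wrapper" is why get_target_layer is called on cuda_model rather than on model_2. torch.nn.DataParallel nests the wrapped network under a .module attribute, so layer names resolve one level deeper on the wrapped object. A minimal sketch (the two-layer Sequential model is illustrative, not from this test):

import torch.nn as nn

# Illustrative model, not from the diff: DataParallel adds one level of nesting.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
dp_model = nn.DataParallel(model)

assert dp_model.module is model  # the original network sits one level down
print([n for n, _ in model.named_modules()])     # ['', '0', '1']
print([n for n, _ in dp_model.named_modules()])  # ['', 'module', 'module.0', 'module.1']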

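In data_parallel_alt_dev_ids mode, the refactor also replaces the list comprehension with the equivalent [0] + list(range(torch.cuda.device_count() - 1, 0, -1)), which keeps device 0 first and lists the remaining device ids in reverse order. A quick behavior-preservation check, assuming a hypothetical 4-GPU host:

# Hypothetical device count; torch.cuda.device_count() is not called here.
device_count = 4
old_ids = [0] + [x for x in range(device_count - 1, 0, -1)]  # pre-refactor form
new_ids = [0] + list(range(device_count - 1, 0, -1))         # post-refactor form
assert old_ids == new_ids == [0, 3, 2, 1]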
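Finally, the dist_data_parallel mode constructs DistributedDataParallel, which requires an initialized process group; that setup happens outside the lines shown in this diff (the test module imports torch.distributed as dist). A minimal single-process sketch of such setup, with the backend choice and env values as assumptions rather than details taken from this test:

import os
import torch.distributed as dist

# Assumed single-process setup; values are illustrative, not from the test.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", rank=0, world_size=1)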