 import copy
 import os
 from enum import Enum
-from typing import Any, Callable, cast, Dict, Optional, Tuple, Type
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
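
(Aside: the os and torch.distributed imports support the dist_data_parallel mode handled further down; DistributedDataParallel can only be constructed once a process group is initialized. A single-process setup would presumably look like the sketch below. The rendezvous address, port, and backend here are assumptions for illustration, not taken from this commit.)

import os

import torch.distributed as dist

# Assumed env:// rendezvous for a one-process group; values are illustrative.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="gloo", rank=0, world_size=1)
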
@@ -136,91 +136,22 @@ def data_parallel_test_assert(self) -> None:
                 else:
                     cuda_args[key] = args[key]

-            alt_device_ids = None
             cuda_model = copy.deepcopy(model).cuda()
-            # Initialize models based on DataParallelCompareMode
-            if mode is DataParallelCompareMode.cpu_cuda:
-                model_1, model_2 = model, cuda_model
-                args_1, args_2 = args, cuda_args
-            elif mode is DataParallelCompareMode.data_parallel_default:
-                model_1, model_2 = (
-                    cuda_model,
-                    torch.nn.parallel.DataParallel(cuda_model),
-                )
-                args_1, args_2 = cuda_args, cuda_args
-            elif mode is DataParallelCompareMode.data_parallel_alt_dev_ids:
-                alt_device_ids = [0] + [
-                    x for x in range(torch.cuda.device_count() - 1, 0, -1)
-                ]
-                model_1, model_2 = (
-                    cuda_model,
-                    torch.nn.parallel.DataParallel(
-                        cuda_model, device_ids=alt_device_ids
-                    ),
-                )
-                args_1, args_2 = cuda_args, cuda_args
-            elif mode is DataParallelCompareMode.dist_data_parallel:
-
-                model_1, model_2 = (
-                    cuda_model,
-                    torch.nn.parallel.DistributedDataParallel(
-                        cuda_model, device_ids=[0], output_device=0
-                    ),
-                )
-                args_1, args_2 = cuda_args, cuda_args
-            else:
-                raise AssertionError("DataParallel compare mode type is not valid.")
-
-            attr_method_1: Attribution
-            attr_method_2: Attribution
-            if target_layer:
-                internal_algorithm = cast(Type[InternalAttribution], algorithm)
-                attr_method_1 = internal_algorithm(
-                    model_1, get_target_layer(model_1, target_layer)
-                )
-                # cuda_model is used to obtain target_layer since DataParallel
-                # adds an additional wrapper.
-                # model_2 is always either the CUDA model itself or DataParallel.
-                if alt_device_ids is None:
-                    attr_method_2 = internal_algorithm(
-                        model_2, get_target_layer(cuda_model, target_layer)
-                    )
-                else:
-                    # LayerDeepLift and LayerDeepLiftShap do not take device ids
-                    # as a parameter, since they must always have the DataParallel
-                    # model object directly.
-                    # Some neuron methods and GuidedGradCAM also require the
-                    # model and cannot take a forward function.
-                    if issubclass(
-                        internal_algorithm,
-                        (
-                            LayerDeepLift,
-                            LayerDeepLiftShap,
-                            LayerLRP,
-                            NeuronDeepLift,
-                            NeuronDeepLiftShap,
-                            NeuronDeconvolution,
-                            NeuronGuidedBackprop,
-                            GuidedGradCam,
-                        ),
-                    ):
-                        attr_method_2 = internal_algorithm(
-                            model_2,
-                            get_target_layer(cuda_model, target_layer),  # type: ignore
-                        )
-                    else:
-                        attr_method_2 = internal_algorithm(
-                            model_2.forward,
-                            get_target_layer(cuda_model, target_layer),
-                            device_ids=alt_device_ids,
-                        )
-            else:
-                attr_method_1 = algorithm(model_1)
-                attr_method_2 = algorithm(model_2)
+            # Set up test arguments based on DataParallelCompareMode
+            model_1, model_2, args_1, args_2, alt_device_ids = _get_dp_test_args(
+                cuda_model, model, cuda_args, args, mode
+            )
-            if noise_tunnel:
-                attr_method_1 = NoiseTunnel(attr_method_1)
-                attr_method_2 = NoiseTunnel(attr_method_2)
+            # Construct attribution methods
+            attr_method_1, attr_method_2 = _get_dp_attr_methods(
+                algorithm,
+                target_layer,
+                model_1,
+                model_2,
+                cuda_model,
+                alt_device_ids,
+                noise_tunnel,
+            )
             if attr_method_1.has_convergence_delta():
                 attributions_1, delta_1 = attr_method_1.attribute(
                     return_convergence_delta=True, **args_1
@@ -266,6 +197,105 @@ def data_parallel_test_assert(self) -> None:
         return data_parallel_test_assert


+def _get_dp_test_args(
+    cuda_model: Module,
+    model: Module,
+    cuda_args: Dict[str, Any],
+    args: Dict[str, Any],
+    mode: DataParallelCompareMode,
+) -> Tuple[Module, Module, Dict[str, Any], Dict[str, Any], Optional[List[int]]]:
+    # Initialize models based on DataParallelCompareMode
+    alt_device_ids = None
+    if mode is DataParallelCompareMode.cpu_cuda:
+        model_1, model_2 = model, cuda_model
+        args_1, args_2 = args, cuda_args
+    elif mode is DataParallelCompareMode.data_parallel_default:
+        model_1, model_2 = (
+            cuda_model,
+            torch.nn.parallel.DataParallel(cuda_model),
+        )
+        args_1, args_2 = cuda_args, cuda_args
+    elif mode is DataParallelCompareMode.data_parallel_alt_dev_ids:
+        alt_device_ids = [0] + list(range(torch.cuda.device_count() - 1, 0, -1))
+        model_1, model_2 = (
+            cuda_model,
+            torch.nn.parallel.DataParallel(cuda_model, device_ids=alt_device_ids),
+        )
+        args_1, args_2 = cuda_args, cuda_args
+    elif mode is DataParallelCompareMode.dist_data_parallel:
+
+        model_1, model_2 = (
+            cuda_model,
+            torch.nn.parallel.DistributedDataParallel(
+                cuda_model, device_ids=[0], output_device=0
+            ),
+        )
+        args_1, args_2 = cuda_args, cuda_args
+    else:
+        raise AssertionError("DataParallel compare mode type is not valid.")
+
+    return model_1, model_2, args_1, args_2, alt_device_ids
+
+
+def _get_dp_attr_methods(
+    algorithm: Type[Attribution],
+    target_layer: Optional[str],
+    model_1: Module,
+    model_2: Module,
+    cuda_model: Module,
+    alt_device_ids: Optional[List[int]],
+    noise_tunnel: bool,
+) -> Tuple[Attribution, Attribution]:
+    if target_layer:
+        internal_algorithm = cast(Type[InternalAttribution], algorithm)
+        attr_method_1 = internal_algorithm(
+            model_1, get_target_layer(model_1, target_layer)
+        )
+        # cuda_model is used to obtain target_layer since DataParallel
+        # adds an additional wrapper.
+        # model_2 is always either the CUDA model itself or DataParallel.
+        if alt_device_ids is None:
+            attr_method_2 = internal_algorithm(
+                model_2, get_target_layer(cuda_model, target_layer)
+            )
+        else:
+            # LayerDeepLift and LayerDeepLiftShap do not take device ids
+            # as a parameter, since they must always have the DataParallel
+            # model object directly.
+            # Some neuron methods and GuidedGradCAM also require the
+            # model and cannot take a forward function.
+            if issubclass(
+                internal_algorithm,
+                (
+                    LayerDeepLift,
+                    LayerDeepLiftShap,
+                    LayerLRP,
+                    NeuronDeepLift,
+                    NeuronDeepLiftShap,
+                    NeuronDeconvolution,
+                    NeuronGuidedBackprop,
+                    GuidedGradCam,
+                ),
+            ):
+                attr_method_2 = internal_algorithm(
+                    model_2,
+                    get_target_layer(cuda_model, target_layer),  # type: ignore
+                )
+            else:
+                attr_method_2 = internal_algorithm(
+                    model_2.forward,
+                    get_target_layer(cuda_model, target_layer),
+                    device_ids=alt_device_ids,
+                )
+    else:
+        attr_method_1 = algorithm(model_1)
+        attr_method_2 = algorithm(model_2)
+    if noise_tunnel:
+        attr_method_1 = NoiseTunnel(attr_method_1)
+        attr_method_2 = NoiseTunnel(attr_method_2)
+    return attr_method_1, attr_method_2
+
+
 if torch.cuda.is_available() and torch.cuda.device_count() != 0:

     class DataParallelTest(BaseTest, metaclass=DataParallelMeta):
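
For reference, the four branches in _get_dp_test_args imply a DataParallelCompareMode enum along the following lines. This is a minimal sketch reconstructed from the dispatch above: only the member names come from this diff; the values and ordering are assumptions.

from enum import Enum


class DataParallelCompareMode(Enum):
    # Member names taken from the branches in _get_dp_test_args; values assumed.
    cpu_cuda = 1
    data_parallel_default = 2
    data_parallel_alt_dev_ids = 3
    dist_data_parallel = 4

Each mode pairs a reference model (model_1) with a comparison model (model_2) and the test asserts that both produce matching attributions.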
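
The data_parallel_alt_dev_ids ordering is easiest to see with a concrete device count. A worked example, assuming four visible GPUs (the count is illustrative):

# With torch.cuda.device_count() == 4, device 0 stays first (DataParallel
# gathers outputs on device_ids[0] by default) and the remaining ids are
# reversed.
device_count = 4  # assumed for illustration
alt_device_ids = [0] + list(range(device_count - 1, 0, -1))
assert alt_device_ids == [0, 3, 2, 1]

This exercises DataParallel with a non-default device ordering while still collecting results on device 0.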
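
The comment about DataParallel adding a wrapper refers to the extra .module level on the wrapped network, which is why get_target_layer is called on cuda_model rather than on model_2. A minimal sketch of that behavior (the toy Sequential model is an assumption for illustration):

import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 2))
dp = torch.nn.DataParallel(net)

# DataParallel keeps the original network one level down, so submodule
# names on the wrapper gain a "module." prefix:
assert dp.module is net
assert "0" in dict(net.named_modules())
assert "module.0" in dict(dp.named_modules())

Resolving the layer on the unwrapped cuda_model keeps the dotted-name lookup identical across all compare modes.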