
Commit 77e2bc4

jjuncho authored and facebook-github-bot committed

layer_integrated_gradients is too complex (#1407)

Summary: This diff addresses the C901 complexity warning in layer_integrated_gradients.py by breaking down the method.

Reviewed By: vivekmig

Differential Revision: D64565179
1 parent b80e488 commit 77e2bc4
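
The change is a standard extract-method refactor: the gradient_func closure that attribute() used to define inline is hoisted into a _make_gradient_func factory, which takes the closure's former free variables (num_outputs_cumsum, attribute_to_layer_input) as explicit parameters and returns the inner function unchanged. Below is a minimal sketch of the pattern only; the class and the placeholder closure body are illustrative, not Captum's actual implementation.

from typing import Any, Callable, Tuple


class RefactorSketch:
    """Illustrative stand-in; only the shape of the refactor matches the real class."""

    def __init__(self) -> None:
        # Placeholder for the wrapped object whose gradient_func gets
        # overridden (self.ig in the real LayerIntegratedGradients).
        self.gradient_func: Callable[..., Tuple[Any, ...]] = lambda *a, **kw: ()

    def _make_gradient_func(
        self, num_outputs_cumsum: Any, attribute_to_layer_input: bool
    ) -> Callable[..., Tuple[Any, ...]]:
        # The closure's former free variables are now parameters, so its
        # behavior is unchanged while attribute() sheds ~110 lines of body.
        def _gradient_func(*args: Any, **kwargs: Any) -> Tuple[Any, ...]:
            # Same body as before the refactor in the real code: register
            # layer hooks, run the forward pass, take torch.autograd.grad.
            return ()

        return _gradient_func

    def attribute(
        self, num_outputs_cumsum: Any, attribute_to_layer_input: bool
    ) -> None:
        # attribute() now only wires the prebuilt closure in.
        self.gradient_func = self._make_gradient_func(
            num_outputs_cumsum, attribute_to_layer_input
        )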

1 file changed: +121 -109 lines

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 121 additions & 109 deletions
@@ -109,6 +109,124 @@ def __init__(
                 stacklevel=2,
             )

+    def _make_gradient_func(
+        self,
+        # pyre-fixme[2]: Parameter needs type annotation.
+        num_outputs_cumsum,
+        attribute_to_layer_input: bool,
+    ) -> Callable[..., Tuple[Tensor, ...]]:
+
+        def _gradient_func(
+            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+            forward_fn: Callable,
+            inputs: Union[Tensor, Tuple[Tensor, ...]],
+            target_ind: TargetType = None,
+            # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+            additional_forward_args: Any = None,
+        ) -> Tuple[Tensor, ...]:
+            if self.device_ids is None or len(self.device_ids) == 0:
+                scattered_inputs = (inputs,)
+            else:
+                # scatter method does not have a precise enough return type in its
+                # stub, so suppress the type warning.
+                scattered_inputs = scatter(  # type:ignore
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    # `Union[Tensor, typing.Tuple[Tensor, ...]]`.
+                    inputs,
+                    target_gpus=self.device_ids,
+                )
+
+            scattered_inputs_dict = {
+                scattered_input[0].device: scattered_input
+                for scattered_input in scattered_inputs
+            }
+
+            with torch.autograd.set_grad_enabled(True):
+
+                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
+                # annotated.
+                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
+                # annotated.
+                # pyre-fixme[3]: Return type must be annotated.
+                def layer_forward_hook(
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    module,
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    hook_inputs,
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    hook_outputs=None,
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    layer_idx=0,
+                ):
+                    device = _extract_device(module, hook_inputs, hook_outputs)
+                    is_layer_tuple = (
+                        isinstance(hook_outputs, tuple)
+                        # hook_outputs is None if attribute_to_layer_input == True
+                        if hook_outputs is not None
+                        else isinstance(hook_inputs, tuple)
+                    )
+
+                    if is_layer_tuple:
+                        return scattered_inputs_dict[device][
+                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                layer_idx + 1
+                            ]
+                        ]
+
+                    return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
+
+                hooks = []
+                try:
+
+                    layers = self.layer
+                    if not isinstance(layers, list):
+                        layers = [self.layer]
+
+                    for layer_idx, layer in enumerate(layers):
+                        hook = None
+                        # TODO:
+                        # Allow multiple attribute_to_layer_input flags for
+                        # each layer, i.e. attribute_to_layer_input[layer_idx]
+                        if attribute_to_layer_input:
+                            hook = layer.register_forward_pre_hook(
+                                functools.partial(
+                                    layer_forward_hook, layer_idx=layer_idx
+                                )
+                            )
+                        else:
+                            hook = layer.register_forward_hook(
+                                functools.partial(
+                                    layer_forward_hook, layer_idx=layer_idx
+                                )
+                            )
+
+                        hooks.append(hook)
+
+                    # the inputs is an empty tuple
+                    # coz it is prepended into additional_forward_args
+                    output = _run_forward(
+                        self.forward_func, (), target_ind, additional_forward_args
+                    )
+                finally:
+                    for hook in hooks:
+                        if hook is not None:
+                            hook.remove()
+
+                # _run_forward may return future of Tensor,
+                # but we don't support it here now
+                # And it will fail before here.
+                output = cast(Tensor, output)
+                assert output[0].numel() == 1, (
+                    "Target not provided when necessary, cannot"
+                    " take gradient with respect to multiple outputs."
+                )
+                # torch.unbind(forward_out) is a list of scalar tensor tuples and
+                # contains batch_size * #steps elements
+                grads = torch.autograd.grad(torch.unbind(output), inputs)
+            return grads
+
+        return _gradient_func
+
     @overload
     # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
     # arguments of overload defined on line `112`.
@@ -415,116 +533,10 @@ def flatten_tuple(tup):
         baselines_layer = flatten_tuple(baselines_layer)

         # inputs -> these inputs are scaled
-        def gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
-            inputs: Union[Tensor, Tuple[Tensor, ...]],
-            target_ind: TargetType = None,
-            # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-            additional_forward_args: Any = None,
-        ) -> Tuple[Tensor, ...]:
-            if self.device_ids is None or len(self.device_ids) == 0:
-                scattered_inputs = (inputs,)
-            else:
-                # scatter method does not have a precise enough return type in its
-                # stub, so suppress the type warning.
-                scattered_inputs = scatter(  # type:ignore
-                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                    # `Union[Tensor, typing.Tuple[Tensor, ...]]`.
-                    inputs,
-                    target_gpus=self.device_ids,
-                )
-
-            scattered_inputs_dict = {
-                scattered_input[0].device: scattered_input
-                for scattered_input in scattered_inputs
-            }
-
-            with torch.autograd.set_grad_enabled(True):
-
-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                # annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                # annotated.
-                # pyre-fixme[3]: Return type must be annotated.
-                def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
-                    device = _extract_device(module, hook_inputs, hook_outputs)
-                    is_layer_tuple = (
-                        isinstance(hook_outputs, tuple)
-                        # hook_outputs is None if attribute_to_layer_input == True
-                        if hook_outputs is not None
-                        else isinstance(hook_inputs, tuple)
-                    )
-
-                    if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
-
-                    return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
-
-                hooks = []
-                try:

-                    layers = self.layer
-                    if not isinstance(layers, list):
-                        layers = [self.layer]
-
-                    for layer_idx, layer in enumerate(layers):
-                        hook = None
-                        # TODO:
-                        # Allow multiple attribute_to_layer_input flags for
-                        # each layer, i.e. attribute_to_layer_input[layer_idx]
-                        if attribute_to_layer_input:
-                            hook = layer.register_forward_pre_hook(
-                                functools.partial(
-                                    layer_forward_hook, layer_idx=layer_idx
-                                )
-                            )
-                        else:
-                            hook = layer.register_forward_hook(
-                                functools.partial(
-                                    layer_forward_hook, layer_idx=layer_idx
-                                )
-                            )
-
-                        hooks.append(hook)
-
-                    # the inputs is an empty tuple
-                    # coz it is prepended into additional_forward_args
-                    output = _run_forward(
-                        self.forward_func, (), target_ind, additional_forward_args
-                    )
-                finally:
-                    for hook in hooks:
-                        if hook is not None:
-                            hook.remove()
-
-                # _run_forward may return future of Tensor,
-                # but we don't support it here now
-                # And it will fail before here.
-                output = cast(Tensor, output)
-                assert output[0].numel() == 1, (
-                    "Target not provided when necessary, cannot"
-                    " take gradient with respect to multiple outputs."
-                )
-                # torch.unbind(forward_out) is a list of scalar tensor tuples and
-                # contains batch_size * #steps elements
-                grads = torch.autograd.grad(torch.unbind(output), inputs)
-            return grads
-
-        self.ig.gradient_func = gradient_func
+        self.ig.gradient_func = self._make_gradient_func(
+            num_outputs_cumsum, attribute_to_layer_input
+        )
         all_inputs = (
             (inps + additional_forward_args)
             if additional_forward_args is not None
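
Because the refactor only moves code, the public LayerIntegratedGradients API and its results are unchanged. A minimal usage sketch follows; the toy model and layer choice are hypothetical, only the Captum calls are real.

import torch
import torch.nn as nn

from captum.attr import LayerIntegratedGradients

# Hypothetical toy model; any nn.Module works.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
lig = LayerIntegratedGradients(model, model[0])

inputs = torch.randn(4, 8)
# Internally, attribute() now assigns self.ig.gradient_func via the new
# _make_gradient_func helper; the attributions themselves are unaffected.
attributions = lig.attribute(inputs, target=0)
print(attributions.shape)  # torch.Size([4, 16]): attributions w.r.t. the layer output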
