@@ -109,6 +109,123 @@ def __init__(
109109 stacklevel = 2 ,
110110 )
111111
112+ def _make_gradient_func (
113+ self ,
114+ num_outputs_cumsum : Tensor ,
115+ attribute_to_layer_input : bool ,
116+ ) -> Callable [..., Tuple [Tensor , ...]]:
117+
118+ def _gradient_func (
119+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
120+ forward_fn : Callable ,
121+ inputs : Union [Tensor , Tuple [Tensor , ...]],
122+ target_ind : TargetType = None ,
123+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
124+ additional_forward_args : Any = None ,
125+ ) -> Tuple [Tensor , ...]:
126+ if self .device_ids is None or len (self .device_ids ) == 0 :
127+ scattered_inputs = (inputs ,)
128+ else :
129+ # scatter method does not have a precise enough return type in its
130+ # stub, so suppress the type warning.
131+ scattered_inputs = scatter ( # type:ignore
132+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got
133+ # `Union[Tensor, typing.Tuple[Tensor, ...]]`.
134+ inputs ,
135+ target_gpus = self .device_ids ,
136+ )
137+
138+ scattered_inputs_dict = {
139+ scattered_input [0 ].device : scattered_input
140+ for scattered_input in scattered_inputs
141+ }
142+
143+ with torch .autograd .set_grad_enabled (True ):
144+
145+ # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
146+ # annotated.
147+ # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
148+ # annotated.
149+ # pyre-fixme[3]: Return type must be annotated.
150+ def layer_forward_hook (
151+ # pyre-fixme[2]: Parameter must be annotated.
152+ module ,
153+ # pyre-fixme[2]: Parameter must be annotated.
154+ hook_inputs ,
155+ # pyre-fixme[2]: Parameter must be annotated.
156+ hook_outputs = None ,
157+ # pyre-fixme[2]: Parameter must be annotated.
158+ layer_idx = 0 ,
159+ ):
160+ device = _extract_device (module , hook_inputs , hook_outputs )
161+ is_layer_tuple = (
162+ isinstance (hook_outputs , tuple )
163+ # hook_outputs is None if attribute_to_layer_input == True
164+ if hook_outputs is not None
165+ else isinstance (hook_inputs , tuple )
166+ )
167+
168+ if is_layer_tuple :
169+ return scattered_inputs_dict [device ][
170+ num_outputs_cumsum [layer_idx ] : num_outputs_cumsum [
171+ layer_idx + 1
172+ ]
173+ ]
174+
175+ return scattered_inputs_dict [device ][num_outputs_cumsum [layer_idx ]]
176+
177+ hooks = []
178+ try :
179+
180+ layers = self .layer
181+ if not isinstance (layers , list ):
182+ layers = [self .layer ]
183+
184+ for layer_idx , layer in enumerate (layers ):
185+ hook = None
186+ # TODO:
187+ # Allow multiple attribute_to_layer_input flags for
188+ # each layer, i.e. attribute_to_layer_input[layer_idx]
189+ if attribute_to_layer_input :
190+ hook = layer .register_forward_pre_hook (
191+ functools .partial (
192+ layer_forward_hook , layer_idx = layer_idx
193+ )
194+ )
195+ else :
196+ hook = layer .register_forward_hook (
197+ functools .partial (
198+ layer_forward_hook , layer_idx = layer_idx
199+ )
200+ )
201+
202+ hooks .append (hook )
203+
204+ # the inputs is an empty tuple
205+ # coz it is prepended into additional_forward_args
206+ output = _run_forward (
207+ self .forward_func , (), target_ind , additional_forward_args
208+ )
209+ finally :
210+ for hook in hooks :
211+ if hook is not None :
212+ hook .remove ()
213+
214+ # _run_forward may return future of Tensor,
215+ # but we don't support it here now
216+ # And it will fail before here.
217+ output = cast (Tensor , output )
218+ assert output [0 ].numel () == 1 , (
219+ "Target not provided when necessary, cannot"
220+ " take gradient with respect to multiple outputs."
221+ )
222+ # torch.unbind(forward_out) is a list of scalar tensor tuples and
223+ # contains batch_size * #steps elements
224+ grads = torch .autograd .grad (torch .unbind (output ), inputs )
225+ return grads
226+
227+ return _gradient_func
228+
112229 @overload
113230 # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
114231 # arguments of overload defined on line `112`.
@@ -415,116 +532,10 @@ def flatten_tuple(tup):
415532 baselines_layer = flatten_tuple (baselines_layer )
416533
417534 # inputs -> these inputs are scaled
418- def gradient_func (
419- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
420- forward_fn : Callable ,
421- inputs : Union [Tensor , Tuple [Tensor , ...]],
422- target_ind : TargetType = None ,
423- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
424- additional_forward_args : Any = None ,
425- ) -> Tuple [Tensor , ...]:
426- if self .device_ids is None or len (self .device_ids ) == 0 :
427- scattered_inputs = (inputs ,)
428- else :
429- # scatter method does not have a precise enough return type in its
430- # stub, so suppress the type warning.
431- scattered_inputs = scatter ( # type:ignore
432- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
433- # `Union[Tensor, typing.Tuple[Tensor, ...]]`.
434- inputs ,
435- target_gpus = self .device_ids ,
436- )
437-
438- scattered_inputs_dict = {
439- scattered_input [0 ].device : scattered_input
440- for scattered_input in scattered_inputs
441- }
442-
443- with torch .autograd .set_grad_enabled (True ):
444-
445- # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
446- # annotated.
447- # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
448- # annotated.
449- # pyre-fixme[3]: Return type must be annotated.
450- def layer_forward_hook (
451- # pyre-fixme[2]: Parameter must be annotated.
452- module ,
453- # pyre-fixme[2]: Parameter must be annotated.
454- hook_inputs ,
455- # pyre-fixme[2]: Parameter must be annotated.
456- hook_outputs = None ,
457- # pyre-fixme[2]: Parameter must be annotated.
458- layer_idx = 0 ,
459- ):
460- device = _extract_device (module , hook_inputs , hook_outputs )
461- is_layer_tuple = (
462- isinstance (hook_outputs , tuple )
463- # hook_outputs is None if attribute_to_layer_input == True
464- if hook_outputs is not None
465- else isinstance (hook_inputs , tuple )
466- )
467-
468- if is_layer_tuple :
469- return scattered_inputs_dict [device ][
470- num_outputs_cumsum [layer_idx ] : num_outputs_cumsum [
471- layer_idx + 1
472- ]
473- ]
474-
475- return scattered_inputs_dict [device ][num_outputs_cumsum [layer_idx ]]
476-
477- hooks = []
478- try :
479535
480- layers = self .layer
481- if not isinstance (layers , list ):
482- layers = [self .layer ]
483-
484- for layer_idx , layer in enumerate (layers ):
485- hook = None
486- # TODO:
487- # Allow multiple attribute_to_layer_input flags for
488- # each layer, i.e. attribute_to_layer_input[layer_idx]
489- if attribute_to_layer_input :
490- hook = layer .register_forward_pre_hook (
491- functools .partial (
492- layer_forward_hook , layer_idx = layer_idx
493- )
494- )
495- else :
496- hook = layer .register_forward_hook (
497- functools .partial (
498- layer_forward_hook , layer_idx = layer_idx
499- )
500- )
501-
502- hooks .append (hook )
503-
504- # the inputs is an empty tuple
505- # coz it is prepended into additional_forward_args
506- output = _run_forward (
507- self .forward_func , (), target_ind , additional_forward_args
508- )
509- finally :
510- for hook in hooks :
511- if hook is not None :
512- hook .remove ()
513-
514- # _run_forward may return future of Tensor,
515- # but we don't support it here now
516- # And it will fail before here.
517- output = cast (Tensor , output )
518- assert output [0 ].numel () == 1 , (
519- "Target not provided when necessary, cannot"
520- " take gradient with respect to multiple outputs."
521- )
522- # torch.unbind(forward_out) is a list of scalar tensor tuples and
523- # contains batch_size * #steps elements
524- grads = torch .autograd .grad (torch .unbind (output ), inputs )
525- return grads
526-
527- self .ig .gradient_func = gradient_func
536+ self .ig .gradient_func = self ._make_gradient_func (
537+ num_outputs_cumsum , attribute_to_layer_input
538+ )
528539 all_inputs = (
529540 (inps + additional_forward_args )
530541 if additional_forward_args is not None
0 commit comments