@@ -109,6 +109,124 @@ def __init__(
                 stacklevel=2,
             )
 
+    def _make_gradient_func(
+        self,
+        # pyre-fixme[2]: Parameter needs type annotation.
+        num_outputs_cumsum,
+        attribute_to_layer_input: bool,
+    ) -> Callable[..., Tuple[Tensor, ...]]:
+
+        def _gradient_func(
+            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+            forward_fn: Callable,
+            inputs: Union[Tensor, Tuple[Tensor, ...]],
+            target_ind: TargetType = None,
+            # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+            additional_forward_args: Any = None,
+        ) -> Tuple[Tensor, ...]:
+            if self.device_ids is None or len(self.device_ids) == 0:
+                scattered_inputs = (inputs,)
+            else:
+                # scatter method does not have a precise enough return type in its
+                # stub, so suppress the type warning.
+                scattered_inputs = scatter(  # type:ignore
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    # `Union[Tensor, typing.Tuple[Tensor, ...]]`.
+                    inputs,
+                    target_gpus=self.device_ids,
+                )
+
+            scattered_inputs_dict = {
+                scattered_input[0].device: scattered_input
+                for scattered_input in scattered_inputs
+            }
+
+            with torch.autograd.set_grad_enabled(True):
+
+                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
+                # annotated.
+                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
+                # annotated.
+                # pyre-fixme[3]: Return type must be annotated.
+                def layer_forward_hook(
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    module,
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    hook_inputs,
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    hook_outputs=None,
+                    # pyre-fixme[2]: Parameter must be annotated.
+                    layer_idx=0,
+                ):
+                    device = _extract_device(module, hook_inputs, hook_outputs)
+                    is_layer_tuple = (
+                        isinstance(hook_outputs, tuple)
+                        # hook_outputs is None if attribute_to_layer_input == True
+                        if hook_outputs is not None
+                        else isinstance(hook_inputs, tuple)
+                    )
+
+                    if is_layer_tuple:
+                        return scattered_inputs_dict[device][
+                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
+                                layer_idx + 1
+                            ]
+                        ]
+
+                    return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
+
+                hooks = []
+                try:
+
+                    layers = self.layer
+                    if not isinstance(layers, list):
+                        layers = [self.layer]
+
+                    for layer_idx, layer in enumerate(layers):
+                        hook = None
+                        # TODO:
+                        # Allow multiple attribute_to_layer_input flags for
+                        # each layer, i.e. attribute_to_layer_input[layer_idx]
+                        if attribute_to_layer_input:
+                            hook = layer.register_forward_pre_hook(
+                                functools.partial(
+                                    layer_forward_hook, layer_idx=layer_idx
+                                )
+                            )
+                        else:
+                            hook = layer.register_forward_hook(
+                                functools.partial(
+                                    layer_forward_hook, layer_idx=layer_idx
+                                )
+                            )
+
+                        hooks.append(hook)
+
+                    # inputs is an empty tuple here because the actual inputs
+                    # have been prepended into additional_forward_args
+                    output = _run_forward(
+                        self.forward_func, (), target_ind, additional_forward_args
+                    )
+                finally:
+                    for hook in hooks:
+                        if hook is not None:
+                            hook.remove()
+
+                # _run_forward may return a Future of Tensor, which is not
+                # supported here; that case would already have failed above.
+                output = cast(Tensor, output)
+                assert output[0].numel() == 1, (
+                    "Target not provided when necessary, cannot"
+                    " take gradient with respect to multiple outputs."
+                )
+                # torch.unbind(output) yields a tuple of scalar tensors with
+                # batch_size * #steps elements
+                grads = torch.autograd.grad(torch.unbind(output), inputs)
+            return grads
+
+        return _gradient_func
+
     @overload
     # pyre-fixme[43]: The implementation of `attribute` does not accept all possible
     # arguments of overload defined on line `112`.
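
The `layer_forward_hook` registered in `_gradient_func` above relies on a PyTorch behavior worth spelling out: a hook added with `register_forward_pre_hook` (or `register_forward_hook`) may return a value, and PyTorch then uses that value in place of the module's inputs (or outputs). That is how the pre-computed, scaled activations held in `scattered_inputs_dict` are injected into the forward pass. A minimal, self-contained sketch of the mechanism; the toy module and tensor names here are illustrative, not taken from this diff:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# Tensor we want the first layer to see instead of its real input.
replacement = torch.ones(2, 4, requires_grad=True)

def override_input(module, inputs):
    # Returning a non-None value from a pre-forward hook replaces the
    # module's inputs for this call (a lone Tensor is wrapped in a tuple).
    return replacement

handle = model[0].register_forward_pre_hook(override_input)
out = model(torch.zeros(2, 4))  # nn.Linear actually receives `replacement`
handle.remove()

# Because `replacement` participated in the forward pass, gradients can be
# taken with respect to it, which is what _gradient_func does at scale.
grads = torch.autograd.grad(out.sum(), replacement)
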
@@ -415,116 +533,10 @@ def flatten_tuple(tup):
             baselines_layer = flatten_tuple(baselines_layer)
 
         # inputs -> these inputs are scaled
-        def gradient_func(
-            # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-            forward_fn: Callable,
-            inputs: Union[Tensor, Tuple[Tensor, ...]],
-            target_ind: TargetType = None,
-            # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-            additional_forward_args: Any = None,
-        ) -> Tuple[Tensor, ...]:
-            if self.device_ids is None or len(self.device_ids) == 0:
-                scattered_inputs = (inputs,)
-            else:
-                # scatter method does not have a precise enough return type in its
-                # stub, so suppress the type warning.
-                scattered_inputs = scatter(  # type:ignore
-                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                    # `Union[Tensor, typing.Tuple[Tensor, ...]]`.
-                    inputs,
-                    target_gpus=self.device_ids,
-                )
-
-            scattered_inputs_dict = {
-                scattered_input[0].device: scattered_input
-                for scattered_input in scattered_inputs
-            }
-
-            with torch.autograd.set_grad_enabled(True):
-
-                # pyre-fixme[53]: Captured variable `num_outputs_cumsum` is not
-                # annotated.
-                # pyre-fixme[53]: Captured variable `scattered_inputs_dict` is not
-                # annotated.
-                # pyre-fixme[3]: Return type must be annotated.
-                def layer_forward_hook(
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    module,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_inputs,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    hook_outputs=None,
-                    # pyre-fixme[2]: Parameter must be annotated.
-                    layer_idx=0,
-                ):
-                    device = _extract_device(module, hook_inputs, hook_outputs)
-                    is_layer_tuple = (
-                        isinstance(hook_outputs, tuple)
-                        # hook_outputs is None if attribute_to_layer_input == True
-                        if hook_outputs is not None
-                        else isinstance(hook_inputs, tuple)
-                    )
-
-                    if is_layer_tuple:
-                        return scattered_inputs_dict[device][
-                            num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
-                                layer_idx + 1
-                            ]
-                        ]
-
-                    return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]
-
-                hooks = []
-                try:
 
-                    layers = self.layer
-                    if not isinstance(layers, list):
-                        layers = [self.layer]
-
-                    for layer_idx, layer in enumerate(layers):
-                        hook = None
-                        # TODO:
-                        # Allow multiple attribute_to_layer_input flags for
-                        # each layer, i.e. attribute_to_layer_input[layer_idx]
-                        if attribute_to_layer_input:
-                            hook = layer.register_forward_pre_hook(
-                                functools.partial(
-                                    layer_forward_hook, layer_idx=layer_idx
-                                )
-                            )
-                        else:
-                            hook = layer.register_forward_hook(
-                                functools.partial(
-                                    layer_forward_hook, layer_idx=layer_idx
-                                )
-                            )
-
-                        hooks.append(hook)
-
-                    # the inputs is an empty tuple
-                    # coz it is prepended into additional_forward_args
-                    output = _run_forward(
-                        self.forward_func, (), target_ind, additional_forward_args
-                    )
-                finally:
-                    for hook in hooks:
-                        if hook is not None:
-                            hook.remove()
-
-                # _run_forward may return future of Tensor,
-                # but we don't support it here now
-                # And it will fail before here.
-                output = cast(Tensor, output)
-                assert output[0].numel() == 1, (
-                    "Target not provided when necessary, cannot"
-                    " take gradient with respect to multiple outputs."
-                )
-                # torch.unbind(forward_out) is a list of scalar tensor tuples and
-                # contains batch_size * #steps elements
-                grads = torch.autograd.grad(torch.unbind(output), inputs)
-            return grads
-
-        self.ig.gradient_func = gradient_func
+        self.ig.gradient_func = self._make_gradient_func(
+            num_outputs_cumsum, attribute_to_layer_input
+        )
         all_inputs = (
             (inps + additional_forward_args)
             if additional_forward_args is not None
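
Taken together, the two hunks move the multi-device gradient closure out of `attribute` into the reusable `_make_gradient_func` helper, with the call site keeping the same behavior by assigning the returned closure to `self.ig.gradient_func`. For orientation, a rough sketch of how this code path is reached from the public captum API; the model, layer choice, and inputs below are illustrative and not taken from this commit:

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
lig = LayerIntegratedGradients(model, model[0])

inputs = torch.randn(4, 8)
# attribute() builds num_outputs_cumsum, swaps self.ig.gradient_func for the
# closure returned by _make_gradient_func, and then delegates to
# IntegratedGradients over the scaled layer activations.
attributions = lig.attribute(inputs, target=0)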