@@ -104,8 +104,7 @@ def compute_gradients(
104104 forward_fn : Callable ,
105105 inputs : Union [Tensor , Tuple [Tensor , ...]],
106106 target_ind : TargetType = None ,
107- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
108- additional_forward_args : Any = None ,
107+ additional_forward_args : Optional [object ] = None ,
109108) -> Tuple [Tensor , ...]:
110109 r"""
111110 Computes gradients of the output with respect to inputs for an
@@ -175,8 +174,7 @@ def _forward_layer_eval(
175174 forward_fn : Callable ,
176175 inputs : Union [Tensor , Tuple [Tensor , ...]],
177176 layer : List [Module ],
178- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
179- additional_forward_args : Any = None ,
177+ additional_forward_args : Optional [object ] = None ,
180178 device_ids : Union [None , List [int ]] = None ,
181179 attribute_to_layer_input : bool = False ,
182180 grad_enabled : bool = False ,
@@ -191,8 +189,7 @@ def _forward_layer_eval(
191189 forward_fn : Callable ,
192190 inputs : Union [Tensor , Tuple [Tensor , ...]],
193191 layer : Module ,
194- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
195- additional_forward_args : Any = None ,
192+ additional_forward_args : Optional [object ] = None ,
196193 device_ids : Union [None , List [int ]] = None ,
197194 attribute_to_layer_input : bool = False ,
198195 grad_enabled : bool = False ,
@@ -204,7 +201,7 @@ def _forward_layer_eval(
204201 forward_fn : Callable ,
205202 inputs : Union [Tensor , Tuple [Tensor , ...]],
206203 layer : ModuleOrModuleList ,
207- additional_forward_args : Any = None ,
204+ additional_forward_args : Optional [ object ] = None ,
208205 device_ids : Union [None , List [int ]] = None ,
209206 attribute_to_layer_input : bool = False ,
210207 grad_enabled : bool = False ,
@@ -233,8 +230,7 @@ def _forward_layer_distributed_eval(
233230 inputs : Any ,
234231 layer : ModuleOrModuleList ,
235232 target_ind : TargetType = None ,
236- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
237- additional_forward_args : Any = None ,
233+ additional_forward_args : Optional [object ] = None ,
238234 attribute_to_layer_input : bool = False ,
239235 forward_hook_with_return : Literal [False ] = False ,
240236 require_layer_grads : bool = False ,
@@ -250,7 +246,7 @@ def _forward_layer_distributed_eval(
250246 inputs : Any ,
251247 layer : ModuleOrModuleList ,
252248 target_ind : TargetType = None ,
253- additional_forward_args : Any = None ,
249+ additional_forward_args : Optional [ object ] = None ,
254250 attribute_to_layer_input : bool = False ,
255251 * ,
256252 forward_hook_with_return : Literal [True ],
@@ -264,7 +260,7 @@ def _forward_layer_distributed_eval(
264260 inputs : Any ,
265261 layer : ModuleOrModuleList ,
266262 target_ind : TargetType = None ,
267- additional_forward_args : Any = None ,
263+ additional_forward_args : Optional [ object ] = None ,
268264 attribute_to_layer_input : bool = False ,
269265 forward_hook_with_return : bool = False ,
270266 require_layer_grads : bool = False ,
@@ -427,8 +423,7 @@ def _forward_layer_eval_with_neuron_grads(
427423 forward_fn : Callable ,
428424 inputs : Union [Tensor , Tuple [Tensor , ...]],
429425 layer : Module ,
430- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
431- additional_forward_args : Any = None ,
426+ additional_forward_args : Optional [object ] = None ,
432427 * ,
433428 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
434429 gradient_neuron_selector : Union [int , Tuple [Union [int , slice ], ...], Callable ],
@@ -446,7 +441,7 @@ def _forward_layer_eval_with_neuron_grads(
446441 forward_fn : Callable ,
447442 inputs : Union [Tensor , Tuple [Tensor , ...]],
448443 layer : List [Module ],
449- additional_forward_args : Any = None ,
444+ additional_forward_args : Optional [ object ] = None ,
450445 gradient_neuron_selector : None = None ,
451446 grad_enabled : bool = False ,
452447 device_ids : Union [None , List [int ]] = None ,
@@ -462,7 +457,7 @@ def _forward_layer_eval_with_neuron_grads(
462457 forward_fn : Callable ,
463458 inputs : Union [Tensor , Tuple [Tensor , ...]],
464459 layer : Module ,
465- additional_forward_args : Any = None ,
460+ additional_forward_args : Optional [ object ] = None ,
466461 gradient_neuron_selector : None = None ,
467462 grad_enabled : bool = False ,
468463 device_ids : Union [None , List [int ]] = None ,
@@ -475,7 +470,7 @@ def _forward_layer_eval_with_neuron_grads(
475470 forward_fn : Callable ,
476471 inputs : Union [Tensor , Tuple [Tensor , ...]],
477472 layer : ModuleOrModuleList ,
478- additional_forward_args : Any = None ,
473+ additional_forward_args : Optional [ object ] = None ,
479474 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
480475 gradient_neuron_selector : Union [
481476 None , int , Tuple [Union [int , slice ], ...], Callable
@@ -549,8 +544,7 @@ def compute_layer_gradients_and_eval(
549544 layer : Module ,
550545 inputs : Union [Tensor , Tuple [Tensor , ...]],
551546 target_ind : TargetType = None ,
552- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
553- additional_forward_args : Any = None ,
547+ additional_forward_args : Optional [object ] = None ,
554548 * ,
555549 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
556550 gradient_neuron_selector : Union [int , Tuple [Union [int , slice ], ...], Callable ],
@@ -571,7 +565,7 @@ def compute_layer_gradients_and_eval(
571565 layer : List [Module ],
572566 inputs : Union [Tensor , Tuple [Tensor , ...]],
573567 target_ind : TargetType = None ,
574- additional_forward_args : Any = None ,
568+ additional_forward_args : Optional [ object ] = None ,
575569 gradient_neuron_selector : None = None ,
576570 device_ids : Union [None , List [int ]] = None ,
577571 attribute_to_layer_input : bool = False ,
@@ -590,7 +584,7 @@ def compute_layer_gradients_and_eval(
590584 layer : Module ,
591585 inputs : Union [Tensor , Tuple [Tensor , ...]],
592586 target_ind : TargetType = None ,
593- additional_forward_args : Any = None ,
587+ additional_forward_args : Optional [ object ] = None ,
594588 gradient_neuron_selector : None = None ,
595589 device_ids : Union [None , List [int ]] = None ,
596590 attribute_to_layer_input : bool = False ,
@@ -606,7 +600,7 @@ def compute_layer_gradients_and_eval(
606600 layer : ModuleOrModuleList ,
607601 inputs : Union [Tensor , Tuple [Tensor , ...]],
608602 target_ind : TargetType = None ,
609- additional_forward_args : Any = None ,
603+ additional_forward_args : Optional [ object ] = None ,
610604 # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
611605 gradient_neuron_selector : Union [
612606 None , int , Tuple [Union [int , slice ], ...], Callable
@@ -792,8 +786,7 @@ def grad_fn(
792786 forward_fn : Callable ,
793787 inputs : TensorOrTupleOfTensorsGeneric ,
794788 target_ind : TargetType = None ,
795- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
796- additional_forward_args : Any = None ,
789+ additional_forward_args : Optional [object ] = None ,
797790 ) -> Tuple [Tensor , ...]:
798791 _ , grads = _forward_layer_eval_with_neuron_grads (
799792 forward_fn ,
0 commit comments