11#!/usr/bin/env python3
2+
3+ # pyre-strict
24import typing
35from enum import Enum
46from functools import reduce
@@ -68,10 +70,18 @@ def safe_div(
6870
6971
7072@typing .overload
73+ # pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
74+ # is incompatible with the return type of the implementation (`bool`).
75+ # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
76+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
7177def _is_tuple (inputs : Tensor ) -> Literal [False ]: ...
7278
7379
7480@typing .overload
81+ # pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
82+ # is incompatible with the return type of the implementation (`bool`).
83+ # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
84+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
7585def _is_tuple (inputs : Tuple [Tensor , ...]) -> Literal [True ]: ...
7686
7787
@@ -230,6 +240,8 @@ def _format_tensor_into_tuples(
230240 return inputs
231241
232242
243+ # pyre-fixme[3]: Return annotation cannot be `Any`.
244+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
233245def _format_inputs (inputs : Any , unpack_inputs : bool = True ) -> Any :
234246 return (
235247 inputs
@@ -257,16 +269,21 @@ def _format_additional_forward_args(additional_forward_args: None) -> None: ...
257269
258270@overload
259271def _format_additional_forward_args (
272+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
260273 additional_forward_args : Union [Tensor , Tuple ]
274+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
261275) -> Tuple : ...
262276
263277
264278@overload
265279def _format_additional_forward_args (
280+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
266281 additional_forward_args : Any ,
282+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
267283) -> Union [None , Tuple ]: ...
268284
269285
286+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
270287def _format_additional_forward_args (additional_forward_args : Any ) -> Union [None , Tuple ]:
271288 if additional_forward_args is not None and not isinstance (
272289 additional_forward_args , tuple
@@ -276,9 +293,11 @@ def _format_additional_forward_args(additional_forward_args: Any) -> Union[None,
276293
277294
278295def _expand_additional_forward_args (
296+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
279297 additional_forward_args : Any ,
280298 n_steps : int ,
281299 expansion_type : ExpansionTypes = ExpansionTypes .repeat ,
300+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
282301) -> Union [None , Tuple ]:
283302 def _expand_tensor_forward_arg (
284303 additional_forward_arg : Tensor ,
@@ -343,9 +362,12 @@ def _expand_target(
343362 return target
344363
345364
365+ # pyre-fixme[3]: Return type must be annotated.
346366def _expand_feature_mask (
347367 feature_mask : Union [Tensor , Tuple [Tensor , ...]], n_samples : int
348368):
369+ # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor,
370+ # typing.Tuple[Tensor, ...]]`.
349371 is_feature_mask_tuple = _is_tuple (feature_mask )
350372 feature_mask = _format_tensor_into_tuples (feature_mask )
351373 feature_mask_new = tuple (
@@ -359,12 +381,17 @@ def _expand_feature_mask(
359381 return _format_output (is_feature_mask_tuple , feature_mask_new )
360382
361383
384+ # pyre-fixme[3]: Return type must be annotated.
362385def _expand_and_update_baselines (
363386 inputs : Tuple [Tensor , ...],
364387 n_samples : int ,
388+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
389+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
365390 kwargs : dict ,
366391 draw_baseline_from_distrib : bool = False ,
367392):
393+ # pyre-fixme[3]: Return type must be annotated.
394+ # pyre-fixme[2]: Parameter must be annotated.
368395 def get_random_baseline_indices (bsz , baseline ):
369396 num_ref_samples = baseline .shape [0 ]
370397 return np .random .choice (num_ref_samples , n_samples * bsz ).tolist ()
@@ -404,6 +431,9 @@ def get_random_baseline_indices(bsz, baseline):
404431 kwargs ["baselines" ] = baselines
405432
406433
434+ # pyre-fixme[3]: Return type must be annotated.
435+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
436+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
407437def _expand_and_update_additional_forward_args (n_samples : int , kwargs : dict ):
408438 if "additional_forward_args" not in kwargs :
409439 return
@@ -420,6 +450,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
420450 kwargs ["additional_forward_args" ] = additional_forward_args
421451
422452
453+ # pyre-fixme[3]: Return type must be annotated.
454+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
455+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
423456def _expand_and_update_target (n_samples : int , kwargs : dict ):
424457 if "target" not in kwargs :
425458 return
@@ -431,6 +464,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
431464 kwargs ["target" ] = target
432465
433466
467+ # pyre-fixme[3]: Return type must be annotated.
468+ # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
469+ # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
434470def _expand_and_update_feature_mask (n_samples : int , kwargs : dict ):
435471 if "feature_mask" not in kwargs :
436472 return
@@ -444,14 +480,24 @@ def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
444480
445481
446482@typing .overload
483+ # pyre-fixme[43]: The implementation of `_format_output` does not accept all
484+ # possible arguments of overload defined on line `449`.
447485def _format_output (
448- is_inputs_tuple : Literal [True ], output : Tuple [Tensor , ...]
486+ # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
487+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
488+ is_inputs_tuple : Literal [True ],
489+ output : Tuple [Tensor , ...],
449490) -> Tuple [Tensor , ...]: ...
450491
451492
452493@typing .overload
494+ # pyre-fixme[43]: The implementation of `_format_output` does not accept all
495+ # possible arguments of overload defined on line `455`.
453496def _format_output (
454- is_inputs_tuple : Literal [False ], output : Tuple [Tensor , ...]
497+ # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
498+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
499+ is_inputs_tuple : Literal [False ],
500+ output : Tuple [Tensor , ...],
455501) -> Tensor : ...
456502
457503
@@ -474,18 +520,30 @@ def _format_output(
474520 "The input is a single tensor however the output isn't."
475521 "The number of output tensors is: {}" .format (len (output ))
476522 )
523+ # pyre-fixme[7]: Expected `Union[Tensor, typing.Tuple[Tensor, ...]]` but got
524+ # `Union[tuple[Tensor], Tensor]`.
477525 return output if is_inputs_tuple else output [0 ]
478526
479527
480528@typing .overload
529+ # pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
530+ # possible arguments of overload defined on line `483`.
481531def _format_outputs (
482- is_multiple_inputs : Literal [False ], outputs : List [Tuple [Tensor , ...]]
532+ # pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
533+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
534+ is_multiple_inputs : Literal [False ],
535+ outputs : List [Tuple [Tensor , ...]],
483536) -> Union [Tensor , Tuple [Tensor , ...]]: ...
484537
485538
486539@typing .overload
540+ # pyre-fixme[43]: The implementation of `_format_outputs` does not accept all
541+ # possible arguments of overload defined on line `489`.
487542def _format_outputs (
488- is_multiple_inputs : Literal [True ], outputs : List [Tuple [Tensor , ...]]
543+ # pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
544+ # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
545+ is_multiple_inputs : Literal [True ],
546+ outputs : List [Tuple [Tensor , ...]],
489547) -> List [Union [Tensor , Tuple [Tensor , ...]]]: ...
490548
491549
@@ -512,9 +570,12 @@ def _format_outputs(
512570
513571
514572def _run_forward (
573+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
515574 forward_func : Callable ,
575+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
516576 inputs : Any ,
517577 target : TargetType = None ,
578+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
518579 additional_forward_args : Any = None ,
519580) -> Union [Tensor , Future [Tensor ]]:
520581 forward_func_args = signature (forward_func ).parameters
@@ -529,6 +590,8 @@ def _run_forward(
529590
530591 output = forward_func (
531592 * (
593+ # pyre-fixme[60]: Concatenation not yet support for multiple variadic
594+ # tuples: `*inputs, *additional_forward_args`.
532595 (* inputs , * additional_forward_args )
533596 if additional_forward_args is not None
534597 else inputs
@@ -606,6 +669,8 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
606669 elif isinstance (target [0 ], tuple ):
607670 return torch .stack (
608671 [
672+ # pyre-fixme[24]: Generic type `tuple` expects at least 1 type
673+ # parameter.
609674 output [(i ,) + cast (Tuple , targ_elem )]
610675 for i , targ_elem in enumerate (target )
611676 ]
@@ -639,9 +704,11 @@ def _verify_select_column(
639704
640705def _verify_select_neuron (
641706 layer_output : Tuple [Tensor , ...],
707+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
642708 selector : Union [int , Tuple [Union [int , slice ], ...], Callable ],
643709) -> Tensor :
644710 if callable (selector ):
711+ # pyre-fixme[7]: Expected `Tensor` but got `object`.
645712 return selector (layer_output if len (layer_output ) > 1 else layer_output [0 ])
646713
647714 assert len (layer_output ) == 1 , (
@@ -688,6 +755,9 @@ def _extract_device(
688755
689756def _reduce_list (
690757 val_list : Sequence [TupleOrTensorOrBoolGeneric ],
758+ # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
759+ # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
760+ # `typing.List[<element type>]` to avoid runtime subscripting errors.
691761 red_func : Callable [[List ], Any ] = torch .cat ,
692762) -> TupleOrTensorOrBoolGeneric :
693763 """
@@ -702,21 +772,28 @@ def _reduce_list(
702772 """
703773 assert len (val_list ) > 0 , "Cannot reduce empty list!"
704774 if isinstance (val_list [0 ], torch .Tensor ):
775+ # pyre-fixme[16]: `bool` has no attribute `device`.
705776 first_device = val_list [0 ].device
777+ # pyre-fixme[16]: `bool` has no attribute `to`.
706778 return red_func ([elem .to (first_device ) for elem in val_list ])
707779 elif isinstance (val_list [0 ], bool ):
780+ # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
708781 return any (val_list )
709782 elif isinstance (val_list [0 ], tuple ):
710783 final_out = []
784+ # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
785+ # but got `TupleOrTensorOrBoolGeneric`.
711786 for i in range (len (val_list [0 ])):
712787 final_out .append (
788+ # pyre-fixme[16]: `bool` has no attribute `__getitem__`.
713789 _reduce_list ([val_elem [i ] for val_elem in val_list ], red_func )
714790 )
715791 else :
716792 raise AssertionError (
717793 "Elements to be reduced can only be"
718794 "either Tensors or tuples containing Tensors."
719795 )
796+ # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`.
720797 return tuple (final_out )
721798
722799
@@ -756,6 +833,7 @@ def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
756833 return torch .cat ([single_inp .flatten () for single_inp in inp ])
757834
758835
836+ # pyre-fixme[3]: Return annotation cannot be `Any`.
759837def _get_module_from_name (model : Module , layer_name : str ) -> Any :
760838 r"""
761839 Returns the module (layer) object, given its (string) name
@@ -772,7 +850,11 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any:
772850
773851
774852def _register_backward_hook (
775- module : Module , hook : Callable , attr_obj : Any
853+ module : Module ,
854+ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
855+ hook : Callable ,
856+ # pyre-fixme[2]: Parameter annotation cannot be `Any`.
857+ attr_obj : Any ,
776858) -> List [torch .utils .hooks .RemovableHandle ]:
777859 grad_out : Dict [device , Tensor ] = {}
778860
@@ -784,6 +866,7 @@ def forward_hook(
784866 nonlocal grad_out
785867 grad_out = {}
786868
869+ # pyre-fixme[53]: Captured variable `grad_out` is not annotated.
787870 def output_tensor_hook (output_grad : Tensor ) -> None :
788871 grad_out [output_grad .device ] = output_grad
789872
@@ -795,7 +878,11 @@ def output_tensor_hook(output_grad: Tensor) -> None:
795878 else :
796879 out .register_hook (output_tensor_hook )
797880
881+ # pyre-fixme[3]: Return type must be annotated.
882+ # pyre-fixme[2]: Parameter must be annotated.
798883 def pre_hook (module , inp ):
884+ # pyre-fixme[53]: Captured variable `module` is not annotated.
885+ # pyre-fixme[3]: Return type must be annotated.
799886 def input_tensor_hook (input_grad : Tensor ):
800887 if len (grad_out ) == 0 :
801888 return
@@ -820,6 +907,7 @@ def input_tensor_hook(input_grad: Tensor):
820907 ]
821908
822909
910+ # pyre-fixme[3]: Return type must be annotated.
823911def _get_max_feature_index (feature_mask : Tuple [Tensor , ...]):
824912 """
825913 Returns the max feature mask index
0 commit comments