@@ -363,10 +363,9 @@ def _expand_target(
363363 return target
364364
365365
366- # pyre-fixme[3]: Return type must be annotated.
367366def _expand_feature_mask (
368367 feature_mask : Union [Tensor , Tuple [Tensor , ...]], n_samples : int
369- ):
368+ ) -> Tuple [ Tensor , ...] :
370369 # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor,
371370 # typing.Tuple[Tensor, ...]]`.
372371 is_feature_mask_tuple = _is_tuple (feature_mask )
@@ -379,18 +378,17 @@ def _expand_feature_mask(
379378 )
380379 for feature_mask_elem in feature_mask
381380 )
382- return _format_output (is_feature_mask_tuple , feature_mask_new )
381+ return _format_output (is_feature_mask_tuple , feature_mask_new ) # type: ignore
383382
384383
385- # pyre-fixme[3]: Return type must be annotated.
386384def _expand_and_update_baselines (
387385 inputs : Tuple [Tensor , ...],
388386 n_samples : int ,
389387 # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
390388 # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
391389 kwargs : dict ,
392390 draw_baseline_from_distrib : bool = False ,
393- ):
391+ ) -> None :
394392 # pyre-fixme[3]: Return type must be annotated.
395393 # pyre-fixme[2]: Parameter must be annotated.
396394 def get_random_baseline_indices (bsz , baseline ):
@@ -432,10 +430,9 @@ def get_random_baseline_indices(bsz, baseline):
432430 kwargs ["baselines" ] = baselines
433431
434432
435- # pyre-fixme[3]: Return type must be annotated.
436433# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
437434# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
438- def _expand_and_update_additional_forward_args (n_samples : int , kwargs : dict ):
435+ def _expand_and_update_additional_forward_args (n_samples : int , kwargs : dict ) -> None :
439436 if "additional_forward_args" not in kwargs :
440437 return
441438 additional_forward_args = kwargs ["additional_forward_args" ]
@@ -451,10 +448,9 @@ def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
451448 kwargs ["additional_forward_args" ] = additional_forward_args
452449
453450
454- # pyre-fixme[3]: Return type must be annotated.
455451# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
456452# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
457- def _expand_and_update_target (n_samples : int , kwargs : dict ):
453+ def _expand_and_update_target (n_samples : int , kwargs : dict ) -> None :
458454 if "target" not in kwargs :
459455 return
460456 target = kwargs ["target" ]
@@ -465,10 +461,9 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
465461 kwargs ["target" ] = target
466462
467463
468- # pyre-fixme[3]: Return type must be annotated.
469464# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
470465# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
471- def _expand_and_update_feature_mask (n_samples : int , kwargs : dict ):
466+ def _expand_and_update_feature_mask (n_samples : int , kwargs : dict ) -> None :
472467 if "feature_mask" not in kwargs :
473468 return
474469
@@ -573,10 +568,9 @@ def _format_outputs(
573568# pyre-fixme[24] Callable requires 2 arguments
574569def _construct_future_forward (original_forward : Callable ) -> Callable :
575570 # pyre-fixme[3] return type not specified
576- # pyre-ignore
577- def future_forward (* args , ** kwargs ):
578- # pyre-ignore
579- fut = torch .futures .Future ()
571+ def future_forward (* args : Any , ** kwargs : Any ):
572+ # pyre-fixme[29]: `typing.Type[torch.futures.Future]` is not a function.
573+ fut : torch .futures .Future [Tensor ] = torch .futures .Future ()
580574 fut .set_result (original_forward (* args , ** kwargs ))
581575 return fut
582576
@@ -921,8 +915,7 @@ def input_tensor_hook(input_grad: Tensor):
921915 ]
922916
923917
924- # pyre-fixme[3]: Return type must be annotated.
925- def _get_max_feature_index (feature_mask : Tuple [Tensor , ...]):
918+ def _get_max_feature_index (feature_mask : Tuple [Tensor , ...]) -> int :
926919 """
927920 Returns the max feature mask index
928921 The feature mask should be formatted to tuple of tensors at first.
0 commit comments