@@ -47,8 +47,9 @@ class FeatureAblation(PerturbationAttribution):
     first dimension (i.e. a feature mask requires to be applied to all inputs).
     """

-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def __init__(self, forward_func: Callable) -> None:
+    def __init__(
+        self, forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]]
+    ) -> None:
         r"""
         Args:

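Aside (illustrative, not part of the commit): the new annotation only constrains the return type of `forward_func`; any callable returning a scalar, a `Tensor`, or a `Future[Tensor]` still type-checks. A minimal sketch with a hypothetical `forward_fn`:

import torch
from torch import Tensor
from captum.attr import FeatureAblation

def forward_fn(inputs: Tensor) -> Tensor:
    # One score per example; matches
    # Callable[..., Union[int, float, Tensor, Future[Tensor]]].
    return inputs.sum(dim=1)

ablator = FeatureAblation(forward_fn)
attributions = ablator.attribute(torch.rand(2, 4))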
@@ -74,8 +75,7 @@ def attribute(
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
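Aside (illustrative, not part of the commit): `object` stays maximally permissive for callers, but unlike `Any` it forces code that consumes the value to narrow it before use, which is why pyre prefers it. A small sketch of the difference, with made-up names:

from typing import Any

def use_any(args: Any) -> int:
    return len(args)  # the checker accepts this even if args has no len()

def use_object(args: object) -> int:
    if isinstance(args, tuple):  # explicit narrowing is required before use
        return len(args)
    return 0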
@@ -261,17 +261,13 @@ def attribute(
261261 """
262262 # Keeps track whether original input is a tuple or not before
263263 # converting it into a tuple.
264- # pyre-fixme[6]: For 1st argument expected `Tensor` but got
265- # `TensorOrTupleOfTensorsGeneric`.
266264 is_inputs_tuple = _is_tuple (inputs )
267265
268266 formatted_inputs , baselines = _format_input_baseline (inputs , baselines )
269267 formatted_additional_forward_args = _format_additional_forward_args (
270268 additional_forward_args
271269 )
272270 num_examples = formatted_inputs [0 ].shape [0 ]
273- # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
274- # `TensorOrTupleOfTensorsGeneric`.
275271 formatted_feature_mask = _format_feature_mask (feature_mask , formatted_inputs )
276272
277273 assert (
@@ -384,8 +380,6 @@ def attribute(
         # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
         # [Tensor, typing.Tuple[Tensor, ...]]]`
         # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
-        # pyre-fixme[6]: In call `FeatureAblation._generate_result`,
-        # for 3rd positional argument, expected `bool` but got `Literal[]`.
         return self._generate_result(total_attrib, weights, is_inputs_tuple)  # type: ignore # noqa: E501 line too long

     def _initial_eval_to_processed_initial_eval_fut(
@@ -414,8 +408,7 @@ def attribute_future(
         inputs: TensorOrTupleOfTensorsGeneric,
         baselines: BaselineType = None,
         target: TargetType = None,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_forward_args: Any = None,
+        additional_forward_args: object = None,
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
@@ -428,8 +421,6 @@ def attribute_future(

         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-        # `TensorOrTupleOfTensorsGeneric`.
         is_inputs_tuple = _is_tuple(inputs)
         formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
         formatted_additional_forward_args = _format_additional_forward_args(
@@ -660,13 +651,11 @@ def _eval_fut_to_ablated_out_fut(
             ) from e
         return result

-    # pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
     def _ith_input_ablation_generator(
         self,
         i: int,
         inputs: TensorOrTupleOfTensorsGeneric,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        additional_args: Any,
+        additional_args: object,
         target: TargetType,
         baselines: BaselineType,
         input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
@@ -675,7 +664,7 @@ def _ith_input_ablation_generator(
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Any,
+            object,
             TargetType,
             Tensor,
         ],
@@ -705,10 +694,9 @@ def _ith_input_ablation_generator(
         perturbations_per_eval = min(perturbations_per_eval, num_features)
         baseline = baselines[i] if isinstance(baselines, tuple) else baselines
         if isinstance(baseline, torch.Tensor):
-            # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
-            # and `Size`.
-            baseline = baseline.reshape((1,) + baseline.shape)
+            baseline = baseline.reshape((1,) + tuple(baseline.shape))

+        additional_args_repeated: object
         if perturbations_per_eval > 1:
             # Repeat features and additional args for batch size.
             all_features_repeated = [
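Aside (illustrative, not part of the commit): `torch.Size` is a subclass of `tuple`, so `(1,) + baseline.shape` already works at runtime; the explicit `tuple(...)` conversion only makes pyre see two plain tuples being concatenated. For example:

import torch

baseline = torch.zeros(3, 4)
assert isinstance(baseline.shape, tuple)                   # torch.Size subclasses tuple
reshaped = baseline.reshape((1,) + tuple(baseline.shape))  # same result, pyre-clean
assert reshaped.shape == (1, 3, 4)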
@@ -727,6 +715,7 @@ def _ith_input_ablation_generator(
             target_repeated = target

         num_features_processed = min_feature
+        current_additional_args: object
         while num_features_processed < num_features:
             current_num_ablated_features = min(
                 perturbations_per_eval, num_features - num_features_processed
@@ -762,9 +751,7 @@ def _ith_input_ablation_generator(
             # dimension of this tensor.
             current_reshaped = current_features[i].reshape(
                 (current_num_ablated_features, -1)
-                # pyre-fixme[58]: `+` is not supported for operand types
-                # `Tuple[int, int]` and `Size`.
-                + current_features[i].shape[1:]
+                + tuple(current_features[i].shape[1:])
             )

             ablated_features, current_mask = self._construct_ablated_input(
@@ -780,10 +767,7 @@ def _ith_input_ablation_generator(
             # (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
             # which can be provided to the model as input.
             current_features[i] = ablated_features.reshape(
-                (-1,)
-                # pyre-fixme[58]: `+` is not supported for operand types
-                # `Tuple[int]` and `Size`.
-                + ablated_features.shape[2:]
+                (-1,) + tuple(ablated_features.shape[2:])
             )
             yield tuple(
                 current_features
@@ -818,9 +802,7 @@ def _construct_ablated_input(
         thus counted towards ablations for that feature) and 0s otherwise.
         """
         current_mask = torch.stack(
-            # pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
-            # Tuple[Tensor, ...]]` but got `List[Union[bool, Tensor]]`.
-            [input_mask == j for j in range(start_feature, end_feature)],  # type: ignore # noqa: E501 line too long
+            cast(List[Tensor], [input_mask == j for j in range(start_feature, end_feature)]),  # type: ignore # noqa: E501 line too long
             dim=0,
         ).long()
         current_mask = current_mask.to(expanded_input.device)
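Aside (illustrative, not part of the commit): `cast` has no runtime effect; it only tells pyre that each `input_mask == j` comparison yields a boolean `Tensor` rather than a Python `bool`, so the list is acceptable to `torch.stack`. A small sketch with made-up values:

from typing import List, cast
import torch
from torch import Tensor

input_mask = torch.tensor([[0, 1, 1], [2, 0, 1]])
start_feature, end_feature = 0, 3
# Each comparison produces a bool Tensor; cast() only informs the type checker.
masks = cast(List[Tensor], [input_mask == j for j in range(start_feature, end_feature)])
current_mask = torch.stack(masks, dim=0).long()
assert current_mask.shape == (3, 2, 3)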