@@ -37,8 +37,7 @@ class Occlusion(FeatureAblation):
3737 /tensorflow/methods.py#L401
3838 """
3939
40- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
41- def __init__ (self , forward_func : Callable ) -> None :
40+ def __init__ (self , forward_func : Callable [..., Tensor ]) -> None :
4241 r"""
4342 Args:
4443
@@ -58,8 +57,7 @@ def attribute( # type: ignore
5857 ] = None ,
5958 baselines : BaselineType = None ,
6059 target : TargetType = None ,
61- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
62- additional_forward_args : Any = None ,
60+ additional_forward_args : object = None ,
6361 perturbations_per_eval : int = 1 ,
6462 show_progress : bool = False ,
6563 ) -> TensorOrTupleOfTensorsGeneric :
@@ -377,9 +375,7 @@ def _occlusion_mask(
377375 padded_tensor = torch .nn .functional .pad (
378376 sliding_window_tsr , tuple (pad_values ) # type: ignore
379377 )
380- # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` and
381- # `Size`.
382- return padded_tensor .reshape ((1 ,) + padded_tensor .shape )
378+ return padded_tensor .reshape ((1 ,) + tuple (padded_tensor .shape ))
383379
384380 def _get_feature_range_and_mask (
385381 self , input : Tensor , input_mask : Optional [Tensor ], ** kwargs : Any
@@ -389,8 +385,7 @@ def _get_feature_range_and_mask(
389385
390386 def _get_feature_counts (
391387 self ,
392- # pyre-fixme[2]: Parameter must be annotated.
393- inputs ,
388+ inputs : TensorOrTupleOfTensorsGeneric ,
394389 feature_mask : Tuple [Tensor , ...],
395390 ** kwargs : Any ,
396391 ) -> Tuple [int , ...]:
0 commit comments