@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -54,8 +54,7 @@ class LayerGradCam(LayerAttribution, GradientAttribution):
 
     def __init__(
         self,
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        forward_func: Callable,
+        forward_func: Callable[..., Tensor],
         layer: Module,
         device_ids: Union[None, List[int]] = None,
     ) -> None:
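The tightened annotation records that forward_func is expected to return a Tensor, which an nn.Module already satisfies. A minimal usage sketch, assuming a hypothetical torchvision-style setup where resnet18 and its layer4 block stand in for the user's model and target layer:

import torch
from captum.attr import LayerGradCam
from torchvision.models import resnet18

model = resnet18(weights=None).eval()  # an nn.Module is a Callable[..., Tensor]: calling it returns logits
grad_cam = LayerGradCam(model, model.layer4)
inputs = torch.randn(1, 3, 224, 224)
# One attribution channel per example: (1, 1, 7, 7) for layer4's 7x7 feature map.
attributions = grad_cam.attribute(inputs, target=0)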
@@ -201,7 +200,7 @@ def attribute(
         # hidden layer and hidden layer evaluated at each input.
         layer_gradients, layer_evals = compute_layer_gradients_and_eval(
             self.forward_func,
-            self.layer,
+            cast(Module, self.layer),
             inputs,
             target,
             additional_forward_args,
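compute_layer_gradients_and_eval wants a single Module, and the cast narrows self.layer, which is presumably declared with a broader type upstream. Note that typing.cast only changes the static type; it is an identity at runtime. A small sketch with a hypothetical some_layer:

from typing import cast
import torch.nn as nn

some_layer: object = nn.Conv2d(3, 8, kernel_size=3)  # hypothetical layer seen under a broad static type
narrowed = cast(nn.Module, some_layer)               # no conversion happens here
assert narrowed is some_layer                        # same object; only the type checker sees nn.Module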
@@ -213,10 +212,7 @@ def attribute(
         summed_grads = tuple(
             (
                 torch.mean(
-                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                    # `Tuple[Tensor, ...]`.
                     layer_grad,
-                    # pyre-fixme[16]: `tuple` has no attribute `shape`.
                     dim=tuple(x for x in range(2, len(layer_grad.shape))),
                     keepdim=True,
                 )
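With the pyre-fixme comments gone, the remaining call is the standard GradCAM weighting step: gradients are averaged over every dimension after batch and channel, giving one weight per channel. A shape-only sketch with hypothetical sizes:

import torch

layer_grad = torch.randn(2, 3, 4, 4)  # (batch, channels, H, W) gradients w.r.t. the layer output
dims = tuple(x for x in range(2, len(layer_grad.shape)))  # (2, 3) here; also covers 1D or 3D feature maps
weights = torch.mean(layer_grad, dim=dims, keepdim=True)
print(weights.shape)  # torch.Size([2, 3, 1, 1])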
@@ -228,27 +224,15 @@ def attribute(
 
         if attr_dim_summation:
             scaled_acts = tuple(
-                # pyre-fixme[58]: `*` is not supported for operand types
-                # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
-                # `Tuple[Tensor, ...]`.
-                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-                # `Tuple[Tensor, ...]`.
                 torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         else:
             scaled_acts = tuple(
-                # pyre-fixme[58]: `*` is not supported for operand types
-                # `Union[tuple[torch._tensor.Tensor], torch._tensor.Tensor]` and
-                # `Tuple[Tensor, ...]`.
                 summed_grad * layer_eval
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
 
         if relu_attributions:
-            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
-            # `Union[tuple[Tensor], Tensor]`.
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
-        # pyre-fixme[6]: For 2nd argument expected `Tuple[Tensor, ...]` but got
-        # `Tuple[Union[tuple[Tensor], Tensor], ...]`.
         return _format_output(len(scaled_acts) > 1, scaled_acts)
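The cleaned-up tail is the rest of the GradCAM recipe: the per-channel weights scale the layer activations, attr_dim_summation=True sums the weighted maps over the channel dimension, and relu_attributions=True keeps only positively contributing regions. A self-contained sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

weights = torch.randn(2, 3, 1, 1)     # hypothetical per-channel weights (see the mean step above)
layer_eval = torch.randn(2, 3, 4, 4)  # hypothetical layer activations (batch, channels, H, W)
scaled = torch.sum(weights * layer_eval, dim=1, keepdim=True)  # (2, 1, 4, 4): the attr_dim_summation branch
scaled = F.relu(scaled)  # the relu_attributions branch keeps only positive evidence
print(scaled.shape)  # torch.Size([2, 1, 4, 4])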