Skip to content

Commit db00ed7

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix neuron conductance pyre fixme issues (#1458)
Summary: Pull Request resolved: #1458 Differential Revision: D67523217
1 parent 526ed18 commit db00ed7

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

captum/attr/_core/neuron/neuron_conductance.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
_verify_select_neuron,
1515
)
1616
from captum._utils.gradient import compute_layer_gradients_and_eval
17-
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
17+
from captum._utils.typing import (
18+
BaselineType,
19+
SliceIntType,
20+
TargetType,
21+
TensorOrTupleOfTensorsGeneric,
22+
)
1823
from captum.attr._utils.approximation_methods import approximation_parameters
1924
from captum.attr._utils.attribution import GradientAttribution, NeuronAttribution
2025
from captum.attr._utils.batching import _batch_attribution
@@ -39,8 +44,7 @@ class NeuronConductance(NeuronAttribution, GradientAttribution):
3944

4045
def __init__(
4146
self,
42-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
43-
forward_func: Callable,
47+
forward_func: Callable[..., Tensor],
4448
layer: Module,
4549
device_ids: Union[None, List[int]] = None,
4650
multiply_by_inputs: bool = True,
@@ -94,8 +98,11 @@ def __init__(
9498
def attribute(
9599
self,
96100
inputs: TensorOrTupleOfTensorsGeneric,
97-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
98-
neuron_selector: Union[int, Tuple[int, ...], Callable],
101+
neuron_selector: Union[
102+
int,
103+
Tuple[Union[int, SliceIntType], ...],
104+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
105+
],
99106
baselines: BaselineType = None,
100107
target: TargetType = None,
101108
additional_forward_args: Optional[object] = None,
@@ -285,28 +292,24 @@ def attribute(
285292
" results.",
286293
stacklevel=1,
287294
)
288-
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
289-
# `TensorOrTupleOfTensorsGeneric`.
290295
is_inputs_tuple = _is_tuple(inputs)
291296

292-
# pyre-fixme[9]: inputs has type `TensorOrTupleOfTensorsGeneric`; used as
293-
# `Tuple[Tensor, ...]`.
294-
inputs, baselines = _format_input_baseline(inputs, baselines)
295-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
296-
# `TensorOrTupleOfTensorsGeneric`.
297-
_validate_input(inputs, baselines, n_steps, method)
297+
formatted_inputs, formatted_baselines = _format_input_baseline(
298+
inputs, baselines
299+
)
300+
_validate_input(formatted_inputs, formatted_baselines, n_steps, method)
298301

299-
num_examples = inputs[0].shape[0]
302+
num_examples = formatted_inputs[0].shape[0]
300303

301304
if internal_batch_size is not None:
302-
num_examples = inputs[0].shape[0]
305+
num_examples = formatted_inputs[0].shape[0]
303306
attrs = _batch_attribution(
304307
self,
305308
num_examples,
306309
internal_batch_size,
307310
n_steps,
308-
inputs=inputs,
309-
baselines=baselines,
311+
inputs=formatted_inputs,
312+
baselines=formatted_baselines,
310313
neuron_selector=neuron_selector,
311314
target=target,
312315
additional_forward_args=additional_forward_args,
@@ -315,11 +318,9 @@ def attribute(
315318
)
316319
else:
317320
attrs = self._attribute(
318-
# pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but
319-
# got `TensorOrTupleOfTensorsGeneric`.
320-
inputs=inputs,
321+
inputs=formatted_inputs,
321322
neuron_selector=neuron_selector,
322-
baselines=baselines,
323+
baselines=formatted_baselines,
323324
target=target,
324325
additional_forward_args=additional_forward_args,
325326
n_steps=n_steps,
@@ -334,8 +335,11 @@ def attribute(
334335
def _attribute(
335336
self,
336337
inputs: Tuple[Tensor, ...],
337-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
338-
neuron_selector: Union[int, Tuple[int, ...], Callable],
338+
neuron_selector: Union[
339+
int,
340+
Tuple[Union[int, SliceIntType], ...],
341+
Callable[[Union[Tensor, Tuple[Tensor, ...]]], Tensor],
342+
],
339343
baselines: Tuple[Union[Tensor, int, float], ...],
340344
target: TargetType = None,
341345
additional_forward_args: Optional[object] = None,
@@ -409,8 +413,9 @@ def _attribute(
409413

410414
# Aggregates across all steps for each tensor in the input tuple
411415
total_grads = tuple(
412-
# pyre-fixme[6]: For 4th argument expected `Tuple[int, ...]` but got `Size`.
413-
_reshape_and_sum(scaled_grad, n_steps, num_examples, input_grad.shape[1:])
416+
_reshape_and_sum(
417+
scaled_grad, n_steps, num_examples, tuple(input_grad.shape[1:])
418+
)
414419
for (scaled_grad, input_grad) in zip(scaled_grads, input_grads)
415420
)
416421

0 commit comments

Comments
 (0)