Skip to content

Commit 514c1a5

Browse files
craymichael authored and facebook-github-bot committed
Improving typing of additional_forward_args (#1425)
Summary: Pull Request resolved: #1425 Change `additional_forward_args` to resolve pyre errors and address the feedback on D64998803. Reviewed By: cyrjano Differential Revision: D65178564
1 parent 07470af commit 514c1a5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+172
-210
lines changed

captum/_utils/av.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -351,8 +351,7 @@ def _compute_and_save_activations(
351351
inputs: Union[Tensor, Tuple[Tensor, ...]],
352352
identifier: str,
353353
num_id: str,
354-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
355-
additional_forward_args: Any = None,
354+
additional_forward_args: Optional[object] = None,
356355
load_from_disk: bool = True,
357356
) -> None:
358357
r"""

captum/_utils/common.py

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -273,9 +273,25 @@ def _format_float_or_tensor_into_tuples(
273273
return inputs
274274

275275

276+
@overload
277+
def _format_additional_forward_args(
278+
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
279+
additional_forward_args: Union[Tensor, Tuple]
280+
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
281+
) -> Tuple: ...
282+
283+
284+
@overload
285+
def _format_additional_forward_args( # type: ignore
286+
additional_forward_args: Optional[object],
287+
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
288+
) -> Union[None, Tuple]: ...
289+
290+
276291
def _format_additional_forward_args(
277292
additional_forward_args: Optional[object],
278-
) -> Union[None, Tuple[object, ...]]:
293+
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
294+
) -> Union[None, Tuple]:
279295
if additional_forward_args is not None and not isinstance(
280296
additional_forward_args, tuple
281297
):
@@ -284,8 +300,8 @@ def _format_additional_forward_args(
284300

285301

286302
def _expand_additional_forward_args(
287-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
288-
additional_forward_args: Any,
303+
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
304+
additional_forward_args: Union[None, Tuple],
289305
n_steps: int,
290306
expansion_type: ExpansionTypes = ExpansionTypes.repeat,
291307
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
@@ -557,8 +573,7 @@ def _run_forward(
557573
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
558574
inputs: Any,
559575
target: TargetType = None,
560-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
561-
additional_forward_args: Any = None,
576+
additional_forward_args: Optional[object] = None,
562577
) -> Union[Tensor, Future[Tensor]]:
563578
forward_func_args = signature(forward_func).parameters
564579
if len(forward_func_args) == 0:

captum/_utils/gradient.py

Lines changed: 16 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -104,8 +104,7 @@ def compute_gradients(
104104
forward_fn: Callable,
105105
inputs: Union[Tensor, Tuple[Tensor, ...]],
106106
target_ind: TargetType = None,
107-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
108-
additional_forward_args: Any = None,
107+
additional_forward_args: Optional[object] = None,
109108
) -> Tuple[Tensor, ...]:
110109
r"""
111110
Computes gradients of the output with respect to inputs for an
@@ -175,8 +174,7 @@ def _forward_layer_eval(
175174
forward_fn: Callable,
176175
inputs: Union[Tensor, Tuple[Tensor, ...]],
177176
layer: List[Module],
178-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
179-
additional_forward_args: Any = None,
177+
additional_forward_args: Optional[object] = None,
180178
device_ids: Union[None, List[int]] = None,
181179
attribute_to_layer_input: bool = False,
182180
grad_enabled: bool = False,
@@ -191,8 +189,7 @@ def _forward_layer_eval(
191189
forward_fn: Callable,
192190
inputs: Union[Tensor, Tuple[Tensor, ...]],
193191
layer: Module,
194-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
195-
additional_forward_args: Any = None,
192+
additional_forward_args: Optional[object] = None,
196193
device_ids: Union[None, List[int]] = None,
197194
attribute_to_layer_input: bool = False,
198195
grad_enabled: bool = False,
@@ -204,7 +201,7 @@ def _forward_layer_eval(
204201
forward_fn: Callable,
205202
inputs: Union[Tensor, Tuple[Tensor, ...]],
206203
layer: ModuleOrModuleList,
207-
additional_forward_args: Any = None,
204+
additional_forward_args: Optional[object] = None,
208205
device_ids: Union[None, List[int]] = None,
209206
attribute_to_layer_input: bool = False,
210207
grad_enabled: bool = False,
@@ -233,8 +230,7 @@ def _forward_layer_distributed_eval(
233230
inputs: Any,
234231
layer: ModuleOrModuleList,
235232
target_ind: TargetType = None,
236-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
237-
additional_forward_args: Any = None,
233+
additional_forward_args: Optional[object] = None,
238234
attribute_to_layer_input: bool = False,
239235
forward_hook_with_return: Literal[False] = False,
240236
require_layer_grads: bool = False,
@@ -250,7 +246,7 @@ def _forward_layer_distributed_eval(
250246
inputs: Any,
251247
layer: ModuleOrModuleList,
252248
target_ind: TargetType = None,
253-
additional_forward_args: Any = None,
249+
additional_forward_args: Optional[object] = None,
254250
attribute_to_layer_input: bool = False,
255251
*,
256252
forward_hook_with_return: Literal[True],
@@ -264,7 +260,7 @@ def _forward_layer_distributed_eval(
264260
inputs: Any,
265261
layer: ModuleOrModuleList,
266262
target_ind: TargetType = None,
267-
additional_forward_args: Any = None,
263+
additional_forward_args: Optional[object] = None,
268264
attribute_to_layer_input: bool = False,
269265
forward_hook_with_return: bool = False,
270266
require_layer_grads: bool = False,
@@ -427,8 +423,7 @@ def _forward_layer_eval_with_neuron_grads(
427423
forward_fn: Callable,
428424
inputs: Union[Tensor, Tuple[Tensor, ...]],
429425
layer: Module,
430-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
431-
additional_forward_args: Any = None,
426+
additional_forward_args: Optional[object] = None,
432427
*,
433428
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
434429
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
@@ -446,7 +441,7 @@ def _forward_layer_eval_with_neuron_grads(
446441
forward_fn: Callable,
447442
inputs: Union[Tensor, Tuple[Tensor, ...]],
448443
layer: List[Module],
449-
additional_forward_args: Any = None,
444+
additional_forward_args: Optional[object] = None,
450445
gradient_neuron_selector: None = None,
451446
grad_enabled: bool = False,
452447
device_ids: Union[None, List[int]] = None,
@@ -462,7 +457,7 @@ def _forward_layer_eval_with_neuron_grads(
462457
forward_fn: Callable,
463458
inputs: Union[Tensor, Tuple[Tensor, ...]],
464459
layer: Module,
465-
additional_forward_args: Any = None,
460+
additional_forward_args: Optional[object] = None,
466461
gradient_neuron_selector: None = None,
467462
grad_enabled: bool = False,
468463
device_ids: Union[None, List[int]] = None,
@@ -475,7 +470,7 @@ def _forward_layer_eval_with_neuron_grads(
475470
forward_fn: Callable,
476471
inputs: Union[Tensor, Tuple[Tensor, ...]],
477472
layer: ModuleOrModuleList,
478-
additional_forward_args: Any = None,
473+
additional_forward_args: Optional[object] = None,
479474
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
480475
gradient_neuron_selector: Union[
481476
None, int, Tuple[Union[int, slice], ...], Callable
@@ -549,8 +544,7 @@ def compute_layer_gradients_and_eval(
549544
layer: Module,
550545
inputs: Union[Tensor, Tuple[Tensor, ...]],
551546
target_ind: TargetType = None,
552-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
553-
additional_forward_args: Any = None,
547+
additional_forward_args: Optional[object] = None,
554548
*,
555549
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
556550
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
@@ -571,7 +565,7 @@ def compute_layer_gradients_and_eval(
571565
layer: List[Module],
572566
inputs: Union[Tensor, Tuple[Tensor, ...]],
573567
target_ind: TargetType = None,
574-
additional_forward_args: Any = None,
568+
additional_forward_args: Optional[object] = None,
575569
gradient_neuron_selector: None = None,
576570
device_ids: Union[None, List[int]] = None,
577571
attribute_to_layer_input: bool = False,
@@ -590,7 +584,7 @@ def compute_layer_gradients_and_eval(
590584
layer: Module,
591585
inputs: Union[Tensor, Tuple[Tensor, ...]],
592586
target_ind: TargetType = None,
593-
additional_forward_args: Any = None,
587+
additional_forward_args: Optional[object] = None,
594588
gradient_neuron_selector: None = None,
595589
device_ids: Union[None, List[int]] = None,
596590
attribute_to_layer_input: bool = False,
@@ -606,7 +600,7 @@ def compute_layer_gradients_and_eval(
606600
layer: ModuleOrModuleList,
607601
inputs: Union[Tensor, Tuple[Tensor, ...]],
608602
target_ind: TargetType = None,
609-
additional_forward_args: Any = None,
603+
additional_forward_args: Optional[object] = None,
610604
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
611605
gradient_neuron_selector: Union[
612606
None, int, Tuple[Union[int, slice], ...], Callable
@@ -792,8 +786,7 @@ def grad_fn(
792786
forward_fn: Callable,
793787
inputs: TensorOrTupleOfTensorsGeneric,
794788
target_ind: TargetType = None,
795-
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
796-
additional_forward_args: Any = None,
789+
additional_forward_args: Optional[object] = None,
797790
) -> Tuple[Tensor, ...]:
798791
_, grads = _forward_layer_eval_with_neuron_grads(
799792
forward_fn,

captum/attr/_core/deep_lift.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44
import typing
55
import warnings
6-
from typing import Callable, cast, Dict, List, Literal, Tuple, Type, Union
6+
from typing import Callable, cast, Dict, List, Literal, Optional, Tuple, Type, Union
77

88
import torch
99
import torch.nn as nn
@@ -117,7 +117,7 @@ def attribute(
117117
inputs: TensorOrTupleOfTensorsGeneric,
118118
baselines: BaselineType = None,
119119
target: TargetType = None,
120-
additional_forward_args: object = None,
120+
additional_forward_args: Optional[Tuple[object, ...]] = None,
121121
*,
122122
return_convergence_delta: Literal[True],
123123
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -129,7 +129,7 @@ def attribute(
129129
inputs: TensorOrTupleOfTensorsGeneric,
130130
baselines: BaselineType = None,
131131
target: TargetType = None,
132-
additional_forward_args: object = None,
132+
additional_forward_args: Optional[Tuple[object, ...]] = None,
133133
return_convergence_delta: Literal[False] = False,
134134
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
135135
) -> TensorOrTupleOfTensorsGeneric: ...
@@ -140,7 +140,7 @@ def attribute( # type: ignore
140140
inputs: TensorOrTupleOfTensorsGeneric,
141141
baselines: BaselineType = None,
142142
target: TargetType = None,
143-
additional_forward_args: object = None,
143+
additional_forward_args: Optional[Tuple[object, ...]] = None,
144144
return_convergence_delta: bool = False,
145145
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
146146
) -> Union[
@@ -370,7 +370,7 @@ def _construct_forward_func(
370370
forward_func: Callable[..., Tensor],
371371
inputs: Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
372372
target: TargetType = None,
373-
additional_forward_args: object = None,
373+
additional_forward_args: Optional[Tuple[object, ...]] = None,
374374
) -> Callable[[], Tensor]:
375375
def forward_fn() -> Tensor:
376376
model_out = cast(
@@ -604,7 +604,7 @@ def attribute(
604604
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
605605
],
606606
target: TargetType = None,
607-
additional_forward_args: object = None,
607+
additional_forward_args: Optional[Tuple[object, ...]] = None,
608608
*,
609609
return_convergence_delta: Literal[True],
610610
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
@@ -618,7 +618,7 @@ def attribute(
618618
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
619619
],
620620
target: TargetType = None,
621-
additional_forward_args: object = None,
621+
additional_forward_args: Optional[Tuple[object, ...]] = None,
622622
return_convergence_delta: Literal[False] = False,
623623
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
624624
) -> TensorOrTupleOfTensorsGeneric: ...
@@ -631,7 +631,7 @@ def attribute( # type: ignore
631631
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
632632
],
633633
target: TargetType = None,
634-
additional_forward_args: object = None,
634+
additional_forward_args: Optional[Tuple[object, ...]] = None,
635635
return_convergence_delta: bool = False,
636636
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
637637
) -> Union[
@@ -840,7 +840,7 @@ def _expand_inputs_baselines_targets(
840840
baselines: Tuple[Tensor, ...],
841841
inputs: Tuple[Tensor, ...],
842842
target: TargetType,
843-
additional_forward_args: object,
843+
additional_forward_args: Optional[Tuple[object, ...]],
844844
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, object]:
845845
inp_bsz = inputs[0].shape[0]
846846
base_bsz = baselines[0].shape[0]

captum/attr/_core/feature_ablation.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def attribute(
7575
inputs: TensorOrTupleOfTensorsGeneric,
7676
baselines: BaselineType = None,
7777
target: TargetType = None,
78-
additional_forward_args: object = None,
78+
additional_forward_args: Optional[object] = None,
7979
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
8080
perturbations_per_eval: int = 1,
8181
show_progress: bool = False,
@@ -408,7 +408,7 @@ def attribute_future(
408408
inputs: TensorOrTupleOfTensorsGeneric,
409409
baselines: BaselineType = None,
410410
target: TargetType = None,
411-
additional_forward_args: object = None,
411+
additional_forward_args: Optional[object] = None,
412412
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
413413
perturbations_per_eval: int = 1,
414414
show_progress: bool = False,
@@ -655,7 +655,7 @@ def _ith_input_ablation_generator(
655655
self,
656656
i: int,
657657
inputs: TensorOrTupleOfTensorsGeneric,
658-
additional_args: object,
658+
additional_args: Optional[Tuple[object, ...]],
659659
target: TargetType,
660660
baselines: BaselineType,
661661
input_mask: Union[None, Tensor, Tuple[Tensor, ...]],

captum/attr/_core/feature_permutation.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, Tuple, Union
4+
from typing import Any, Callable, Optional, Tuple, Union
55

66
import torch
77
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
@@ -99,7 +99,7 @@ def attribute( # type: ignore
9999
self,
100100
inputs: TensorOrTupleOfTensorsGeneric,
101101
target: TargetType = None,
102-
additional_forward_args: object = None,
102+
additional_forward_args: Optional[object] = None,
103103
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
104104
perturbations_per_eval: int = 1,
105105
show_progress: bool = False,
@@ -280,7 +280,7 @@ def attribute_future(
280280
self,
281281
inputs: TensorOrTupleOfTensorsGeneric,
282282
target: TargetType = None,
283-
additional_forward_args: object = None,
283+
additional_forward_args: Optional[object] = None,
284284
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
285285
perturbations_per_eval: int = 1,
286286
show_progress: bool = False,

0 commit comments

Comments (0)