
Commit 787f17f

Vivek Miglani authored and facebook-github-bot committed
Testing linux job fix (#1586)
Summary: Pull Request resolved: #1586

Rollback Plan:

Differential Revision: D76657713
1 parent cdebed0 commit 787f17f

30 files changed: +149 -164 lines

.github/workflows/lint.yml
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ on:
 
 jobs:
   tests:
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       runner: linux.12xlarge
       docker-image: cimg/python:3.11

.github/workflows/test-pip-cpu-with-type-checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
matrix:
1515
pytorch_args: ["", "-n"]
1616
fail-fast: false
17-
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
17+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
1818
with:
1919
runner: linux.12xlarge
2020
docker-image: cimg/python:3.11

.github/workflows/test-pip-cpu.yml
Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ jobs:
         - pytorch_args: "-v 2.1.0"
           docker_img: "cimg/python:3.12"
       fail-fast: false
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       runner: linux.12xlarge
       docker-image: ${{ matrix.docker_img }}

.github/workflows/test-pip-gpu.yml
Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ jobs:
       matrix:
         cuda_arch_version: ["12.1"]
       fail-fast: false
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       runner: linux.4xlarge.nvidia.gpu
       repository: pytorch/captum

captum/_utils/common.py
Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ def _is_tuple(inputs: Tensor) -> Literal[False]: ...
 
 @typing.overload
 def _is_tuple(
-    inputs: TensorOrTupleOfTensorsGeneric,  # type: ignore
+    inputs: Union[Tensor, Tuple[Tensor, ...]],
 ) -> bool: ...
 
 
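The deleted overload took the module's TensorOrTupleOfTensorsGeneric TypeVar and needed a `# type: ignore`; an overload meant to accept every constraint of a TypeVar is better written as the explicit Union. A minimal sketch of the resulting overload set, assuming only that torch is installed (the Literal[False] case is visible in the hunk context above; the Literal[True] case is assumed to mirror it):

import typing
from typing import Literal, Tuple, Union

from torch import Tensor


@typing.overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
@typing.overload
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
@typing.overload
def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # Runtime implementation; call sites only see the overloads above.
    # A Tensor argument infers Literal[False], a tuple infers Literal[True],
    # and the plain Union falls back to bool.
    return isinstance(inputs, tuple)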

captum/_utils/models/model.py
Lines changed: 1 addition & 2 deletions

@@ -22,8 +22,7 @@ class Model(ABC):
     def fit(
         self,
         train_data: DataLoader,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        **kwargs: object,
     ) -> Optional[Dict[str, Union[int, float, Tensor]]]:
         r"""
         Override this method to actually train your model.
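Annotating the catch-all as `**kwargs: object` types each keyword *value* as `object`, which satisfies pyre's strict mode (error [2]) without constraining what callers may pass; a subclass that actually reads an option must narrow it first. A short sketch of the idea — `train_data` is simplified to `object` and the `max_epochs` option is made up for illustration:

from typing import Dict, Optional


class Model:
    def fit(self, train_data: object, **kwargs: object) -> Optional[Dict[str, float]]:
        # Every value in kwargs is typed as object: anything may be passed,
        # but it must be narrowed (e.g. via isinstance) before use.
        max_epochs = kwargs.get("max_epochs")  # inferred as object | None
        if isinstance(max_epochs, int):
            print(f"would train for {max_epochs} epochs")
        return None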

captum/attr/_core/deep_lift.py
Lines changed: 26 additions & 20 deletions

@@ -348,20 +348,21 @@ def attribute(  # type: ignore
         self._remove_hooks(main_model_hooks)
 
         undo_gradient_requirements(inputs_tuple, gradient_mask)
-        # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric...
-        return _compute_conv_delta_and_format_attrs(
-            self,
-            return_convergence_delta,
-            attributions,
-            baselines,
-            inputs_tuple,
-            additional_forward_args,
-            target,
-            is_inputs_tuple,
+        return cast(
+            TensorOrTupleOfTensorsGeneric,
+            _compute_conv_delta_and_format_attrs(
+                self,
+                return_convergence_delta,
+                attributions,
+                baselines,
+                inputs_tuple,
+                additional_forward_args,
+                target,
+                is_inputs_tuple,
+            ),
         )
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for DeepLift.
         """

@@ -831,11 +832,18 @@ def attribute(  # type: ignore
         )
 
         if return_convergence_delta:
-            # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
-            return _format_output(is_inputs_tuple, attributions), delta
+            return (
+                cast(
+                    TensorOrTupleOfTensorsGeneric,
+                    _format_output(is_inputs_tuple, attributions),
+                ),
+                delta,
+            )
         else:
-            # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
-            return _format_output(is_inputs_tuple, attributions)
+            return cast(
+                TensorOrTupleOfTensorsGeneric,
+                _format_output(is_inputs_tuple, attributions),
+            )
 
     def _expand_inputs_baselines_targets(
         self,

@@ -995,10 +1003,8 @@ def maxpool3d(
 
 def maxpool(
     module: Module,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    pool_func: Callable,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    unpool_func: Callable,
+    pool_func: Callable[..., Tuple[Tensor, Tensor]],
+    unpool_func: Callable[..., Tensor],
     inputs: Tensor,
     outputs: Tensor,
     grad_input: Tensor,
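All three hunks follow one recipe: delete a `# pyre-fixme` and state the type honestly, either by parameterizing `Callable` (`Callable[..., Tuple[Tensor, Tensor]]` allows any argument list but pins the return type) or by wrapping the result in `typing.cast`, a runtime no-op that reasserts the TypeVar the signature promises. A simplified sketch of the cast-on-return idiom, with a stand-in TypeVar for captum's TensorOrTupleOfTensorsGeneric:

from typing import Tuple, TypeVar, Union, cast

from torch import Tensor

# Stand-in for captum's TensorOrTupleOfTensorsGeneric.
TensorOrTuple = TypeVar("TensorOrTuple", Tensor, Tuple[Tensor, ...])


def _format_output(
    is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    # Helpers like this return the plain Union, losing the link between
    # the input's type and the output's type.
    return output if is_inputs_tuple else output[0]


def attribute(inputs: TensorOrTuple, grads: Tuple[Tensor, ...]) -> TensorOrTuple:
    is_inputs_tuple = isinstance(inputs, tuple)
    # cast() restores the TypeVar that the Union-returning helper dropped;
    # it costs nothing at runtime.
    return cast(TensorOrTuple, _format_output(is_inputs_tuple, grads))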

captum/attr/_core/feature_ablation.py
Lines changed: 32 additions & 33 deletions

@@ -401,10 +401,10 @@ def attribute(
         if attr_progress is not None:
             attr_progress.close()
 
-        # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
-        # [Tensor, typing.Tuple[Tensor, ...]]]`
-        # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
-        return self._generate_result(total_attrib, weights, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+        return cast(
+            TensorOrTupleOfTensorsGeneric,
+            self._generate_result(total_attrib, weights, is_inputs_tuple),
+        )
 
     def _attribute_with_independent_feature_masks(
         self,

@@ -629,8 +629,7 @@ def _should_skip_inputs_and_warn(
                 all_empty = False
                 if self._min_examples_per_batch_grouped is not None and (
                     formatted_inputs[tensor_idx].shape[0]
-                    # pyre-ignore[58]: Type has been narrowed to int
-                    < self._min_examples_per_batch_grouped
+                    < cast(int, self._min_examples_per_batch_grouped)
                 ):
                     should_skip = True
                     break

@@ -789,35 +788,35 @@ def attribute_future(
         )
 
         if enable_cross_tensor_attribution:
-            # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
-            # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
-            # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
-            return self._attribute_with_cross_tensor_feature_masks_future(  # type: ignore # noqa: E501 line too long
-                formatted_inputs=formatted_inputs,
-                formatted_additional_forward_args=formatted_additional_forward_args,
-                target=target,
-                baselines=baselines,
-                formatted_feature_mask=formatted_feature_mask,
-                attr_progress=attr_progress,
-                processed_initial_eval_fut=processed_initial_eval_fut,
-                is_inputs_tuple=is_inputs_tuple,
-                perturbations_per_eval=perturbations_per_eval,
+            return cast(
+                Future[TensorOrTupleOfTensorsGeneric],
+                self._attribute_with_cross_tensor_feature_masks_future(  # type: ignore
+                    formatted_inputs=formatted_inputs,
+                    formatted_additional_forward_args=formatted_additional_forward_args,  # noqa: E501 line too long
+                    target=target,
+                    baselines=baselines,
+                    formatted_feature_mask=formatted_feature_mask,
+                    attr_progress=attr_progress,
+                    processed_initial_eval_fut=processed_initial_eval_fut,
+                    is_inputs_tuple=is_inputs_tuple,
+                    perturbations_per_eval=perturbations_per_eval,
+                ),
             )
         else:
-            # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
-            # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
-            # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
-            return self._attribute_with_independent_feature_masks_future(  # type: ignore # noqa: E501 line too long
-                formatted_inputs,
-                formatted_additional_forward_args,
-                target,
-                baselines,
-                formatted_feature_mask,
-                perturbations_per_eval,
-                attr_progress,
-                processed_initial_eval_fut,
-                is_inputs_tuple,
-                **kwargs,
+            return cast(
+                Future[TensorOrTupleOfTensorsGeneric],
+                self._attribute_with_independent_feature_masks_future(  # type: ignore # noqa: E501 line too long
+                    formatted_inputs,
+                    formatted_additional_forward_args,
+                    target,
+                    baselines,
+                    formatted_feature_mask,
+                    perturbations_per_eval,
+                    attr_progress,
+                    processed_initial_eval_fut,
+                    is_inputs_tuple,
+                    **kwargs,
+                ),
            )
 
     def _attribute_with_independent_feature_masks_future(
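The second hunk is the Optional-narrowing flavor of the same cleanup: the `is not None` test already guarantees an int, and `cast(int, ...)` records that for checkers which lose the narrowing across a compound condition on an attribute (the Future-returning hunks make the identical move with `Future[TensorOrTupleOfTensorsGeneric]`). A minimal sketch with a made-up `min_examples` attribute:

from typing import Optional, cast


class Ablator:
    def __init__(self, min_examples: Optional[int] = None) -> None:
        self._min_examples = min_examples

    def should_skip(self, batch_size: int) -> bool:
        # The `is not None` check narrows the attribute to int, but some
        # checkers drop that narrowing inside a multi-line condition;
        # cast(int, ...) states it explicitly, at zero runtime cost.
        return self._min_examples is not None and (
            batch_size < cast(int, self._min_examples)
        )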

captum/attr/_core/guided_backprop_deconvnet.py
Lines changed: 5 additions & 6 deletions

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import cast, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F

@@ -78,12 +78,11 @@ def attribute(
         self._remove_hooks()
 
         undo_gradient_requirements(inputs_tuple, gradient_mask)
-        # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
-        # `Tuple[Tensor, ...]`.
-        return _format_output(is_inputs_tuple, gradients)
+        return cast(
+            TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, gradients)
+        )
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for ModifiedReluGradientAttribution.
         """

captum/attr/_core/guided_grad_cam.py
Lines changed: 5 additions & 4 deletions

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import warnings
-from typing import List, Optional, Union
+from typing import cast, List, Optional, Union
 
 import torch
 from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple

@@ -223,6 +223,7 @@ def attribute(
             )
             output_attr.append(torch.empty(0))
 
-        # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
-        # `Tuple[Tensor, ...]`.
-        return _format_output(is_inputs_tuple, tuple(output_attr))
+        return cast(
+            TensorOrTupleOfTensorsGeneric,
+            _format_output(is_inputs_tuple, tuple(output_attr)),
+        )

captum/attr/_core/input_x_gradient.py
Lines changed: 5 additions & 6 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Callable, Optional
+from typing import Callable, cast, Optional
 
 from captum._utils.common import _format_output, _format_tensor_into_tuples, _is_tuple
 from captum._utils.gradient import (

@@ -126,12 +126,11 @@ def attribute(
         )
 
         undo_gradient_requirements(inputs_tuple, gradient_mask)
-        # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
-        # `Tuple[Tensor, ...]`.
-        return _format_output(is_inputs_tuple, attributions)
+        return cast(
+            TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions)
+        )
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for InputXGradient.
         """

captum/attr/_core/integrated_gradients.py
Lines changed: 13 additions & 11 deletions

@@ -2,7 +2,7 @@
 
 # pyre-strict
 import typing
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Callable, cast, List, Literal, Optional, Tuple, Union
 
 import torch
 from captum._utils.common import (

@@ -301,16 +301,18 @@ def attribute(  # type: ignore
                 additional_forward_args=additional_forward_args,
                 target=target,
             )
-            # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
-            return _format_output(is_inputs_tuple, attributions), delta
-        # pyre-fixme[7]: Expected
-        # `Union[Tuple[Variable[TensorOrTupleOfTensorsGeneric <: [Tensor,
-        # typing.Tuple[Tensor, ...]]], Tensor], Variable[TensorOrTupleOfTensorsGeneric
-        # <: [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Tuple[Tensor, ...]`.
-        return _format_output(is_inputs_tuple, attributions)
-
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+            return (
+                cast(
+                    TensorOrTupleOfTensorsGeneric,
+                    _format_output(is_inputs_tuple, attributions),
+                ),
+                delta,
+            )
+        return cast(
+            TensorOrTupleOfTensorsGeneric, _format_output(is_inputs_tuple, attributions)
+        )
+
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for IntegratedGradients.
         """

captum/attr/_core/kernel_shap.py
Lines changed: 1 addition & 2 deletions

@@ -292,8 +292,7 @@ def attribute(  # type: ignore
             show_progress=show_progress,
         )
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for KernelShap.
         """

captum/attr/_core/lime.py
Lines changed: 2 additions & 4 deletions

@@ -548,8 +548,7 @@ def generate_perturbation() -> (
 
         return generate_perturbation
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for LimeBase.
         """

@@ -1116,8 +1115,7 @@ def attribute(  # type: ignore
             show_progress=show_progress,
         )
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         return super().attribute_future()
 
     def _attribute_kwargs(  # type: ignore

captum/attr/_core/lrp.py
Lines changed: 6 additions & 5 deletions

@@ -4,7 +4,7 @@
 
 import typing
 from collections import defaultdict
-from typing import Any, Callable, cast, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Union
 
 import torch.nn as nn
 from captum._utils.common import (

@@ -230,16 +230,17 @@ def attribute(
         undo_gradient_requirements(input_tuple, gradient_mask)
 
         if return_convergence_delta:
-            # pyre-fixme[7]: Expected `Union[Tuple[Variable[TensorOrTupleOfTensorsGen...
             return (
-                _format_output(is_inputs_tuple, relevances),
+                cast(
+                    TensorOrTupleOfTensorsGeneric,
+                    _format_output(is_inputs_tuple, relevances),
+                ),
                 self.compute_convergence_delta(relevances, output),
             )
         else:
             return _format_output(is_inputs_tuple, relevances)  # type: ignore
 
-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(self) -> None:
         r"""
         This method is not implemented for LRP.
         """
