Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ build:

requirements:
host:
- python>=3.6
- python>=3.8
run:
- numpy
- pytorch>=1.6
- numpy<2.0
- pytorch>=1.10
- matplotlib-base

test:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
runner: linux.12xlarge
docker-image: cimg/python:3.9
docker-image: cimg/python:3.11
repository: pytorch/captum
script: |
sudo chmod -R 777 .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-conda-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
tests:
strategy:
matrix:
python_version: ["3.7", "3.8", "3.9", "3.10"]
python_version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-pip-cpu-with-mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
runner: linux.12xlarge
docker-image: cimg/python:3.6
docker-image: cimg/python:3.11
repository: pytorch/captum
script: |
sudo chmod -R 777 .
Expand Down
17 changes: 8 additions & 9 deletions .github/workflows/test-pip-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@ jobs:
tests:
strategy:
matrix:
pytorch_args: ["-v 1.6", "-v 1.7", "-v 1.8", "-v 1.9", "-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13"]
docker_img: ["cimg/python:3.6", "cimg/python:3.7"]
include:
- pytorch_args: "-v 2.0"
docker_img: "cimg/python:3.8"
pytorch_args: ["-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13", "-v 2.0.0", "-v 2.1.0", "-v 2.2.0", "-v 2.3.0"]
docker_img: ["cimg/python:3.8", "cimg/python:3.9", "cimg/python:3.10", "cimg/python:3.11"]
exclude:
- pytorch_args: "-v 1.10"
docker_img: "cimg/python:3.10"
- pytorch_args: "-v 1.10"
docker_img: "cimg/python:3.11"
- pytorch_args: "-v 1.11"
docker_img: "cimg/python:3.6"
docker_img: "cimg/python:3.11"
- pytorch_args: "-v 1.12"
docker_img: "cimg/python:3.6"
- pytorch_args: "-v 1.13"
docker_img: "cimg/python:3.6"
docker_img: "cimg/python:3.11"
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ Captum can also be used by application engineers who are using trained models in
## Installation

**Installation Requirements**
- Python >= 3.6
- PyTorch >= 1.6
- Python >= 3.8
- PyTorch >= 1.10


##### Installing the latest release
Expand Down
4 changes: 2 additions & 2 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def _compute_jacobian_wrt_params(
inputs: Tuple[Any, ...],
labels: Optional[Tensor] = None,
loss_fn: Optional[Union[Module, Callable]] = None,
layer_modules: List[Module] = None,
layer_modules: Optional[List[Module]] = None,
) -> Tuple[Tensor, ...]:
r"""
Computes the Jacobian of a batch of test examples given a model, and optional
Expand Down Expand Up @@ -805,7 +805,7 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
labels: Optional[Tensor] = None,
loss_fn: Optional[Union[Module, Callable]] = None,
reduction_type: Optional[str] = "sum",
layer_modules: List[Module] = None,
layer_modules: Optional[List[Module]] = None,
) -> Tuple[Any, ...]:
r"""
Computes the Jacobian of a batch of test examples given a model, and optional
Expand Down
20 changes: 10 additions & 10 deletions captum/_utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import warnings
from time import time
from typing import cast, Iterable, Sized, TextIO
from typing import cast, Iterable, Optional, Sized, TextIO

from captum._utils.typing import Literal

Expand Down Expand Up @@ -51,7 +51,7 @@ class NullProgress:
progress bars.
"""

def __init__(self, iterable: Iterable = None, *args, **kwargs):
def __init__(self, iterable: Optional[Iterable] = None, *args, **kwargs):
del args, kwargs
self.iterable = iterable

Expand All @@ -77,10 +77,10 @@ def close(self):
class SimpleProgress:
def __init__(
self,
iterable: Iterable = None,
desc: str = None,
total: int = None,
file: TextIO = None,
iterable: Optional[Iterable] = None,
desc: Optional[str] = None,
total: Optional[int] = None,
file: Optional[TextIO] = None,
mininterval: float = 0.5,
) -> None:
"""
Expand Down Expand Up @@ -155,11 +155,11 @@ def close(self):


def progress(
iterable: Iterable = None,
desc: str = None,
total: int = None,
iterable: Optional[Iterable] = None,
desc: Optional[str] = None,
total: Optional[int] = None,
use_tqdm=True,
file: TextIO = None,
file: Optional[TextIO] = None,
mininterval: float = 0.5,
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions captum/attr/_core/noise_tunnel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from enum import Enum
from typing import Any, cast, List, Tuple, Union
from typing import Any, cast, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -80,7 +80,7 @@ def attribute(
inputs: Union[Tensor, Tuple[Tensor, ...]],
nt_type: str = "smoothgrad",
nt_samples: int = 5,
nt_samples_batch_size: int = None,
nt_samples_batch_size: Optional[int] = None,
stdevs: Union[float, Tuple[float, ...]] = 1.0,
draw_baseline_from_distrib: bool = False,
**kwargs: Any,
Expand Down
5 changes: 4 additions & 1 deletion captum/attr/_utils/interpretable_input.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -45,7 +46,7 @@ def _scatter_itp_attr_by_mask(
return attr


class InterpretableInput:
class InterpretableInput(ABC):
"""
InterpretableInput is an adapter for different kinds of model inputs to
work in Captum's attribution methods. Generally, attribution methods of Captum
Expand Down Expand Up @@ -94,6 +95,7 @@ class to create other types of customized input.
is only allowed in certain attribution classes like LLMAttribution for now.)
"""

@abstractmethod
def to_tensor(self) -> Tensor:
"""
Return the interpretable representation of this input as a tensor
Expand All @@ -104,6 +106,7 @@ def to_tensor(self) -> Tensor:
"""
pass

@abstractmethod
def to_model_input(self, itp_tensor: Optional[Tensor] = None) -> Any:
"""
Get the (perturbed) input in the format required by the model
Expand Down
Loading