diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 6db0727024..1bad602896 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -18,6 +18,27 @@ from torch.nn import Module +def _parse_version(v: str) -> Tuple[int, ...]: + """ + Parse version strings into tuples for comparison. + + Versions should be in the form of "..", ".", + or "". The "dev", "post" and other letter portions of the given version will + be ignored. + + Args: + + v (str): A version string. + + Returns: + version_tuple (tuple of int): A tuple of integer values to use for version + comparison. + """ + v = [n for n in v.split(".") if n.isdigit()] + assert v != [] + return tuple(map(int, v)) + + class ExpansionTypes(Enum): repeat = 1 repeat_interleave = 2 @@ -671,7 +692,7 @@ def _register_backward_hook( ): return module.register_backward_hook(hook) - if torch.__version__ >= "1.9": + if _parse_version(torch.__version__) >= (1, 9, 0): # Only supported for torch >= 1.9 return module.register_full_backward_hook(hook) else: diff --git a/captum/influence/_core/similarity_influence.py b/captum/influence/_core/similarity_influence.py index 83cb2966fa..0fd21eedb7 100644 --- a/captum/influence/_core/similarity_influence.py +++ b/captum/influence/_core/similarity_influence.py @@ -40,7 +40,7 @@ def cosine_similarity(test, train, replace_nan=0) -> Tensor: test = test.view(test.shape[0], -1) train = train.view(train.shape[0], -1) - if torch.__version__ <= "1.6.0": + if common._parse_version(torch.__version__) <= (1, 6, 0): test_norm = torch.norm(test, p=None, dim=1, keepdim=True) train_norm = torch.norm(train, p=None, dim=1, keepdim=True) else: diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index 131f8964b8..c7b60529c7 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -5,7 +5,9 @@ import torch import torch.nn as nn +from captum._utils.common import _parse_version from captum._utils.progress import progress + from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset @@ -126,7 +128,7 @@ def _jacobian_loss_wrt_inputs( "Must be either 'sum' or 'mean'." ) - if torch.__version__ >= "1.8": + if _parse_version(torch.__version__) >= (1, 8, 0): input_jacobians = torch.autograd.functional.jacobian( lambda out: loss_fn(out, targets), out, vectorize=vectorize ) diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index 5bea797e97..e19c3c26b9 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -3,7 +3,13 @@ from typing import cast, List, Tuple import torch -from captum._utils.common import _reduce_list, _select_targets, _sort_key_list, safe_div +from captum._utils.common import ( + _parse_version, + _reduce_list, + _select_targets, + _sort_key_list, + safe_div, +) from tests.helpers.basic import assertTensorAlmostEqual, BaseTest @@ -109,3 +115,35 @@ def test_select_target_3d(self) -> None: # Verify error is raised if too many dimensions are provided. with self.assertRaises(AssertionError): _select_targets(output_tensor, (1, 2, 3)) + + +class TestParseVersion(BaseTest): + def test_parse_version_dev(self) -> None: + version_str = "1.12.0.dev20201109" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 0)) + + def test_parse_version_post(self) -> None: + version_str = "1.3.0.post2" + output = _parse_version(version_str) + self.assertEqual(output, (1, 3, 0)) + + def test_parse_version_1_12_0(self) -> None: + version_str = "1.12.0" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 0)) + + def test_parse_version_1_12_2(self) -> None: + version_str = "1.12.2" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12, 2)) + + def test_parse_version_1_6_0(self) -> None: + version_str = "1.6.0" + output = _parse_version(version_str) + self.assertEqual(output, (1, 6, 0)) + + def test_parse_version_1_12(self) -> None: + version_str = "1.12" + output = _parse_version(version_str) + self.assertEqual(output, (1, 12))