From f446f931fb0332e5aba4315ba436257bffe43195 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 26 May 2022 13:08:39 +0200 Subject: [PATCH 1/2] prevent feature wrapping if the feature is not the primary operand --- test/test_prototype_features.py | 72 ++++++++++++++++++++++ torchvision/prototype/features/_feature.py | 7 +++ 2 files changed, 79 insertions(+) create mode 100644 test/test_prototype_features.py diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py new file mode 100644 index 00000000000..65f30e9d569 --- /dev/null +++ b/test/test_prototype_features.py @@ -0,0 +1,72 @@ +import torch +from torchvision.prototype import features + + +def test_isinstance(): + assert isinstance( + features.Label([0, 1, 0], categories=["foo", "bar"]), + torch.Tensor, + ) + + +def test_wrapping_no_copy(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + assert label.data_ptr() == tensor.data_ptr() + + +def test_to_wrapping(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + label_to = label.to(torch.int32) + + assert type(label_to) is features.Label + assert label_to.dtype is torch.int32 + assert label_to.categories is label.categories + + +def test_to_feature_reference(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) + + tensor_to = tensor.to(label) + + assert type(tensor_to) is torch.Tensor + assert tensor_to.dtype is torch.int32 + + +def test_clone_wrapping(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + label_clone = label.clone() + + assert type(label_clone) is features.Label + assert label_clone.data_ptr() != label.data_ptr() + assert label_clone.categories is label.categories + + +def test_other_op_no_wrapping(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + # any operation besides .to() and .clone() will do here + output = label * 2 + + assert type(output) is torch.Tensor + + +def test_new_like(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + # any operation besides .to() and .clone() will do here + output = label * 2 + + label_new = features.Label.new_like(label, output) + + assert type(label_new) is features.Label + assert label_new.data_ptr() == output.data_ptr() + assert label_new.categories is label.categories diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index f8026b4d34d..e971eb2e862 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -77,6 +77,13 @@ def __torch_function__( with DisableTorchFunction(): output = func(*args, **kwargs) + # The __torch_function__ protocol will invoke this method on all types involved in the computation by walking + # the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke + # `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a + # case. + if not isinstance(args[0], cls): + return output + if func is torch.Tensor.clone: return cls.new_like(args[0], output) elif func is torch.Tensor.to: From 517089c14027358e3a2e1d08af081f4875a27ca7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 23 Sep 2022 15:18:55 +0200 Subject: [PATCH 2/2] explicitly add feature tests to CI --- .github/workflows/prototype-tests.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/prototype-tests.yml b/.github/workflows/prototype-tests.yml index e9832860c40..5e9ca360d08 100644 --- a/.github/workflows/prototype-tests.yml +++ b/.github/workflows/prototype-tests.yml @@ -43,6 +43,15 @@ jobs: id: setup run: exit 0 + - name: Run prototype features tests + shell: bash + run: | + pytest \ + --durations=20 \ + --cov=torchvision/prototype/features \ + --cov-report=term-missing \ + test/test_prototype_features*.py + - name: Run prototype datasets tests if: success() || ( failure() && steps.setup.conclusion == 'success' ) shell: bash