Skip to content

Commit 38f8e21

Browse files
committed
Use _tensor ref
1 parent 3799ce7 commit 38f8e21

File tree

7 files changed

+18
-11
lines changed

7 files changed

+18
-11
lines changed

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size:
2424
bounding_box = tensor.as_subclass(cls)
2525
bounding_box.format = format
2626
bounding_box.spatial_size = spatial_size
27+
bounding_box._tensor = tensor
2728
return bounding_box
2829

2930
def __new__(

torchvision/prototype/features/_encoded.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
class EncodedData(_Feature):
1717
@classmethod
1818
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
19-
return tensor.as_subclass(cls)
19+
output = tensor.as_subclass(cls)
20+
output._tensor = tensor
21+
return output
2022

2123
def __new__(
2224
cls,

torchvision/prototype/features/_feature.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __new__(
4444
) -> _Feature:
4545
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
4646
output = tensor.as_subclass(_Feature)
47+
output._tensor = tensor
4748
return output
4849

4950
@classmethod
@@ -108,7 +109,9 @@ def __torch_function__(
108109
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
109110
# will retain the input type. Thus, we need to unwrap here.
110111
if isinstance(output, cls):
111-
return output.as_subclass(torch.Tensor)
112+
tensor = output.as_subclass(torch.Tensor)
113+
output._tensor = tensor
114+
return tensor
112115

113116
return output
114117

@@ -134,23 +137,19 @@ def _F(self) -> ModuleType:
134137
# this way we return the result without passing into __torch_function__
135138
@property
136139
def shape(self) -> _size: # type: ignore[override]
137-
with DisableTorchFunction():
138-
return super().shape
140+
return self._tensor.shape
139141

140142
@property
141143
def ndim(self) -> int: # type: ignore[override]
142-
with DisableTorchFunction():
143-
return super().ndim
144+
return self._tensor.ndim
144145

145146
@property
146147
def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override]
147-
with DisableTorchFunction():
148-
return super().device
148+
return self._tensor.device
149149

150150
@property
151151
def dtype(self) -> _dtype: # type: ignore[override]
152-
with DisableTorchFunction():
153-
return super().dtype
152+
return self._tensor.dtype
154153

155154
def horizontal_flip(self) -> _Feature:
156155
return self

torchvision/prototype/features/_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class Image(_Feature):
6464
def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
6565
image = tensor.as_subclass(cls)
6666
image.color_space = color_space
67+
image._tensor = tensor
6768
return image
6869

6970
def __new__(

torchvision/prototype/features/_label.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class _LabelBase(_Feature):
1818
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
1919
label_base = tensor.as_subclass(cls)
2020
label_base.categories = categories
21+
label_base._tensor = tensor
2122
return label_base
2223

2324
def __new__(

torchvision/prototype/features/_mask.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
class Mask(_Feature):
1212
@classmethod
1313
def _wrap(cls, tensor: torch.Tensor) -> Mask:
14-
return tensor.as_subclass(cls)
14+
output = tensor.as_subclass(cls)
15+
output._tensor = tensor
16+
return output
1517

1618
def __new__(
1719
cls,

torchvision/prototype/features/_video.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class Video(_Feature):
1717
def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
1818
video = tensor.as_subclass(cls)
1919
video.color_space = color_space
20+
video._tensor = tensor
2021
return video
2122

2223
def __new__(

0 commit comments

Comments
 (0)