Skip to content

Commit 7043b6e

Browse files
committed
move simple_tensor to features module
1 parent afa772c commit 7043b6e

File tree

11 files changed

+25
-34
lines changed

11 files changed

+25
-34
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._bounding_box import BoundingBox, BoundingBoxFormat
22
from ._encoded import EncodedData, EncodedImage, EncodedVideo
3-
from ._feature import _Feature
3+
from ._feature import _Feature, is_simple_tensor
44
from ._image import ColorSpace, Image
55
from ._label import Label, OneHotLabel
66
from ._segmentation_mask import SegmentationMask

torchvision/prototype/features/_feature.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
F = TypeVar("F", bound="_Feature")
1111

1212

13+
def is_simple_tensor(inpt: Any) -> bool:
14+
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
15+
16+
1317
class _Feature(torch.Tensor):
1418
__F: Optional[ModuleType] = None
1519

torchvision/prototype/transforms/_augment.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
1414

1515
from ._transform import _RandomApplyTransform
16-
from ._utils import has_any, is_simple_tensor, query_chw
16+
from ._utils import has_any, query_chw
1717

1818

1919
class RandomErasing(_RandomApplyTransform):
@@ -103,7 +103,7 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None:
103103

104104
def forward(self, *inpts: Any) -> Any:
105105
sample = inpts if len(inpts) > 1 else inpts[0]
106-
if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)):
106+
if not (has_any(sample, features.Image, features.is_simple_tensor) and has_any(sample, features.OneHotLabel)):
107107
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
108108
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
109109
raise TypeError(
@@ -125,7 +125,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
125125

126126
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
127127
lam = params["lam"]
128-
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
128+
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
129129
if inpt.ndim < 4:
130130
raise ValueError("Need a batch of images")
131131
output = inpt.clone()
@@ -165,7 +165,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
165165
return dict(box=box, lam_adjusted=lam_adjusted)
166166

167167
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
168-
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
168+
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
169169
box = params["box"]
170170
if inpt.ndim < 4:
171171
raise ValueError("Need a batch of images")
@@ -275,7 +275,7 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis
275275
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
276276
images, bboxes, masks, labels = [], [], [], []
277277
for obj in flat_sample:
278-
if isinstance(obj, features.Image) or is_simple_tensor(obj):
278+
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
279279
images.append(obj)
280280
elif isinstance(obj, PIL.Image.Image):
281281
images.append(pil_to_tensor(obj))
@@ -309,7 +309,7 @@ def _insert_outputs(
309309
elif isinstance(obj, PIL.Image.Image):
310310
flat_sample[i] = F.to_image_pil(output_images[c0])
311311
c0 += 1
312-
elif is_simple_tensor(obj):
312+
elif features.is_simple_tensor(obj):
313313
flat_sample[i] = output_images[c0]
314314
c0 += 1
315315
elif isinstance(obj, features.BoundingBox):

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torchvision.transforms.autoaugment import AutoAugmentPolicy
1111
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
1212

13-
from ._utils import get_chw, is_simple_tensor
13+
from ._utils import get_chw
1414

1515
K = TypeVar("K")
1616
V = TypeVar("V")
@@ -46,7 +46,7 @@ def _extract_image(
4646
sample_flat, _ = tree_flatten(sample)
4747
images = []
4848
for id, inpt in enumerate(sample_flat):
49-
if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
49+
if isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt):
5050
images.append((id, inpt))
5151
elif isinstance(inpt, unsupported_types):
5252
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

torchvision/prototype/transforms/_color.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torchvision.transforms import functional as _F
99

1010
from ._transform import _RandomApplyTransform
11-
from ._utils import is_simple_tensor, query_chw
11+
from ._utils import query_chw
1212

1313
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
1414

@@ -112,7 +112,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
112112
)
113113

114114
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
115-
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
115+
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt)):
116116
return inpt
117117

118118
image = inpt

torchvision/prototype/transforms/_deprecated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing_extensions import Literal
1111

1212
from ._transform import _RandomApplyTransform
13-
from ._utils import is_simple_tensor, query_chw
13+
from ._utils import query_chw
1414

1515

1616
class ToTensor(Transform):
@@ -61,7 +61,7 @@ def __init__(self, mode: Optional[str] = None) -> None:
6161
self.mode = mode
6262

6363
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
64-
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
64+
if features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
6565
return _F.to_pil_image(inpt, mode=self.mode)
6666
else:
6767
return inpt

torchvision/prototype/transforms/_geometry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing_extensions import Literal
1616

1717
from ._transform import _RandomApplyTransform
18-
from ._utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_chw
18+
from ._utils import has_all, has_any, query_bounding_box, query_chw
1919

2020

2121
class RandomHorizontalFlip(_RandomApplyTransform):
@@ -700,7 +700,7 @@ def forward(self, *inputs: Any) -> Any:
700700
sample = inputs if len(inputs) > 1 else inputs[0]
701701
if not (
702702
has_all(sample, features.BoundingBox)
703-
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
703+
and has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor)
704704
and has_any(sample, features.Label, features.OneHotLabel)
705705
):
706706
raise TypeError(
@@ -848,7 +848,7 @@ def forward(self, *inputs: Any) -> Any:
848848
sample = inputs if len(inputs) > 1 else inputs[0]
849849
if not (
850850
has_all(sample, features.BoundingBox)
851-
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
851+
and has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor)
852852
and has_any(sample, features.Label, features.OneHotLabel)
853853
):
854854
raise TypeError(

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from torchvision.prototype.transforms import functional as F, Transform
88
from torchvision.transforms.functional import convert_image_dtype
99

10-
from ._utils import is_simple_tensor
11-
1210

1311
class ConvertBoundingBoxFormat(Transform):
1412
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
@@ -34,7 +32,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
3432
if isinstance(inpt, features.Image):
3533
output = convert_image_dtype(inpt, dtype=self.dtype)
3634
return features.Image.new_like(inpt, output, dtype=self.dtype)
37-
elif is_simple_tensor(inpt):
35+
elif features.is_simple_tensor(inpt):
3836
return convert_image_dtype(inpt, dtype=self.dtype)
3937
else:
4038
return inpt

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from torchvision.prototype import features
99
from torchvision.prototype.transforms import functional as F, Transform
1010

11-
from ._utils import is_simple_tensor
12-
1311

1412
class DecodeImage(Transform):
1513
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -51,7 +49,7 @@ def __init__(self, *, copy: bool = False) -> None:
5149
self.copy = copy
5250

5351
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
54-
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
52+
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or features.is_simple_tensor(inpt):
5553
output = F.to_image_tensor(inpt, copy=self.copy)
5654
return features.Image(output)
5755
else:
@@ -68,7 +66,7 @@ def __init__(self, *, mode: Optional[str] = None) -> None:
6866
self.mode = mode
6967

7068
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
71-
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
69+
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or features.is_simple_tensor(inpt):
7270
return F.to_image_pil(inpt, mode=self.mode)
7371
else:
7472
return inpt

torchvision/prototype/transforms/_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
3636
chws = {
3737
get_chw(item)
3838
for item in flat_sample
39-
if isinstance(item, (features.Image, PIL.Image.Image)) or is_simple_tensor(item)
39+
if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item)
4040
}
4141
if not chws:
4242
raise TypeError("No image was found in the sample")
@@ -63,10 +63,3 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
6363
else:
6464
return False
6565
return True
66-
67-
68-
# TODO: Given that this is not related to pytree / the Transform object, we should probably move it to somewhere else.
69-
# One possibility is `functional._utils` so both the functionals and the transforms have proper access to it. We could
70-
# also move it `features` since it literally checks for the _Feature type.
71-
def is_simple_tensor(inpt: Any) -> bool:
72-
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, features._Feature)

torchvision/prototype/transforms/functional/_deprecated.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from torchvision.prototype import features
77
from torchvision.transforms import functional as _F
88

9-
from .._utils import is_simple_tensor
10-
119

1210
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
1311
call = ", num_output_channels=3" if num_output_channels == 3 else ""
@@ -23,7 +21,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
2321

2422

2523
def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
26-
old_color_space = features.Image.guess_color_space(inpt) if is_simple_tensor(inpt) else None
24+
old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None
2725

2826
call = ", num_output_channels=3" if num_output_channels == 3 else ""
2927
replacement = (

0 commit comments

Comments
 (0)