Skip to content

Commit 3761855

Browse files
authored
Unwrap features before passing them into a kernel (#6807)
* Unwrap features before calling the kernels. * Revert double unwrapping. * Cleanup. * Remove debug raise. * More cleanup.
1 parent d0de55d commit 3761855

File tree

5 files changed

+142
-79
lines changed

5 files changed

+142
-79
lines changed

torchvision/prototype/features/_bounding_box.py

Lines changed: 31 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -66,15 +66,23 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
6666
format = BoundingBoxFormat.from_str(format.upper())
6767

6868
return BoundingBox.wrap_like(
69-
self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
69+
self,
70+
self._F.convert_format_bounding_box(
71+
self.as_subclass(torch.Tensor), old_format=self.format, new_format=format
72+
),
73+
format=format,
7074
)
7175

7276
def horizontal_flip(self) -> BoundingBox:
73-
output = self._F.horizontal_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
77+
output = self._F.horizontal_flip_bounding_box(
78+
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
79+
)
7480
return BoundingBox.wrap_like(self, output)
7581

7682
def vertical_flip(self) -> BoundingBox:
77-
output = self._F.vertical_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
83+
output = self._F.vertical_flip_bounding_box(
84+
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
85+
)
7886
return BoundingBox.wrap_like(self, output)
7987

8088
def resize( # type: ignore[override]
@@ -85,19 +93,19 @@ def resize( # type: ignore[override]
8593
antialias: bool = False,
8694
) -> BoundingBox:
8795
output, spatial_size = self._F.resize_bounding_box(
88-
self, spatial_size=self.spatial_size, size=size, max_size=max_size
96+
self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
8997
)
9098
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
9199

92100
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
93101
output, spatial_size = self._F.crop_bounding_box(
94-
self, self.format, top=top, left=left, height=height, width=width
102+
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
95103
)
96104
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
97105

98106
def center_crop(self, output_size: List[int]) -> BoundingBox:
99107
output, spatial_size = self._F.center_crop_bounding_box(
100-
self, format=self.format, spatial_size=self.spatial_size, output_size=output_size
108+
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
101109
)
102110
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
103111

@@ -111,7 +119,9 @@ def resized_crop(
111119
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
112120
antialias: bool = False,
113121
) -> BoundingBox:
114-
output, spatial_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
122+
output, spatial_size = self._F.resized_crop_bounding_box(
123+
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
124+
)
115125
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
116126

117127
def pad(
@@ -121,7 +131,11 @@ def pad(
121131
padding_mode: str = "constant",
122132
) -> BoundingBox:
123133
output, spatial_size = self._F.pad_bounding_box(
124-
self, format=self.format, spatial_size=self.spatial_size, padding=padding, padding_mode=padding_mode
134+
self.as_subclass(torch.Tensor),
135+
format=self.format,
136+
spatial_size=self.spatial_size,
137+
padding=padding,
138+
padding_mode=padding_mode,
125139
)
126140
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
127141

@@ -134,7 +148,12 @@ def rotate(
134148
center: Optional[List[float]] = None,
135149
) -> BoundingBox:
136150
output, spatial_size = self._F.rotate_bounding_box(
137-
self, format=self.format, spatial_size=self.spatial_size, angle=angle, expand=expand, center=center
151+
self.as_subclass(torch.Tensor),
152+
format=self.format,
153+
spatial_size=self.spatial_size,
154+
angle=angle,
155+
expand=expand,
156+
center=center,
138157
)
139158
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
140159

@@ -149,7 +168,7 @@ def affine(
149168
center: Optional[List[float]] = None,
150169
) -> BoundingBox:
151170
output = self._F.affine_bounding_box(
152-
self,
171+
self.as_subclass(torch.Tensor),
153172
self.format,
154173
self.spatial_size,
155174
angle,
@@ -166,7 +185,7 @@ def perspective(
166185
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
167186
fill: FillTypeJIT = None,
168187
) -> BoundingBox:
169-
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
188+
output = self._F.perspective_bounding_box(self.as_subclass(torch.Tensor), self.format, perspective_coeffs)
170189
return BoundingBox.wrap_like(self, output)
171190

172191
def elastic(
@@ -175,5 +194,5 @@ def elastic(
175194
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
176195
fill: FillTypeJIT = None,
177196
) -> BoundingBox:
178-
output = self._F.elastic_bounding_box(self, self.format, displacement)
197+
output = self._F.elastic_bounding_box(self.as_subclass(torch.Tensor), self.format, displacement)
179198
return BoundingBox.wrap_like(self, output)

torchvision/prototype/features/_image.py

Lines changed: 44 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -117,17 +117,17 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True)
117117
return Image.wrap_like(
118118
self,
119119
self._F.convert_color_space_image_tensor(
120-
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
120+
self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
121121
),
122122
color_space=color_space,
123123
)
124124

125125
def horizontal_flip(self) -> Image:
126-
output = self._F.horizontal_flip_image_tensor(self)
126+
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
127127
return Image.wrap_like(self, output)
128128

129129
def vertical_flip(self) -> Image:
130-
output = self._F.vertical_flip_image_tensor(self)
130+
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
131131
return Image.wrap_like(self, output)
132132

133133
def resize( # type: ignore[override]
@@ -138,16 +138,16 @@ def resize( # type: ignore[override]
138138
antialias: bool = False,
139139
) -> Image:
140140
output = self._F.resize_image_tensor(
141-
self, size, interpolation=interpolation, max_size=max_size, antialias=antialias
141+
self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
142142
)
143143
return Image.wrap_like(self, output)
144144

145145
def crop(self, top: int, left: int, height: int, width: int) -> Image:
146-
output = self._F.crop_image_tensor(self, top, left, height, width)
146+
output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
147147
return Image.wrap_like(self, output)
148148

149149
def center_crop(self, output_size: List[int]) -> Image:
150-
output = self._F.center_crop_image_tensor(self, output_size=output_size)
150+
output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
151151
return Image.wrap_like(self, output)
152152

153153
def resized_crop(
@@ -161,7 +161,14 @@ def resized_crop(
161161
antialias: bool = False,
162162
) -> Image:
163163
output = self._F.resized_crop_image_tensor(
164-
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
164+
self.as_subclass(torch.Tensor),
165+
top,
166+
left,
167+
height,
168+
width,
169+
size=list(size),
170+
interpolation=interpolation,
171+
antialias=antialias,
165172
)
166173
return Image.wrap_like(self, output)
167174

@@ -171,7 +178,7 @@ def pad(
171178
fill: FillTypeJIT = None,
172179
padding_mode: str = "constant",
173180
) -> Image:
174-
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
181+
output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
175182
return Image.wrap_like(self, output)
176183

177184
def rotate(
@@ -182,8 +189,8 @@ def rotate(
182189
fill: FillTypeJIT = None,
183190
center: Optional[List[float]] = None,
184191
) -> Image:
185-
output = self._F._geometry.rotate_image_tensor(
186-
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
192+
output = self._F.rotate_image_tensor(
193+
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
187194
)
188195
return Image.wrap_like(self, output)
189196

@@ -197,8 +204,8 @@ def affine(
197204
fill: FillTypeJIT = None,
198205
center: Optional[List[float]] = None,
199206
) -> Image:
200-
output = self._F._geometry.affine_image_tensor(
201-
self,
207+
output = self._F.affine_image_tensor(
208+
self.as_subclass(torch.Tensor),
202209
angle,
203210
translate=translate,
204211
scale=scale,
@@ -215,8 +222,8 @@ def perspective(
215222
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
216223
fill: FillTypeJIT = None,
217224
) -> Image:
218-
output = self._F._geometry.perspective_image_tensor(
219-
self, perspective_coeffs, interpolation=interpolation, fill=fill
225+
output = self._F.perspective_image_tensor(
226+
self.as_subclass(torch.Tensor), perspective_coeffs, interpolation=interpolation, fill=fill
220227
)
221228
return Image.wrap_like(self, output)
222229

@@ -226,55 +233,65 @@ def elastic(
226233
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
227234
fill: FillTypeJIT = None,
228235
) -> Image:
229-
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
236+
output = self._F.elastic_image_tensor(
237+
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
238+
)
230239
return Image.wrap_like(self, output)
231240

232241
def adjust_brightness(self, brightness_factor: float) -> Image:
233-
output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor)
242+
output = self._F.adjust_brightness_image_tensor(
243+
self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
244+
)
234245
return Image.wrap_like(self, output)
235246

236247
def adjust_saturation(self, saturation_factor: float) -> Image:
237-
output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor)
248+
output = self._F.adjust_saturation_image_tensor(
249+
self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
250+
)
238251
return Image.wrap_like(self, output)
239252

240253
def adjust_contrast(self, contrast_factor: float) -> Image:
241-
output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor)
254+
output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
242255
return Image.wrap_like(self, output)
243256

244257
def adjust_sharpness(self, sharpness_factor: float) -> Image:
245-
output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor)
258+
output = self._F.adjust_sharpness_image_tensor(
259+
self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
260+
)
246261
return Image.wrap_like(self, output)
247262

248263
def adjust_hue(self, hue_factor: float) -> Image:
249-
output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor)
264+
output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
250265
return Image.wrap_like(self, output)
251266

252267
def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
253-
output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain)
268+
output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
254269
return Image.wrap_like(self, output)
255270

256271
def posterize(self, bits: int) -> Image:
257-
output = self._F.posterize_image_tensor(self, bits=bits)
272+
output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
258273
return Image.wrap_like(self, output)
259274

260275
def solarize(self, threshold: float) -> Image:
261-
output = self._F.solarize_image_tensor(self, threshold=threshold)
276+
output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
262277
return Image.wrap_like(self, output)
263278

264279
def autocontrast(self) -> Image:
265-
output = self._F.autocontrast_image_tensor(self)
280+
output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
266281
return Image.wrap_like(self, output)
267282

268283
def equalize(self) -> Image:
269-
output = self._F.equalize_image_tensor(self)
284+
output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
270285
return Image.wrap_like(self, output)
271286

272287
def invert(self) -> Image:
273-
output = self._F.invert_image_tensor(self)
288+
output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
274289
return Image.wrap_like(self, output)
275290

276291
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
277-
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
292+
output = self._F.gaussian_blur_image_tensor(
293+
self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
294+
)
278295
return Image.wrap_like(self, output)
279296

280297

torchvision/prototype/features/_mask.py

Lines changed: 11 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -37,11 +37,11 @@ def spatial_size(self) -> Tuple[int, int]:
3737
return cast(Tuple[int, int], tuple(self.shape[-2:]))
3838

3939
def horizontal_flip(self) -> Mask:
40-
output = self._F.horizontal_flip_mask(self)
40+
output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
4141
return Mask.wrap_like(self, output)
4242

4343
def vertical_flip(self) -> Mask:
44-
output = self._F.vertical_flip_mask(self)
44+
output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
4545
return Mask.wrap_like(self, output)
4646

4747
def resize( # type: ignore[override]
@@ -51,15 +51,15 @@ def resize( # type: ignore[override]
5151
max_size: Optional[int] = None,
5252
antialias: bool = False,
5353
) -> Mask:
54-
output = self._F.resize_mask(self, size, max_size=max_size)
54+
output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
5555
return Mask.wrap_like(self, output)
5656

5757
def crop(self, top: int, left: int, height: int, width: int) -> Mask:
58-
output = self._F.crop_mask(self, top, left, height, width)
58+
output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
5959
return Mask.wrap_like(self, output)
6060

6161
def center_crop(self, output_size: List[int]) -> Mask:
62-
output = self._F.center_crop_mask(self, output_size=output_size)
62+
output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
6363
return Mask.wrap_like(self, output)
6464

6565
def resized_crop(
@@ -72,7 +72,7 @@ def resized_crop(
7272
interpolation: InterpolationMode = InterpolationMode.NEAREST,
7373
antialias: bool = False,
7474
) -> Mask:
75-
output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
75+
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
7676
return Mask.wrap_like(self, output)
7777

7878
def pad(
@@ -81,7 +81,7 @@ def pad(
8181
fill: FillTypeJIT = None,
8282
padding_mode: str = "constant",
8383
) -> Mask:
84-
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
84+
output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
8585
return Mask.wrap_like(self, output)
8686

8787
def rotate(
@@ -92,7 +92,7 @@ def rotate(
9292
fill: FillTypeJIT = None,
9393
center: Optional[List[float]] = None,
9494
) -> Mask:
95-
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
95+
output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
9696
return Mask.wrap_like(self, output)
9797

9898
def affine(
@@ -106,7 +106,7 @@ def affine(
106106
center: Optional[List[float]] = None,
107107
) -> Mask:
108108
output = self._F.affine_mask(
109-
self,
109+
self.as_subclass(torch.Tensor),
110110
angle,
111111
translate=translate,
112112
scale=scale,
@@ -122,7 +122,7 @@ def perspective(
122122
interpolation: InterpolationMode = InterpolationMode.NEAREST,
123123
fill: FillTypeJIT = None,
124124
) -> Mask:
125-
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
125+
output = self._F.perspective_mask(self.as_subclass(torch.Tensor), perspective_coeffs, fill=fill)
126126
return Mask.wrap_like(self, output)
127127

128128
def elastic(
@@ -131,5 +131,5 @@ def elastic(
131131
interpolation: InterpolationMode = InterpolationMode.NEAREST,
132132
fill: FillTypeJIT = None,
133133
) -> Mask:
134-
output = self._F.elastic_mask(self, displacement, fill=fill)
134+
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
135135
return Mask.wrap_like(self, output)

0 commit comments

Comments
 (0)