Skip to content

Unwrap features before passing them into a kernel #6807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,23 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
format = BoundingBoxFormat.from_str(format.upper())

return BoundingBox.wrap_like(
self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
self,
self._F.convert_format_bounding_box(
self.as_subclass(torch.Tensor), old_format=self.format, new_format=format
),
format=format,
)

def horizontal_flip(self) -> BoundingBox:
output = self._F.horizontal_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
output = self._F.horizontal_flip_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
)
return BoundingBox.wrap_like(self, output)

def vertical_flip(self) -> BoundingBox:
output = self._F.vertical_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
output = self._F.vertical_flip_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
)
return BoundingBox.wrap_like(self, output)

def resize( # type: ignore[override]
Expand All @@ -85,19 +93,19 @@ def resize( # type: ignore[override]
antialias: bool = False,
) -> BoundingBox:
output, spatial_size = self._F.resize_bounding_box(
self, spatial_size=self.spatial_size, size=size, max_size=max_size
self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
output, spatial_size = self._F.crop_bounding_box(
self, self.format, top=top, left=left, height=height, width=width
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def center_crop(self, output_size: List[int]) -> BoundingBox:
output, spatial_size = self._F.center_crop_bounding_box(
self, format=self.format, spatial_size=self.spatial_size, output_size=output_size
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

Expand All @@ -111,7 +119,9 @@ def resized_crop(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
output, spatial_size = self._F.resized_crop_bounding_box(
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def pad(
Expand All @@ -121,7 +131,11 @@ def pad(
padding_mode: str = "constant",
) -> BoundingBox:
output, spatial_size = self._F.pad_bounding_box(
self, format=self.format, spatial_size=self.spatial_size, padding=padding, padding_mode=padding_mode
self.as_subclass(torch.Tensor),
format=self.format,
spatial_size=self.spatial_size,
padding=padding,
padding_mode=padding_mode,
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

Expand All @@ -134,7 +148,12 @@ def rotate(
center: Optional[List[float]] = None,
) -> BoundingBox:
output, spatial_size = self._F.rotate_bounding_box(
self, format=self.format, spatial_size=self.spatial_size, angle=angle, expand=expand, center=center
self.as_subclass(torch.Tensor),
format=self.format,
spatial_size=self.spatial_size,
angle=angle,
expand=expand,
center=center,
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

Expand All @@ -149,7 +168,7 @@ def affine(
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.affine_bounding_box(
self,
self.as_subclass(torch.Tensor),
self.format,
self.spatial_size,
angle,
Expand All @@ -166,7 +185,7 @@ def perspective(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
output = self._F.perspective_bounding_box(self.as_subclass(torch.Tensor), self.format, perspective_coeffs)
return BoundingBox.wrap_like(self, output)

def elastic(
Expand All @@ -175,5 +194,5 @@ def elastic(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(self, self.format, displacement)
output = self._F.elastic_bounding_box(self.as_subclass(torch.Tensor), self.format, displacement)
return BoundingBox.wrap_like(self, output)
71 changes: 44 additions & 27 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,17 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True)
return Image.wrap_like(
self,
self._F.convert_color_space_image_tensor(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)

def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self)
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)

def vertical_flip(self) -> Image:
output = self._F.vertical_flip_image_tensor(self)
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)

def resize( # type: ignore[override]
Expand All @@ -138,16 +138,16 @@ def resize( # type: ignore[override]
antialias: bool = False,
) -> Image:
output = self._F.resize_image_tensor(
self, size, interpolation=interpolation, max_size=max_size, antialias=antialias
self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
)
return Image.wrap_like(self, output)

def crop(self, top: int, left: int, height: int, width: int) -> Image:
output = self._F.crop_image_tensor(self, top, left, height, width)
output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
return Image.wrap_like(self, output)

def center_crop(self, output_size: List[int]) -> Image:
output = self._F.center_crop_image_tensor(self, output_size=output_size)
output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
return Image.wrap_like(self, output)

def resized_crop(
Expand All @@ -161,7 +161,14 @@ def resized_crop(
antialias: bool = False,
) -> Image:
output = self._F.resized_crop_image_tensor(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
self.as_subclass(torch.Tensor),
top,
left,
height,
width,
size=list(size),
interpolation=interpolation,
antialias=antialias,
)
return Image.wrap_like(self, output)

Expand All @@ -171,7 +178,7 @@ def pad(
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
return Image.wrap_like(self, output)

def rotate(
Expand All @@ -182,8 +189,8 @@ def rotate(
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
output = self._F.rotate_image_tensor(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Image.wrap_like(self, output)

Expand All @@ -197,8 +204,8 @@ def affine(
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.affine_image_tensor(
self,
output = self._F.affine_image_tensor(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
Expand All @@ -215,8 +222,8 @@ def perspective(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.perspective_image_tensor(
self, perspective_coeffs, interpolation=interpolation, fill=fill
output = self._F.perspective_image_tensor(
self.as_subclass(torch.Tensor), perspective_coeffs, interpolation=interpolation, fill=fill
)
return Image.wrap_like(self, output)

Expand All @@ -226,55 +233,65 @@ def elastic(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
output = self._F.elastic_image_tensor(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
)
return Image.wrap_like(self, output)

def adjust_brightness(self, brightness_factor: float) -> Image:
output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor)
output = self._F.adjust_brightness_image_tensor(
self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
)
return Image.wrap_like(self, output)

def adjust_saturation(self, saturation_factor: float) -> Image:
output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor)
output = self._F.adjust_saturation_image_tensor(
self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
)
return Image.wrap_like(self, output)

def adjust_contrast(self, contrast_factor: float) -> Image:
output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor)
output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
return Image.wrap_like(self, output)

def adjust_sharpness(self, sharpness_factor: float) -> Image:
output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor)
output = self._F.adjust_sharpness_image_tensor(
self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
)
return Image.wrap_like(self, output)

def adjust_hue(self, hue_factor: float) -> Image:
output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor)
output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
return Image.wrap_like(self, output)

def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain)
output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
return Image.wrap_like(self, output)

def posterize(self, bits: int) -> Image:
output = self._F.posterize_image_tensor(self, bits=bits)
output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
return Image.wrap_like(self, output)

def solarize(self, threshold: float) -> Image:
output = self._F.solarize_image_tensor(self, threshold=threshold)
output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
return Image.wrap_like(self, output)

def autocontrast(self) -> Image:
output = self._F.autocontrast_image_tensor(self)
output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)

def equalize(self) -> Image:
output = self._F.equalize_image_tensor(self)
output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)

def invert(self) -> Image:
output = self._F.invert_image_tensor(self)
output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)

def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
output = self._F.gaussian_blur_image_tensor(
self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
)
return Image.wrap_like(self, output)


Expand Down
22 changes: 11 additions & 11 deletions torchvision/prototype/features/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def spatial_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:]))

def horizontal_flip(self) -> Mask:
output = self._F.horizontal_flip_mask(self)
output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
return Mask.wrap_like(self, output)

def vertical_flip(self) -> Mask:
output = self._F.vertical_flip_mask(self)
output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
return Mask.wrap_like(self, output)

def resize( # type: ignore[override]
Expand All @@ -51,15 +51,15 @@ def resize( # type: ignore[override]
max_size: Optional[int] = None,
antialias: bool = False,
) -> Mask:
output = self._F.resize_mask(self, size, max_size=max_size)
output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
return Mask.wrap_like(self, output)

def crop(self, top: int, left: int, height: int, width: int) -> Mask:
output = self._F.crop_mask(self, top, left, height, width)
output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
return Mask.wrap_like(self, output)

def center_crop(self, output_size: List[int]) -> Mask:
output = self._F.center_crop_mask(self, output_size=output_size)
output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
return Mask.wrap_like(self, output)

def resized_crop(
Expand All @@ -72,7 +72,7 @@ def resized_crop(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
antialias: bool = False,
) -> Mask:
output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
return Mask.wrap_like(self, output)

def pad(
Expand All @@ -81,7 +81,7 @@ def pad(
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
return Mask.wrap_like(self, output)

def rotate(
Expand All @@ -92,7 +92,7 @@ def rotate(
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
return Mask.wrap_like(self, output)

def affine(
Expand All @@ -106,7 +106,7 @@ def affine(
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.affine_mask(
self,
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
Expand All @@ -122,7 +122,7 @@ def perspective(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
output = self._F.perspective_mask(self.as_subclass(torch.Tensor), perspective_coeffs, fill=fill)
return Mask.wrap_like(self, output)

def elastic(
Expand All @@ -131,5 +131,5 @@ def elastic(
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self, displacement, fill=fill)
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
return Mask.wrap_like(self, output)
Loading