extend equalize to all integer and floating dtypes #6851

Merged: 3 commits, Oct 27, 2022
test/prototype_transforms_kernel_infos.py (31 additions, 18 deletions)

```diff
@@ -1322,7 +1322,7 @@ def sample_inputs_gaussian_blur_video():

 def sample_inputs_equalize_image_tensor():
     for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
     ):
         yield ArgsKwargs(image_loader)

@@ -1331,27 +1331,41 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+        if dtype.is_floating_point:
+            low = low_factor
+            high = high_factor
+        else:
+            max_value = torch.iinfo(dtype).max
+            low = int(low_factor * max_value)
+            high = int(high_factor * max_value)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+        image = torch.distributions.Beta(alpha, beta).sample(shape)
+        if not dtype.is_floating_point:
+            image.mul_(torch.iinfo(dtype).max).round_()
+        return image.to(dtype=dtype, device=device)
+
     spatial_size = (256, 256)
-    for fn, color_space in itertools.product(
+    for dtype, color_space, fn in itertools.product(
+        [torch.uint8, torch.float32],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
         [
             lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
             lambda shape, dtype, device: torch.full(
                 shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
             ),
             *[
-                lambda shape, dtype, device, low=low, high=high: torch.randint(
-                    low, high, shape, dtype=dtype, device=device
-                )
-                for low, high in [
-                    (0, 1),
-                    (255, 256),
-                    (0, 64),
-                    (64, 192),
-                    (192, 256),
+                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
+                for low_factor, high_factor in [
+                    (0.0, 0.25),
+                    (0.25, 0.75),
+                    (0.75, 1.0),
                 ]
             ],
             *[
-                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
-                .sample(shape)
-                .mul_(255)
-                .round_()
-                .to(dtype=dtype, device=device)
+                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                 for alpha, beta in [
                     (0.5, 0.5),
                     (2, 2),
@@ -1360,10 +1374,9 @@ def reference_inputs_equalize_image_tensor():
                 ]
             ],
         ],
-        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
     ):
         image_loader = ImageLoader(
-            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
         )
         yield ArgsKwargs(image_loader)
```

@pmeier (Collaborator, Author) commented on Oct 27, 2022, on the `dtype=dtype` line:
Although there is quite a diff in this file, the only functional change is to also test torch.float32 next to torch.uint8. The rest is just refactoring to account for that.

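The comment at the top of `reference_inputs_equalize_image_tensor` explains why skewed inputs matter: equalize should flatten an arbitrary value distribution, so near-uniform inputs carry little information. Below is a minimal standalone sketch of that idea, using the stable `torchvision.transforms.functional.equalize` (uint8 only) instead of the prototype kernel; the helper mirrors the PR's `make_beta_distributed_image` but is illustrative rather than the test suite's code:

```python
import torch
import torchvision.transforms.functional as F

def make_beta_distributed_image(shape, dtype=torch.uint8, *, alpha, beta):
    # Sample a skewed distribution in [0, 1] and scale it to the integer range.
    image = torch.distributions.Beta(alpha, beta).sample(shape)
    if not dtype.is_floating_point:
        image = image.mul(torch.iinfo(dtype).max).round()
    return image.to(dtype)

image = make_beta_distributed_image((3, 256, 256), alpha=0.5, beta=0.5)
equalized = F.equalize(image)

# Equalization should flatten the histogram: bin counts of the output vary
# far less than those of the skewed input.
before = torch.histc(image.float(), bins=256, min=0, max=255)
after = torch.histc(equalized.float(), bins=256, min=0, max=255)
print(before.std().item(), after.std().item())  # expect: after << before
```

Running this, the per-bin standard deviation drops sharply after equalization, which is exactly the property the reference inputs are designed to exercise.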

torchvision/prototype/transforms/functional/_color.py (19 additions, 17 deletions)

```diff
@@ -371,33 +371,34 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:


 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype != torch.uint8:
-        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
-
     num_channels, height, width = get_dimensions_image_tensor(image)
     if num_channels not in (1, 3):
         raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")

     if image.numel() == 0:
         return image

+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    # would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    # `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32`, and completely
+    # unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    # and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    # Since we need to convert in most cases anyway and, out of the acceptable dtypes mentioned in 1., `torch.uint8`
+    # is by far the most common, we choose it as the base.
+    output_dtype = image.dtype
+    image = convert_dtype_image_tensor(image, torch.uint8)
+
+    # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
+    # corresponds to adding 1 to index 127 in the histogram.
     batch_shape = image.shape[:-2]
     flat_image = image.flatten(start_dim=-2).to(torch.long)
-
-    # The algorithm for histogram equalization is mirrored from PIL:
-    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
-
-    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
-    # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using
-    # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127
-    # in the histogram.
     hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
     hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
     cum_hist = hist.cumsum(dim=-1)

+    # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
+    # `lut = cum_hist / flat_image.shape[-1] * 255`
+    # However, PIL uses a more elaborate scheme:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
+    # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`

     # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
```
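Two things happen in the hunk above: every input is funneled through uint8 via `convert_dtype_image_tensor` (with the original dtype remembered for the way back), and the histogram is built by indexed scatter because `torch.histc` does not support batches. A minimal sketch of the batched-histogram trick in isolation; the function name is illustrative, not part of the library:

```python
import torch

def batched_histogram(image: torch.Tensor) -> torch.Tensor:
    # image: integer tensor of shape (..., H, W) with values in [0, 255]
    flat = image.flatten(start_dim=-2).to(torch.long)
    hist = flat.new_zeros(image.shape[:-2] + (256,), dtype=torch.int32)
    # Every pixel with value v adds 1 to hist[..., v].
    hist.scatter_add_(dim=-1, index=flat, src=hist.new_ones(1).expand_as(flat))
    return hist

images = torch.randint(0, 256, (4, 3, 8, 8), dtype=torch.uint8)
hist = batched_histogram(images)         # shape (4, 3, 256)
assert hist.sum(dim=-1).eq(8 * 8).all()  # each histogram counts all 64 pixels
```

Because every uint8 value lands in one of 256 bins, no explicit binning step is needed; the flattened image itself serves as the index tensor.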
```diff
@@ -415,7 +416,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
     # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
     # pay the runtime cost for checking it every time.
-    no_equalization = step.eq(0).unsqueeze_(-1)
+    valid_equalization = step.ne(0).unsqueeze_(-1)

     # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
     # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
```
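To make the LUT scheme concrete, here is the integer arithmetic from the comments above on a single hand-built histogram, in pure Python rather than the kernel's batched tensor ops. The `min(..., 255)` stands in for the clamp to the uint8 range that the kernel applies, since PIL's scheme can overshoot it:

```python
hist = [0] * 256
hist[10], hist[20], hist[255] = 3000, 1000, 60

# PIL excludes the last non-zero bin from the pixel count.
num_non_max_pixels = sum(hist) - hist[255]  # 4000
step = num_non_max_pixels // 255            # 15

# lut[k] uses the cumulative count *before* bin k, so lut[0] == (step // 2) // step == 0.
cum = 0
lut = []
for count in hist:
    lut.append(min((cum + step // 2) // step, 255))
    cum += count

print(lut[10], lut[20], lut[255])  # 0 200 255: the two dark bins are spread far apart
```

When `step == 0` (for example, an image where nearly all pixels share one value), this division is undefined, which is exactly the edge case the `valid_equalization` mask above guards against.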
```diff
@@ -434,7 +435,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
     equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)

-    return torch.where(no_equalization, image, equalized_image)
+    output = torch.where(valid_equalization, equalized_image, image)
+    return convert_dtype_image_tensor(output, output_dtype)


 equalize_image_pil = _FP.equalize
```
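End to end, a floating point input now goes convert-to-uint8, equalize, convert-back, where previously it raised a TypeError. A quick sketch of the new behavior, assuming the prototype namespace layout at the time of this PR:

```python
import torch
from torchvision.prototype.transforms.functional import equalize_image_tensor

image = torch.rand(3, 256, 256)    # float32 in [0, 1]; rejected before this PR
out = equalize_image_tensor(image)
assert out.dtype == torch.float32  # output dtype matches the input dtype
assert 0.0 <= out.min() and out.max() <= 1.0
```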