diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index c417b33c2a3..7cfb9b6a785 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -1322,7 +1322,7 @@ def sample_inputs_gaussian_blur_video():
 
 
 def sample_inputs_equalize_image_tensor():
     for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
     ):
         yield ArgsKwargs(image_loader)
@@ -1331,27 +1331,41 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+        if dtype.is_floating_point:
+            low = low_factor
+            high = high_factor
+        else:
+            max_value = torch.iinfo(dtype).max
+            low = int(low_factor * max_value)
+            high = int(high_factor * max_value)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+        image = torch.distributions.Beta(alpha, beta).sample(shape)
+        if not dtype.is_floating_point:
+            image.mul_(torch.iinfo(dtype).max).round_()
+        return image.to(dtype=dtype, device=device)
+
     spatial_size = (256, 256)
-    for fn, color_space in itertools.product(
+    for dtype, color_space, fn in itertools.product(
+        [torch.uint8, torch.float32],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
         [
+            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
+            lambda shape, dtype, device: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ),
             *[
-                lambda shape, dtype, device, low=low, high=high: torch.randint(
-                    low, high, shape, dtype=dtype, device=device
-                )
-                for low, high in [
-                    (0, 1),
-                    (255, 256),
-                    (0, 64),
-                    (64, 192),
-                    (192, 256),
+                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
+                for low_factor, high_factor in [
+                    (0.0, 0.25),
+                    (0.25, 0.75),
+                    (0.75, 1.0),
                 ]
             ],
             *[
-                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
-                .sample(shape)
-                .mul_(255)
-                .round_()
-                .to(dtype=dtype, device=device)
+                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                 for alpha, beta in [
                     (0.5, 0.5),
                     (2, 2),
@@ -1360,10 +1374,9 @@ def reference_inputs_equalize_image_tensor():
                 ]
             ],
         ],
-        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
     ):
         image_loader = ImageLoader(
-            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
         )
         yield ArgsKwargs(image_loader)
 
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 3ad65493f70..fb238510242 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -371,26 +371,26 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
 
 
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype != torch.uint8:
-        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
-
-    num_channels, height, width = get_dimensions_image_tensor(image)
-    if num_channels not in (1, 3):
-        raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
-
     if image.numel() == 0:
         return image
 
+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32`, and completely
+    #    infeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    # Since we need to convert in most cases anyway and, of the acceptable dtypes mentioned in 1., `torch.uint8` is
+    # by far the most common, we choose it as the base.
+    output_dtype = image.dtype
+    image = convert_dtype_image_tensor(image, torch.uint8)
+
+    # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
+    # corresponds to adding 1 to index 127 in the histogram.
     batch_shape = image.shape[:-2]
     flat_image = image.flatten(start_dim=-2).to(torch.long)
-
-    # The algorithm for histogram equalization is mirrored from PIL:
-    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
-
-    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
-    # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using
-    # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127
-    # in the histogram.
     hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
     hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
     cum_hist = hist.cumsum(dim=-1)
@@ -398,6 +398,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
     # `lut = cum_hist / flat_image.shape[-1] * 255`
     # However, PIL uses a more elaborate scheme:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
     # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
 
     # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
@@ -415,7 +416,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
     # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
    # pay the runtime cost for checking it every time.
-    no_equalization = step.eq(0).unsqueeze_(-1)
+    valid_equalization = step.ne(0).unsqueeze_(-1)
 
     # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
     # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
@@ -434,7 +435,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
     equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
 
-    return torch.where(no_equalization, image, equalized_image)
+    output = torch.where(valid_equalization, equalized_image, image)
+    return convert_dtype_image_tensor(output, output_dtype)
 
 
 equalize_image_pil = _FP.equalize
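
Note for reviewers: the sketch below isolates the kernel's core trick, the batched `scatter_add_` histogram plus gather-based LUT application, so it can be run on its own. It deliberately uses the simple LUT form mentioned in the kernel's comments (`lut = cum_hist / flat_image.shape[-1] * 255`) rather than PIL's rounding scheme, so its output differs slightly from `equalize_image_tensor`. The helper name `equalize_uint8_sketch` is made up for illustration and exists nowhere in the tree.

```python
import torch


def equalize_uint8_sketch(image: torch.Tensor) -> torch.Tensor:
    # Expects a (..., H, W) uint8 tensor; each leading dim gets its own histogram.
    batch_shape = image.shape[:-2]
    flat_image = image.flatten(start_dim=-2).to(torch.long)  # (..., H * W)

    # Batched histogram: every pixel value indexes into its sample's 256-bin
    # histogram, so scatter_add_ with a constant source of ones counts
    # occurrences per bin without a Python loop over the batch.
    hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
    hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
    cum_hist = hist.cumsum(dim=-1)

    # Simple LUT: map each value to its scaled cumulative frequency.
    lut = cum_hist.to(torch.float32).div_(flat_image.shape[-1]).mul_(255).to(torch.uint8)

    # Applying the LUT is the same index-by-pixel-value trick, via gather.
    return lut.gather(dim=-1, index=flat_image).view_as(image)


image = torch.randint(64, 192, (3, 256, 256), dtype=torch.uint8)  # low-contrast band
equalized = equalize_uint8_sketch(image)
print(image.float().std().item(), equalized.float().std().item())  # spread increases
```

This is also why the kernel avoids `torch.histc`: it has no batch support, while the `scatter_add_` formulation handles arbitrary leading dimensions in one call.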