|
1 | 1 | import functools
|
2 | 2 | from typing import Any, Callable, Dict, List, Sequence, Type, Union
|
3 | 3 |
|
| 4 | +import PIL.Image |
| 5 | + |
4 | 6 | import torch
|
| 7 | +from torchvision.prototype import features |
5 | 8 | from torchvision.prototype.transforms import functional as F, Transform
|
6 | 9 | from torchvision.transforms.transforms import _setup_size
|
7 | 10 |
|
@@ -32,6 +35,59 @@ def extra_repr(self) -> str:
|
32 | 35 | return ", ".join(extras)
|
33 | 36 |
|
34 | 37 |
|
class LinearTransformation(Transform):
    """Transform a tensor image with a square transformation matrix and a mean vector.

    Each image is flattened to a vector of length ``D = C * H * W``, the mean
    vector is subtracted, and the result is multiplied by the ``D x D``
    transformation matrix. Because the value range after an arbitrary linear
    map is unknown, the result is returned as a plain ``torch.Tensor`` even for
    ``features.Image`` inputs.

    Args:
        transformation_matrix: square tensor of shape ``[D, D]``.
        mean_vector: tensor of shape ``[D]``; must be on the same device as
            ``transformation_matrix``.

    Raises:
        ValueError: if the matrix is not square, the mean vector length does
            not match the matrix, or the two tensors are on different devices.
    """

    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
        super().__init__()
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Non-image features (bounding boxes, segmentation masks, labels, ...)
        # pass through untouched; PIL images are explicitly unsupported.
        if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
            return inpt
        elif isinstance(inpt, PIL.Image.Image):
            raise TypeError("Unsupported input type")

        # Image instance after linear transformation is not Image anymore due to unknown data range
        # Thus we will return Tensor for input Image

        shape = inpt.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                # NOTE: trailing space fixes the previously garbled message
                # ("...incompatible shape.[3 x ...").
                "Input tensor and transformation matrix have incompatible shape. "
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        # Only the device *type* is compared here (cpu vs cuda), so tensors on
        # different CUDA ordinals are not rejected — matching the stricter full
        # device check in __init__ would be a behavior change for callers.
        if inpt.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {inpt.device} vs {self.mean_vector.device}"
            )

        # reshape (not view): view raises RuntimeError on non-contiguous
        # inputs, e.g. tensors produced by permute or strided crops.
        flat_tensor = inpt.reshape(-1, n) - self.mean_vector
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        return transformed_tensor.reshape(shape)
| 90 | + |
35 | 91 | class Normalize(Transform):
|
36 | 92 | def __init__(self, mean: List[float], std: List[float]):
|
37 | 93 | super().__init__()
|
|
0 commit comments