Commit e6c5e5c

NicolasHug authored and facebook-github-bot committed
[fbsync] [proto] Ported LinearTransformation (#6458)
Summary:
* WIP
* Fixed dtype correction and tests
* Removed PIL Image support; output is always a Tensor

Reviewed By: datumbox
Differential Revision: D39013682
fbshipit-source-id: e98d9ba0f2ea703cfb9807dbc9abcd34e5db3331
1 parent b8fc032 commit e6c5e5c
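
The ported transform flattens each input to a vector of length n = C * H * W, subtracts `mean_vector`, and multiplies by `transformation_matrix` (see the `_misc.py` hunk below). A minimal usage sketch, not part of this commit, using an illustrative identity matrix and zero mean vector in place of a real whitening matrix:

    # Minimal usage sketch (assumption: the prototype API exposed by this commit).
    import torch
    from torchvision.prototype import transforms

    n = 3 * 8 * 8                                              # flattened size per image
    transform = transforms.LinearTransformation(torch.eye(n), torch.zeros(n))

    img = torch.rand(1, 3, 8, 8)                               # plain tensor input
    out = transform(img)                                       # output is always a torch.Tensor
    assert out.shape == img.shape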

File tree: 3 files changed, +90 -1 lines

test/test_prototype_transforms.py

Lines changed: 33 additions & 0 deletions

@@ -1487,3 +1487,36 @@ def test__transform_bounding_box_clamping(self, mocker):
         transform(bounding_box)
 
         mock.assert_called_once()
+
+
+class TestLinearTransformation:
+    def test_assertions(self):
+        with pytest.raises(ValueError, match="transformation_matrix should be square"):
+            transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))
+
+        with pytest.raises(ValueError, match="mean_vector should have the same length"):
+            transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            122 * torch.ones(1, 3, 8, 8),
+            122.0 * torch.ones(1, 3, 8, 8),
+            features.Image(122 * torch.ones(1, 3, 8, 8)),
+            PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
+        ],
+    )
+    def test__transform(self, inpt):
+
+        v = 121 * torch.ones(3 * 8 * 8)
+        m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
+        transform = transforms.LinearTransformation(m, v)
+
+        if isinstance(inpt, PIL.Image.Image):
+            with pytest.raises(TypeError, match="Unsupported input type"):
+                transform(inpt)
+        else:
+            output = transform(inpt)
+            assert isinstance(output, torch.Tensor)
+            assert output.unique() == 3 * 8 * 8
+            assert output.dtype == inpt.dtype
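
For reference, the expected value in `test__transform` follows from the chosen constants: each of the n = 3 * 8 * 8 = 192 flattened entries becomes 122 - 121 = 1 after mean subtraction, and the all-ones matrix then sums them, so every output element equals 192. A standalone check of that arithmetic (an illustrative sketch, not part of the commit):

    import torch

    n = 3 * 8 * 8                          # 192 flattened elements per image
    x = 122 * torch.ones(1, 3, 8, 8)       # the tensor case from the parametrization
    v = 121 * torch.ones(n)                # mean vector used by the test
    m = torch.ones(n, n)                   # all-ones transformation matrix

    flat = x.view(-1, n) - v               # every entry becomes 1.0
    out = torch.mm(flat, m).view(x.shape)  # every entry becomes the sum of 192 ones
    assert out.unique().item() == n        # matches `output.unique() == 3 * 8 * 8`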

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@
     TenCrop,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
-from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
+from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, ToDtype
 from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
 
 from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor  # usort: skip

torchvision/prototype/transforms/_misc.py

Lines changed: 56 additions & 0 deletions

@@ -1,7 +1,10 @@
 import functools
 from typing import Any, Callable, Dict, List, Sequence, Type, Union
 
+import PIL.Image
+
 import torch
+from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.transforms import _setup_size
 
@@ -32,6 +35,59 @@ def extra_repr(self) -> str:
         return ", ".join(extras)
 
 
+class LinearTransformation(Transform):
+    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
+        super().__init__()
+        if transformation_matrix.size(0) != transformation_matrix.size(1):
+            raise ValueError(
+                "transformation_matrix should be square. Got "
+                f"{tuple(transformation_matrix.size())} rectangular matrix."
+            )
+
+        if mean_vector.size(0) != transformation_matrix.size(0):
+            raise ValueError(
+                f"mean_vector should have the same length {mean_vector.size(0)}"
+                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
+            )
+
+        if transformation_matrix.device != mean_vector.device:
+            raise ValueError(
+                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
+            )
+
+        self.transformation_matrix = transformation_matrix
+        self.mean_vector = mean_vector
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+
+        if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
+            return inpt
+        elif isinstance(inpt, PIL.Image.Image):
+            raise TypeError("Unsupported input type")
+
+        # Image instance after linear transformation is not Image anymore due to unknown data range
+        # Thus we will return Tensor for input Image
+
+        shape = inpt.shape
+        n = shape[-3] * shape[-2] * shape[-1]
+        if n != self.transformation_matrix.shape[0]:
+            raise ValueError(
+                "Input tensor and transformation matrix have incompatible shape."
+                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
+                + f"{self.transformation_matrix.shape[0]}"
+            )
+
+        if inpt.device.type != self.mean_vector.device.type:
+            raise ValueError(
+                "Input tensor should be on the same device as transformation matrix and mean vector. "
+                f"Got {inpt.device} vs {self.mean_vector.device}"
+            )
+
+        flat_tensor = inpt.view(-1, n) - self.mean_vector
+        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+        return transformed_tensor.view(shape)
+
+
 class Normalize(Transform):
     def __init__(self, mean: List[float], std: List[float]):
         super().__init__()
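
In practice the constructor arguments come from a whitening computation over a dataset; the commit itself does not include one, so the following ZCA-style estimate is only a sketch under that assumption (the batch shape, epsilon, and variable names are illustrative):

    import torch
    from torchvision.prototype import transforms

    # Estimate a ZCA whitening matrix and mean from a small batch of images.
    imgs = torch.rand(64, 3, 8, 8)
    n = 3 * 8 * 8
    flat = imgs.view(64, n)

    mean_vector = flat.mean(dim=0)
    centered = flat - mean_vector
    cov = centered.t() @ centered / (flat.shape[0] - 1)
    u, s, _ = torch.linalg.svd(cov)
    transformation_matrix = u @ torch.diag(1.0 / torch.sqrt(s + 1e-5)) @ u.t()

    transform = transforms.LinearTransformation(transformation_matrix, mean_vector)
    whitened = transform(imgs)   # plain tensor, same shape as the input batch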
