@@ -76,12 +76,7 @@ def _check_inputs(self, sample: Any) -> Any:
76
76
if has_any (sample , PIL .Image .Image ):
77
77
raise TypeError ("LinearTransformation does not work on PIL Images" )
78
78
79
- def _transform (
80
- self , inpt : Union [datapoints .TensorImageType , datapoints .TensorVideoType ], params : Dict [str , Any ]
81
- ) -> torch .Tensor :
82
- # Image instance after linear transformation is not Image anymore due to unknown data range
83
- # Thus we will return Tensor for input Image
84
-
79
+ def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
85
80
shape = inpt .shape
86
81
n = shape [- 3 ] * shape [- 2 ] * shape [- 1 ]
87
82
if n != self .transformation_matrix .shape [0 ]:
@@ -97,11 +92,15 @@ def _transform(
97
92
f"Got { inpt .device } vs { self .mean_vector .device } "
98
93
)
99
94
100
- flat_tensor = inpt .reshape (- 1 , n ) - self .mean_vector
95
+ flat_inpt = inpt .reshape (- 1 , n ) - self .mean_vector
96
+
97
+ transformation_matrix = self .transformation_matrix .to (flat_inpt .dtype )
98
+ output = torch .mm (flat_inpt , transformation_matrix )
99
+ output = output .reshape (shape )
101
100
102
- transformation_matrix = self . transformation_matrix . to ( flat_tensor . dtype )
103
- transformed_tensor = torch . mm ( flat_tensor , transformation_matrix )
104
- return transformed_tensor . reshape ( shape )
101
+ if isinstance ( inpt , ( datapoints . Image , datapoints . Video )):
102
+ output = type ( inpt ). wrap_like ( inpt , output ) # type: ignore[arg-type]
103
+ return output
105
104
106
105
107
106
class Normalize (Transform ):
@@ -120,7 +119,7 @@ def _check_inputs(self, sample: Any) -> Any:
120
119
121
120
def _transform (
122
121
self , inpt : Union [datapoints .TensorImageType , datapoints .TensorVideoType ], params : Dict [str , Any ]
123
- ) -> torch . Tensor :
122
+ ) -> Any :
124
123
return F .normalize (inpt , mean = self .mean , std = self .std , inplace = self .inplace )
125
124
126
125
0 commit comments