@@ -124,7 +124,24 @@ def pil_to_tensor(pic):
124
124
return img
125
125
126
126
127
- def convert_image_dtype (image : torch .Tensor , dtype : torch .dtype = torch .float ) -> torch .Tensor :
127
+ # torch.iinfo isn't scriptable so using this helper function
128
+ # https://github.com/pytorch/pytorch/issues/41492
129
+ def _max_value (dtype : int ) -> int :
130
+ a = torch .tensor (2 , dtype = dtype )
131
+ signed = 1 if torch .tensor (0 , dtype = dtype ).is_signed () else 0
132
+ bits = 1
133
+ max_value = torch .tensor (- signed , dtype = torch .long )
134
+ while (True ):
135
+ next_value = a .pow (bits - signed ).sub (1 )
136
+ if next_value > max_value :
137
+ max_value = next_value
138
+ bits *= 2
139
+ else :
140
+ return max_value .item ()
141
+ return max_value .item ()
142
+
143
+
144
+ def convert_image_dtype (image : torch .Tensor , dtype : int = torch .float ) -> torch .Tensor :
128
145
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
129
146
130
147
Args:
@@ -148,9 +165,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
148
165
if image .dtype == dtype :
149
166
return image
150
167
151
- if image .dtype .is_floating_point :
168
+ if torch . empty ( 0 , dtype = image .dtype ) .is_floating_point () :
152
169
# float to float
153
- if dtype .is_floating_point :
170
+ if torch . tensor ( 0 , dtype = dtype ) .is_floating_point () :
154
171
return image .to (dtype )
155
172
156
173
# float to int
@@ -166,19 +183,19 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
166
183
# `max + 1 - epsilon` provides more evenly distributed mapping of
167
184
# ranges of floats to ints.
168
185
eps = 1e-3
169
- result = image .mul (torch .iinfo (dtype ).max + 1 - eps )
186
+ max_val = _max_value (dtype )
187
+ result = image .mul (max_val + 1.0 - eps )
170
188
return result .to (dtype )
171
189
else :
190
+ input_max = _max_value (image .dtype )
191
+ output_max = _max_value (dtype )
192
+
172
193
# int to float
173
- if dtype .is_floating_point :
174
- max = torch .iinfo (image .dtype ).max
194
+ if torch .tensor (0 , dtype = dtype ).is_floating_point ():
175
195
image = image .to (dtype )
176
- return image / max
196
+ return image / input_max
177
197
178
198
# int to int
179
- input_max = torch .iinfo (image .dtype ).max
180
- output_max = torch .iinfo (dtype ).max
181
-
182
199
if input_max > output_max :
183
200
factor = (input_max + 1 ) // (output_max + 1 )
184
201
image = image // factor
0 commit comments