torch.Tensor.to accepts another tensor as a reference, in which case it converts to that tensor's dtype and device:
>>> a = torch.zeros((), dtype=torch.float)
>>> b = torch.ones((), dtype=torch.int)
>>> b
tensor(1, dtype=torch.int32)
>>> b.to(a)
tensor(1.)
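As a quick check of the behavior above, here is a small sketch in plain PyTorch, without any custom features. It shows that passing a reference tensor gives the same result as passing that tensor's dtype and device explicitly. The values stored in the reference are not used.

```python
import torch

a = torch.zeros((), dtype=torch.float)  # reference tensor: float32
b = torch.ones((), dtype=torch.int)     # source tensor: int32

# Passing a tensor to .to() takes only its dtype and device, not its values.
via_reference = b.to(a)
via_explicit = b.to(dtype=a.dtype, device=a.device)

print(via_reference.dtype)                      # torch.float32
print(torch.equal(via_reference, via_explicit)) # True
```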
This also works for our custom features:
>>> data = torch.rand(3, 2, 2)
>>> image = features.Image(data.to(torch.float64))
>>> image
Image([[[0.9595, 0.0947],
        [0.9553, 0.5563]],
       [[0.5435, 0.2975],
        [0.3037, 0.0863]],
       [[0.6253, 0.3481],
        [0.3518, 0.4499]]], dtype=torch.float64)
>>> image.to(torch.float32)
Image([[[0.9595, 0.0947],
        [0.9553, 0.5563]],
       [[0.5435, 0.2975],
        [0.3037, 0.0863]],
       [[0.6253, 0.3481],
        [0.3518, 0.4499]]])
>>> image.to(data)
Image([[[0.9595, 0.0947],
        [0.9553, 0.5563]],
       [[0.5435, 0.2975],
        [0.3037, 0.0863]],
       [[0.6253, 0.3481],
        [0.3518, 0.4499]]])
However, it fails when we convert a plain tensor and pass a custom feature as the reference:
>>> data.to(image)
AttributeError: 'Tensor' object has no attribute 'color_space'
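The following is a possible cause, offered as an assumption rather than a confirmed diagnosis. In data.to(image), image is one of the arguments, so PyTorch dispatches to the feature's __torch_function__. If that method re-wraps the output of Tensor.to using attributes of args[0], it ends up reading color_space from the plain tensor.

The sketch below uses a hypothetical TaggedImage class. Its name and the "RGB" default are illustrative; only the color_space attribute name comes from the traceback. It adds a guard so that the output is re-wrapped only when args[0] is itself a feature. It relies on torch._C.DisableTorchFunction, which is a private PyTorch API, and it is not the torchvision implementation.

```python
import torch


class TaggedImage(torch.Tensor):
    # Hypothetical stand-in for a feature type such as features.Image:
    # a tensor subclass that carries one metadata attribute.

    def __new__(cls, data, *, color_space="RGB"):
        tensor = torch.as_tensor(data).as_subclass(cls)
        tensor.color_space = color_space
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Compute the result with torch-function handling disabled
        # (private PyTorch API). This keeps the method from being re-entered,
        # and the result comes back as a plain torch.Tensor.
        with torch._C.DisableTorchFunction():
            output = func(*args, **kwargs)
        # Re-wrap the result of .to() only if the tensor being converted
        # (args[0]) is itself a TaggedImage. For plain.to(tagged), args[0]
        # is the plain tensor, and reading args[0].color_space would raise
        # the AttributeError from the report.
        if func is torch.Tensor.to and isinstance(args[0], cls):
            return cls(output, color_space=args[0].color_space)
        return output


data = torch.rand(3, 2, 2)
image = TaggedImage(data.to(torch.float64), color_space="RGB")

converted = image.to(torch.float32)  # feature -> feature, metadata kept
plain = data.to(image)               # plain tensor -> plain float64 tensor

# Workaround that does not depend on the guard: pass only dtype and device.
explicit = data.to(dtype=image.dtype, device=image.device)

print(type(converted).__name__, converted.color_space)  # TaggedImage RGB
print(type(plain).__name__, plain.dtype)                # Tensor torch.float64
```

The last conversion, data.to(dtype=image.dtype, device=image.device), does not pass a feature instance to Tensor.to at all. It should therefore avoid the error no matter how the features implement __torch_function__.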