
Commit 1de9f17

datumbox authored and facebook-github-bot committed
[fbsync] Remove to_tensor() and ToTensor() usages (#5553)
Summary:
* Remove to_tensor()/ToTensor() usages from models and references.
* Add most tests and docs.
* Add transforms tests.
* Remove unnecessary ipython notebook.
* Simplify tests.
* Address comments.

Reviewed By: vmoens
Differential Revision: D34878980
fbshipit-source-id: 870b09c50697cbbdf8956b446712dcc4bbd5ed96
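For context, the substitution this commit applies is behavior-preserving for the common case of 8-bit PIL images: ToTensor() produces a float tensor in [0, 1] in one step, while PILToTensor() keeps the raw uint8 values and ConvertImageDtype(torch.float) rescales them. A minimal sketch of that equivalence (illustrative, not taken from the diff):

# Illustrative sketch, not part of the commit: for 8-bit PIL images, ToTensor()
# matches PILToTensor() followed by ConvertImageDtype(torch.float).
import torch
from PIL import Image
from torchvision import transforms as T

img = Image.new("RGB", (32, 32), color=(128, 64, 255))

old = T.ToTensor()(img)                                                    # float32 in [0, 1]
new = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)])(img)

torch.testing.assert_close(old, new)  # same dtype, same values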
1 parent 0957981 commit 1de9f17

File tree: 16 files changed, +68 −590 lines

docs/source/models.rst

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ to::
     import torch
     from torchvision import datasets, transforms as T
 
-    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
+    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.PILToTensor(), T.ConvertImageDtype(torch.float)])
     dataset = datasets.ImageNet(".", split="train", transform=transform)
 
     means = []
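The surrounding doc section goes on to accumulate per-channel statistics; a minimal sketch of that kind of computation with the updated pipeline, using FakeData as a stand-in for ImageNet (illustrative, not taken from the diff):

# Illustrative sketch, not from the diff: per-channel means with the new pipeline.
import torch
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.PILToTensor(), T.ConvertImageDtype(torch.float)])
dataset = datasets.FakeData(size=16, image_size=(3, 256, 256), transform=transform)  # stand-in dataset

means = []
for img, _ in dataset:
    means.append(img.mean(dim=(1, 2)))    # per-channel mean of one image
print(torch.stack(means).mean(dim=0))     # per-channel mean over the dataset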

references/detection/presets.py

Lines changed: 6 additions & 1 deletion
@@ -41,7 +41,12 @@ def __call__(self, img, target):
 
 class DetectionPresetEval:
     def __init__(self):
-        self.transforms = T.ToTensor()
+        self.transforms = T.Compose(
+            [
+                T.PILToTensor(),
+                T.ConvertImageDtype(torch.float),
+            ]
+        )
 
     def __call__(self, img, target):
         return self.transforms(img, target)

references/detection/transforms.py

Lines changed: 0 additions & 9 deletions
@@ -45,15 +45,6 @@ def forward(
         return image, target
 
 
-class ToTensor(nn.Module):
-    def forward(
-        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
-    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
-        image = F.pil_to_tensor(image)
-        image = F.convert_image_dtype(image)
-        return image, target
-
-
 class PILToTensor(nn.Module):
     def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None

references/similarity/test.py

Lines changed: 8 additions & 1 deletion
@@ -1,6 +1,7 @@
 import unittest
 from collections import defaultdict
 
+import torch
 import torchvision.transforms as transforms
 from sampler import PKSampler
 from torch.utils.data import DataLoader
@@ -17,7 +18,13 @@ def test_pksampler(self):
         self.assertRaises(AssertionError, PKSampler, targets, p, k)
 
         # Ensure p, k constraints on batch
-        dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=transforms.ToTensor())
+        trans = transforms.Compose(
+            [
+                transforms.PILToTensor(),
+                transforms.ConvertImageDtype(torch.float),
+            ]
+        )
+        dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=trans)
         targets = [target.item() for _, target in dataset]
         sampler = PKSampler(targets, p, k)
         loader = DataLoader(dataset, batch_size=p * k, sampler=sampler)

references/similarity/train.py

Lines changed: 6 additions & 1 deletion
@@ -102,7 +102,12 @@ def main(args):
     optimizer = Adam(model.parameters(), lr=args.lr)
 
     transform = transforms.Compose(
-        [transforms.Lambda(lambda image: image.convert("RGB")), transforms.Resize((224, 224)), transforms.ToTensor()]
+        [
+            transforms.Lambda(lambda image: image.convert("RGB")),
+            transforms.Resize((224, 224)),
+            transforms.PILToTensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
     )
 
     # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can

test/preprocess-bench.py

Lines changed: 2 additions & 1 deletion
@@ -33,7 +33,8 @@
     [
         transforms.RandomSizedCrop(224),
         transforms.RandomHorizontalFlip(),
-        transforms.ToTensor(),
+        transforms.PILToTensor(),
+        transforms.ConvertImageDtype(torch.float),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
     ]
 )

test/sanity_checks.ipynb

Lines changed: 0 additions & 529 deletions
This file was deleted.

test/test_cpp_models.py

Lines changed: 4 additions & 2 deletions
@@ -30,7 +30,8 @@ def read_image1():
     )
     image = Image.open(image_path)
     image = image.resize((224, 224))
-    x = F.to_tensor(image)
+    x = F.pil_to_tensor(image)
+    x = F.convert_image_dtype(x)
     return x.view(1, 3, 224, 224)
 
 
@@ -40,7 +41,8 @@ def read_image2():
     )
     image = Image.open(image_path)
     image = image.resize((299, 299))
-    x = F.to_tensor(image)
+    x = F.pil_to_tensor(image)
+    x = F.convert_image_dtype(x)
     x = x.view(1, 3, 299, 299)
     return torch.cat([x, x], 0)
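The functional API follows the same pattern as the read_image helpers above; note that F.convert_image_dtype defaults to torch.float. A small sketch (illustrative, not from the diff):

# Illustrative sketch, not part of the commit: functional-API equivalent of to_tensor().
import torch
from PIL import Image
from torchvision.transforms import functional as F

image = Image.new("RGB", (224, 224))

old = F.to_tensor(image)                              # float32 in [0, 1]
new = F.convert_image_dtype(F.pil_to_tensor(image))   # default target dtype is torch.float
torch.testing.assert_close(old, new)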

test/test_onnx.py

Lines changed: 2 additions & 2 deletions
@@ -413,13 +413,13 @@ def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
         import os
 
         from PIL import Image
-        from torchvision import transforms
+        from torchvision.transforms import functional as F
 
         data_dir = os.path.join(os.path.dirname(__file__), "assets")
         path = os.path.join(data_dir, *rel_path.split("/"))
         image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
 
-        return transforms.ToTensor()(image)
+        return F.convert_image_dtype(F.pil_to_tensor(image))
 
     def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         return (

test/test_transforms.py

Lines changed: 28 additions & 32 deletions
@@ -154,7 +154,7 @@ def test_int_to_int_consistency(self, input_dtype, output_dtype):
 @pytest.mark.skipif(accimage is None, reason="accimage not available")
 class TestAccImage:
     def test_accimage_to_tensor(self):
-        trans = transforms.ToTensor()
+        trans = transforms.PILToTensor()
 
         expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
         output = trans(accimage.Image(GRACE_HOPPER))
@@ -174,7 +174,8 @@ def test_accimage_resize(self):
         trans = transforms.Compose(
             [
                 transforms.Resize(256, interpolation=Image.LINEAR),
-                transforms.ToTensor(),
+                transforms.PILToTensor(),
+                transforms.ConvertImageDtype(dtype=torch.float),
             ]
         )
 
@@ -192,10 +193,7 @@ def test_accimage_resize(self):
 
     def test_accimage_crop(self):
         trans = transforms.Compose(
-            [
-                transforms.CenterCrop(256),
-                transforms.ToTensor(),
-            ]
+            [transforms.CenterCrop(256), transforms.PILToTensor(), transforms.ConvertImageDtype(dtype=torch.float)]
         )
 
         # Checking if Compose, CenterCrop and ToTensor can be printed as string
@@ -457,26 +455,24 @@ class TestPad:
     def test_pad(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
-        img = torch.ones(3, height, width)
+        img = torch.ones(3, height, width, dtype=torch.uint8)
         padding = random.randint(1, 20)
         fill = random.randint(1, 50)
         result = transforms.Compose(
             [
                 transforms.ToPILImage(),
                 transforms.Pad(padding, fill=fill),
-                transforms.ToTensor(),
+                transforms.PILToTensor(),
             ]
         )(img)
         assert result.size(1) == height + 2 * padding
         assert result.size(2) == width + 2 * padding
         # check that all elements in the padded region correspond
         # to the pad value
-        fill_v = fill / 255
-        eps = 1e-5
         h_padded = result[:, :padding, :]
         w_padded = result[:, :, :padding]
-        torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps)
-        torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps)
+        torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill), rtol=0.0, atol=0.0)
+        torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill), rtol=0.0, atol=0.0)
         pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img))
 
     def test_pad_with_tuple_of_pad_values(self):
@@ -509,23 +505,23 @@ def test_pad_with_non_constant_padding_modes(self):
         # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
         edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
         assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35)
+        assert transforms.PILToTensor()(edge_padded_img).size() == (3, 35, 35)
 
         # Pad 3 to left/right, 2 to top/bottom
         reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect")
         # First 6 elements of leftmost edge in the middle of the image, values are in order:
         # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
         reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
         assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35)
+        assert transforms.PILToTensor()(reflect_padded_img).size() == (3, 33, 35)
 
         # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
         symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric")
         # First 6 elements of leftmost edge in the middle of the image, values are in order:
         # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
         symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
         assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(symmetric_padded_img).size() == (3, 32, 34)
+        assert transforms.PILToTensor()(symmetric_padded_img).size() == (3, 32, 34)
 
         # Check negative padding explicitly for symmetric case, since it is not
         # implemented for tensor case to compare to
@@ -535,7 +531,7 @@ def test_pad_with_non_constant_padding_modes(self):
         symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
         assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8))
         assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)
+        assert transforms.PILToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)
 
     def test_pad_raises_with_invalid_pad_sequence_len(self):
         with pytest.raises(ValueError):
@@ -1625,12 +1621,12 @@ def test_random_crop():
     width = random.randint(10, 32) * 2
     oheight = random.randint(5, (height - 2) / 2) * 2
     owidth = random.randint(5, (width - 2) / 2) * 2
-    img = torch.ones(3, height, width)
+    img = torch.ones(3, height, width, dtype=torch.uint8)
    result = transforms.Compose(
         [
             transforms.ToPILImage(),
             transforms.RandomCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.size(1) == oheight
@@ -1641,14 +1637,14 @@ def test_random_crop():
         [
             transforms.ToPILImage(),
             transforms.RandomCrop((oheight, owidth), padding=padding),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.size(1) == oheight
     assert result.size(2) == owidth
 
     result = transforms.Compose(
-        [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.ToTensor()]
+        [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.PILToTensor()]
     )(img)
     assert result.size(1) == height
     assert result.size(2) == width
@@ -1658,7 +1654,7 @@ def test_random_crop():
         [
             transforms.ToPILImage(),
             transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.size(1) == height + 1
@@ -1676,7 +1672,7 @@ def test_center_crop():
     oheight = random.randint(5, (height - 2) / 2) * 2
     owidth = random.randint(5, (width - 2) / 2) * 2
 
-    img = torch.ones(3, height, width)
+    img = torch.ones(3, height, width, dtype=torch.uint8)
     oh1 = (height - oheight) // 2
     ow1 = (width - owidth) // 2
     imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth]
@@ -1685,7 +1681,7 @@ def test_center_crop():
         [
             transforms.ToPILImage(),
             transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.sum() == 0
@@ -1695,7 +1691,7 @@ def test_center_crop():
         [
             transforms.ToPILImage(),
             transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     sum1 = result.sum()
@@ -1706,7 +1702,7 @@ def test_center_crop():
         [
             transforms.ToPILImage(),
             transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     sum2 = result.sum()
@@ -1729,12 +1725,12 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
     delta_height *= delta
     delta_width *= delta
 
-    img = torch.ones(3, *input_image_size)
+    img = torch.ones(3, *input_image_size, dtype=torch.uint8)
     crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)
 
     # Test both transforms, one with PIL input and one with tensor
     output_pil = transforms.Compose(
-        [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.ToTensor()],
+        [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.PILToTensor()],
     )(img)
     assert output_pil.size()[1:3] == crop_size
 
@@ -1893,13 +1889,13 @@ def test_randomperspective():
     perp = transforms.RandomPerspective()
     startpoints, endpoints = perp.get_params(width, height, 0.5)
     tr_img = F.perspective(img, startpoints, endpoints)
-    tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
-    tr_img = F.to_tensor(tr_img)
+    tr_img2 = F.convert_image_dtype(F.pil_to_tensor(F.perspective(tr_img, endpoints, startpoints)))
+    tr_img = F.convert_image_dtype(F.pil_to_tensor(tr_img))
     assert img.size[0] == width
     assert img.size[1] == height
-    assert torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > torch.nn.functional.mse_loss(
-        tr_img2, F.to_tensor(img)
-    )
+    assert torch.nn.functional.mse_loss(
+        tr_img, F.convert_image_dtype(F.pil_to_tensor(img))
+    ) + 0.3 > torch.nn.functional.mse_loss(tr_img2, F.convert_image_dtype(F.pil_to_tensor(img)))
 
 
 @pytest.mark.parametrize("seed", range(10))
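One detail worth noting in the test_pad change above: because PILToTensor keeps the uint8 range instead of rescaling to [0, 1], the pad fill value can now be compared exactly, which is why the fill / 255 scaling and the eps tolerance were dropped. A small sketch (illustrative, not from the diff):

# Illustrative sketch, not part of the commit: PILToTensor preserves uint8 values,
# so an integer pad fill can be checked without rescaling or a tolerance.
import torch
from torchvision import transforms

img = torch.zeros(3, 8, 8, dtype=torch.uint8)
fill = 37  # arbitrary example value
padded = transforms.Compose(
    [transforms.ToPILImage(), transforms.Pad(2, fill=fill), transforms.PILToTensor()]
)(img)

assert padded.dtype == torch.uint8
assert int(padded[0, 0, 0]) == fill  # exact match in the padded region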

test/test_utils.py

Lines changed: 2 additions & 2 deletions
@@ -76,7 +76,7 @@ def test_save_image_file_object():
     fp = BytesIO()
     utils.save_image(t, fp, format="png")
     img_bytes = Image.open(fp)
-    assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object")
+    assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")
 
 
 @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
@@ -88,7 +88,7 @@ def test_save_image_single_pixel_file_object():
     fp = BytesIO()
     utils.save_image(t, fp, format="png")
     img_bytes = Image.open(fp)
-    assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object")
+    assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")
 
 
 def test_draw_boxes():

torchvision/datasets/celeba.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ class CelebA(VisionDataset):
            Defaults to ``attr``. If empty, ``None`` will be returned as target.
 
         transform (callable, optional): A function/transform that takes in an PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         download (bool, optional): If true, downloads the dataset from the internet and

torchvision/datasets/coco.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@ class CocoDetection(VisionDataset):
         root (string): Root directory where images are downloaded to.
         annFile (string): Path to json annotation file.
         transform (callable, optional): A function/transform that takes in an PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         transforms (callable, optional): A function/transform that takes input sample and its target as entry
@@ -66,7 +66,7 @@ class CocoCaptions(CocoDetection):
         root (string): Root directory where images are downloaded to.
         annFile (string): Path to json annotation file.
         transform (callable, optional): A function/transform that takes in an PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         transforms (callable, optional): A function/transform that takes input sample and its target as entry
@@ -80,7 +80,7 @@ class CocoCaptions(CocoDetection):
             import torchvision.transforms as transforms
             cap = dset.CocoCaptions(root = 'dir where images are',
                                     annFile = 'json annotation file',
-                                    transform=transforms.ToTensor())
+                                    transform=transforms.PILToTensor())
 
             print('Number of samples: ', len(cap))
             img, target = cap[3]  # load 4th sample
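Following the docstring example above, a usage sketch that also converts the image to float; the paths are placeholders, not values from the diff:

# Illustrative sketch, not part of the commit: docstring example extended with an
# explicit float conversion. Paths below are placeholders.
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

cap = dset.CocoCaptions(
    root="dir where images are",       # placeholder path
    annFile="json annotation file",    # placeholder path
    transform=transforms.Compose(
        [transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float)]
    ),
)

img, target = cap[3]          # load 4th sample
print(img.dtype, img.shape)   # torch.float32, C x H x W
print(len(target))            # number of captions for this image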
