From e659e27edc9dde86735777f7e83668c5a56912e7 Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Fri, 11 Nov 2016 20:53:14 -0500 Subject: [PATCH] fix ToTensor to handle numpy --- test/cifar.py | 35 +++++++++++++++++++++++++++++------ torchvision/transforms.py | 16 +++++++++++----- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/test/cifar.py b/test/cifar.py index daf542fa800..c3272d85e8d 100644 --- a/test/cifar.py +++ b/test/cifar.py @@ -1,12 +1,35 @@ import torch import torchvision.datasets as dset +import torchvision.transforms as transforms -print('\n\nCifar 10') -a = dset.CIFAR10(root="abc/def/ghi", download=True) +# print('\n\nCifar 10') +# a = dset.CIFAR10(root="abc/def/ghi", download=True) -print(a[3]) +# print(a[3]) -print('\n\nCifar 100') -a = dset.CIFAR100(root="abc/def/ghi", download=True) +# print('\n\nCifar 100') +# a = dset.CIFAR100(root="abc/def/ghi", download=True) -print(a[3]) +# print(a[3]) + + +dataset = dset.CIFAR10(root='cifar', download=True, transform=transforms.ToTensor()) +dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, + shuffle=True, num_workers=2) + + +# miter = dataloader.__iter__() +# def getBatch(): +# global miter +# try: +# return miter.next() +# except StopIteration: +# miter = dataloader.__iter__() +# return miter.next() + +# i=0 +# while True: +# print(i) +# img, target = getBatch() +# i+=1 + diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 64483f611e3..cfc0b5c8755 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -2,6 +2,7 @@ import math import random from PIL import Image +import numpy as np class Compose(object): @@ -16,11 +17,16 @@ def __call__(self, img): class ToTensor(object): def __call__(self, pic): - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - img = img.view(pic.size[0], pic.size[1], 3) - # put it in CHW format - # yikes, this transpose takes 80% of the loading time/CPU - img = img.transpose(0, 2).transpose(1, 2).contiguous() + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic) + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[0], pic.size[1], 3) + # put it in CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 2).transpose(1, 2).contiguous() return img.float() class Normalize(object):