diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index a467ebee554..7bbb4d54214 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -36,15 +36,20 @@ class MNIST(data.Dataset): training_file = 'training.pt' test_file = 'test.pt' - def __init__(self, root, train=True, transform=None, target_transform=None, download=False): + def __init__(self, root, train=True, transform=None, target_transform=None, + download=False, from_local=False): self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform self.train = train # training set or test set + self.from_local = from_local if download: self.download() + elif self.from_local: + self.download() + if not self._check_exists(): raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') @@ -95,6 +100,7 @@ def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" from six.moves import urllib import gzip + import shutil if self._check_exists(): return @@ -110,12 +116,18 @@ def download(self): raise for url in self.urls: - print('Downloading ' + url) - data = urllib.request.urlopen(url) filename = url.rpartition('/')[2] file_path = os.path.join(self.root, self.raw_folder, filename) - with open(file_path, 'wb') as f: - f.write(data.read()) + + if self.from_local: + tmp_file_path = os.path.join(self.root, filename) + shutil.move(tmp_file_path, os.path.join(self.root, self.raw_folder)) + else: + print('Downloading ' + url) + data = urllib.request.urlopen(url) + with open(file_path, 'wb') as f: + f.write(data.read()) + with open(file_path.replace('.gz', ''), 'wb') as out_f, \ gzip.GzipFile(file_path) as zip_f: out_f.write(zip_f.read())