-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata.py
46 lines (37 loc) · 1.41 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import logging
import os
import shutil
import urllib.parse
from typing import Any, Dict
import requests
from torchvision import datasets, transforms
def get_dataset(data_dir: str, train: bool) -> Any:
    """Return the MNIST dataset rooted at ``data_dir``.

    Args:
        data_dir: Directory containing the (already downloaded) MNIST files.
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        A ``torchvision.datasets.MNIST`` instance whose images are converted
        to tensors and normalized.
    """
    # These are the precomputed mean and standard deviation of the MNIST
    # data; normalizing with them gives zero mean and unit standard deviation.
    mnist_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    return datasets.MNIST(data_dir, train=train, transform=mnist_transform)
def download_dataset(download_directory: str, data_config: Dict[str, Any]) -> str:
    """Download (if needed) and unpack the MNIST archive named in ``data_config``.

    The archive is cached under a ``MNIST`` subdirectory of
    ``download_directory``; if it is already present, the download is skipped
    and the cached file is unpacked again.

    Args:
        download_directory: Base directory for the dataset cache.
        data_config: Configuration dict; must contain ``"url"``, the URL of
            the archive to fetch.

    Returns:
        The directory containing the ``MNIST`` subdirectory (i.e. the base
        directory after path joining), suitable for passing to ``get_dataset``.

    Raises:
        requests.HTTPError: If the server responds with a non-2xx status.
    """
    url = data_config["url"]
    url_path = urllib.parse.urlparse(url).path
    # The archive's filename is the last component of the URL path.
    basename = url_path.rsplit("/", 1)[1]

    download_directory = os.path.join(download_directory, "MNIST")
    os.makedirs(download_directory, exist_ok=True)
    filepath = os.path.join(download_directory, basename)

    if not os.path.exists(filepath):
        # Lazy %-style args so the message is only formatted if INFO is enabled.
        logging.info("Downloading %s to %s", url, filepath)
        # Stream in chunks so large archives never sit fully in memory.
        # The timeout prevents hanging forever on an unresponsive server, and
        # the context manager guarantees the connection is released.
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail loudly instead of silently writing an HTTP error page to
            # disk and then trying to unpack it.
            r.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)

    shutil.unpack_archive(filepath, download_directory)
    return os.path.dirname(download_directory)