diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index c4d4cad350e..a961ce0e133 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -66,7 +66,10 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down f = self.test_list[0][0] file = os.path.join(root, self.base_folder, f) fo = open(file, 'rb') - entry = pickle.load(fo) + if sys.version_info[0] == 2: + entry = pickle.load(fo) + else: + entry = pickle.load(fo, encoding='latin1') self.test_data = entry['data'] if 'labels' in entry: self.test_labels = entry['labels']