Skip to content

Commit e848542

Browse files
authored
CI: fix examples - patch download MNIST (#6357)
* patch download * CI * isort * extra
1 parent b6aa350 commit e848542

File tree

11 files changed

+58
-31
lines changed

11 files changed

+58
-31
lines changed

.github/workflows/ci_test-full.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,9 @@ jobs:
143143
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
144144
coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
145145
146-
# todo: put this back just when TorchVision can download datasets
147-
#- name: Examples
148-
# run: |
149-
# python -m pytest pl_examples -v --durations=10
146+
- name: Examples
147+
run: |
148+
python -m pytest pl_examples -v --durations=10
150149
151150
- name: Upload pytest test results
152151
uses: actions/upload-artifact@v2

azure-pipelines.yml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,12 @@ jobs:
100100
python -m pytest benchmarks -v --maxfail=2 --durations=0
101101
displayName: 'Testing: benchmarks'
102102
103-
# todo: put this back just when TorchVision can download datasets
104-
#- bash: |
105-
# python -m pytest pl_examples -v --maxfail=2 --durations=0
106-
# python setup.py install --user --quiet
107-
# bash pl_examples/run_ddp-example.sh
108-
# pip uninstall -y pytorch-lightning
109-
# displayName: 'Examples'
103+
- bash: |
104+
python -m pytest pl_examples -v --maxfail=2 --durations=0
105+
python setup.py install --user --quiet
106+
bash pl_examples/run_ddp-example.sh
107+
cd pl_examples/basic_examples
108+
bash submit_ddp_job.sh
109+
bash submit_ddp2_job.sh
110+
pip uninstall -y pytorch-lightning
111+
displayName: 'Examples'

pl_examples/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
11
import os
2+
from urllib.error import HTTPError
3+
4+
from six.moves import urllib
25

36
from pytorch_lightning.utilities import _module_available
47

8+
# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
9+
opener = urllib.request.build_opener()
10+
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
11+
urllib.request.install_opener(opener)
12+
513
_EXAMPLES_ROOT = os.path.dirname(__file__)
614
_PACKAGE_ROOT = os.path.dirname(_EXAMPLES_ROOT)
715
_DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets')
816

917
_TORCHVISION_AVAILABLE = _module_available("torchvision")
18+
_TORCHVISION_MNIST_AVAILABLE = True
1019
_DALI_AVAILABLE = _module_available("nvidia.dali")
1120

21+
if _TORCHVISION_AVAILABLE:
22+
try:
23+
from torchvision.datasets.mnist import MNIST
24+
MNIST(_DATASETS_PATH, download=True)
25+
except HTTPError:
26+
_TORCHVISION_MNIST_AVAILABLE = False
27+
1228
LIGHTNING_LOGO = """
1329
####
1430
###########

pl_examples/basic_examples/autoencoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from torch.utils.data import DataLoader, random_split
2121

2222
import pytorch_lightning as pl
23-
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo
23+
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
2424

25-
if _TORCHVISION_AVAILABLE:
25+
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
2626
from torchvision import transforms
2727
from torchvision.datasets.mnist import MNIST
2828
else:

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from torch.utils.data import DataLoader, random_split
2020

2121
import pytorch_lightning as pl
22-
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo
22+
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
2323

24-
if _TORCHVISION_AVAILABLE:
24+
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
2525
from torchvision import transforms
2626
from torchvision.datasets.mnist import MNIST
2727
else:

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@
2323
from torch.utils.data import random_split
2424

2525
import pytorch_lightning as pl
26-
from pl_examples import _DALI_AVAILABLE, _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo
27-
28-
if _TORCHVISION_AVAILABLE:
26+
from pl_examples import (
27+
_DALI_AVAILABLE,
28+
_DATASETS_PATH,
29+
_TORCHVISION_AVAILABLE,
30+
_TORCHVISION_MNIST_AVAILABLE,
31+
cli_lightning_logo,
32+
)
33+
34+
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
2935
from torchvision import transforms
3036
from torchvision.datasets.mnist import MNIST
3137
else:

pl_examples/basic_examples/mnist_datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from torch.utils.data import DataLoader, random_split
1919

20-
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE
20+
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE
2121
from pytorch_lightning import LightningDataModule
2222

23-
if _TORCHVISION_AVAILABLE:
23+
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
2424
from torchvision import transforms as transform_lib
2525
from torchvision.datasets import MNIST
2626
else:

pl_examples/domain_templates/generative_adversarial_net.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,19 @@
2626
import torch
2727
import torch.nn as nn
2828
import torch.nn.functional as F # noqa
29-
import torchvision
30-
import torchvision.transforms as transforms
3129
from torch.utils.data import DataLoader
32-
from torchvision.datasets import MNIST
3330

34-
from pl_examples import cli_lightning_logo
31+
from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
3532
from pytorch_lightning.core import LightningDataModule, LightningModule
3633
from pytorch_lightning.trainer import Trainer
3734

35+
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
36+
import torchvision
37+
import torchvision.transforms as transforms
38+
from torchvision.datasets import MNIST
39+
else:
40+
from tests.helpers.datasets import MNIST
41+
3842

3943
class Generator(nn.Module):
4044
"""

pytorch_lightning/callbacks/pruning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
import logging
2020
from copy import deepcopy
2121
from functools import partial
22-
from typing import Any, Callable, List, Optional, Tuple, Union, Dict
22+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2323

2424
import torch
2525
import torch.nn.utils.prune as pytorch_prune
2626
from torch import nn
2727

2828
from pytorch_lightning.callbacks.base import Callback
2929
from pytorch_lightning.core.lightning import LightningModule
30-
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
30+
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
3131
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3232

3333
log = logging.getLogger(__name__)

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from abc import ABC, abstractmethod
15-
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union, Dict
15+
from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union
1616

1717
import torch
1818
from torch.nn import Module

0 commit comments

Comments
 (0)