Skip to content

Commit 60d0c95

Browse files
Bordalexierule
authored andcommitted
fix importing torchtext batch (#6365)
* copy torchtext batch * update * rev * rev
1 parent 7a39789 commit 60d0c95

File tree

6 files changed

+25
-15
lines changed

6 files changed

+25
-15
lines changed

.github/workflows/docs-checks.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ jobs:
4141
4242
- name: Install dependencies
4343
run: |
44+
python --version
45+
pip --version
4446
# remove Horovod from requirements
4547
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)"
4648
# python -m pip install --upgrade --user pip
4749
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
4850
pip install --requirement requirements/extra.txt
4951
pip install --requirement requirements/loggers.txt
5052
pip install --requirement requirements/docs.txt
51-
python --version
52-
pip --version
5353
pip list
5454
shell: bash
5555

@@ -84,12 +84,12 @@ jobs:
8484
8585
- name: Install dependencies
8686
run: |
87-
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
87+
python --version
88+
pip --version
89+
# pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
8890
pip install --requirement requirements/docs.txt
8991
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
9092
sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
91-
python --version
92-
pip --version
9393
pip list
9494
shell: bash
9595

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from abc import ABC, abstractmethod
15-
from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union
15+
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
1616

1717
import torch
1818
from torch.nn import Module

pytorch_lightning/utilities/apply_func.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import operator
1515
from abc import ABC
1616
from collections.abc import Mapping, Sequence
1717
from copy import copy
@@ -22,10 +22,13 @@
2222
import torch
2323

2424
from pytorch_lightning.utilities.exceptions import MisconfigurationException
25-
from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE
25+
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE
2626

2727
if _TORCHTEXT_AVAILABLE:
28-
from torchtext.data import Batch
28+
if _compare_version("torchtext", operator.ge, "0.9.0"):
29+
from torchtext.legacy.data import Batch
30+
else:
31+
from torchtext.data import Batch
2932
else:
3033
Batch = type(None)
3134

tests/helpers/imports.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import operator
2+
3+
from pytorch_lightning.utilities.imports import _compare_version
4+
5+
if _compare_version("torchtext", operator.ge, "0.9.0"):
6+
from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401
7+
else:
8+
from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401

tests/models/test_gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
import pytest
1818
import torch
19-
from torchtext.data import Batch, Dataset, Example, Field, LabelField
2019

2120
import tests.helpers.pipelines as tpipes
2221
import tests.helpers.utils as tutils
2322
from pytorch_lightning import Trainer
2423
from pytorch_lightning.utilities import device_parser
2524
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2625
from tests.helpers import BoringModel
26+
from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
2727

2828
PRETEND_N_OF_GPUS = 16
2929

tests/utilities/test_apply_func_torchtext.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@
1313
# limitations under the License.
1414
import pytest
1515
import torch
16-
import torchtext
17-
from torchtext.data.example import Example
1816

1917
from pytorch_lightning.utilities.apply_func import move_data_to_device
18+
from tests.helpers.imports import Dataset, Example, Field, Iterator
2019

2120

2221
def _get_torchtext_data_iterator(include_lengths=False):
23-
text_field = torchtext.data.Field(
22+
text_field = Field(
2423
sequential=True,
2524
pad_first=False, # nosec
2625
init_token="<s>",
@@ -32,13 +31,13 @@ def _get_torchtext_data_iterator(include_lengths=False):
3231
example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
3332
example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
3433

35-
dataset = torchtext.data.Dataset(
34+
dataset = Dataset(
3635
[example1, example2, example3],
3736
{"text": text_field},
3837
)
3938
text_field.build_vocab(dataset)
4039

41-
iterator = torchtext.data.Iterator(
40+
iterator = Iterator(
4241
dataset,
4342
batch_size=3,
4443
sort_key=None,

0 commit comments

Comments
 (0)