Skip to content

Commit df93022

Browse files
committed
Update test
1 parent 138fc3a commit df93022

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

tests/models/test_tpu.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -495,24 +495,3 @@ def teardown(self, stage):
495495

496496
model = DebugModel()
497497
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
498-
499-
500-
@RunIf(tpu=True)
501-
@pl_multi_process_test
502-
def test_predict_tpu_multi_cores(tmpdir):
503-
"""Test if predict works for Multi TPU cores"""
504-
505-
tutils.reset_seed()
506-
model = BoringModel()
507-
trainer = Trainer(
508-
default_root_dir=tmpdir,
509-
progress_bar_refresh_rate=0,
510-
max_epochs=2,
511-
limit_train_batches=2,
512-
limit_val_batches=2,
513-
tpu_cores=8,
514-
)
515-
trainer.fit(model)
516-
trainer.predict(model, DataLoader(RandomDataset(32, 2000), batch_size=32))
517-
518-
assert trainer.state.finished, f"Training failed with {trainer.state}"

tests/overrides/test_distributed.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from collections.abc import Iterable
15+
1416
import pytest
1517
from torch.utils.data import BatchSampler, SequentialSampler
1618

1719
from pytorch_lightning import seed_everything
1820
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
21+
from pytorch_lightning.utilities.data import has_len
1922

2023

2124
@pytest.mark.parametrize("shuffle", [False, True])
@@ -54,3 +57,13 @@ def test_index_batch_sampler(tmpdir):
5457

5558
for batch in index_batch_sampler:
5659
assert index_batch_sampler.batch_indices == batch
60+
61+
62+
def test_index_batch_sampler_methods():
63+
dataset = range(15)
64+
sampler = SequentialSampler(dataset)
65+
batch_sampler = BatchSampler(sampler, 3, False)
66+
index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
67+
68+
assert isinstance(index_batch_sampler, Iterable)
69+
assert has_len(index_batch_sampler)

0 commit comments

Comments
 (0)