
Commit 78a4f0d

kaushikb11 and pre-commit-ci[bot] authored and committed

Add __len__ method to IndexBatchSamplerWrapper (#7681)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent e22a3cf · commit 78a4f0d

File tree

3 files changed: +201 -5 lines changed

CHANGELOG.md

Lines changed: 185 additions & 5 deletions

@@ -5,20 +5,200 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
-## [1.3.3] - 2021-05-27
+## [1.4.0] - 2021-MM-DD
+
+### Added
+
+- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
+
+- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow
+
+- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
+
+- Added `ModelPruning(prune_on_train_epoch_end=True|False)` to choose when to apply pruning ([#7704](https://github.com/PyTorchLightning/pytorch-lightning/pull/7704))
+
+- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))
+
+- Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))
+
+- Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))
+
+- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
+
+- Added `sub_dir` parameter to `TensorBoardLogger` ([#6195](https://github.com/PyTorchLightning/pytorch-lightning/pull/6195))
+
+- Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241))
+
+- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
+
+- Added `__len__` to `IndexBatchSamplerWrapper` ([#7681](https://github.com/PyTorchLightning/pytorch-lightning/pull/7681))
+
+- Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
+
+- Added `save_config_filename` init argument to `LightningCLI` to ease resolving name conflicts ([#7741](https://github.com/PyTorchLightning/pytorch-lightning/pull/7741))
 
 ### Changed
 
-- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
+- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
+
+- Changed the `Trainer`'s `checkpoint_callback` argument to allow only boolean values ([#7539](https://github.com/PyTorchLightning/pytorch-lightning/pull/7539))
+
+- Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272))
+
+- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`
+
+- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
+
+- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
+
+- Refactored Loops
+  * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))
+  * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
+  * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
+  * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))
+  * Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
+  * Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
+
+- Refactored logging
+  * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
+
+- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
+
+- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
+
+- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+- Default `seed_everything(workers=True)` in the `LightningCLI` ([#7504](https://github.com/PyTorchLightning/pytorch-lightning/pull/7504))
+
+- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([#7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
+
+- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))
+
+- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))
+
+- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))
+
+- Changed `teardown()` in `Accelerator` to allow `training_type_plugin` to customize `teardown` logic ([#7579](https://github.com/PyTorchLightning/pytorch-lightning/pull/7579))
+
+### Deprecated
+
+- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
+
+- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+### Removed
+
+- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
+
+- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))
+
+- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([#7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))
+
+- Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([#7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))
+
+- Removed support for `self.log(tbptt_reduce_fx)` and `self.log(tbptt_pad_token)`. Please open a discussion explaining your use case if you relied on these. ([#7644](https://github.com/PyTorchLightning/pytorch-lightning/pull/7644))
+
+- Removed deprecated utils modules `model_utils`, `warning_utils`, `xla_device_utils` and partially `argparse_utils` ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))
+
+- Removed deprecated trainer attributes - `on_cpu`, `on_tpu`, `use_tpu`, `on_gpu`, `use_dp`, `use_ddp`, `use_ddp2`, `use_horovod`, `use_single_gpu` ([#7501](https://github.com/PyTorchLightning/pytorch-lightning/pull/7501))
 
 ### Fixed
 
-- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
-- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+- Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
+
 - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
-- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+
 - Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+
 - Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+
+- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+
+- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
+
+- Fixed `None` loss keys getting added in `training_epoch_end` when using manual optimization and not returning a loss ([#7772](https://github.com/PyTorchLightning/pytorch-lightning/pull/7772))
+
+- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+
+- Fixed formatting of info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))
+
+
+## [1.3.4] - 2021-06-01
+
+### Changed
+
+### Fixed
+
+- Update pre-commit and add new hooks ([#7781](https://github.com/PyTorchLightning/pytorch-lightning/pull/7781))
+- Fix info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))
+- Fixed missing `__len__` method in `IndexBatchSamplerWrapper` ([#7681](https://github.com/PyTorchLightning/pytorch-lightning/pull/7681))
+
+
+## [1.3.3] - 2021-05-27
+
+### Changed
+
+- Move parameter validation specific to TPU Training plugins ([#7415](https://github.com/PyTorchLightning/pytorch-lightning/pull/7415))
+
+### Fixed
+
+- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
+- Fix/mismatched toggle optimizer ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
+- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
 - Fixed formatting of info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))
pytorch_lightning/overrides/distributed.py

Lines changed: 3 additions & 0 deletions

@@ -132,6 +132,9 @@ def __iter__(self) -> Iterator[List[int]]:
             self.batch_indices = batch
             yield batch
 
+    def __len__(self) -> int:
+        return len(self._sampler)
+
     @property
     def drop_last(self) -> bool:
         return self._sampler.drop_last
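
The change itself is tiny: the wrapper now forwards `len()` to the wrapped `BatchSampler`, so anything that probes for a length (progress bars, the `has_len` utility) sees the number of batches instead of hitting a `TypeError`. A minimal, self-contained sketch of the resulting behavior; this is an illustrative stand-in, not the library class, and the `_IndexBatchSamplerSketch` name is hypothetical:

```python
from typing import Iterator, List

from torch.utils.data import BatchSampler, SequentialSampler


class _IndexBatchSamplerSketch:
    """Illustrative stand-in for IndexBatchSamplerWrapper after this commit."""

    def __init__(self, sampler: BatchSampler) -> None:
        self._sampler = sampler
        self.batch_indices: List[int] = []

    def __iter__(self) -> Iterator[List[int]]:
        for batch in self._sampler:
            self.batch_indices = batch  # remember the most recently yielded indices
            yield batch

    def __len__(self) -> int:
        # Delegate to the wrapped BatchSampler, exactly as the diff above does;
        # BatchSampler.__len__ reports the number of batches it will yield.
        return len(self._sampler)


# 15 samples with batch_size=3 and drop_last=False -> 5 batches.
batch_sampler = BatchSampler(SequentialSampler(range(15)), batch_size=3, drop_last=False)
wrapper = _IndexBatchSamplerSketch(batch_sampler)
assert len(wrapper) == len(batch_sampler) == 5
```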

tests/overrides/test_distributed.py

Lines changed: 13 additions & 0 deletions

@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterable
+
 import pytest
 from torch.utils.data import BatchSampler, SequentialSampler
 
 from pytorch_lightning import seed_everything
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
+from pytorch_lightning.utilities.data import has_len
 
 
 @pytest.mark.parametrize("shuffle", [False, True])
@@ -54,3 +57,13 @@ def test_index_batch_sampler(tmpdir):
 
     for batch in index_batch_sampler:
         assert index_batch_sampler.batch_indices == batch
+
+
+def test_index_batch_sampler_methods():
+    dataset = range(15)
+    sampler = SequentialSampler(dataset)
+    batch_sampler = BatchSampler(sampler, 3, False)
+    index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
+
+    assert isinstance(index_batch_sampler, Iterable)
+    assert has_len(index_batch_sampler)
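
The new test pins down the two protocols the wrapper must satisfy: it is an `Iterable` (it defines `__iter__`) and it reports a usable length via `has_len`. For intuition, a length probe of roughly this shape is what such a check does; this is a hedged sketch, not the actual `pytorch_lightning.utilities.data.has_len` implementation, which may handle extra cases (e.g. zero-length or infinite dataloaders):

```python
def has_len_sketch(obj) -> bool:
    # len() raises TypeError when the object does not define __len__.
    try:
        len(obj)
        return True
    except TypeError:
        return False
```

Before this commit, `IndexBatchSamplerWrapper` defined `__iter__` but not `__len__`, so `len(wrapper)` raised `TypeError` and a probe like the one above returned `False` even though the wrapped `BatchSampler` knew its length.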
