
Commit 42d88f7

Merge branch 'master' into bugfix/7930_parent_module_w_param

2 parents 59abbe5 + f15ea60

File tree: 4 files changed (+93 −16 lines)

CHANGELOG.md

Lines changed: 31 additions & 16 deletions
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add `dataclass` support for `pytorch_lightning.utilities.apply_to_collection` ([#7935](https://github.com/PyTorchLightning/pytorch-lightning/pull/7935))
+
+
 - Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
 
 
@@ -53,9 +56,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
 
 
-- Added `__len__` to `IndexBatchSamplerWrapper` ([#7681](https://github.com/PyTorchLightning/pytorch-lightning/pull/7681))
-
-
 - Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
 
 
@@ -82,7 +82,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)
 
 - Changed the `Trainer`'s `checkpoint_callback` argument to allow only boolean values ([#7539](https://github.com/PyTorchLightning/pytorch-lightning/pull/7539))
 
@@ -231,40 +230,56 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
 
 
-- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+- Fixed dev debugger memory growing due to tracking events even when disabled ([#7875](https://github.com/PyTorchLightning/pytorch-lightning/pull/7875))
 
 
-- Fixed dev debugger memory growing due to tracking events even when disabled ([#7875](https://github.com/PyTorchLightning/pytorch-lightning/pull/7875))
+- Fixed `None` loss keys getting added in `training_epoch_end` when using manual optimization and not returning a loss ([#7772](https://github.com/PyTorchLightning/pytorch-lightning/pull/7772))
 
 
-- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
 
 
-- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+## [1.3.5] - 2021-06-08
 
+### Added
 
-- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
+- Added warning to Training Step output ([#7779](https://github.com/PyTorchLightning/pytorch-lightning/pull/7779))
 
+### Fixed
 
-- Fixed `None` loss keys getting added in `training_epoch_end` when using manual optimization and not returning a loss ([#7772](https://github.com/PyTorchLightning/pytorch-lightning/pull/7772))
+- Fixed `LearningRateMonitor` and `BackboneFinetuning` ([#7835](https://github.com/PyTorchLightning/pytorch-lightning/pull/7835))
+- Minor improvements to `apply_to_collection` and type signature of `log_dict` ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
+- Fixed docker versions ([#7834](https://github.com/PyTorchLightning/pytorch-lightning/pull/7834))
+- Fixed sharded training check for fp16 precision ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))
+- Fixed support for torch Module type hints in LightningCLI ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
 
+### Changed
 
-- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Move `training_output` validation to after `train_step_end` ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
 
 
-- Fixed formatting of info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))
+## [1.3.4] - 2021-06-01
 
+### Fixed
 
-- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
+- Fixed info message when max training time reached ([#7780](https://github.com/PyTorchLightning/pytorch-lightning/pull/7780))
+- Fixed missing `__len__` method to `IndexBatchSamplerWrapper` ([#7681](https://github.com/PyTorchLightning/pytorch-lightning/pull/7681))
 
 
-- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
+## [1.3.3] - 2021-05-27
 
+### Changed
 
-- Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))
+- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
 
+### Fixed
 
-- Fixed `LearningRateMonitor` keys not properly setup when running with `BackboneFinetuning` Callback ([#7835](https://github.com/PyTorchLightning/pytorch-lightning/pull/7835))
+- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
+- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
+- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
+- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
+- Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
 
 
 ## [1.3.2] - 2021-05-18

pytorch_lightning/utilities/apply_func.py

Lines changed: 14 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import dataclasses
 import operator
 from abc import ABC
 from collections import OrderedDict
@@ -60,6 +61,11 @@ def _is_namedtuple(obj: object) -> bool:
     return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
 
 
+def _is_dataclass_instance(obj):
+    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
+    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+
+
 def apply_to_collection(
     data: Any,
     dtype: Union[type, tuple],
@@ -110,6 +116,14 @@ def apply_to_collection(
                 out.append(v)
         return elem_type(*out) if is_namedtuple else elem_type(out)
 
+    if _is_dataclass_instance(data):
+        out = dict()
+        for field in data.__dataclass_fields__:
+            v = apply_to_collection(getattr(data, field), dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
+            if include_none or v is not None:
+                out[field] = v
+        return elem_type(**out)
+
     # data is neither of dtype, nor a collection
     return data
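The `_is_dataclass_instance` helper above hinges on a stdlib subtlety: `dataclasses.is_dataclass` returns `True` for both a dataclass type and its instances, so the extra `isinstance(obj, type)` check is what keeps class objects out of the new branch. A minimal sketch of that distinction, using a hypothetical `Point` dataclass that is not part of this repository:

    import dataclasses

    @dataclasses.dataclass
    class Point:  # hypothetical dataclass, for illustration only
        x: float
        y: float

    p = Point(1.0, 2.0)
    assert dataclasses.is_dataclass(Point)  # True for the class object itself...
    assert dataclasses.is_dataclass(p)      # ...and also for an instance
    assert isinstance(Point, type)          # class objects are filtered out by the guard
    assert not isinstance(p, type)          # only instances take the dataclass branch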

tests/checkpointing/test_legacy_checkpoints.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@
         "1.3.0",
         "1.3.1",
         "1.3.2",
+        "1.3.3",
+        "1.3.4",
+        "1.3.5",
     ]
 )
 def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

tests/utilities/test_apply_func.py

Lines changed: 45 additions & 0 deletions
@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import dataclasses
 import numbers
 from collections import namedtuple, OrderedDict
+from typing import List
 
 import numpy as np
 import pytest
@@ -24,6 +26,17 @@
 def test_recursive_application_to_collection():
     ntc = namedtuple('Foo', ['bar'])
 
+    @dataclasses.dataclass
+    class Feature:
+        input_ids: torch.Tensor
+        segment_ids: np.ndarray
+
+    @dataclasses.dataclass
+    class ModelExample:
+        example_ids: List[str]
+        feature: Feature
+        label: torch.Tensor
+
     to_reduce = {
         'a': torch.tensor([1.]),  # Tensor
         'b': [torch.tensor([2.])],  # list
@@ -32,6 +45,12 @@ def test_recursive_application_to_collection():
         'e': np.array([10.]),  # numpy array
         'f': 'this_is_a_dummy_str',  # string
         'g': 12.,  # number
+        'h': Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])),  # dataclass
+        'i': ModelExample(
+            example_ids=['i-1', 'i-2', 'i-3'],
+            feature=Feature(input_ids=torch.tensor([1., 2., 3.]), segment_ids=np.array([4., 5., 6.])),
+            label=torch.tensor([7., 8., 9.])
+        )  # nested dataclass
     }
 
     expected_result = {
@@ -42,6 +61,12 @@ def test_recursive_application_to_collection():
         'e': np.array([20.]),
         'f': 'this_is_a_dummy_str',
         'g': 24.,
+        'h': Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
+        'i': ModelExample(
+            example_ids=['i-1', 'i-2', 'i-3'],
+            feature=Feature(input_ids=torch.tensor([2., 4., 6.]), segment_ids=np.array([8., 10., 12.])),
+            label=torch.tensor([14., 16., 18.])
+        )
     }
 
     reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
@@ -78,6 +103,26 @@
     assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
    assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
 
+    assert dataclasses.is_dataclass(reduced['h']) and not isinstance(reduced['h'], type), \
+        'Reduction of a dataclass should result in a dataclass'
+    assert torch.allclose(reduced['h'].input_ids, expected_result['h'].input_ids), \
+        'Reduction of a dataclass did not yield the desired result'
+    assert np.allclose(reduced['h'].segment_ids, expected_result['h'].segment_ids), \
+        'Reduction of a dataclass did not yield the desired result'
+
+    assert dataclasses.is_dataclass(reduced['i']) and not isinstance(reduced['i'], type), \
+        'Reduction of a dataclass should result in a dataclass'
+    assert dataclasses.is_dataclass(reduced['i'].feature) and not isinstance(reduced['i'].feature, type), \
+        'Reduction of a nested dataclass should result in a nested dataclass'
+    assert reduced['i'].example_ids == expected_result['i'].example_ids, \
+        'Reduction of a nested dataclass did not yield the desired result'
+    assert torch.allclose(reduced['i'].label, expected_result['i'].label), \
+        'Reduction of a nested dataclass did not yield the desired result'
+    assert torch.allclose(reduced['i'].feature.input_ids, expected_result['i'].feature.input_ids), \
+        'Reduction of a nested dataclass did not yield the desired result'
+    assert np.allclose(reduced['i'].feature.segment_ids, expected_result['i'].feature.segment_ids), \
+        'Reduction of a nested dataclass did not yield the desired result'
+
     # mapping support
     reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
     assert reduced == {'a': '1', 'b': '2'}
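Taken together with the production change above, these tests pin down the new contract: `apply_to_collection` recurses into each dataclass field and rebuilds the result through the dataclass constructor, so the output type matches the input type. A minimal usage sketch under that reading, with a hypothetical `Batch` dataclass that is not part of the repository (the import path follows the module modified in this commit):

    import dataclasses

    import torch

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    @dataclasses.dataclass
    class Batch:  # hypothetical container, for illustration only
        inputs: torch.Tensor
        targets: torch.Tensor

    batch = Batch(inputs=torch.tensor([1., 2.]), targets=torch.tensor([3., 4.]))
    # Every torch.Tensor field is doubled; the result is again a Batch instance.
    doubled = apply_to_collection(batch, torch.Tensor, lambda t: t * 2)
    assert isinstance(doubled, Batch)
    assert torch.equal(doubled.inputs, torch.tensor([2., 4.]))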
