Skip to content

Commit 9f5d495

Browse files
ananthsub and carmocca authored
[1/N] Define dataclasses for progress tracking (#6603)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 901b2ba commit 9f5d495

File tree

3 files changed

+242
-0
lines changed

3 files changed

+242
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1414
- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
1515

1616

17+
- Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603))
18+
19+
1720
- Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))
1821

1922

pytorch_lightning/trainer/progress.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass, field
16+
17+
18+
@dataclass
class ProgressState:
    """
    Four integer counters describing how far one kind of event has advanced.

    Args:
        ready: Intended to track the number of events ready to start.
        started: Intended to be incremented after the event is started (e.g. after `on_*_start` runs).
        processed: Intended to be incremented after the event is processed.
        completed: Intended to be incremented after the event completes (e.g. after `on_*_end` runs).
    """
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset(self) -> None:
        """Set every counter on this instance back to zero, in place."""
        self.ready = self.started = self.processed = self.completed = 0
39+
40+
41+
@dataclass
class Progress:
    """
    Pairs two ``ProgressState`` objects: an aggregated one and a current one.

    Every ``increment_*`` method adds one to the same-named counter on both states.

    Args:
        total: Intended to track the total progress of an event
        current: Intended to track the current progress of an event
    """
    total: ProgressState = field(default_factory=ProgressState)
    current: ProgressState = field(default_factory=ProgressState)

    def increment_ready(self) -> None:
        """Add one to ``ready`` on ``total`` and then on ``current``."""
        for state in (self.total, self.current):
            state.ready += 1

    def increment_started(self) -> None:
        """Add one to ``started`` on ``total`` and then on ``current``."""
        for state in (self.total, self.current):
            state.started += 1

    def increment_processed(self) -> None:
        """Add one to ``processed`` on ``total`` and then on ``current``."""
        for state in (self.total, self.current):
            state.processed += 1

    def increment_completed(self) -> None:
        """Add one to ``completed`` on ``total`` and then on ``current``."""
        for state in (self.total, self.current):
            state.completed += 1
68+
69+
70+
@dataclass
class LoopProgress:
    """
    Dataclass holding the epoch-level and batch-level progress of a loop.

    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.

    Args:
        epoch: Tracks epochs progress.
        batch: Tracks batch progress.
    """
    epoch: Progress = field(default_factory=Progress)
    batch: Progress = field(default_factory=Progress)

    def increment_epoch_completed(self) -> None:
        """Record one more completed epoch, then clear the current counters via ``reset_on_epoch``."""
        self.epoch.increment_completed()
        self.reset_on_epoch()

    def reset_on_epoch(self) -> None:
        """Reset the ``current`` state of batch and epoch; the ``total`` states are left untouched."""
        for tracker in (self.batch, self.epoch):
            tracker.current.reset()

tests/trainer/test_progress.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pytorch_lightning.trainer.progress import LoopProgress, ProgressState
16+
17+
18+
def test_increment_ready(tmpdir):
    """ Incrementing batch `ready` raises both the total and the current counter. """
    progress = LoopProgress()
    batch = progress.batch
    batch.increment_ready()
    assert batch.total.ready == 1
    assert batch.current.ready == batch.total.ready
23+
24+
25+
def test_increment_started(tmpdir):
    """ Incrementing epoch `started` leaves the `ready` counters untouched. """
    progress = LoopProgress()
    progress.epoch.increment_started()
    batch_total = progress.batch.total
    epoch_total = progress.epoch.total
    assert batch_total.ready == 0
    assert epoch_total.ready == 0
    assert epoch_total.started == 1
    assert epoch_total.started == progress.epoch.current.started
32+
33+
34+
def test_increment_processed(tmpdir):
    """ Incrementing epoch `processed` does not touch `ready` or `started`. """
    progress = LoopProgress()
    progress.epoch.increment_processed()
    batch_total = progress.batch.total
    epoch_total = progress.epoch.total
    assert batch_total.ready == 0
    assert batch_total.started == 0
    assert epoch_total.started == 0
    assert epoch_total.processed == 1
    assert epoch_total.processed == progress.epoch.current.processed
42+
43+
44+
def test_increment_completed(tmpdir):
    """ Incrementing epoch `completed` does not touch the earlier-stage counters. """
    progress = LoopProgress()
    progress.epoch.increment_completed()
    batch_total = progress.batch.total
    epoch_total = progress.epoch.total
    assert batch_total.ready == 0
    assert batch_total.started == 0
    assert epoch_total.started == 0
    assert epoch_total.processed == 0
    assert epoch_total.completed == 1
    assert epoch_total.completed == progress.epoch.current.completed
53+
54+
55+
def test_increment_epoch(tmpdir):
    """ Completing epochs accumulates totals and clears the current counters. """
    progress = LoopProgress()
    epoch, batch = progress.epoch, progress.batch

    batch.increment_completed()
    assert batch.current.completed == 1

    for _ in range(2):
        progress.increment_epoch_completed()
    assert epoch.current.completed == 0
    assert epoch.total.completed == 2
    assert batch.current.completed == 0
    assert batch.total.completed == 1
67+
68+
69+
def test_reset_on_epoch(tmpdir):
    """ Resetting on epoch clears the current state but keeps the totals. """
    progress = LoopProgress()
    epoch, batch = progress.epoch, progress.batch

    batch.increment_started()
    assert batch.total.started == 1
    assert epoch.total.started == 0

    progress.reset_on_epoch()
    assert batch.current.started == 0
    assert batch.total == ProgressState(started=1)

    batch.increment_started()
    assert batch.total == ProgressState(started=2)
    assert epoch.total.started == 0
84+
85+
86+
def test_increment_batch_ready_start_process_finish_epoch(tmpdir):
    """ Walk through one epoch of increments, checking total and current counters at each step. """
    progress = LoopProgress()
    epoch, batch = progress.epoch, progress.batch

    def pair(tracker, counter):
        # (total, current) value of a single counter on one tracker
        return getattr(tracker.total, counter), getattr(tracker.current, counter)

    epoch.increment_ready()
    assert pair(epoch, "ready") == (1, 1)
    assert pair(batch, "ready") == (0, 0)

    epoch.increment_started()
    assert pair(epoch, "ready") == (1, 1)
    assert pair(epoch, "started") == (1, 1)
    assert pair(batch, "started") == (0, 0)

    batch.increment_ready()
    assert pair(batch, "ready") == (1, 1)
    assert pair(epoch, "ready") == (1, 1)

    batch.increment_started()
    assert pair(batch, "started") == (1, 1)
    assert pair(epoch, "started") == (1, 1)

    batch.increment_processed()
    assert pair(batch, "processed") == (1, 1)
    assert pair(epoch, "processed") == (0, 0)

    batch.increment_completed()
    assert pair(batch, "completed") == (1, 1)
    assert pair(epoch, "completed") == (0, 0)

    epoch.increment_processed()
    assert pair(batch, "processed") == (1, 1)
    assert pair(epoch, "processed") == (1, 1)

    # completing the epoch keeps the totals and clears the current counters
    progress.increment_epoch_completed()
    assert pair(batch, "completed") == (1, 0)
    assert pair(epoch, "completed") == (1, 0)

    epoch.increment_ready()
    assert pair(epoch, "ready") == (2, 1)

    batch.increment_ready()
    assert pair(batch, "ready") == (2, 1)

    progress.reset_on_epoch()
    assert batch.current.ready == 0
    assert epoch.current.ready == 0

0 commit comments

Comments
 (0)