@@ -11,16 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
 from unittest.mock import patch
 
 import pytest
 
 from pytorch_lightning import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.loops import TrainingEpochLoop
 from pytorch_lightning.trainer.trainer import Trainer
-from tests_pytorch.deprecated_api import no_deprecated_call
 
 _out00 = {"loss": 0.0}
 _out01 = {"loss": 0.1}
@@ -33,43 +31,33 @@
 
 
 class TestPrepareOutputs:
-    def prepare_outputs(self, fn, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization):
+    def prepare_outputs(self, fn, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization):
         lightning_module = LightningModule()
-        lightning_module.on_train_batch_end = lambda *_: None  # override to trigger the deprecation message
         lightning_module.automatic_optimization = automatic_optimization
         lightning_module.truncated_bptt_steps = tbptt_splits
-        match = "will change in version v1.8.*new_format=True"
-        will_warn = tbptt_splits and num_optimizers > 1 and not new_format
-        ctx_manager = pytest.deprecated_call if will_warn else no_deprecated_call
-        with ctx_manager(match=match):
-            with mock.patch(
-                "pytorch_lightning.loops.epoch.training_epoch_loop._v1_8_output_format", return_value=new_format
-            ):
-                return fn(
-                    batch_outputs,
-                    lightning_module=lightning_module,
-                    num_optimizers=num_optimizers,  # does not matter for manual optimization
-                )
+        return fn(
+            batch_outputs,
+            lightning_module=lightning_module,
+            num_optimizers=num_optimizers,  # does not matter for manual optimization
+        )
 
     def prepare_outputs_training_epoch_end(
-        self, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization=True
+        self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
     ):
         return self.prepare_outputs(
             TrainingEpochLoop._prepare_outputs_training_epoch_end,
             tbptt_splits,
-            new_format,
             batch_outputs,
             num_optimizers,
             automatic_optimization=automatic_optimization,
         )
 
     def prepare_outputs_training_batch_end(
-        self, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization=True
+        self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
     ):
         return self.prepare_outputs(
             TrainingEpochLoop._prepare_outputs_training_batch_end,
             tbptt_splits,
-            new_format,
             batch_outputs,
             num_optimizers,
             automatic_optimization=automatic_optimization,
@@ -97,53 +85,19 @@ def prepare_outputs_training_batch_end(
             ),
             # 1 batch, tbptt with 2 splits (uneven)
             (1, 2, [[{0: _out00}, {0: _out01}], [{0: _out03}]], [[_out00, _out01], [_out03]]),
-        ],
-    )
-    @pytest.mark.parametrize("new_format", (False, True))
-    def test_prepare_outputs_training_epoch_end_automatic(
-        self, num_optimizers, tbptt_splits, batch_outputs, expected, new_format
-    ):
-        """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
-        currently expects in the case of automatic optimization."""
-        assert (
-            self.prepare_outputs_training_epoch_end(tbptt_splits, new_format, batch_outputs, num_optimizers) == expected
-        )
-
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_outputs,expected",
-        [
-            # 3 batches, tbptt with 2 splits, 2 optimizers alternating
-            (
-                2,
-                2,
-                [[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
-                [[[_out00, _out01], [], [_out02, _out03]], [[], [_out10, _out11], []]],
-            )
-        ],
-    )
-    def test_prepare_outputs_training_epoch_end_automatic_old_format(
-        self, num_optimizers, tbptt_splits, batch_outputs, expected
-    ):
-        assert self.prepare_outputs_training_epoch_end(tbptt_splits, False, batch_outputs, num_optimizers) == expected
-
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_outputs,expected",
-        [
             # 3 batches, tbptt with 2 splits, 2 optimizers alternating
             (
                 2,
                 2,
                 [[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
                 [[[_out00], [_out01]], [[_out10], [_out11]], [[_out02], [_out03]]],
-            )
+            ),
         ],
     )
-    def test_prepare_outputs_training_epoch_end_automatic_new_format(
-        self, num_optimizers, tbptt_splits, batch_outputs, expected
-    ):
+    def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, tbptt_splits, batch_outputs, expected):
         """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
         currently expects in the case of automatic optimization."""
-        assert self.prepare_outputs_training_epoch_end(tbptt_splits, True, batch_outputs, num_optimizers) == expected
+        assert self.prepare_outputs_training_epoch_end(tbptt_splits, batch_outputs, num_optimizers) == expected
 
     @pytest.mark.parametrize(
         "batch_outputs,expected",
@@ -160,14 +114,10 @@ def test_prepare_outputs_training_epoch_end_automatic_new_format(
             ([[_out00, _out01], [_out02, _out03], [], [_out10]], [[_out00, _out01], [_out02, _out03], [_out10]]),
         ],
     )
-    @pytest.mark.parametrize("new_format", (False, True))
-    def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected, new_format):
+    def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected):
         """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
         currently expects in the case of manual optimization."""
-        assert (
-            self.prepare_outputs_training_epoch_end(0, new_format, batch_outputs, -1, automatic_optimization=False)
-            == expected
-        )
+        assert self.prepare_outputs_training_epoch_end(0, batch_outputs, -1, automatic_optimization=False) == expected
 
     @pytest.mark.parametrize(
         "num_optimizers,tbptt_splits,batch_end_outputs,expected",
@@ -180,47 +130,17 @@ def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected
             (2, 0, [{0: _out00, 1: _out01}], [_out00, _out01]),
             # tbptt with 2 splits
             (1, 2, [{0: _out00}, {0: _out01}], [_out00, _out01]),
+            # 2 optimizers, tbptt with 2 splits
+            (2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]]),
         ],
     )
-    @pytest.mark.parametrize("new_format", (False, True))
     def test_prepare_outputs_training_batch_end_automatic(
-        self, num_optimizers, tbptt_splits, batch_end_outputs, expected, new_format
-    ):
-        """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
-        currently expects in the case of automatic optimization."""
-
-        assert (
-            self.prepare_outputs_training_batch_end(tbptt_splits, new_format, batch_end_outputs, num_optimizers)
-            == expected
-        )
-
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_end_outputs,expected",
-        # 2 optimizers, tbptt with 2 splits
-        [(2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out10], [_out01, _out11]])],
-    )
-    def test_prepare_outputs_training_batch_end_automatic_old_format(
         self, num_optimizers, tbptt_splits, batch_end_outputs, expected
     ):
         """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
         currently expects in the case of automatic optimization."""
-        assert (
-            self.prepare_outputs_training_batch_end(tbptt_splits, False, batch_end_outputs, num_optimizers) == expected
-        )
 
-    @pytest.mark.parametrize(
-        "num_optimizers,tbptt_splits,batch_end_outputs,expected",
-        # 2 optimizers, tbptt with 2 splits
-        [(2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]])],
-    )
-    def test_prepare_outputs_training_batch_end_automatic_new_format(
-        self, num_optimizers, tbptt_splits, batch_end_outputs, expected
-    ):
-        """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
-        currently expects in the case of automatic optimization."""
-        assert (
-            self.prepare_outputs_training_batch_end(tbptt_splits, True, batch_end_outputs, num_optimizers) == expected
-        )
+        assert self.prepare_outputs_training_batch_end(tbptt_splits, batch_end_outputs, num_optimizers) == expected
 
     @pytest.mark.parametrize(
         "batch_end_outputs,expected",
@@ -237,8 +157,7 @@ def test_prepare_outputs_training_batch_end_manual(self, batch_end_outputs, expe
         """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
         currently expects in the case of manual optimization."""
         assert (
-            self.prepare_outputs_training_batch_end(0, False, batch_end_outputs, -1, automatic_optimization=False)
-            == expected
+            self.prepare_outputs_training_batch_end(0, batch_end_outputs, -1, automatic_optimization=False) == expected
         )
 
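For context, here is a minimal sketch (not part of the patch) of the behavior the simplified tests assert, using only names that appear in the diff above. The expected values are copied from the parametrizations in the hunks; it assumes a pytorch_lightning checkout where the `new_format` code path has already been removed, as in this change.

```python
# Hedged sketch: mirrors the test setup above (LightningModule with
# automatic_optimization/truncated_bptt_steps set, then the two staticmethods
# called exactly as the `prepare_outputs` helper calls them).
from pytorch_lightning import LightningModule
from pytorch_lightning.loops import TrainingEpochLoop

_out00 = {"loss": 0.0}
_out01 = {"loss": 0.1}
_out10 = {"loss": 1.0}
_out11 = {"loss": 1.1}

module = LightningModule()
module.automatic_optimization = True
module.truncated_bptt_steps = 2  # two tbptt splits per batch

# `on_train_batch_end` shape: 2 optimizers, 2 tbptt splits. Per the newly
# added parametrization, splits stay outermost and each split holds the
# per-optimizer outputs.
batch_end_outputs = [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}]
assert TrainingEpochLoop._prepare_outputs_training_batch_end(
    batch_end_outputs,
    lightning_module=module,
    num_optimizers=2,
) == [[_out00, _out01], [_out10, _out11]]

# `training_epoch_end` shape: 1 optimizer, 2 tbptt splits, one batch. The
# expected value follows the "(1, 2, ...)" case in the epoch-end table above.
batch_outputs = [[{0: _out00}, {0: _out01}]]
assert TrainingEpochLoop._prepare_outputs_training_epoch_end(
    batch_outputs,
    lightning_module=module,
    num_optimizers=1,
) == [[_out00, _out01]]
```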