import torch.nn as nn
import torch.optim as optim
import utils
+ from torch.cuda.amp import GradScaler, autocast

import ignite
import ignite.distributed as idist
from ignite.contrib.engines import common
from ignite.contrib.handlers import PiecewiseLinear
from ignite.engine import Engine, Events, create_supervised_evaluator
- from ignite.handlers import Checkpoint, DiskSaver
+ from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from ignite.metrics import Accuracy, Loss
from ignite.utils import manual_seed, setup_logger
@@ -31,16 +32,37 @@ def training(local_rank, config):
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")

-         folder_name = "{}_backend-{}-{}_{}".format(config["model"], idist.backend(), idist.get_world_size(), now)
+         folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
-         logger.info("Output path: {}".format(config["output_path"]))
+         logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

+         if config["with_clearml"]:
+             try:
+                 from clearml import Task
+             except ImportError:
+                 # Backwards-compatibility for legacy Trains SDK
+                 from trains import Task
+
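+             # Experiment tracking runs on rank 0 only, so a single ClearML task is created per run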
+             task = Task.init("CIFAR10-Training", task_name=output_path.stem)
+             task.connect_configuration(config)
+             # Log hyper parameters
+             hyper_params = [
+                 "model",
+                 "batch_size",
+                 "momentum",
+                 "weight_decay",
+                 "num_epochs",
+                 "learning_rate",
+                 "num_warmup_epochs",
+             ]
+             task.connect({k: config[k] for k in hyper_params})
+
    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)
@@ -78,15 +100,18 @@ def run_validation(engine):
    evaluators = {"training": train_evaluator, "test": evaluator}
    tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

-     # Store 3 best models by validation accuracy:
-     common.save_best_model_by_val_score(
-         output_path=config["output_path"],
-         evaluator=evaluator,
-         model=model,
-         metric_name="Accuracy",
-         n_saved=1,
-         trainer=trainer,
-         tag="test",
+     # Store 2 best models by validation accuracy starting from num_epochs / 2:
+     best_model_handler = Checkpoint(
+         {"model": model},
+         get_save_handler(config),
+         filename_prefix="best",
+         n_saved=2,
+         global_step_transform=global_step_from_engine(trainer),
+         score_name="test_accuracy",
+         score_function=Checkpoint.get_default_score_fn("accuracy"),
+     )
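+     # Best-model checkpointing only kicks in once training is past num_epochs // 2 (see the
+     # filtered COMPLETED event below); global_step_from_engine(trainer) stamps filenames with the trainer epoch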
+     evaluator.add_event_handler(
+         Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    trainer.run(train_loader, max_epochs=config["num_epochs"])
@@ -108,11 +133,13 @@ def run(
    learning_rate=0.4,
    num_warmup_epochs=4,
    validate_every=3,
-     checkpoint_every=200,
+     checkpoint_every=1000,
    backend=None,
    resume_from=None,
    log_every_iters=15,
    nproc_per_node=None,
+     with_clearml=False,
+     with_amp=False,
    **spawn_kwargs,
):
    """Main entry to train a model on the CIFAR10 dataset.
@@ -138,6 +165,8 @@ def run(
        resume_from (str, optional): path to checkpoint to use to resume the training from. Default, None.
        log_every_iters (int): argument to log batch loss every ``log_every_iters`` iterations.
            It can be 0 to disable it. Default, 15.
+         with_clearml (bool): if True, the ClearML experiment logger is set up. Default, False.
+         with_amp (bool): if True, enables native automatic mixed precision. Default, False.
        **spawn_kwargs: Other kwargs to spawn run in child processes: master_addr, master_port, node_rank, nnodes

    """
@@ -149,10 +178,8 @@ def run(
    spawn_kwargs["nproc_per_node"] = nproc_per_node

    with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
-         try:
-             parallel.run(training, config)
-         except Exception as e:
-             raise e
+
+         parallel.run(training, config)


def get_dataflow(config):
@@ -167,7 +194,7 @@ def get_dataflow(config):
        # Ensure that only rank 0 downloads the dataset
        idist.barrier()

-     # Setup data loader also adapted to distributed config
+     # Setup data loader also adapted to distributed config: nccl, gloo, xla-tpu
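+     # idist.auto_dataloader adapts the DataLoader to the current distributed setting (e.g. adds a DistributedSampler)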
    train_loader = idist.auto_dataloader(
        train_dataset, batch_size=config["batch_size"], num_workers=config["num_workers"], shuffle=True, drop_last=True,
    )
@@ -180,6 +207,7 @@ def get_dataflow(config):

def initialize(config):
    model = utils.get_model(config["model"])
+     # Adapt model for distributed settings if configured
    model = idist.auto_model(model, find_unused_parameters=True)

    optimizer = optim.SGD(
@@ -205,24 +233,28 @@ def initialize(config):

def log_metrics(logger, epoch, elapsed, tag, metrics):
    metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
-     logger.info(f"\nEpoch {epoch} - Time taken (seconds) : {elapsed:.02f} - {tag} metrics:\n {metrics_output}")
+     logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}")


def log_basic_info(logger, config):
-     logger.info("Quantization Aware Training {} on CIFAR10".format(config["model"]))
-     logger.info("- PyTorch version: {}".format(torch.__version__))
-     logger.info("- Ignite version: {}".format(ignite.__version__))
+     logger.info(f"Quantization Aware Training {config['model']} on CIFAR10")
+     logger.info(f"- PyTorch version: {torch.__version__}")
+     logger.info(f"- Ignite version: {ignite.__version__}")
+     if torch.cuda.is_available():
+         logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
+         logger.info(f"- CUDA version: {torch.version.cuda}")
+         logger.info(f"- CUDNN version: {torch.backends.cudnn.version()}")

    logger.info("\n")
    logger.info("Configuration:")
    for key, value in config.items():
-         logger.info("\t{}: {}".format(key, value))
+         logger.info(f"\t{key}: {value}")
    logger.info("\n")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
-         logger.info("\tbackend: {}".format(idist.backend()))
-         logger.info("\tworld size: {}".format(idist.get_world_size()))
+         logger.info(f"\tbackend: {idist.backend()}")
+         logger.info(f"\tworld size: {idist.get_world_size()}")
        logger.info("\n")
@@ -239,6 +271,9 @@ def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, con
    # - `RunningAverage` on `train_step` output
    # - Two progress bars on epochs and optionally on iterations

+     with_amp = config["with_amp"]
+     scaler = GradScaler(enabled=with_amp)
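+     # With with_amp=False, GradScaler and autocast below are effectively no-ops, so the same
+     # train_step serves both full-precision and mixed-precision training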
+
    def train_step(engine, batch):

        x, y = batch[0], batch[1]
@@ -248,12 +283,15 @@ def train_step(engine, batch):
            y = y.to(device, non_blocking=True)

        model.train()
-         y_pred = model(x)
-         loss = criterion(y_pred, y)
+
+         with autocast(enabled=with_amp):
+             y_pred = model(x)
+             loss = criterion(y_pred, y)

        optimizer.zero_grad()
-         loss.backward()
-         optimizer.step()
+         scaler.scale(loss).backward()
+         scaler.step(optimizer)
+         scaler.update()

        return {
            "batch loss": loss.item(),
@@ -272,7 +310,7 @@ def train_step(engine, batch):
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
-         output_path=config["output_path"],
+         save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
@@ -282,13 +320,22 @@ def train_step(engine, batch):
    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
-         assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
-         logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
+         assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
+         logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer


+ def get_save_handler(config):
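+     # Checkpoints are saved through ClearML when experiment tracking is enabled, otherwise
+     # written to the local output folder (require_empty=False allows reusing an existing directory)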
+     if config["with_clearml"]:
+         from ignite.contrib.handlers.clearml_logger import ClearMLSaver
+
+         return ClearMLSaver(dirname=config["output_path"])
+
+     return DiskSaver(config["output_path"], require_empty=False)
+
+
if __name__ == "__main__":
    fire.Fire({"run": run})