
Commit 90a753c

Updated cifar10 example (#1632)
* Updates for cifar10 example
* Updates for cifar10 example
* More updates
* Updated code
* Fixed code-formatting
1 parent ad138ea commit 90a753c

File tree

2 files changed: +116, -64 lines changed


examples/contrib/cifar10/main.py

Lines changed: 37 additions & 32 deletions
@@ -6,13 +6,14 @@
 import torch.nn as nn
 import torch.optim as optim
 import utils
+from torch.cuda.amp import GradScaler, autocast
 
 import ignite
 import ignite.distributed as idist
 from ignite.contrib.engines import common
 from ignite.contrib.handlers import PiecewiseLinear
 from ignite.engine import Engine, Events, create_supervised_evaluator
-from ignite.handlers import Checkpoint, DiskSaver
+from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
 from ignite.metrics import Accuracy, Loss
 from ignite.utils import manual_seed, setup_logger
 
@@ -76,8 +77,8 @@ def training(local_rank, config):
 
     # Let's now setup evaluator engine to perform model's validation and compute metrics
     metrics = {
-        "accuracy": Accuracy(),
-        "loss": Loss(criterion),
+        "Accuracy": Accuracy(),
+        "Loss": Loss(criterion),
     }
 
     # We define two evaluators as they wont have exactly similar roles:
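For context, a small self-contained sketch (toy model and data, not part of the commit) of how the renamed metric keys end up in evaluator.state.metrics:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}

data = DataLoader(TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,))), batch_size=8)
evaluator = create_supervised_evaluator(model, metrics=metrics)
state = evaluator.run(data)
print(state.metrics["Accuracy"], state.metrics["Loss"])  # keys match the dict above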
@@ -102,15 +103,18 @@ def run_validation(engine):
     evaluators = {"training": train_evaluator, "test": evaluator}
     tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)
 
-    # Store 3 best models by validation accuracy:
-    common.gen_save_best_models_by_val_score(
-        save_handler=get_save_handler(config),
-        evaluator=evaluator,
-        models={"model": model},
-        metric_name="accuracy",
-        n_saved=3,
-        trainer=trainer,
-        tag="test",
+    # Store 2 best models by validation accuracy starting from num_epochs / 2:
+    best_model_handler = Checkpoint(
+        {"model": model},
+        get_save_handler(config),
+        filename_prefix="best",
+        n_saved=2,
+        global_step_transform=global_step_from_engine(trainer),
+        score_name="test_accuracy",
+        score_function=Checkpoint.get_default_score_fn("accuracy"),
+    )
+    evaluator.add_event_handler(
+        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
     )
 
     # In order to check training resuming we can stop training on a given iteration
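A minimal sketch of the new best-model logic, assuming toy trainer/evaluator engines and a /tmp output directory (none of which come from the commit): the score is read from evaluator.state.metrics["accuracy"], the filename step comes from the trainer, and the Events.COMPLETED filter delays saving until the second half of training.

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

model = nn.Linear(4, 2)
trainer = Engine(lambda engine, batch: None)

def toy_eval_step(engine, batch):
    engine.state.metrics["accuracy"] = 0.5  # pretend validation accuracy

evaluator = Engine(toy_eval_step)
num_epochs = 10

best_model_handler = Checkpoint(
    {"model": model},
    DiskSaver("/tmp/best-models", require_empty=False),
    filename_prefix="best",
    n_saved=2,
    global_step_transform=global_step_from_engine(trainer),
    score_name="test_accuracy",
    score_function=Checkpoint.get_default_score_fn("accuracy"),
)
evaluator.add_event_handler(
    Events.COMPLETED(lambda *_: trainer.state.epoch > num_epochs // 2), best_model_handler
)

evaluator.run([0])       # epoch 0: the filter is False, nothing is saved
trainer.state.epoch = 6
evaluator.run([0])       # past the midpoint: a file like best_model_6_test_accuracy=0.5000.pt is written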
@@ -124,9 +128,8 @@ def _():
     try:
         trainer.run(train_loader, max_epochs=config["num_epochs"])
     except Exception as e:
-        import traceback
-
-        print(traceback.format_exc())
+        logger.exception("")
+        raise e
 
     if rank == 0:
         tb_logger.close()
@@ -145,13 +148,14 @@ def run(
     learning_rate=0.4,
     num_warmup_epochs=4,
     validate_every=3,
-    checkpoint_every=200,
+    checkpoint_every=1000,
     backend=None,
     resume_from=None,
     log_every_iters=15,
     nproc_per_node=None,
     stop_iteration=None,
     with_clearml=False,
+    with_amp=False,
     **spawn_kwargs,
 ):
     """Main entry to train an model on CIFAR10 dataset.
@@ -179,6 +183,7 @@ def run(
             It can be 0 to disable it. Default, 15.
         stop_iteration (int, optional): iteration to stop the training. Can be used to check resume from checkpoint.
         with_clearml (bool): if True, experiment ClearML logger is setup. Default, False.
+        with_amp (bool): if True, enables native automatic mixed precision. Default, False.
         **spawn_kwargs: Other kwargs to spawn run in child processes: master_addr, master_port, node_rank, nnodes
 
     """
@@ -245,13 +250,17 @@ def initialize(config):
 
 def log_metrics(logger, epoch, elapsed, tag, metrics):
     metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
-    logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {int(elapsed)} - {tag} metrics:\n {metrics_output}")
+    logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}")
 
 
 def log_basic_info(logger, config):
     logger.info(f"Train {config['model']} on CIFAR10")
     logger.info(f"- PyTorch version: {torch.__version__}")
     logger.info(f"- Ignite version: {ignite.__version__}")
+    if torch.cuda.is_available():
+        logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
+        logger.info(f"- CUDA version: {torch.version.cuda}")
+        logger.info(f"- CUDNN version: {torch.backends.cudnn.version()}")
 
     logger.info("\n")
     logger.info("Configuration:")
@@ -279,6 +288,9 @@ def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, con
     # - RunningAverage` on `train_step` output
     # - Two progress bars on epochs and optionally on iterations
 
+    with_amp = config["with_amp"]
+    scaler = GradScaler(enabled=with_amp)
+
     def train_step(engine, batch):
 
         x, y = batch[0], batch[1]
@@ -288,28 +300,21 @@ def train_step(engine, batch):
         y = y.to(device, non_blocking=True)
 
         model.train()
-        # Supervised part
-        y_pred = model(x)
-        loss = criterion(y_pred, y)
 
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        with autocast(enabled=with_amp):
+            y_pred = model(x)
+            loss = criterion(y_pred, y)
 
-        # This can be helpful for XLA to avoid performance slow down if fetch loss.item() every iteration
-        if config["log_every_iters"] > 0 and (engine.state.iteration - 1) % config["log_every_iters"] == 0:
-            batch_loss = loss.item()
-            engine.state.saved_batch_loss = batch_loss
-        else:
-            batch_loss = engine.state.saved_batch_loss
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
 
         return {
-            "batch loss": batch_loss,
+            "batch loss": loss.item(),
         }
 
     trainer = Engine(train_step)
-    trainer.state.saved_batch_loss = -1.0
-    trainer.state_dict_user_keys.append("saved_batch_loss")
     trainer.logger = logger
 
     to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
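The same AMP pattern in isolation, assuming an arbitrary nn.Module and optimizer; with with_amp=False the autocast context and GradScaler become no-ops, so this sketch also runs on a CPU-only machine:

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

with_amp = torch.cuda.is_available()
device = "cuda" if with_amp else "cpu"

model = nn.Linear(32, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=with_amp)

x = torch.randn(8, 32, device=device)
y = torch.randint(0, 10, (8,), device=device)

model.train()
with autocast(enabled=with_amp):   # forward pass and loss in mixed precision when enabled
    loss = criterion(model(x), y)

optimizer.zero_grad()
scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
scaler.update()                    # adjust the scale factor for the next iteration
print(loss.item())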

examples/contrib/cifar10_qat/main.py

Lines changed: 79 additions & 32 deletions
@@ -6,13 +6,14 @@
 import torch.nn as nn
 import torch.optim as optim
 import utils
+from torch.cuda.amp import GradScaler, autocast
 
 import ignite
 import ignite.distributed as idist
 from ignite.contrib.engines import common
 from ignite.contrib.handlers import PiecewiseLinear
 from ignite.engine import Engine, Events, create_supervised_evaluator
-from ignite.handlers import Checkpoint, DiskSaver
+from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
 from ignite.metrics import Accuracy, Loss
 from ignite.utils import manual_seed, setup_logger
 
@@ -31,16 +32,37 @@ def training(local_rank, config):
     if rank == 0:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
 
-        folder_name = "{}_backend-{}-{}_{}".format(config["model"], idist.backend(), idist.get_world_size(), now)
+        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
         output_path = Path(output_path) / folder_name
         if not output_path.exists():
             output_path.mkdir(parents=True)
         config["output_path"] = output_path.as_posix()
-        logger.info("Output path: {}".format(config["output_path"]))
+        logger.info(f"Output path: {config['output_path']}")
 
         if "cuda" in device.type:
             config["cuda device name"] = torch.cuda.get_device_name(local_rank)
 
+        if config["with_clearml"]:
+            try:
+                from clearml import Task
+            except ImportError:
+                # Backwards-compatibility for legacy Trains SDK
+                from trains import Task
+
+            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
+            task.connect_configuration(config)
+            # Log hyper parameters
+            hyper_params = [
+                "model",
+                "batch_size",
+                "momentum",
+                "weight_decay",
+                "num_epochs",
+                "learning_rate",
+                "num_warmup_epochs",
+            ]
+            task.connect({k: config[k] for k in hyper_params})
+
     # Setup dataflow, model, optimizer, criterion
     train_loader, test_loader = get_dataflow(config)
 
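A hedged sketch of the ClearML wiring added above; it runs only if the clearml package is installed and configured, and the project/task names and config values here are illustrative:

from clearml import Task

config = {"model": "resnet18", "batch_size": 512, "momentum": 0.9, "weight_decay": 1e-4,
          "num_epochs": 24, "learning_rate": 0.4, "num_warmup_epochs": 4}

task = Task.init(project_name="CIFAR10-Training", task_name="example-run")
task.connect_configuration(config)   # full config, editable in the ClearML UI
task.connect({k: config[k] for k in ("model", "batch_size", "learning_rate")})  # selected hyper parameters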

@@ -78,15 +100,18 @@ def run_validation(engine):
     evaluators = {"training": train_evaluator, "test": evaluator}
     tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)
 
-    # Store 3 best models by validation accuracy:
-    common.save_best_model_by_val_score(
-        output_path=config["output_path"],
-        evaluator=evaluator,
-        model=model,
-        metric_name="Accuracy",
-        n_saved=1,
-        trainer=trainer,
-        tag="test",
+    # Store 2 best models by validation accuracy starting from num_epochs / 2:
+    best_model_handler = Checkpoint(
+        {"model": model},
+        get_save_handler(config),
+        filename_prefix="best",
+        n_saved=2,
+        global_step_transform=global_step_from_engine(trainer),
+        score_name="test_accuracy",
+        score_function=Checkpoint.get_default_score_fn("accuracy"),
+    )
+    evaluator.add_event_handler(
+        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
     )
 
     trainer.run(train_loader, max_epochs=config["num_epochs"])
@@ -108,11 +133,13 @@ def run(
     learning_rate=0.4,
     num_warmup_epochs=4,
     validate_every=3,
-    checkpoint_every=200,
+    checkpoint_every=1000,
     backend=None,
     resume_from=None,
     log_every_iters=15,
     nproc_per_node=None,
+    with_clearml=False,
+    with_amp=False,
     **spawn_kwargs,
 ):
     """Main entry to train an model on CIFAR10 dataset.
@@ -138,6 +165,8 @@ def run(
         resume_from (str, optional): path to checkpoint to use to resume the training from. Default, None.
         log_every_iters (int): argument to log batch loss every ``log_every_iters`` iterations.
             It can be 0 to disable it. Default, 15.
+        with_clearml (bool): if True, experiment ClearML logger is setup. Default, False.
+        with_amp (bool): if True, enables native automatic mixed precision. Default, False.
         **spawn_kwargs: Other kwargs to spawn run in child processes: master_addr, master_port, node_rank, nnodes
 
     """
@@ -149,10 +178,8 @@ def run(
     spawn_kwargs["nproc_per_node"] = nproc_per_node
 
     with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
-        try:
-            parallel.run(training, config)
-        except Exception as e:
-            raise e
+
+        parallel.run(training, config)
 
 
 def get_dataflow(config):
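For reference, a minimal sketch of the idist.Parallel entry point used here; with backend=None the function runs in the current process, so this is safe to try locally (the training function and config below are toy placeholders):

import ignite.distributed as idist

def training(local_rank, config):
    # Parallel.run always passes the local rank first, then the extra arguments
    print(f"rank={idist.get_rank()} local_rank={local_rank} model={config['model']}")

config = {"model": "resnet18"}
with idist.Parallel(backend=None) as parallel:  # backend=None -> single, non-distributed process
    parallel.run(training, config)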
@@ -167,7 +194,7 @@ def get_dataflow(config):
     # Ensure that only rank 0 download the dataset
     idist.barrier()
 
-    # Setup data loader also adapted to distributed config
+    # Setup data loader also adapted to distributed config: nccl, gloo, xla-tpu
     train_loader = idist.auto_dataloader(
         train_dataset, batch_size=config["batch_size"], num_workers=config["num_workers"], shuffle=True, drop_last=True,
     )
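A small illustration with a toy TensorDataset (an assumption, not from the example): outside a distributed context idist.auto_dataloader behaves like a plain DataLoader, while under nccl, gloo or xla-tpu it also sets up a distributed sampler and adapts the loader settings per process:

import torch
from torch.utils.data import TensorDataset
import ignite.distributed as idist

# 64 fake CIFAR-sized images with integer labels
train_dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))

train_loader = idist.auto_dataloader(
    train_dataset, batch_size=16, num_workers=0, shuffle=True, drop_last=True,
)
print(len(train_loader))  # 4 batches of 16 in a single-process run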
@@ -180,6 +207,7 @@ def get_dataflow(config):
 
 def initialize(config):
     model = utils.get_model(config["model"])
+    # Adapt model for distributed settings if configured
     model = idist.auto_model(model, find_unused_parameters=True)
 
     optimizer = optim.SGD(
@@ -205,24 +233,28 @@ def initialize(config):
 
 def log_metrics(logger, epoch, elapsed, tag, metrics):
     metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
-    logger.info(f"\nEpoch {epoch} - Time taken (seconds) : {elapsed:.02f} - {tag} metrics:\n {metrics_output}")
+    logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}")
 
 
 def log_basic_info(logger, config):
-    logger.info("Quantization Aware Training {} on CIFAR10".format(config["model"]))
-    logger.info("- PyTorch version: {}".format(torch.__version__))
-    logger.info("- Ignite version: {}".format(ignite.__version__))
+    logger.info(f"Quantization Aware Training {config['model']} on CIFAR10")
+    logger.info(f"- PyTorch version: {torch.__version__}")
+    logger.info(f"- Ignite version: {ignite.__version__}")
+    if torch.cuda.is_available():
+        logger.info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
+        logger.info(f"- CUDA version: {torch.version.cuda}")
+        logger.info(f"- CUDNN version: {torch.backends.cudnn.version()}")
 
     logger.info("\n")
     logger.info("Configuration:")
     for key, value in config.items():
-        logger.info("\t{}: {}".format(key, value))
+        logger.info(f"\t{key}: {value}")
     logger.info("\n")
 
     if idist.get_world_size() > 1:
         logger.info("\nDistributed setting:")
-        logger.info("\tbackend: {}".format(idist.backend()))
-        logger.info("\tworld size: {}".format(idist.get_world_size()))
+        logger.info(f"\tbackend: {idist.backend()}")
+        logger.info(f"\tworld size: {idist.get_world_size()}")
         logger.info("\n")
 
@@ -239,6 +271,9 @@ def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, con
     # - RunningAverage` on `train_step` output
     # - Two progress bars on epochs and optionally on iterations
 
+    with_amp = config["with_amp"]
+    scaler = GradScaler(enabled=with_amp)
+
     def train_step(engine, batch):
 
         x, y = batch[0], batch[1]
@@ -248,12 +283,15 @@ def train_step(engine, batch):
         y = y.to(device, non_blocking=True)
 
         model.train()
-        y_pred = model(x)
-        loss = criterion(y_pred, y)
+
+        with autocast(enabled=with_amp):
+            y_pred = model(x)
+            loss = criterion(y_pred, y)
 
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
 
         return {
             "batch loss": loss.item(),
@@ -272,7 +310,7 @@ def train_step(engine, batch):
         train_sampler=train_sampler,
         to_save=to_save,
         save_every_iters=config["checkpoint_every"],
-        output_path=config["output_path"],
+        save_handler=get_save_handler(config),
         lr_scheduler=lr_scheduler,
         output_names=metric_names if config["log_every_iters"] > 0 else None,
         with_pbars=False,
282320
resume_from = config["resume_from"]
283321
if resume_from is not None:
284322
checkpoint_fp = Path(resume_from)
285-
assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
286-
logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
323+
assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
324+
logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
287325
checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
288326
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
289327

290328
return trainer
291329

292330

331+
def get_save_handler(config):
332+
if config["with_clearml"]:
333+
from ignite.contrib.handlers.clearml_logger import ClearMLSaver
334+
335+
return ClearMLSaver(dirname=config["output_path"])
336+
337+
return DiskSaver(config["output_path"], require_empty=False)
338+
339+
293340
if __name__ == "__main__":
294341
fire.Fire({"run": run})
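Finally, a round-trip sketch of the save/resume path shown above, using a hypothetical /tmp directory and toy objects (none of them from the commit): a Checkpoint backed by DiskSaver writes a file, and Checkpoint.load_objects restores the saved state dicts, which is what the resume_from branch does with a real checkpoint:

import os
import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.handlers import Checkpoint, DiskSaver

ckpt_dir = "/tmp/qat-ckpt"
os.makedirs(ckpt_dir, exist_ok=True)

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
to_save = {"model": model, "optimizer": optimizer}

trainer = Engine(lambda engine, batch: None)   # only used to trigger the save
handler = Checkpoint(to_save, DiskSaver(ckpt_dir, require_empty=False), n_saved=1)
handler(trainer)                               # normally fired by an Ignite event

checkpoint = torch.load(os.path.join(ckpt_dir, handler.last_checkpoint), map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)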
