diff --git a/train.py b/train.py
index 284ebf8477..b0d74d39e2 100755
--- a/train.py
+++ b/train.py
@@ -17,6 +17,13 @@ from model import Model
 from test import validation
 
+try:
+    from apex import amp
+    from apex import fp16_utils
+    APEX_AVAILABLE = True
+    amp_handle = amp.init(enabled=True)
+except ModuleNotFoundError:
+    APEX_AVAILABLE = False
 
 def train(opt):
     """ dataset preparation """
@@ -42,7 +49,7 @@ def train(opt):
     if opt.rgb:
         opt.input_channel = 3
-    model = Model(opt)
+    model = Model(opt).cuda()
     print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
           opt.output_channel, opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation,
           opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
@@ -62,9 +69,7 @@ def train(opt):
             param.data.fill_(1)
             continue
 
-    # data parallel for multi-GPU
-    model = torch.nn.DataParallel(model).cuda()
-    model.train()
+
     if opt.continue_model != '':
         print(f'loading pretrained model from {opt.continue_model}')
         model.load_state_dict(torch.load(opt.continue_model))
@@ -118,6 +123,13 @@ def train(opt):
     best_norm_ED = 1e+6
     i = start_iter
 
+    if APEX_AVAILABLE:
+        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
+
+    # data parallel for multi-GPU
+    model = torch.nn.DataParallel(model).cuda()
+    model.train()
+
     while(True):
         # train part
         for p in model.parameters():
@@ -140,8 +152,13 @@ def train(opt):
             cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
 
         model.zero_grad()
-        cost.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
+        if APEX_AVAILABLE:
+            with amp.scale_loss(cost, optimizer) as scaled_loss:
+                scaled_loss.backward()
+            fp16_utils.clip_grad_norm(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
+        else:
+            cost.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
         optimizer.step()
 
         loss_avg.add(cost)
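For readers unfamiliar with the NVIDIA Apex API used above: the patch follows the standard Apex AMP recipe of casting the model and optimizer once with amp.initialize(..., opt_level="O2") before the DataParallel wrap, then routing every backward pass through amp.scale_loss so fp16 gradients are loss-scaled. The sketch below isolates that pattern in a self-contained loop; the toy linear model, random tensors, and Adadelta settings are placeholders, not taken from train.py, and it clips the fp32 master gradients via amp.master_params (the pattern Apex documents), whereas the patch itself keeps fp16_utils.clip_grad_norm.

# Minimal sketch of the Apex AMP pattern applied in the patch above.
# Assumptions: toy nn.Linear model, random data, Adadelta with lr=1;
# none of these names come from train.py itself.
import torch
import torch.nn as nn
from apex import amp

model = nn.Linear(128, 10).cuda()            # stand-in for Model(opt).cuda()
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0)
criterion = nn.CrossEntropyLoss().cuda()

# Cast model/optimizer to mixed precision once, before DataParallel wrapping.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
model = torch.nn.DataParallel(model).cuda()
model.train()

for _ in range(10):
    images = torch.randn(32, 128).cuda()
    labels = torch.randint(0, 10, (32,)).cuda()

    preds = model(images)
    cost = criterion(preds, labels)

    model.zero_grad()
    # Loss scaling keeps small fp16 gradients from underflowing; clip the
    # fp32 master gradients that the optimizer actually updates.
    with amp.scale_loss(cost, optimizer) as scaled_loss:
        scaled_loss.backward()
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 5)
    optimizer.step()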