Skip to content

Commit f52ddb0

Browse files
authored
Adding label smoothing on classification reference (#4335)
* Adding label smoothing on classification reference. * Replace underscore with dash.
1 parent 388b19c commit f52ddb0

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

references/classification/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def main(args):
175175
if args.distributed and args.sync_bn:
176176
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
177177

178-
criterion = nn.CrossEntropyLoss()
178+
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
179179

180180
opt_name = args.opt.lower()
181181
if opt_name == 'sgd':
@@ -256,6 +256,9 @@ def get_args_parser(add_help=True):
256256
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
257257
metavar='W', help='weight decay (default: 1e-4)',
258258
dest='weight_decay')
259+
parser.add_argument('--label-smoothing', default=0.0, type=float,
260+
help='label smoothing (default: 0.0)',
261+
dest='label_smoothing')
259262
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
260263
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
261264
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')

0 commit comments

Comments
 (0)