Skip to content

Commit c2ab0c5

Browse files
authored
Make reference scripts compatible with submitit (#3785)
* Add submitit script, partition param, and move the parser into its own method.
* Fix method names, handle add_help correctly, and refactor.
* Delete the run_with_submitit.py file.
1 parent 90a6206 commit c2ab0c5

File tree

4 files changed

+71
-75
lines changed

4 files changed

+71
-75
lines changed

references/classification/train.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -224,9 +224,9 @@ def main(args):
224224
print('Training time {}'.format(total_time_str))
225225

226226

227-
def parse_args():
227+
def get_args_parser(add_help=True):
228228
import argparse
229-
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
229+
parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
230230

231231
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
232232
parser.add_argument('--model', default='resnet18', help='model')
@@ -291,11 +291,9 @@ def parse_args():
291291
help='number of distributed processes')
292292
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
293293

294-
args = parser.parse_args()
295-
296-
return args
294+
return parser
297295

298296

299297
if __name__ == "__main__":
300-
args = parse_args()
298+
args = get_args_parser().parse_args()
301299
main(args)

references/classification/train_quantization.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414

1515
def main(args):
16-
1716
if args.output_dir:
1817
utils.mkdir(args.output_dir)
1918

@@ -162,9 +161,9 @@ def main(args):
162161
print('Training time {}'.format(total_time_str))
163162

164163

165-
def parse_args():
164+
def get_args_parser(add_help=True):
166165
import argparse
167-
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
166+
parser = argparse.ArgumentParser(description='PyTorch Quantized Classification Training', add_help=add_help)
168167

169168
parser.add_argument('--data-path',
170169
default='/datasets01/imagenet_full_size/061417/',
@@ -250,11 +249,9 @@ def parse_args():
250249
default='env://',
251250
help='url used to set up distributed training')
252251

253-
args = parser.parse_args()
254-
255-
return args
252+
return parser
256253

257254

258255
if __name__ == "__main__":
259-
args = parse_args()
256+
args = get_args_parser().parse_args()
260257
main(args)

references/detection/train.py

Lines changed: 59 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,65 @@ def get_transform(train, data_augmentation):
5151
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()
5252

5353

54+
def get_args_parser(add_help=True):
55+
import argparse
56+
parser = argparse.ArgumentParser(description='PyTorch Detection Training', add_help=add_help)
57+
58+
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
59+
parser.add_argument('--dataset', default='coco', help='dataset')
60+
parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
61+
parser.add_argument('--device', default='cuda', help='device')
62+
parser.add_argument('-b', '--batch-size', default=2, type=int,
63+
help='images per gpu, the total batch size is $NGPU x batch_size')
64+
parser.add_argument('--epochs', default=26, type=int, metavar='N',
65+
help='number of total epochs to run')
66+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
67+
help='number of data loading workers (default: 4)')
68+
parser.add_argument('--lr', default=0.02, type=float,
69+
help='initial learning rate, 0.02 is the default value for training '
70+
'on 8 gpus and 2 images_per_gpu')
71+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
72+
help='momentum')
73+
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
74+
metavar='W', help='weight decay (default: 1e-4)',
75+
dest='weight_decay')
76+
parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
77+
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
78+
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
79+
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
80+
parser.add_argument('--output-dir', default='.', help='path where to save')
81+
parser.add_argument('--resume', default='', help='resume from checkpoint')
82+
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
83+
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
84+
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
85+
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
86+
help='number of trainable layers of backbone')
87+
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
88+
parser.add_argument(
89+
"--test-only",
90+
dest="test_only",
91+
help="Only test the model",
92+
action="store_true",
93+
)
94+
parser.add_argument(
95+
"--pretrained",
96+
dest="pretrained",
97+
help="Use pre-trained models from the modelzoo",
98+
action="store_true",
99+
)
100+
101+
# distributed training parameters
102+
parser.add_argument('--world-size', default=1, type=int,
103+
help='number of distributed processes')
104+
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
105+
106+
return parser
107+
108+
54109
def main(args):
110+
if args.output_dir:
111+
utils.mkdir(args.output_dir)
112+
55113
utils.init_distributed_mode(args)
56114
print(args)
57115

@@ -147,61 +205,5 @@ def main(args):
147205

148206

149207
if __name__ == "__main__":
150-
import argparse
151-
parser = argparse.ArgumentParser(
152-
description=__doc__)
153-
154-
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
155-
parser.add_argument('--dataset', default='coco', help='dataset')
156-
parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
157-
parser.add_argument('--device', default='cuda', help='device')
158-
parser.add_argument('-b', '--batch-size', default=2, type=int,
159-
help='images per gpu, the total batch size is $NGPU x batch_size')
160-
parser.add_argument('--epochs', default=26, type=int, metavar='N',
161-
help='number of total epochs to run')
162-
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
163-
help='number of data loading workers (default: 4)')
164-
parser.add_argument('--lr', default=0.02, type=float,
165-
help='initial learning rate, 0.02 is the default value for training '
166-
'on 8 gpus and 2 images_per_gpu')
167-
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
168-
help='momentum')
169-
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
170-
metavar='W', help='weight decay (default: 1e-4)',
171-
dest='weight_decay')
172-
parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
173-
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
174-
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
175-
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
176-
parser.add_argument('--output-dir', default='.', help='path where to save')
177-
parser.add_argument('--resume', default='', help='resume from checkpoint')
178-
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
179-
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
180-
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
181-
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
182-
help='number of trainable layers of backbone')
183-
parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')
184-
parser.add_argument(
185-
"--test-only",
186-
dest="test_only",
187-
help="Only test the model",
188-
action="store_true",
189-
)
190-
parser.add_argument(
191-
"--pretrained",
192-
dest="pretrained",
193-
help="Use pre-trained models from the modelzoo",
194-
action="store_true",
195-
)
196-
197-
# distributed training parameters
198-
parser.add_argument('--world-size', default=1, type=int,
199-
help='number of distributed processes')
200-
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
201-
202-
args = parser.parse_args()
203-
204-
if args.output_dir:
205-
utils.mkdir(args.output_dir)
206-
208+
args = get_args_parser().parse_args()
207209
main(args)

references/segmentation/train.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -172,9 +172,9 @@ def main(args):
172172
print('Training time {}'.format(total_time_str))
173173

174174

175-
def parse_args():
175+
def get_args_parser(add_help=True):
176176
import argparse
177-
parser = argparse.ArgumentParser(description='PyTorch Segmentation Training')
177+
parser = argparse.ArgumentParser(description='PyTorch Segmentation Training', add_help=add_help)
178178

179179
parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path')
180180
parser.add_argument('--dataset', default='coco', help='dataset name')
@@ -215,10 +215,9 @@ def parse_args():
215215
help='number of distributed processes')
216216
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
217217

218-
args = parser.parse_args()
219-
return args
218+
return parser
220219

221220

222221
if __name__ == "__main__":
223-
args = parse_args()
222+
args = get_args_parser().parse_args()
224223
main(args)

0 commit comments

Comments
 (0)