Skip to content

Commit 27150a9

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] Moving the check for prototype support in all references. (#4849)
Reviewed By: kazhang Differential Revision: D32216666 fbshipit-source-id: 97eee21e32d3946cfca569eaf0dd9f98ea2fce1a
1 parent 496fb36 commit 27150a9

File tree

5 files changed

+10
-10
lines changed

5 files changed

+10
-10
lines changed

references/classification/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ def load_data(traindir, valdir, args):
182182

183183

184184
def main(args):
185+
if args.weights and PM is None:
186+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
185187
if args.output_dir:
186188
utils.mkdir(args.output_dir)
187189

@@ -226,8 +228,6 @@ def main(args):
226228
if not args.weights:
227229
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
228230
else:
229-
if PM is None:
230-
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
231231
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
232232
model.to(device)
233233

references/classification/train_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020

2121
def main(args):
22+
if args.weights and PM is None:
23+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
2224
if args.output_dir:
2325
utils.mkdir(args.output_dir)
2426

@@ -55,8 +57,6 @@ def main(args):
5557
if not args.weights:
5658
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
5759
else:
58-
if PM is None:
59-
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
6060
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
6161
model.to(device)
6262

references/detection/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def get_args_parser(add_help=True):
148148

149149

150150
def main(args):
151+
if args.weights and PM is None:
152+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
151153
if args.output_dir:
152154
utils.mkdir(args.output_dir)
153155

@@ -194,8 +196,6 @@ def main(args):
194196
pretrained=args.pretrained, num_classes=num_classes, **kwargs
195197
)
196198
else:
197-
if PM is None:
198-
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
199199
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
200200
model.to(device)
201201
if args.distributed and args.sync_bn:

references/segmentation/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
9292

9393

9494
def main(args):
95+
if args.weights and PM is None:
96+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
9597
if args.output_dir:
9698
utils.mkdir(args.output_dir)
9799

@@ -130,8 +132,6 @@ def main(args):
130132
aux_loss=args.aux_loss,
131133
)
132134
else:
133-
if PM is None:
134-
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
135135
model = PM.segmentation.__dict__[args.model](
136136
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
137137
)

references/video_classification/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def collate_fn(batch):
9999

100100

101101
def main(args):
102+
if args.weights and PM is None:
103+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
102104
if args.apex and amp is None:
103105
raise RuntimeError(
104106
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
@@ -214,8 +216,6 @@ def main(args):
214216
if not args.weights:
215217
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
216218
else:
217-
if PM is None:
218-
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
219219
model = PM.video.__dict__[args.model](weights=args.weights)
220220
model.to(device)
221221
if args.distributed and args.sync_bn:

0 commit comments

Comments
 (0)