Skip to content

Commit 9ee69c4

Browse files
committed
Removing hardcoded interpolation and sizes from the scripts.
1 parent d861b33 commit 9ee69c4

File tree

3 files changed

+51
-20
lines changed

3 files changed

+51
-20
lines changed

references/classification/README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
3131
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
3232
normalization and thus are trained with the default parameters.
3333

34+
### Inception V3
35+
36+
The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
37+
38+
Since it expects tensors with a size of N x 3 x 299 x 299, use the following command to validate the model:
39+
40+
```
41+
torchrun --nproc_per_node=8 train.py --model inception_v3\
42+
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
43+
```
44+
3445
### ResNext-50 32x4d
3546
```
3647
torchrun --nproc_per_node=8 train.py\
@@ -79,6 +90,25 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht
7990

8091
The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
8192

93+
All models were trained using Bicubic interpolation, and each has its own custom crop and resize sizes. To validate the models, use the following commands:
94+
```
95+
torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\
96+
--val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
97+
torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\
98+
--val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
99+
torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\
100+
--val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
101+
torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\
102+
--val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
103+
torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\
104+
--val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
105+
torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\
106+
--val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
107+
torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\
108+
--val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
109+
torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\
110+
--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
111+
```
82112

83113
### RegNet
84114

@@ -181,3 +211,8 @@ For post training quant, device is set to CPU. For training, the device is set t
181211
```
182212
python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
183213
```
214+
215+
For inception_v3 you need to pass the following extra parameters:
216+
```
217+
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299
218+
```

references/classification/train.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -107,26 +107,8 @@ def _get_cache_path(filepath):
107107
def load_data(traindir, valdir, args):
108108
# Data loading code
109109
print("Loading data")
110-
val_resize_size, val_crop_size, train_crop_size = 256, 224, 224
110+
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
111111
interpolation = InterpolationMode(args.interpolation)
112-
if args.model == "inception_v3":
113-
val_resize_size, val_crop_size, train_crop_size = 342, 299, 299
114-
elif args.model == "resnet50":
115-
val_resize_size, val_crop_size, train_crop_size = 256, 224, 176
116-
elif args.model.startswith("efficientnet_"):
117-
sizes = {
118-
"b0": (256, 224, 224),
119-
"b1": (256, 240, 240),
120-
"b2": (288, 288, 288),
121-
"b3": (320, 300, 300),
122-
"b4": (384, 380, 380),
123-
"b5": (456, 456, 456),
124-
"b6": (528, 528, 528),
125-
"b7": (600, 600, 600),
126-
}
127-
e_type = args.model.replace("efficientnet_", "")
128-
val_resize_size, val_crop_size, train_crop_size = sizes[e_type]
129-
interpolation = InterpolationMode.BICUBIC
130112

131113
print("Loading training data")
132114
st = time.time()
@@ -458,7 +440,13 @@ def get_args_parser(add_help=True):
458440
parser.add_argument(
459441
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
460442
)
461-
parser.add_argument("--interpolation", default="bilinear", help="the default interpolation (default: bilinear)")
443+
parser.add_argument("--interpolation", default="bilinear", help="the interpolation method (default: bilinear)")
444+
parser.add_argument("--val-resize-size", default=256, type=int,
445+
help="the resize size used for validation (default: 256)")
446+
parser.add_argument("--val-crop-size", default=224, type=int,
447+
help="the central crop size used for validation (default: 224)")
448+
parser.add_argument("--train-crop-size", default=224, type=int,
449+
help="the random crop size used for training (default: 224)")
462450

463451
return parser
464452

references/classification/train_quantization.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ def get_args_parser(add_help=True):
236236
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
237237
parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
238238

239+
parser.add_argument("--interpolation", default="bilinear", help="the interpolation method (default: bilinear)")
240+
parser.add_argument("--val-resize-size", default=256, type=int,
241+
help="the resize size used for validation (default: 256)")
242+
parser.add_argument("--val-crop-size", default=224, type=int,
243+
help="the central crop size used for validation (default: 224)")
244+
parser.add_argument("--train-crop-size", default=224, type=int,
245+
help="the random crop size used for training (default: 224)")
246+
239247
return parser
240248

241249

0 commit comments

Comments
 (0)