
Commit d59398b

Detection recipe enhancements (#5715)
* Detection recipe enhancements
* Add back nesterov momentum

Parent commit: ec1c2a1

12 files changed (+76, -11 lines)

references/classification/train.py

Lines changed: 1 addition & 1 deletion

@@ -230,7 +230,7 @@ def main(args):
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

     if args.norm_weight_decay is None:
-        parameters = model.parameters()
+        parameters = [p for p in model.parameters() if p.requires_grad]
     else:
         param_groups = torchvision.ops._utils.split_normalization_params(model)
         wd_groups = [args.norm_weight_decay, args.weight_decay]
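
The swap from model.parameters() to a filtered list matters once parts of the model are frozen: the optimizer should only ever see tensors that require gradients. A minimal sketch of the effect; the toy model and the freezing choice are illustrative, not from the commit:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in model[0].parameters():  # freeze the first layer, as transfer learning often does
    p.requires_grad = False

# Before this change, parameters = model.parameters() would include the frozen tensors.
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(parameters, lr=0.1)  # only trainable weights are updated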

references/detection/presets.py

Lines changed: 22 additions & 1 deletion

@@ -3,7 +3,7 @@


 class DetectionPresetTrain:
-    def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
+    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
         if data_augmentation == "hflip":
             self.transforms = T.Compose(
                 [
@@ -12,6 +12,27 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)
                     T.ConvertImageDtype(torch.float),
                 ]
             )
+        elif data_augmentation == "lsj":
+            self.transforms = T.Compose(
+                [
+                    T.ScaleJitter(target_size=(1024, 1024)),
+                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                    T.RandomHorizontalFlip(p=hflip_prob),
+                    T.PILToTensor(),
+                    T.ConvertImageDtype(torch.float),
+                ]
+            )
+        elif data_augmentation == "multiscale":
+            self.transforms = T.Compose(
+                [
+                    T.RandomShortestSize(
+                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
+                    ),
+                    T.RandomHorizontalFlip(p=hflip_prob),
+                    T.PILToTensor(),
+                    T.ConvertImageDtype(torch.float),
+                ]
+            )
         elif data_augmentation == "ssd":
             self.transforms = T.Compose(
                 [
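
Two recipes join the "hflip" and "ssd" branches shown here: "lsj" (large scale jitter, which jitters and crops to a fixed 1024x1024 canvas) and "multiscale" (the classic short-side multi-scale resize capped at 1333). A usage sketch; the __call__ dispatch to self.transforms is assumed from the rest of the class, and the import assumes you run from the references/detection folder:

from presets import DetectionPresetTrain

# data_augmentation is now keyword-only, so positional calls raise a TypeError.
lsj = DetectionPresetTrain(data_augmentation="lsj")
multiscale = DetectionPresetTrain(data_augmentation="multiscale")

# Given a PIL image `img` and a target dict of boxes/labels:
# img, target = lsj(img, target)  # float tensor, jitter-cropped to 1024x1024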

references/detection/train.py

Lines changed: 29 additions & 2 deletions

@@ -68,6 +68,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
     )
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
     parser.add_argument(
         "--lr",
         default=0.02,
@@ -84,6 +85,12 @@
         help="weight decay (default: 1e-4)",
         dest="weight_decay",
     )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
     )
@@ -176,6 +183,8 @@ def main(args):

     print("Creating model")
     kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
+    if args.data_augmentation in ["multiscale", "lsj"]:
+        kwargs["_skip_resize"] = True
     if "rcnn" in args.model:
         if args.rpn_score_thresh is not None:
             kwargs["rpn_score_thresh"] = args.rpn_score_thresh
@@ -191,8 +200,26 @@
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module

-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    if args.norm_weight_decay is None:
+        parameters = [p for p in model.parameters() if p.requires_grad]
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
+    opt_name = args.opt.lower()
+    if opt_name.startswith("sgd"):
+        optimizer = torch.optim.SGD(
+            parameters,
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name,
+        )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")

     scaler = torch.cuda.amp.GradScaler() if args.amp else None
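
Taken together, the new flags drive the whole recipe from the command line. Note that --opt is matched loosely: any name starting with "sgd" selects SGD, and Nesterov momentum switches on when the string also contains "nesterov" (e.g. "sgd_nesterov"). A hypothetical invocation; the dataset path, epoch count and hyperparameter values are placeholders, not values from this commit:

torchrun --nproc_per_node=8 train.py \
    --dataset coco --data-path /path/to/coco \
    --model retinanet_resnet50_fpn \
    --data-augmentation lsj \
    --opt sgd_nesterov --norm-weight-decay 0.0 \
    --lr 0.02 --epochs 26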

test/test_extended_models.py

Lines changed: 0 additions & 1 deletion

@@ -64,7 +64,6 @@ def test_get_weight(name, weight):
 )
 def test_naming_conventions(model_fn):
     weights_enum = _get_model_weights(model_fn)
-    print(weights_enum)
     assert weights_enum is not None
     assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")

torchvision/models/detection/faster_rcnn.py

Lines changed: 2 additions & 1 deletion

@@ -187,6 +187,7 @@ def __init__(
         box_batch_size_per_image=512,
         box_positive_fraction=0.25,
         bbox_reg_weights=None,
+        **kwargs,
     ):

         if not hasattr(backbone, "out_channels"):
@@ -268,7 +269,7 @@ def __init__(
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

         super().__init__(backbone, rpn, roi_heads, transform)
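
The same **kwargs plumbing is repeated in the fcos.py, keypoint_rcnn.py, mask_rcnn.py, retinanet.py and ssd.py diffs below: each detection model now forwards unrecognized keyword arguments down to GeneralizedRCNNTransform, which is how the reference script's _skip_resize flag reaches the transform. A sketch of the resulting call path; it assumes the model builder passes **kwargs through to the model class, as the torchvision builders do:

import torchvision

# _skip_resize travels: builder kwargs -> FasterRCNN.__init__ -> GeneralizedRCNNTransform.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(_skip_resize=True)
print(model.transform._skip_resize)  # True: training-time resizing is left to the data pipeline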

torchvision/models/detection/fcos.py

Lines changed: 2 additions & 1 deletion

@@ -373,6 +373,7 @@ def __init__(
         nms_thresh: float = 0.6,
         detections_per_img: int = 100,
         topk_candidates: int = 1000,
+        **kwargs,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -410,7 +411,7 @@ def __init__(
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

         self.center_sampling_radius = center_sampling_radius
         self.score_thresh = score_thresh

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 2 additions & 0 deletions

@@ -198,6 +198,7 @@ def __init__(
         keypoint_head=None,
         keypoint_predictor=None,
         num_keypoints=None,
+        **kwargs,
     ):

         if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
@@ -259,6 +260,7 @@ def __init__(
             box_batch_size_per_image,
             box_positive_fraction,
             bbox_reg_weights,
+            **kwargs,
         )

         self.roi_heads.keypoint_roi_pool = keypoint_roi_pool

torchvision/models/detection/mask_rcnn.py

Lines changed: 2 additions & 0 deletions

@@ -195,6 +195,7 @@ def __init__(
         mask_roi_pool=None,
         mask_head=None,
         mask_predictor=None,
+        **kwargs,
     ):

         if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
@@ -254,6 +255,7 @@ def __init__(
             box_batch_size_per_image,
             box_positive_fraction,
             bbox_reg_weights,
+            **kwargs,
         )

         self.roi_heads.mask_roi_pool = mask_roi_pool

torchvision/models/detection/retinanet.py

Lines changed: 2 additions & 1 deletion

@@ -342,6 +342,7 @@ def __init__(
         fg_iou_thresh=0.5,
         bg_iou_thresh=0.4,
         topk_candidates=1000,
+        **kwargs,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -383,7 +384,7 @@ def __init__(
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh

torchvision/models/detection/ssd.py

Lines changed: 2 additions & 1 deletion

@@ -195,6 +195,7 @@ def __init__(
         iou_thresh: float = 0.5,
         topk_candidates: int = 400,
         positive_fraction: float = 0.25,
+        **kwargs: Any,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -227,7 +228,7 @@ def __init__(
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
         self.transform = GeneralizedRCNNTransform(
-            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
+            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
         )

         self.score_thresh = score_thresh

torchvision/models/detection/transform.py

Lines changed: 5 additions & 1 deletion

@@ -1,5 +1,5 @@
 import math
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, Any

 import torch
 import torchvision
@@ -91,6 +91,7 @@ def __init__(
         image_std: List[float],
         size_divisible: int = 32,
         fixed_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
     ):
         super().__init__()
         if not isinstance(min_size, (list, tuple)):
@@ -101,6 +102,7 @@ def __init__(
         self.image_std = image_std
         self.size_divisible = size_divisible
         self.fixed_size = fixed_size
+        self._skip_resize = kwargs.pop("_skip_resize", False)

     def forward(
         self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
@@ -170,6 +172,8 @@ def resize(
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         h, w = image.shape[-2:]
         if self.training:
+            if self._skip_resize:
+                return image, target
             size = float(self.torch_choice(self.min_size))
         else:
             # FIXME assume for now that testing uses the largest scale
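
With the flag set, resize() becomes a no-op in training mode only; evaluation keeps the usual min/max-size rule, so inference behavior is unchanged. A small sketch of the two modes (input shapes chosen for illustration):

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(
    min_size=800, max_size=1333,
    image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225],
    _skip_resize=True,
)

transform.train()
images, _ = transform([torch.rand(3, 1024, 1024)])
print(images.tensors.shape)  # torch.Size([1, 3, 1024, 1024]): left untouched in training

transform.eval()
images, _ = transform([torch.rand(3, 1024, 1024)])
print(images.tensors.shape)  # torch.Size([1, 3, 800, 800]): resized as before at eval time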

torchvision/ops/_utils.py

Lines changed: 7 additions & 1 deletion

@@ -43,7 +43,13 @@ def split_normalization_params(
 ) -> Tuple[List[Tensor], List[Tensor]]:
     # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
     if not norm_classes:
-        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
+        norm_classes = [
+            nn.modules.batchnorm._BatchNorm,
+            nn.LayerNorm,
+            nn.GroupNorm,
+            nn.modules.instancenorm._InstanceNorm,
+            nn.LocalResponseNorm,
+        ]

     for t in norm_classes:
         if not issubclass(t, nn.Module):
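
The default list of normalization classes now also catches instance norm and local response norm. This mirrors how the detection script consumes the helper (see the train.py hunk above): normalization parameters go into one group, everything else into another, each with its own weight decay. A sketch; the model choice and decay values are illustrative:

import torch
import torchvision
from torchvision.ops._utils import split_normalization_params

model = torchvision.models.detection.retinanet_resnet50_fpn()
norm_params, other_params = split_normalization_params(model)

param_groups = [
    {"params": norm_params, "weight_decay": 0.0},    # what --norm-weight-decay sets
    {"params": other_params, "weight_decay": 1e-4},  # the regular --weight-decay
]
optimizer = torch.optim.SGD(param_groups, lr=0.02, momentum=0.9)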
