
Commit c1bc525

Merge branch 'main' into transforms/mixupcutmix
2 parents 6f2ebea + 12fd3a6 commit c1bc525

12 files changed: +78 -16 lines changed

docs/source/feature_extraction.rst

Lines changed: 10 additions & 7 deletions
@@ -39,6 +39,7 @@ Here is an example of how we might extract features for MaskRCNN:
     from torchvision.models.feature_extraction import get_graph_node_names
     from torchvision.models.feature_extraction import create_feature_extractor
     from torchvision.models.detection.mask_rcnn import MaskRCNN
+    from torchvision.models.detection.backbone_utils import LastLevelMaxPool
     from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


@@ -57,7 +58,7 @@ Here is an example of how we might extract features for MaskRCNN:
     # that appears in each of the main layers:
     return_nodes = {
         # node_name: user-specified key for output dict
-        'layer1.2.relu_2': 'layer1',
+        'layer1.2.relu_2': 'layer1',
         'layer2.3.relu_2': 'layer2',
         'layer3.5.relu_2': 'layer3',
         'layer4.2.relu_2': 'layer4',
@@ -70,7 +71,7 @@ Here is an example of how we might extract features for MaskRCNN:
     # performed is the one that corresponds to the output you desire. You should
     # consult the source code for the input model to confirm.)
     return_nodes = {
-        'layer1': 'layer1',
+        'layer1': 'layer1',
         'layer2': 'layer2',
         'layer3': 'layer3',
         'layer4': 'layer4',
@@ -79,7 +80,7 @@ Here is an example of how we might extract features for MaskRCNN:
     # Now you can build the feature extractor. This returns a module whose forward
     # method returns a dictionary like:
     # {
-    #     'layer1': output of layer 1,
+    #     'layer1': output of layer 1,
     #     'layer2': output of layer 2,
     #     'layer3': output of layer 3,
     #     'layer4': output of layer 4,
@@ -94,10 +95,11 @@ Here is an example of how we might extract features for MaskRCNN:
             super(Resnet50WithFPN, self).__init__()
             # Get a resnet50 backbone
             m = resnet50()
-            # Extract 4 main layers (note: you can also provide a list for return
-            # nodes if the keys and the values are the same)
+            # Extract 4 main layers (note: MaskRCNN needs this particular name
+            # mapping for return nodes)
             self.body = create_feature_extractor(
-                m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
+                m, return_nodes={f'layer{k}': str(v)
+                                 for v, k in enumerate([1, 2, 3, 4])})
             # Dry run to get number of channels for FPN
             inp = torch.randn(2, 3, 224, 224)
             with torch.no_grad():
@@ -106,7 +108,8 @@ Here is an example of how we might extract features for MaskRCNN:
             # Build FPN
             self.out_channels = 256
             self.fpn = FeaturePyramidNetwork(
-                in_channels_list, out_channels=self.out_channels)
+                in_channels_list, out_channels=self.out_channels,
+                extra_blocks=LastLevelMaxPool())

         def forward(self, x):
             x = self.body(x)
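
Note: the new return_nodes expression above builds the string-keyed mapping that the FPN plumbing expects. A minimal sketch of what it expands to (not part of the commit; assumes torchvision >= 0.11 with the FX feature-extraction API):

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import create_feature_extractor

    # The dict comprehension from the updated docs maps layer names to the
    # string indices '0'..'3' used as output keys:
    return_nodes = {f'layer{k}': str(v) for v, k in enumerate([1, 2, 3, 4])}
    assert return_nodes == {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}

    body = create_feature_extractor(resnet50(), return_nodes=return_nodes)
    with torch.no_grad():
        feats = body(torch.randn(2, 3, 224, 224))
    print({k: tuple(v.shape) for k, v in feats.items()})
    # {'0': (2, 256, 56, 56), '1': (2, 512, 28, 28),
    #  '2': (2, 1024, 14, 14), '3': (2, 2048, 7, 7)}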

references/classification/train.py

Lines changed: 24 additions & 4 deletions
@@ -17,7 +17,8 @@
     amp = None


-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

+    if model_ema:
+        model_ema.update_parameters(model)

-def evaluate(model, criterion, data_loader, device, print_freq=100):
+
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module

+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])

     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -215,16 +225,20 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')

     return parser
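
Note: taken together, these train.py changes wire EMA into the training loop: update the averaged copy after each epoch's batches, evaluate it separately, and round-trip it through checkpoints. A condensed sketch of that flow (illustrative only; the model and the training step are stand-ins):

    import torch
    import utils  # references/classification/utils.py

    model = torch.nn.Linear(10, 2)  # stand-in for the real classifier
    model_ema = utils.ExponentialMovingAverage(model, decay=0.99, device='cpu')

    for epoch in range(3):
        # ... forward/backward/optimizer.step() over the epoch goes here ...
        model_ema.update_parameters(model)  # ema = 0.99 * ema + 0.01 * param

    # the EMA copy is evaluated and checkpointed alongside the live model:
    # evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
    checkpoint = {'model': model.state_dict(), 'model_ema': model_ema.state_dict()}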

references/classification/utils.py

Lines changed: 12 additions & 0 deletions
@@ -161,6 +161,18 @@ def log_every(self, iterable, print_freq, header=None):
         print('{} Total time: {}'.format(header, total_time_str))


+class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+    """Maintains moving averages of model parameters using an exponential decay.
+    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
+    is used to compute the EMA.
+    """
+    def __init__(self, model, decay, device='cpu'):
+        ema_avg = (lambda avg_model_param, model_param, num_averaged:
+                   decay * avg_model_param + (1 - decay) * model_param)
+        super().__init__(model, device, ema_avg)
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
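
Note: the lambda above replaces AveragedModel's default running mean with an exponential average, so the influence of old weights decays geometrically. A quick numeric check of the recurrence (illustrative, not from the commit):

    import torch

    decay = 0.99
    avg, param = torch.tensor(1.0), torch.tensor(0.0)
    for _ in range(3):
        avg = decay * avg + (1 - decay) * param  # same formula as ema_avg
    print(avg)  # tensor(0.9703): the initial weight retains 0.99 ** 3 of its value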

test/cpp/test_custom_operators.cpp

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@ TEST(test_custom_operators, nms) {
   double thresh = 0.7;

   torch::jit::push(stack, boxes, scores, thresh);
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   at::Tensor output_jit;
   torch::jit::pop(stack, output_jit);

@@ -47,7 +47,7 @@ TEST(test_custom_operators, roi_align_visible) {
   bool aligned = true;

   torch::jit::push(stack, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
-  op->getOperation()(&stack);
+  op->getOperation()(stack);
   at::Tensor output_jit;
   torch::jit::pop(stack, output_jit);

test/test_backbone_utils.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class TestFxFeatureExtraction:
         'num_classes': 1,
         'pretrained': False
     }
-    leaf_modules = [torchvision.ops.StochasticDepth]
+    leaf_modules = []

     def _create_feature_extractor(self, *args, **kwargs):
         """

test/test_datasets.py

Lines changed: 12 additions & 1 deletion
@@ -512,7 +512,7 @@ def inject_fake_data(self, tmpdir, config):
         return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)

     def _create_split_txt(self, root):
-        num_images_per_split = dict(train=3, valid=2, test=1)
+        num_images_per_split = dict(train=4, valid=3, test=2)

         data = [
             [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
@@ -595,6 +595,17 @@ def test_attr_names(self):
         with self.create_dataset() as (dataset, info):
             assert tuple(dataset.attr_names) == info["attr_names"]

+    def test_images_names_split(self):
+        with self.create_dataset(split='all') as (dataset, _):
+            all_imgs_names = set(dataset.filename)
+
+        merged_imgs_names = set()
+        for split in ["train", "valid", "test"]:
+            with self.create_dataset(split=split) as (dataset, _):
+                merged_imgs_names.update(dataset.filename)
+
+        assert merged_imgs_names == all_imgs_names
+

 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation

torchvision/datasets/celeba.py

Lines changed: 4 additions & 1 deletion
@@ -99,7 +99,10 @@ def __init__(

         mask = slice(None) if split_ is None else (splits.data == split_).squeeze()

-        self.filename = splits.index
+        if mask == slice(None):  # if split == "all"
+            self.filename = splits.index
+        else:
+            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
         self.identity = identity.data[mask]
         self.bbox = bbox.data[mask]
         self.landmarks_align = landmarks_align.data[mask]
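
Note: this fix is needed because splits.index is a plain Python list of filenames, which a boolean tensor cannot index directly; torch.nonzero() first converts the mask into integer positions. A small sketch of the same pattern (hypothetical filenames):

    import torch

    filenames = ['000001.jpg', '000002.jpg', '000003.jpg']
    mask = torch.tensor([True, False, True])  # e.g. (splits.data == split_).squeeze()
    kept = [filenames[i] for i in torch.squeeze(torch.nonzero(mask))]
    print(kept)  # ['000001.jpg', '000003.jpg']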

torchvision/models/detection/faster_rcnn.py

Lines changed: 3 additions & 0 deletions
@@ -300,6 +300,9 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
     """
     Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

+    Reference: `"Faster R-CNN: Towards Real-Time Object Detection with
+    Region Proposal Networks" <https://arxiv.org/abs/1506.01497>`_.
+
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 2 additions & 0 deletions
@@ -278,6 +278,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
     """
     Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

+    Reference: `"Mask R-CNN" <https://arxiv.org/abs/1703.06870>`_.
+
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.

torchvision/models/detection/mask_rcnn.py

Lines changed: 2 additions & 0 deletions
@@ -271,6 +271,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
     """
     Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.

+    Reference: `"Mask R-CNN" <https://arxiv.org/abs/1703.06870>`_.
+
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
