diff --git a/references/detection/presets.py b/references/detection/presets.py index 779f3f218ca..c13d0256e8d 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -53,6 +53,16 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104 T.ConvertImageDtype(torch.float), ] ) + elif data_augmentation == "yolo": + self.transforms = T.Compose( + [ + T.ScaleJitter(target_size=(640, 640)), + T.FixedSizeCrop(size=(640, 640), fill=mean), + T.RandomHorizontalFlip(p=hflip_prob), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + ] + ) else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') diff --git a/test/assets/yolov4-tiny-3l.cfg b/test/assets/yolov4-tiny-3l.cfg new file mode 100644 index 00000000000..c2dcb29e481 --- /dev/null +++ b/test/assets/yolov4-tiny-3l.cfg @@ -0,0 +1,327 @@ +[net] +batch=64 +subdivisions=1 +width=608 +height=608 +channels=3 +momentum=0.9 +decay=0.0005 +angle=0 +saturation = 1.5 +exposure = 1.5 +hue=.1 + +learning_rate=0.00261 +burn_in=1000 +max_batches = 500200 +policy=steps +steps=400000,450000 +scales=.1,.1 + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=2 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=2 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -1,-2 + +[convolutional] +batch_normalize=1 +filters=64 +size=1 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -6,-1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -1,-2 + +[convolutional] +batch_normalize=1 +filters=128 +size=1 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -6,-1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=256 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -1,-2 + +[convolutional] +batch_normalize=1 +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -6,-1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +################################## + +[convolutional] +batch_normalize=1 +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +size=1 +stride=1 +pad=1 +filters=255 +activation=linear + + + +[yolo] +mask = 6,7,8 +anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 +classes=80 +num=9 +jitter=.3 +scale_x_y = 1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +iou_loss=ciou +ignore_thresh = .7 +truth_thresh = 1 +random=0 +resize=1.5 +nms_kind=greedynms +beta_nms=0.6 + +[route] +layers 
= -4 + +[convolutional] +batch_normalize=1 +filters=128 +size=1 +stride=1 +pad=1 +activation=leaky + +[upsample] +stride=2 + +[route] +layers = -1, 23 + +[convolutional] +batch_normalize=1 +filters=256 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +size=1 +stride=1 +pad=1 +filters=255 +activation=linear + +[yolo] +mask = 3,4,5 +anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 +classes=80 +num=9 +jitter=.3 +scale_x_y = 1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +iou_loss=ciou +ignore_thresh = .7 +truth_thresh = 1 +random=0 +resize=1.5 +nms_kind=greedynms +beta_nms=0.6 + + +[route] +layers = -3 + +[convolutional] +batch_normalize=1 +filters=64 +size=1 +stride=1 +pad=1 +activation=leaky + +[upsample] +stride=2 + +[route] +layers = -1, 15 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +size=1 +stride=1 +pad=1 +filters=255 +activation=linear + +[yolo] +mask = 0,1,2 +anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 +classes=80 +num=9 +jitter=.3 +scale_x_y = 1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +iou_loss=ciou +ignore_thresh = .7 +truth_thresh = 1 +random=0 +resize=1.5 +nms_kind=greedynms +beta_nms=0.6 diff --git a/test/expect/ModelTester.test_yolo_darknet_expect.pkl b/test/expect/ModelTester.test_yolo_darknet_expect.pkl new file mode 100644 index 00000000000..d3e297368cf Binary files /dev/null and b/test/expect/ModelTester.test_yolo_darknet_expect.pkl differ diff --git a/test/expect/ModelTester.test_yolov4_expect.pkl b/test/expect/ModelTester.test_yolov4_expect.pkl new file mode 100644 index 00000000000..5996ef88fd5 Binary files /dev/null and b/test/expect/ModelTester.test_yolov4_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index e1a288f4eb5..b220aebff77 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -19,10 +19,12 @@ from PIL import Image from torchvision import models, transforms from torchvision.models import get_model_builder, list_models +from torchvision.models.detection import yolo_darknet ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" +DARKNET_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "yolov4-tiny-3l.cfg") def list_model_fns(module): @@ -370,6 +372,10 @@ def _check_input_backprop(model, inputs): "input_shape": (1, 3, 16, 224, 224), }, "googlenet": {"init_weights": True}, + "yolov4": { + "num_classes": 10, + "input_shape": (3, 224, 224), + }, } # speeding up slow models: slow_models = [ @@ -467,6 +473,10 @@ def is_skippable(model_name, device): "max_trainable": 5, "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107], }, + "yolov4": { + "max_trainable": 5, + "n_trn_params_per_layer": [138, 174, 234, 294, 318, 339], + }, } @@ -783,6 +793,61 @@ def check_out(out): _check_input_backprop(model, x) +def check_model_output(out, model_name): + assert len(out) == 1 + + def compact(tensor): + tensor = tensor.cpu() + size = tensor.size() + elements_per_sample = functools.reduce(operator.mul, size[1:], 1) + if elements_per_sample > 30: + return compute_mean_std(tensor) + else: + return subsample_tensor(tensor) + + def subsample_tensor(tensor): + num_elems = tensor.size(0) + num_samples = 20 + if num_elems <= num_samples: + return tensor + + ith_index = num_elems // num_samples + return tensor[ith_index - 1 :: ith_index] + + def compute_mean_std(tensor): + # can't compute mean of integral tensor + 
tensor = tensor.to(torch.double) + mean = torch.mean(tensor) + std = torch.std(tensor) + return {"mean": mean, "std": std} + + output = map_nested_tensor_object(out, tensor_map_fn=compact) + prec = 0.01 + try: + # We first try to assert the entire output if possible. This is not + # only the best way to assert results but also handles the cases + # where we need to create a new expected result. + _assert_expected(output, model_name, prec=prec) + except AssertionError: + # Unfortunately detection models are flaky due to the unstable sort + # in NMS. If matching across all outputs fails, use the same approach + # as in NMSTester.test_nms_cuda to see if this is caused by duplicate + # scores. + expected_file = _get_expected_file(model_name) + expected = torch.load(expected_file) + torch.testing.assert_close( + output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False + ) + + # Note: Fmassa proposed turning off NMS by adapting the threshold + # and then using the Hungarian algorithm as in DETR to find the + # best match between output and expected boxes and eliminate some + # of the flakiness. Worth exploring. + return False # Partial validation performed + + return True # Full validation performed + + @pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_detection_model(model_fn, dev): @@ -809,61 +874,7 @@ def test_detection_model(model_fn, dev): out = model(model_input) assert model_input[0] is x - def check_out(out): - assert len(out) == 1 - - def compact(tensor): - tensor = tensor.cpu() - size = tensor.size() - elements_per_sample = functools.reduce(operator.mul, size[1:], 1) - if elements_per_sample > 30: - return compute_mean_std(tensor) - else: - return subsample_tensor(tensor) - - def subsample_tensor(tensor): - num_elems = tensor.size(0) - num_samples = 20 - if num_elems <= num_samples: - return tensor - - ith_index = num_elems // num_samples - return tensor[ith_index - 1 :: ith_index] - - def compute_mean_std(tensor): - # can't compute mean of integral tensor - tensor = tensor.to(torch.double) - mean = torch.mean(tensor) - std = torch.std(tensor) - return {"mean": mean, "std": std} - - output = map_nested_tensor_object(out, tensor_map_fn=compact) - prec = 0.01 - try: - # We first try to assert the entire output if possible. This is not - # only the best way to assert results but also handles the cases - # where we need to create a new expected result. - _assert_expected(output, model_name, prec=prec) - except AssertionError: - # Unfortunately detection models are flaky due to the unstable sort - # in NMS. If matching across all outputs fails, use the same approach - # as in NMSTester.test_nms_cuda to see if this is caused by duplicate - # scores. - expected_file = _get_expected_file(model_name) - expected = torch.load(expected_file) - torch.testing.assert_close( - output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False - ) - - # Note: Fmassa proposed turning off NMS by adapting the threshold - # and then using the Hungarian algorithm as in DETR to find the - # best match between output and expected boxes and eliminate some - # of the flakiness. Worth exploring. 
-        return False  # Partial validation performed
-
-        return True  # Full validation performed
-
-    full_validation = check_out(out)
+    full_validation = check_model_output(out, model_name)
     _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)

     if dev == "cuda":
@@ -871,7 +882,7 @@ def compute_mean_std(tensor):
             out = model(model_input)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics:
-                full_validation &= check_out(out)
+                full_validation &= check_model_output(out, model_name)

     if not full_validation:
         msg = (
@@ -886,32 +897,68 @@ def compute_mean_std(tensor):
     _check_input_backprop(model, model_input)


+@pytest.mark.parametrize("dev", cpu_and_gpu())
+def test_yolo_darknet(dev):
+    set_rng_seed(0)
+    model_name = "yolo_darknet"
+    dtype = torch.get_default_dtype()
+    input_shape = (3, 224, 224)
+
+    model = yolo_darknet(DARKNET_CONFIG)
+    model.eval().to(device=dev, dtype=dtype)
+    x = _get_image(input_shape=input_shape, real_image=False, device=dev, dtype=dtype)
+    model_input = [x]
+    with torch.no_grad(), freeze_rng_state():
+        out = model(model_input)
+        assert model_input[0] is x
+
+    full_validation = check_model_output(out, model_name)
+    _check_jit_scriptable(model, ([x],), unwrapper=None, eager_out=out)
+
+    if dev == "cuda":
+        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
+            out = model(model_input)
+            full_validation &= check_model_output(out, model_name)
+
+    if not full_validation:
+        msg = (
+            "The output of yolo_darknet could only be partially validated. "
+            "This is likely due to unit-test flakiness, but you may "
+            "want to do additional manual checks if you made "
+            "significant changes to the codebase."
+        )
+        warnings.warn(msg, RuntimeWarning)
+        pytest.skip(msg)
+
+    _check_input_backprop(model, model_input)
+
+
 @pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
 def test_detection_model_validation(model_fn):
     set_rng_seed(0)
     model = model_fn(num_classes=50, weights=None, weights_backbone=None)
-    input_shape = (3, 300, 300)
+    input_shape = (3, 256, 256)  # YOLO models expect the input dimensions to be a multiple of 32 or 64.
x = [torch.rand(input_shape)] # validate that targets are present in training - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, ValueError)): model(x) # validate type targets = [{"boxes": 0.0}] - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, TypeError)): model(x, targets=targets) # validate boxes shape for boxes in (torch.rand((4,)), torch.rand((1, 5))): targets = [{"boxes": boxes}] - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, ValueError)): model(x, targets=targets) # validate that no degenerate boxes are present boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]]) targets = [{"boxes": boxes}] - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, ValueError)): model(x, targets=targets) diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py index 645d4624d64..edba1b1626c 100644 --- a/test/test_models_detection_anchor_utils.py +++ b/test/test_models_detection_anchor_utils.py @@ -1,7 +1,13 @@ import pytest import torch from common_utils import assert_equal -from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator +from torchvision.models.detection.anchor_utils import ( + AnchorGenerator, + DefaultBoxGenerator, + global_xy, + grid_centers, + grid_offsets, +) from torchvision.models.detection.image_list import ImageList @@ -97,3 +103,40 @@ def test_defaultbox_generator(self): assert tuple(dboxes[1].shape) == (4, 4) torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8) torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8) + + +@pytest.mark.parametrize("width,height", [(10, 5)]) +def test_grid_offsets(width: int, height: int): + size = torch.tensor([width, height]) + offsets = grid_offsets(size) + assert offsets.shape == (height, width, 2) + assert torch.equal(offsets[0, :, 0], torch.arange(width, dtype=offsets.dtype)) + assert torch.equal(offsets[0, :, 1], torch.zeros(width, dtype=offsets.dtype)) + assert torch.equal(offsets[:, 0, 0], torch.zeros(height, dtype=offsets.dtype)) + assert torch.equal(offsets[:, 0, 1], torch.arange(height, dtype=offsets.dtype)) + + +@pytest.mark.parametrize("width,height", [(10, 5)]) +def test_grid_centers(width: int, height: int): + size = torch.tensor([width, height]) + centers = grid_centers(size) + assert centers.shape == (height, width, 2) + assert torch.equal(centers[0, :, 0], 0.5 + torch.arange(width, dtype=torch.float)) + assert torch.equal(centers[0, :, 1], 0.5 * torch.ones(width)) + assert torch.equal(centers[:, 0, 0], 0.5 * torch.ones(height)) + assert torch.equal(centers[:, 0, 1], 0.5 + torch.arange(height, dtype=torch.float)) + + +def test_global_xy(): + xy = torch.ones((2, 4, 4, 3, 2)) * 0.5 # 4x4 grid of coordinates to the center of the cell. 
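+    # With a 4x4 feature grid and a 400x200 image, each grid cell covers 100x50 pixels, so the cell centers are
+    # expected to map to x coordinates 50, 150, 250, 350 and y coordinates 25, 75, 125, 175.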
+ image_size = torch.tensor([400, 200]) + xy = global_xy(xy, image_size) + assert xy.shape == (2, 4, 4, 3, 2) + assert torch.all(xy[:, :, 0, :, 0] == 50) + assert torch.all(xy[:, 0, :, :, 1] == 25) + assert torch.all(xy[:, :, 1, :, 0] == 150) + assert torch.all(xy[:, 1, :, :, 1] == 75) + assert torch.all(xy[:, :, 2, :, 0] == 250) + assert torch.all(xy[:, 2, :, :, 1] == 125) + assert torch.all(xy[:, :, 3, :, 0] == 350) + assert torch.all(xy[:, 3, :, :, 1] == 175) diff --git a/test/test_models_detection_box_utils.py b/test/test_models_detection_box_utils.py new file mode 100644 index 00000000000..75f748249be --- /dev/null +++ b/test/test_models_detection_box_utils.py @@ -0,0 +1,77 @@ +import pytest +import torch +from torchvision.models.detection.anchor_utils import grid_centers +from torchvision.models.detection.box_utils import aligned_iou, box_size_ratio, iou_below, is_inside_box + + +@pytest.mark.parametrize( + "dims1, dims2, expected_ious", + [ + ( + torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), + torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]), + ) + ], +) +def test_aligned_iou(dims1, dims2, expected_ious): + torch.testing.assert_close(aligned_iou(dims1, dims2), expected_ious) + + +def test_iou_below(): + tl = torch.rand((10, 10, 3, 2)) * 100 + br = tl + 10 + pred_boxes = torch.cat((tl, br), -1) + target_boxes = torch.stack((pred_boxes[1, 1, 0], pred_boxes[3, 5, 1])) + result = iou_below(pred_boxes, target_boxes, 0.9) + assert result.shape == (10, 10, 3) + assert not result[1, 1, 0] + assert not result[3, 5, 1] + + +def test_is_inside_box(): + """ + centers: + [[1,1; 3,1; 5,1; 7,1; 9,1; 11,1; 13,1; 15,1; 17,1; 19,1] + [1,3; 3,3; 5,3; 7,3; 9,3; 11,3; 13,3; 15,3; 17,3; 19,3] + [1,5; 3,5; 5,5; 7,5; 9,5; 11,5; 13,5; 15,5; 17,5; 19,5] + [1,7; 3,7; 5,7; 7,7; 9,7; 11,7; 13,7; 15,7; 17,7; 19,7] + [1,9; 3,9; 5,9; 7,9; 9,9; 11,9; 13,9; 15,9; 17,9; 19,9]] + + is_inside[..., 0]: + [[F, F, F, F, F, F, F, F, F, F] + [F, T, T, F, F, F, F, F, F, F] + [F, T, T, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F]] + + is_inside[..., 1]: + [[F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, T, T, F]] + """ + size = torch.tensor([10, 5]) + centers = grid_centers(size) * 2.0 + centers = centers.view(-1, 2) + boxes = torch.tensor([[2, 2, 6, 6], [14, 8, 18, 10]]) + is_inside = is_inside_box(centers, boxes).view(5, 10, 2) + assert torch.count_nonzero(is_inside) == 6 + assert torch.all(is_inside[1:3, 1:3, 0]) + assert torch.all(is_inside[4, 7:9, 1]) + + +def test_box_size_ratio(): + wh1 = torch.tensor([[24, 11], [12, 25], [26, 27], [15, 17]]) + wh2 = torch.tensor([[10, 30], [15, 9]]) + result = box_size_ratio(wh1, wh2) + assert result.shape == (4, 2) + assert result[0, 0] == 30 / 11 + assert result[0, 1] == 24 / 15 + assert result[1, 0] == 12 / 10 + assert result[1, 1] == 25 / 9 + assert result[2, 0] == 26 / 10 + assert result[2, 1] == 27 / 9 + assert result[3, 0] == 30 / 17 + assert result[3, 1] == 17 / 9 diff --git a/test/test_models_detection_target_matching.py b/test/test_models_detection_target_matching.py new file mode 100644 index 00000000000..aaf678575e3 --- /dev/null +++ b/test/test_models_detection_target_matching.py @@ -0,0 +1,26 @@ +import torch + +from torchvision.models.detection.target_matching import _sim_ota_match + + +def test_sim_ota_match(): + # For 
each of the two targets, k will be the sum of the IoUs. 2 and 1 predictions will be selected for the first and + # the second target respectively. + ious = torch.tensor([[0.1, 0.2], [0.1, 0.3], [0.9, 0.4], [0.9, 0.1]]) + # Costs will determine that the first and the last prediction will be selected for the first target, and the first + # prediction will be selected for the second target. The first prediction was selected for two targets, but it will + # be matched to the best target only (the second one). + costs = torch.tensor([[0.3, 0.1], [0.5, 0.2], [0.4, 0.5], [0.3, 0.3]]) + matched_preds, matched_targets = _sim_ota_match(costs, ious) + + # The first and the last prediction were matched. + assert len(matched_preds) == 4 + assert matched_preds[0] + assert not matched_preds[1] + assert not matched_preds[2] + assert matched_preds[3] + + # The first prediction was matched to the target 1 and the last prediction was matched to target 0. + assert len(matched_targets) == 2 + assert matched_targets[0] == 1 + assert matched_targets[1] == 0 diff --git a/test/test_models_detection_yolo_networks.py b/test/test_models_detection_yolo_networks.py new file mode 100644 index 00000000000..0bec4c09e8f --- /dev/null +++ b/test/test_models_detection_yolo_networks.py @@ -0,0 +1,83 @@ +import pytest +import torch.nn as nn +from torchvision.models.detection.yolo_networks import ( + _create_convolutional, + _create_maxpool, + _create_shortcut, + _create_upsample, +) + + +@pytest.mark.parametrize( + "config", + [ + ({"batch_normalize": 1, "filters": 8, "size": 3, "stride": 1, "pad": 1, "activation": "leaky"}), + ({"batch_normalize": 0, "filters": 2, "size": 1, "stride": 1, "pad": 1, "activation": "mish"}), + ({"batch_normalize": 1, "filters": 6, "size": 3, "stride": 2, "pad": 1, "activation": "logistic"}), + ({"batch_normalize": 0, "filters": 4, "size": 3, "stride": 2, "pad": 0, "activation": "linear"}), + ], +) +def test_create_convolutional(config): + conv, _ = _create_convolutional(config, [3]) + + assert conv.conv.out_channels == config["filters"] + assert conv.conv.kernel_size == (config["size"], config["size"]) + assert conv.conv.stride == (config["stride"], config["stride"]) + + pad_size = (config["size"] - 1) // 2 if config["pad"] else 0 + if config["pad"]: + assert conv.conv.padding == (pad_size, pad_size) + + if config["batch_normalize"]: + assert isinstance(conv.norm, nn.BatchNorm2d) + + if config["activation"] == "linear": + assert isinstance(conv.act, nn.Identity) + elif config["activation"] == "logistic": + assert isinstance(conv.act, nn.Sigmoid) + else: + assert conv.act.__class__.__name__.lower().startswith(config["activation"]) + + +@pytest.mark.parametrize( + "config", + [ + ({"size": 2, "stride": 2}), + ({"size": 6, "stride": 3}), + ], +) +def test_create_maxpool(config): + pad_size, remainder = divmod(max(config["size"], config["stride"]) - config["stride"], 2) + maxpool, _ = _create_maxpool(config, [3]) + + assert maxpool.maxpool.kernel_size == config["size"] + assert maxpool.maxpool.stride == config["stride"] + assert maxpool.maxpool.padding == pad_size + if remainder != 0: + assert isinstance(maxpool.pad, nn.ZeroPad2d) + + +@pytest.mark.parametrize( + "config", + [ + ({"from": 1, "activation": "linear"}), + ({"from": 3, "activation": "linear"}), + ], +) +def test_create_shortcut(config): + shortcut, _ = _create_shortcut(config, [3]) + + assert shortcut.source_layer == config["from"] + + +@pytest.mark.parametrize( + "config", + [ + ({"stride": 2}), + ({"stride": 4}), + ], +) +def 
test_create_upsample(config): + upsample, _ = _create_upsample(config, [3]) + + assert upsample.scale_factor == float(config["stride"]) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 4146651c737..35fcdcf9015 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -5,3 +5,13 @@ from .retinanet import * from .ssd import * from .ssdlite import * +from .yolo import YOLO, yolo_darknet, yolov4, YOLOV4_Backbone_Weights, YOLOV4_Weights +from .yolo_networks import ( + DarknetNetwork, + YOLOV4Network, + YOLOV4P6Network, + YOLOV4TinyNetwork, + YOLOV5Network, + YOLOV7Network, + YOLOXNetwork, +) diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 253f6502a9b..7a41a5b0e0c 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -266,3 +266,57 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten ) dboxes.append(dboxes_in_image) return dboxes + + +def grid_offsets(grid_size: Tensor) -> Tensor: + """Given a grid size, returns a tensor containing offsets to the grid cells. + + Args: + The width and height of the grid in a tensor. + + Returns: + A ``[height, width, 2]`` tensor containing the grid cell `(x, y)` offsets. + """ + x_range = torch.arange(grid_size[0].item(), device=grid_size.device) + y_range = torch.arange(grid_size[1].item(), device=grid_size.device) + grid_y, grid_x = torch.meshgrid([y_range, x_range], indexing="ij") + return torch.stack((grid_x, grid_y), -1) + + +def grid_centers(grid_size: Tensor) -> Tensor: + """Given a grid size, returns a tensor containing coordinates to the centers of the grid cells. + + Returns: + A ``[height, width, 2]`` tensor containing coordinates to the centers of the grid cells. + """ + return grid_offsets(grid_size) + 0.5 + + +@torch.jit.script +def global_xy(xy: Tensor, image_size: Tensor) -> Tensor: + """Adds offsets to the predicted box center coordinates to obtain global coordinates to the image. + + The predicted coordinates are interpreted as coordinates inside a grid cell whose width and height is 1. Adding + offset to the cell, dividing by the grid size, and multiplying by the image size, we get global coordinates in the + image scale. + + The function needs the ``@torch.jit.script`` decorator in order for ONNX generation to work. The tracing based + generator will loose track of e.g. ``xy.shape[1]`` and treat it as a Python variable and not a tensor. This will + cause the dimension to be treated as a constant in the model, which prevents dynamic input sizes. + + Args: + xy: The predicted center coordinates before scaling. Values from zero to one in a tensor sized + ``[batch_size, height, width, boxes_per_cell, 2]``. + image_size: Width and height in a vector that will be used to scale the coordinates. + + Returns: + Global coordinates scaled to the size of the network input image, in a tensor with the same shape as the input + tensor. + """ + height = xy.shape[1] + width = xy.shape[2] + grid_size = torch.tensor([width, height], device=xy.device) + # Scripting requires explicit conversion to a floating point type. 
+    offset = grid_offsets(grid_size).to(xy.dtype).unsqueeze(2)  # [height, width, 1, 2]
+    scale = torch.true_divide(image_size, grid_size)
+    return (xy + offset) * scale
diff --git a/torchvision/models/detection/box_utils.py b/torchvision/models/detection/box_utils.py
new file mode 100644
index 00000000000..ac188ebaebf
--- /dev/null
+++ b/torchvision/models/detection/box_utils.py
@@ -0,0 +1,80 @@
+import torch
+from torch import Tensor
+
+from ...ops import box_iou
+
+
+def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor:
+    """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at
+    the same coordinates.
+
+    Args:
+        wh1: An ``[N, 2]`` matrix of box shapes (width and height).
+        wh2: An ``[M, 2]`` matrix of box shapes (width and height).
+
+    Returns:
+        An ``[N, M]`` matrix of pairwise IoU values for every element in ``wh1`` and ``wh2``.
+    """
+    area1 = wh1[:, 0] * wh1[:, 1]  # [N]
+    area2 = wh2[:, 0] * wh2[:, 1]  # [M]
+
+    inter_wh = torch.min(wh1[:, None, :], wh2)  # [N, M, 2]
+    inter = inter_wh[:, :, 0] * inter_wh[:, :, 1]  # [N, M]
+    union = area1[:, None] + area2 - inter  # [N, M]
+
+    return inter / union
+
+
+def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor:
+    """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target
+    significantly (IoU greater than ``threshold``).
+
+    Args:
+        pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``.
+        target_boxes: Corner coordinates of the target boxes. Tensor of size ``[targets, 4]``.
+
+    Returns:
+        A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a
+        target significantly and ``True`` elsewhere.
+    """
+    shape = pred_boxes.shape[:-1]
+    pred_boxes = pred_boxes.view(-1, 4)
+    ious = box_iou(pred_boxes, target_boxes)
+    best_iou = ious.max(-1).values
+    below_threshold = best_iou <= threshold
+    return below_threshold.view(shape)
+
+
+def is_inside_box(points: Tensor, boxes: Tensor) -> Tensor:
+    """Get pairwise truth values of whether the point is inside the box.
+
+    Args:
+        points: Point (x, y) coordinates, a tensor shaped ``[points, 2]``.
+        boxes: Box (x1, y1, x2, y2) coordinates, a tensor shaped ``[boxes, 4]``.
+
+    Returns:
+        A tensor shaped ``[points, boxes]`` containing pairwise truth values of whether the points are inside the boxes.
+    """
+    lt = points[:, None, :] - boxes[None, :, :2]  # [points, boxes, 2]
+    rb = boxes[None, :, 2:] - points[:, None, :]  # [points, boxes, 2]
+    deltas = torch.cat((lt, rb), -1)  # [points, boxes, 4]
+    return deltas.min(-1).values > 0.0  # [points, boxes]
+
+
+def box_size_ratio(wh1: Tensor, wh2: Tensor) -> Tensor:
+    """Compares the dimensions of the boxes pairwise.
+
+    For each pair of boxes, calculates the largest ratio that can be obtained by dividing the widths with each other or
+    dividing the heights with each other.
+
+    Args:
+        wh1: An ``[N, 2]`` matrix of box shapes (width and height).
+        wh2: An ``[M, 2]`` matrix of box shapes (width and height).
+
+    Returns:
+        An ``[N, M]`` matrix of ratios of width or height dimensions, whichever is larger.
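+
+    Example of comparing a single box shape against a single prior shape:
+
+        >>> box_size_ratio(torch.tensor([[24.0, 11.0]]), torch.tensor([[12.0, 22.0]]))
+        tensor([[2.]])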
+ """ + wh_ratio = wh1[:, None, :] / wh2[None, :, :] # [M, N, 2] + wh_ratio = torch.max(wh_ratio, 1.0 / wh_ratio) + wh_ratio = wh_ratio.max(2).values # [M, N] + return wh_ratio diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py new file mode 100644 index 00000000000..b28b89d3b33 --- /dev/null +++ b/torchvision/models/detection/target_matching.py @@ -0,0 +1,452 @@ +from typing import Dict, List, Tuple + +import torch +from torch import Tensor + +from ...ops import box_convert +from .anchor_utils import grid_centers +from .box_utils import aligned_iou, box_size_ratio, iou_below, is_inside_box +from .yolo_loss import YOLOLoss + +PRIOR_SHAPES = List[List[int]] # TorchScript doesn't allow a list of tuples. + + +def target_boxes_to_grid(preds: Tensor, targets: Tensor, image_size: Tensor) -> Tuple[Tensor, Tensor]: + """Scales target bounding boxes to feature map coordinates. + + It would be better to implement this in a super class, but TorchScript doesn't allow class inheritance. + + Args: + preds: Predicted bounding boxes for a single image. + targets: Target bounding boxes for a single image. + image_size: Input image width and height. + + Returns: + Two tensors with as many rows as there are targets. An integer tensor containing x/y coordinates to the feature + map that correspond to the target position, and a floating point tensor containing the target width and height + scaled to the feature map size. + """ + height, width = preds.shape[:2] + + # A multiplier for scaling image coordinates to feature map coordinates + grid_size = torch.tensor([width, height], device=image_size.device) + image_to_grid = torch.true_divide(grid_size, image_size) + + # Bounding box center coordinates are converted to the feature map dimensions so that the whole number tells the + # cell index and the fractional part tells the location inside the cell. + xywh = box_convert(targets, in_fmt="xyxy", out_fmt="cxcywh") + xy = (xywh[:, :2] * image_to_grid).to(torch.int64) + x = xy[:, 0].clamp(0, width - 1) + y = xy[:, 1].clamp(0, height - 1) + xy = torch.stack((x, y), 1) + return xy, xywh[:, 2:] + + +class HighestIoUMatching: + """For each target, select the prior shape that gives the highest IoU. + + This is the original YOLO matching rule. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + """ + + def __init__( + self, prior_shapes: PRIOR_SHAPES, prior_shape_idxs: List[int], ignore_bg_threshold: float = 0.7 + ) -> None: + self.prior_shapes = prior_shapes + # anchor_map maps the anchor indices to anchors in this layer, or to -1 if it's not an anchor of this layer. + # This layer ignores the target if all the selected anchors are in another layer. + self.anchor_map = [ + prior_shape_idxs.index(idx) if idx in prior_shape_idxs else -1 for idx in range(len(prior_shapes)) + ] + self.ignore_bg_threshold = ignore_bg_threshold + + def match(self, wh: Tensor) -> Tuple[Tensor, Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. 
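+
+        For example, if this layer uses prior shapes 3, 4, and 5 out of nine ``prior_shapes``, a target whose
+        highest-IoU prior shape is number 4 is matched and assigned to anchor index 1 within this layer, while a
+        target whose highest-IoU prior shape belongs to another layer is ignored.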
+ + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors. The first vector is used to select the targets that this + layer matched and the second one lists the matching anchors within the grid cell. + """ + prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=wh.device) + + ious = aligned_iou(wh, prior_wh) + highest_iou_anchors = ious.max(1).indices + highest_iou_anchors = anchor_map[highest_iou_anchors] + matched_targets = highest_iou_anchors >= 0 + matched_anchors = highest_iou_anchors[matched_targets] + return matched_targets, matched_anchors + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + anchor_xy, target_wh = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_idx = self.match(target_wh) + anchor_x = anchor_xy[target_selector, 0] + anchor_y = anchor_xy[target_selector, 1] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[anchor_y, anchor_x, anchor_idx] = False + + pred_selector = [anchor_y, anchor_x, anchor_idx] + return pred_selector, background_mask, target_selector + + +class IoUThresholdMatching: + """For each target, select all prior shapes that give a high enough IoU. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + threshold: IoU treshold for matching. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + """ + + def __init__( + self, + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], + threshold: float, + ignore_bg_threshold: float = 0.7, + ) -> None: + self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] + self.threshold = threshold + self.ignore_bg_threshold = ignore_bg_threshold + + def match(self, wh: Tensor) -> Tuple[Tensor, Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. + + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors. 
The first vector is used to select the targets that this + layer matched and the second one lists the matching anchors within the grid cell. + """ + prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + + ious = aligned_iou(wh, prior_wh) + above_threshold = (ious > self.threshold).nonzero() + return above_threshold[:, 0], above_threshold[:, 1] + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + anchor_xy, target_wh = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_idx = self.match(target_wh) + anchor_x = anchor_xy[target_selector, 0] + anchor_y = anchor_xy[target_selector, 1] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[anchor_y, anchor_x, anchor_idx] = False + + pred_selector = [anchor_y, anchor_x, anchor_idx] + return pred_selector, background_mask, target_selector + + +class SizeRatioMatching: + """For each target, select those prior shapes, whose width and height relative to the target is below given + ratio. + + This is the matching rule used by Ultralytics YOLOv5 implementation. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + threshold: Size ratio threshold for matching. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + """ + + def __init__( + self, + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], + threshold: float, + ignore_bg_threshold: float = 0.7, + ) -> None: + self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] + self.threshold = threshold + self.ignore_bg_threshold = ignore_bg_threshold + + def match(self, wh: Tensor) -> Tuple[Tensor, Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. + + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors. The first vector is used to select the targets that this + layer matched and the second one lists the matching anchors within the grid cell. 
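+
+        For example, with ``threshold=4.0`` and prior shapes ``[[10, 10], [20, 20]]``, a target with dimensions
+        ``[12, 12]`` is matched to both priors: the largest width or height ratio is 1.2 for the first prior and
+        about 1.67 for the second, both below the threshold.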
+ """ + prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + below_threshold = (box_size_ratio(wh, prior_wh) < self.threshold).nonzero() + return below_threshold[:, 0], below_threshold[:, 1] + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + anchor_xy, target_wh = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_idx = self.match(target_wh) + anchor_x = anchor_xy[target_selector, 0] + anchor_y = anchor_xy[target_selector, 1] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[anchor_y, anchor_x, anchor_idx] = False + + pred_selector = [anchor_y, anchor_x, anchor_idx] + return pred_selector, background_mask, target_selector + + +def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]: + """Implements the SimOTA matching rule. + + The number of units supplied by each supplier (training target) needs to be decided in the Optimal Transport + problem. "Dynamic k Estimation" uses the sum of the top 10 IoU values (casted to int) between the target and the + predicted boxes. + + Args: + costs: A ``[predictions, targets]`` matrix of losses. + ious: A ``[predictions, targets]`` matrix of IoUs. + + Returns: + A mask of predictions that were matched, and the indices of the matched targets. The latter contains as many + elements as there are ``True`` values in the mask. + """ + num_preds, num_targets = ious.shape + + matching_matrix = torch.zeros_like(costs, dtype=torch.bool) + + if ious.numel() > 0: + # For each target, define k as the sum of the 10 highest IoUs. + top10_iou = torch.topk(ious, min(10, num_preds), dim=0).values.sum(0) + ks = torch.clip(top10_iou.int(), min=1) + assert len(ks) == num_targets + + # For each target, select k predictions with the lowest cost. + for target_idx, (target_costs, k) in enumerate(zip(costs.T, ks)): + pred_idx = torch.topk(target_costs, k, largest=False).indices + matching_matrix[pred_idx, target_idx] = True + + # If there's more than one match for some prediction, match it with the best target. Now we consider all + # targets, regardless of whether they were originally matched with the prediction or not. + more_than_one_match = matching_matrix.sum(1) > 1 + best_targets = costs[more_than_one_match, :].argmin(1) + matching_matrix[more_than_one_match, :] = False + matching_matrix[more_than_one_match, best_targets] = True + + # For those predictions that were matched, get the index of the target. 
+ pred_mask = matching_matrix.sum(1) > 0 + target_selector = matching_matrix[pred_mask, :].int().argmax(1) + return pred_mask, target_selector + + +class SimOTAMatching: + """Selects which anchors are used to predict each target using the SimOTA matching rule. + + This is the matching rule used by YOLOX. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + loss_func: A ``YOLOLoss`` object that can be used to calculate the pairwise costs. + spatial_range: For each target, restrict to the anchors that are within an `N x N` grid cell are centered at the + target, where `N` is the value of this parameter. + size_range: For each target, restrict to the anchors whose prior dimensions are not larger than the target + dimensions multiplied by this value and not smaller than the target dimensions divided by this value. + """ + + def __init__( + self, + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], + loss_func: YOLOLoss, + spatial_range: float, + size_range: float, + ) -> None: + self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] + self.loss_func = loss_func + self.spatial_range = spatial_range + self.size_range = size_range + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """For each target, selects predictions using the SimOTA matching rule. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + A mask of predictions that were matched, background mask (inverse of the first mask), and the indices of the + matched targets. The last tensor contains as many elements as there are ``True`` values in the first mask. + """ + height, width, boxes_per_cell, _ = preds["boxes"].shape + prior_mask, anchor_inside_target = self._get_prior_mask(targets, image_size, width, height, boxes_per_cell) + prior_preds = { + "boxes": preds["boxes"][prior_mask], + "confidences": preds["confidences"][prior_mask], + "classprobs": preds["classprobs"][prior_mask], + } + + losses, ious = self.loss_func.pairwise(prior_preds, targets, input_is_normalized=False) + costs = losses.overlap + losses.confidence + losses.classification + costs += 100000.0 * ~anchor_inside_target + pred_mask, target_selector = _sim_ota_match(costs, ious) + + # Add the anchor dimension to the mask and replace True values with the results of the actual SimOTA matching. + pred_selector = prior_mask.nonzero().T.tolist() + prior_mask[pred_selector] = pred_mask + + background_mask = torch.logical_not(prior_mask) + + return prior_mask, background_mask, target_selector + + def _get_prior_mask( + self, + targets: Dict[str, Tensor], + image_size: Tensor, + grid_width: int, + grid_height: int, + boxes_per_cell: int, + ) -> Tuple[Tensor, Tensor]: + """Creates a mask for selecting the "center prior" anchors. + + In the first step we restrict ourselves to the grid cells whose center is inside or close enough to one or more + targets. + + Args: + targets: Training targets for a single image. + image_size: Input image width and height. + grid_width: Width of the feature grid. + grid_height: Height of the feature grid. + boxes_per_cell: Number of boxes that will be predicted per feature grid cell. 
+ + Returns: + Two masks, a ``[grid_height, grid_width, boxes_per_cell]`` mask for selecting anchors that are close and + similar in shape to a target, and an ``[anchors, targets]`` matrix that indicates which targets are inside + those anchors. + """ + # A multiplier for scaling feature map coordinates to image coordinates + grid_size = torch.tensor([grid_width, grid_height], device=targets["boxes"].device) + grid_to_image = torch.true_divide(image_size, grid_size) + + # Get target center coordinates and dimensions. + xywh = box_convert(targets["boxes"], in_fmt="xyxy", out_fmt="cxcywh") + xy = xywh[:, :2] + wh = xywh[:, 2:] + + # Create a [boxes_per_cell, targets] tensor for selecting prior shapes that are close enough to the target + # dimensions. + prior_wh = torch.tensor(self.prior_shapes, device=targets["boxes"].device) + shape_selector = box_size_ratio(prior_wh, wh) < self.size_range + + # Create a [grid_cells, targets] tensor for selecting spatial locations that are inside target bounding boxes. + centers = grid_centers(grid_size).view(-1, 2) * grid_to_image + inside_selector = is_inside_box(centers, targets["boxes"]) + + # Combine the above selectors into a [grid_cells, boxes_per_cell, targets] tensor for selecting anchors that are + # inside target bounding boxes and close enough shape. + inside_selector = inside_selector[:, None, :].repeat(1, boxes_per_cell, 1) + inside_selector = torch.logical_and(inside_selector, shape_selector) + + # Set the width and height of all target bounding boxes to self.range grid cells and create a selector for + # anchors that are now inside the boxes. If a small target has no anchors inside its bounding box, it will be + # matched to one of these anchors, but a high penalty will ensure that anchors that are inside the bounding box + # will be preferred. + wh = self.spatial_range * grid_to_image * torch.ones_like(xy) + xywh = torch.cat((xy, wh), -1) + boxes = box_convert(xywh, in_fmt="cxcywh", out_fmt="xyxy") + close_selector = is_inside_box(centers, boxes) + + # Create a [grid_cells, boxes_per_cell, targets] tensor for selecting anchors that are spatially close to a + # target and whose shape is close enough to the target. + close_selector = close_selector[:, None, :].repeat(1, boxes_per_cell, 1) + close_selector = torch.logical_and(close_selector, shape_selector) + + mask = torch.logical_or(inside_selector, close_selector).sum(-1) > 0 + mask = mask.view(grid_height, grid_width, boxes_per_cell) + inside_selector = inside_selector.view(grid_height, grid_width, boxes_per_cell, -1) + return mask, inside_selector[mask] diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py new file mode 100644 index 00000000000..54ae7b9aee9 --- /dev/null +++ b/torchvision/models/detection/yolo.py @@ -0,0 +1,437 @@ +import warnings +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from ...ops import batched_nms +from ...transforms import functional as F +from .._api import register_model, Weights, WeightsEnum +from .._utils import _ovewrite_value_param +from ..yolo import YOLOV4Backbone +from .backbone_utils import _validate_trainable_layers +from .yolo_networks import DarknetNetwork, PRED, TARGET, TARGETS, YOLOV4Network + +IMAGES = List[Tensor] # TorchScript doesn't allow a tuple. + + +class YOLO(nn.Module): + """YOLO implementation that supports the most important features of YOLOv3, YOLOv4, YOLOv5, YOLOv7, Scaled- + YOLOv4, and YOLOX. 
+ + *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `__ + + *YOLOv4 paper*: `Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao `__ + + *YOLOv7 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao `__ + + *Scaled-YOLOv4 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao + `__ + + *YOLOX paper*: `Zheng Ge, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun `__ + + The network architecture can be written in PyTorch, or read from a Darknet configuration file using the + :class:`~.yolo_networks.DarknetNetwork` class. ``DarknetNetwork`` is also able to read weights that have been saved + by Darknet. + + The input is expected to be a list of images. Each image is a tensor with shape ``[channels, height, width]``. The + images from a single batch will be stacked into a single tensor, so the sizes have to match. Different batches can + have different image sizes, as long as the size is divisible by the ratio in which the network downsamples the + input. + + During training, the model expects both the image tensors and a list of targets. It's possible to train a model + using one integer class label per box, but the YOLO model supports also multiple labels per box. For multi-label + training, simply use a boolean matrix that indicates which classes are assigned to which boxes, in place of the + class labels. *Each target is a dictionary containing the following tensors*: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format + - labels (``Int64Tensor[N]`` or ``BoolTensor[N, classes]``): the class label or a boolean class mask for each + ground-truth box + + :func:`~.yolo.YOLO.forward` method returns all predictions from all detection layers in one tensor with shape + ``[N, anchors, classes + 5]``, where ``anchors`` is the total number of anchors in all detection layers. The + coordinates are scaled to the input image size. During training it also returns a dictionary containing the + classification, box overlap, and confidence losses. + + During inference, the model requires only the image tensor. :func:`~.yolo.YOLO.infer` method filters and + processes the predictions. If a prediction has a high score for more than one class, it will be duplicated. 
*The + processed output is returned in a dictionary containing the following tensors*: + + - boxes (``FloatTensor[N, 4]``): predicted bounding box `(x1, y1, x2, y2)` coordinates in image space + - scores (``FloatTensor[N]``): detection confidences + - labels (``Int64Tensor[N]``): the predicted labels for each object + + Detection using a Darknet configuration and pretrained weights: + + >>> from urllib.request import urlretrieve + >>> import torch + >>> from torchvision.models.detection import DarknetNetwork, YOLO + >>> + >>> urlretrieve("https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg", "yolov4-tiny-3l.cfg") + >>> urlretrieve("https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-tiny.conv.29", "yolov4-tiny.conv.29") + >>> network = DarknetNetwork("yolov4-tiny-3l.cfg", "yolov4-tiny.conv.29") + >>> model = YOLO(network) + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Detection using a predefined YOLOv4 network: + + >>> import torch + >>> from torchvision.models.detection import YOLOV4Network, YOLO + >>> + >>> network = YOLOV4Network(num_classes=91) + >>> model = YOLO(network) + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Args: + network: A module that represents the network layers. This can be obtained from a Darknet configuration using + :func:`~.yolo_networks.DarknetNetwork`, or it can be defined as PyTorch code. + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this + threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is + higher than this threshold, if the predicted categories are equal. + detections_per_image: Keep at most this number of highest-confidence detections per image. + """ + + def __init__( + self, + network: nn.Module, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + detections_per_image: int = 300, + ) -> None: + super().__init__() + + self.network = network + self.confidence_threshold = confidence_threshold + self.nms_threshold = nms_threshold + self.detections_per_image = detections_per_image + + def forward( + self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None + ) -> Union[Tensor, Dict[str, Tensor]]: + """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets + are provided, computes the losses from the detection layers. + + Detections are concatenated from the detection layers. Each detection layer will produce a number of detections + that depends on the size of the feature map and the number of anchors per feature map cell. + + Args: + images: A tensor of size ``[batch_size, channels, height, width]`` containing a batch of images or a list of + image tensors. + targets: Compute losses against these targets. A list of dictionaries, one for each image. Must be given in + training mode. + + Returns: + If targets are given, returns a dictionary containing the three losses (overlap, confidence, and + classification). Otherwise returns detections in a tensor shaped ``[batch_size, anchors, classes + 5]``, + where ``anchors`` is the total number of anchors in all detection layers. The number of anchors in a + detection layer is the feature map size (width * height) times the number of anchors per cell (usually 3 or + 4). The predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. 
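+
+        Computing the training losses for a batch of one image (an illustrative sketch; assumes a
+        ``YOLOV4Network`` instance as in the class-level examples):
+
+            >>> network = YOLOV4Network(num_classes=91)
+            >>> model = YOLO(network)
+            >>> images = [torch.rand(3, 608, 608)]
+            >>> targets = [{"boxes": torch.tensor([[10.0, 20.0, 200.0, 220.0]]), "labels": torch.tensor([5])}]
+            >>> losses = model(images, targets)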
+ """ + self.validate_batch(images, targets) + images_tensor = images if isinstance(images, Tensor) else torch.stack(images) + detections, losses, hits = self.network(images_tensor, targets) + + if targets is None: + detections = torch.cat(detections, 1) + return detections + + losses = torch.stack(losses).sum(0) + return {"overlap": losses[0], "confidence": losses[1], "classification": losses[2]} + + def infer(self, image: Tensor) -> PRED: + """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class + labels. + + If a prediction has a high score for more than one class, it will be duplicated. + + Args: + image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. + + Returns: + A dictionary containing tensors "boxes", "scores", and "labels". "boxes" is a matrix of detected bounding + box `(x1, y1, x2, y2)` coordinates. "scores" is a vector of confidence scores for the bounding box + detections. "labels" is a vector of predicted class labels. + """ + if not isinstance(image, Tensor): + image = F.to_tensor(image) + + was_training = self.training + self.eval() + + detections = self([image]) + detections = self.process_detections(detections) + detections = detections[0] + + if was_training: + self.train() + return detections + + def process_detections(self, preds: Tensor) -> List[PRED]: + """Splits the detection tensor returned by a forward pass into a list of prediction dictionaries, and + filters them based on confidence threshold, non-maximum suppression (NMS), and maximum number of + predictions. + + If for any single detection there are multiple categories whose score is above the confidence threshold, the + detection will be duplicated to create one detection for each category. NMS processes one category at a time, + iterating over the bounding boxes in descending order of confidence score, and removes lower scoring boxes that + have an IoU greater than the NMS threshold with a higher scoring box. + + The returned detections are sorted by descending confidence. The items of the dictionaries are as follows: + - boxes (``Tensor[batch_size, N, 4]``): detected bounding box `(x1, y1, x2, y2)` coordinates + - scores (``Tensor[batch_size, N]``): detection confidences + - labels (``Int64Tensor[batch_size, N]``): the predicted class IDs + + Args: + preds: A tensor of detected bounding boxes and their attributes. + + Returns: + Filtered detections. A list of prediction dictionaries, one for each image. + """ + + def process(boxes: Tensor, confidences: Tensor, classprobs: Tensor) -> Dict[str, Any]: + scores = classprobs * confidences[:, None] + + # Select predictions with high scores. If a prediction has a high score for more than one class, it will be + # duplicated. + idxs, labels = (scores > self.confidence_threshold).nonzero().T + boxes = boxes[idxs] + scores = scores[idxs, labels] + + keep = batched_nms(boxes, scores, labels, self.nms_threshold) + keep = keep[: self.detections_per_image] + return {"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]} + + return [process(p[..., :4], p[..., 4], p[..., 5:]) for p in preds] + + def process_targets(self, targets: TARGETS) -> List[TARGET]: + """Duplicates multi-label targets to create one target for each label. + + Args: + targets: List of target dictionaries. Each dictionary must contain "boxes" and "labels". "labels" is either + a one-dimensional list of class IDs, or a two-dimensional boolean class map. + + Returns: + Single-label targets. 
A list of target dictionaries, one for each image. + """ + + def process(boxes: Tensor, labels: Tensor, **other: Any) -> Dict[str, Any]: + if labels.ndim == 2: + idxs, labels = labels.nonzero().T + boxes = boxes[idxs] + return {"boxes": boxes, "labels": labels, **other} + + return [process(**t) for t in targets] + + def validate_batch(self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS]) -> None: + """Validates the format of a batch of data. + + Args: + images: A tensor containing a batch of images or a list of image tensors. + targets: A list of target dictionaries or ``None``. If a list is provided, there should be as many target + dictionaries as there are images. + """ + if not isinstance(images, Tensor): + if not isinstance(images, (tuple, list)): + raise TypeError(f"Expected images to be a Tensor, tuple, or a list, got {type(images).__name__}.") + if not images: + raise ValueError("No images in batch.") + shape = images[0].shape + for image in images: + if not isinstance(image, Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") + if image.shape != shape: + raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") + + if targets is None: + if self.training: + raise ValueError("Targets should be given in training mode.") + else: + return + + if not isinstance(targets, (tuple, list)): + raise TypeError(f"Expected targets to be a tuple or a list, got {type(targets).__name__}.") + if len(images) != len(targets): + raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") + + for target in targets: + if "boxes" not in target: + raise ValueError("Target dictionary doesn't contain boxes.") + boxes = target["boxes"] + if not isinstance(boxes, Tensor): + raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes).__name__}.") + if (boxes.ndim != 2) or (boxes.shape[-1] != 4): + raise ValueError(f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}.") + if "labels" not in target: + raise ValueError("Target dictionary doesn't contain labels.") + labels = target["labels"] + if not isinstance(labels, Tensor): + raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels).__name__}.") + if (labels.ndim < 1) or (labels.ndim > 2) or (len(labels) != len(boxes)): + raise ValueError( + f"Expected target labels to be tensors of shape [N] or [N, num_classes], got {list(labels.shape)}." + ) + + +class YOLOV4_Backbone_Weights(WeightsEnum): + # TODO: Create pretrained weights. + DEFAULT = Weights( + url="", + transforms=lambda x: x, + meta={}, + ) + + +class YOLOV4_Weights(WeightsEnum): + # TODO: Create pretrained weights. + DEFAULT = Weights( + url="", + transforms=lambda x: x, + meta={}, + ) + + +def freeze_backbone_layers(backbone: nn.Module, trainable_layers: Optional[int], is_trained: bool) -> None: + """Freezes the backbone layers that won't be used for training. + + Args: + backbone: The backbone network. + trainable_layers: Number of trainable layers (stages), starting from the final stage. + is_trained: Set to ``True`` when using pre-trained weights. Otherwise will issue a warning if + ``trainable_layers`` is set. + """ + if not hasattr(backbone, "stages"): + warnings.warn("Cannot freeze backbone layers. 
Backbone object has no 'stages' attribute.") + return + num_layers = len(backbone.stages) # type: ignore + trainable_layers = _validate_trainable_layers(is_trained, trainable_layers, num_layers, 3) + + layers_to_train = [f"stages.{idx}" for idx in range(num_layers - trainable_layers, num_layers)] + if trainable_layers == num_layers: + layers_to_train.append("stem") + + for name, parameter in backbone.named_parameters(): + if all([not name.startswith(layer) for layer in layers_to_train]): + parameter.requires_grad_(False) + + +@register_model() +def yolov4( + weights: Optional[YOLOV4_Weights] = None, + progress: bool = True, + in_channels: int = 3, + num_classes: Optional[int] = None, + weights_backbone: Optional[YOLOV4_Backbone_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + detections_per_image: int = 300, + **kwargs: Any, +) -> YOLO: + """ + Constructs a YOLOv4 model. + + .. betastatus:: detection module + + Example: + + >>> import torch + >>> from torchvision.models.detection import yolov4, YOLOV4_Weights + >>> + >>> model = yolov4(weights=YOLOV4_Weights.DEFAULT) + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Args: + weights: Pretrained weights to use. See :class:`~.YOLOV4_Weights` below for more details and possible values. By + default, the model will be initialized randomly. + progress: If ``True``, displays a progress bar of the download to ``stderr``. + in_channels: Number of channels in the input image. + num_classes: Number of output classes of the model (including the background). By default, this value is set to + 91 or read from the weights. + weights_backbone: Pretrained weights for the backbone. See :class:`~.YOLOV4_Backbone_Weights` below for more + details and possible values. By default, the backbone will be initialized randomly. + trainable_backbone_layers: Number of trainable (not frozen) layers (stages), starting from the final stage. + Valid values are between 0 and the number of stages in the backbone. By default, this value is set to 3. + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this + threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is + higher than this threshold, if the predicted categories are equal. + detections_per_image: Keep at most this number of highest-confidence detections per image. + **kwargs: Parameters passed to the ``torchvision.models.detection.YOLOV4Network`` class. Please refer to the + `source code `_ + for more details about this class. + + .. autoclass:: .YOLOV4_Weights + :members: + + .. 
autoclass:: .YOLOV4_Backbone_Weights + :members: + """ + weights = YOLOV4_Weights.verify(weights) + weights_backbone = YOLOV4_Backbone_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + backbone_kwargs = {key: kwargs[key] for key in ("widths", "activation", "normalization") if key in kwargs} + backbone = YOLOV4Backbone(in_channels, **backbone_kwargs) + + is_trained = weights is not None or weights_backbone is not None + freeze_backbone_layers(backbone, trainable_backbone_layers, is_trained) + + if weights_backbone is not None: + backbone.load_state_dict(weights_backbone.get_state_dict(progress=progress)) + + network = YOLOV4Network(num_classes, backbone, **kwargs) + model = YOLO(network, confidence_threshold, nms_threshold, detections_per_image) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +def yolo_darknet( + config_path: str, + weights_path: Optional[str] = None, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + detections_per_image: int = 300, + **kwargs: Any, +) -> YOLO: + """ + Constructs a YOLO model from a Darknet configuration file. + + .. betastatus:: detection module + + Example: + + >>> from urllib.request import urlretrieve + >>> import torch + >>> from torchvision.models.detection import yolo_darknet + >>> + >>> urlretrieve("https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg", "yolov4-tiny-3l.cfg") + >>> urlretrieve("https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-tiny.conv.29", "yolov4-tiny.conv.29") + >>> model = yolo_darknet("yolov4-tiny-3l.cfg", "yolov4-tiny.conv.29") + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Args: + config_path: Path to a Darknet configuration file that defines the network architecture. + weights_path: Path to a Darknet weights file to load. + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this + threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is + higher than this threshold, if the predicted categories are equal. + detections_per_image: Keep at most this number of highest-confidence detections per image. + **kwargs: Parameters passed to the ``torchvision.models.detection.DarknetNetwork`` class. Please refer to the + `source code `_ + for more details about this class. 
+ """ + network = DarknetNetwork(config_path, weights_path, **kwargs) + return YOLO(network, confidence_threshold, nms_threshold, detections_per_image) diff --git a/torchvision/models/detection/yolo_loss.py b/torchvision/models/detection/yolo_loss.py new file mode 100644 index 00000000000..1ac0940680a --- /dev/null +++ b/torchvision/models/detection/yolo_loss.py @@ -0,0 +1,362 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.functional import binary_cross_entropy, binary_cross_entropy_with_logits, one_hot + +from torchvision.ops import ( + box_iou, + complete_box_iou, + complete_box_iou_loss, + distance_box_iou, + distance_box_iou_loss, + generalized_box_iou, + generalized_box_iou_loss, +) + + +def _binary_cross_entropy( + inputs: Tensor, targets: Tensor, reduction: str = "mean", input_is_normalized: bool = True +) -> Tensor: + """Returns the binary cross entropy from either normalized inputs or logits. + + It would be more convenient to pass the correct cross entropy function to every function that uses it, but + TorchScript doesn't allow passing functions. + + Args: + inputs: Probabilities in a tensor of an arbitrary shape. + targets: Targets in a tensor of the same shape as ``input``. + reduction: Specifies the reduction to apply to the output. ``'none'``: no reduction will be applied, ``'mean'``: + the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be + summed. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + """ + if input_is_normalized: + return binary_cross_entropy(inputs, targets, reduction=reduction) + else: + return binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) + + +def box_iou_loss(boxes1: Tensor, boxes2: Tensor) -> Tensor: + return 1.0 - box_iou(boxes1, boxes2).diagonal() + + +def _size_compensation(targets: Tensor, image_size: Tensor) -> Tensor: + """Calcuates the size compensation factor for the overlap loss. + + The overlap losses for each target should be multiplied by the returned weight. The returned value is + `2 - (unit_width * unit_height)`, which is large for small boxes (the maximum value is 2) and small for large boxes + (the minimum value is 1). + + Args: + targets: An ``[N, 4]`` matrix of target `(x1, y1, x2, y2)` coordinates. + image_size: Image size, which is used to scale the target boxes to unit coordinates. + + Returns: + The size compensation factor. + """ + unit_wh = targets[:, 2:] / image_size + return 2 - (unit_wh[:, 0] * unit_wh[:, 1]) + + +def _pairwise_confidence_loss( + preds: Tensor, overlap: Tensor, input_is_normalized: bool, predict_overlap: Optional[float] +) -> Tensor: + """Calculates the confidence loss for every pair of a foreground anchor and a target. + + If ``predict_overlap`` is ``None``, the target confidence will be 1. If ``predict_overlap`` is 1.0, ``overlap`` will + be used as the target confidence. Otherwise this parameter defines a balance between these two targets. The method + returns a vector of losses for each foreground anchor. + + Args: + preds: An ``[N]`` vector of predicted confidences. + overlap: An ``[N, M]`` matrix of overlaps between all predicted and target bounding boxes. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + predict_overlap: Balance between binary confidence targets and predicting the overlap. 
0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the overlap. + + Returns: + An ``[N, M]`` matrix of confidence losses between all predictions and targets. + """ + if predict_overlap is not None: + # When predicting overlap, target confidence is different for each pair of a prediction and a target. The + # tensors have to be broadcasted to [N, M]. + preds = preds.unsqueeze(1).expand(overlap.shape) + targets = torch.ones_like(preds) - predict_overlap + # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. + targets += predict_overlap * overlap.detach().clamp(min=0) + return _binary_cross_entropy(preds, targets, reduction="none", input_is_normalized=input_is_normalized) + else: + # When not predicting overlap, target confidence is the same for every prediction, but we should still return a + # matrix. + targets = torch.ones_like(preds) + result = _binary_cross_entropy(preds, targets, reduction="none", input_is_normalized=input_is_normalized) + return result.unsqueeze(1).expand(overlap.shape) + + +def _foreground_confidence_loss( + preds: Tensor, overlap: Tensor, input_is_normalized: bool, predict_overlap: Optional[float] +) -> Tensor: + """Calculates the sum of the confidence losses for foreground anchors and their matched targets. + + If ``predict_overlap`` is ``None``, the target confidence will be 1. If ``predict_overlap`` is 1.0, ``overlap`` will + be used as the target confidence. Otherwise this parameter defines a balance between these two targets. The method + returns the sum of the losses over all foreground anchors. + + Args: + preds: A vector of predicted confidences. + overlap: A vector of overlaps between matched target and predicted bounding boxes. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1, and 1.0 means that the target confidence is the overlap. + + Returns: + The sum of the confidence losses for foreground anchors. + """ + targets = torch.ones_like(preds) + if predict_overlap is not None: + targets -= predict_overlap + # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. + targets += predict_overlap * overlap.detach().clamp(min=0) + return _binary_cross_entropy(preds, targets, reduction="sum", input_is_normalized=input_is_normalized) + + +def _background_confidence_loss(preds: Tensor, input_is_normalized: bool) -> Tensor: + """Calculates the sum of the confidence losses for background anchors. + + Args: + preds: A vector of predicted confidences for background anchors. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + + Returns: + The sum of the background confidence losses. + """ + targets = torch.zeros_like(preds) + return _binary_cross_entropy(preds, targets, reduction="sum", input_is_normalized=input_is_normalized) + + +def _target_labels_to_probs( + targets: Tensor, num_classes: int, dtype: torch.dtype, label_smoothing: Optional[float] = None +) -> Tensor: + """If ``targets`` is a vector of class labels, converts it to a matrix of one-hot class probabilities. + + If label smoothing is disabled, the returned target probabilities will be binary. 
If label smoothing is enabled, the + target probabilities will be, ``(label_smoothing / 2)`` or ``(label_smoothing / 2) + (1.0 - label_smoothing)``. That + corresponds to label smoothing with two categories, since the YOLO model does multi-label classification. + + Args: + targets: An ``[M, C]`` matrix of target class probabilities or an ``[M]`` vector of class labels. + num_classes: The number of classes (C dimension) for the new targets. If ``targets`` is already two-dimensional, + checks that the length of the second dimension matches this number. + dtype: Floating-point data type to be used for the one-hot targets. + label_smoothing: The epsilon parameter (weight) for label smoothing. 0.0 means no smoothing (binary targets), + and 1.0 means that the target probabilities are always 0.5. + + Returns: + An ``[M, C]`` matrix of target class probabilities. + """ + if targets.ndim == 1: + # The data may contain a different number of classes than what the model predicts. In case a label is + # greater than the number of predicted classes, it will be mapped to the last class. + last_class = torch.tensor(num_classes - 1, device=targets.device) + targets = torch.min(targets, last_class) + targets = one_hot(targets, num_classes) + elif targets.shape[-1] != num_classes: + raise ValueError( + f"The number of classes in the data ({targets.shape[-1]}) doesn't match the number of classes " + f"predicted by the model ({num_classes})." + ) + targets = targets.to(dtype=dtype) + if label_smoothing is not None: + targets = (label_smoothing / 2) + targets * (1.0 - label_smoothing) + return targets + + +@torch.jit.script +@dataclass +class Losses: + overlap: Tensor + confidence: Tensor + classification: Tensor + + +class YOLOLoss: + """A class for calculating the YOLO losses from predictions and targets. + + If label smoothing is enabled, the target class probabilities will be ``(label_smoothing / 2)`` or + ``(label_smoothing / 2) + (1.0 - label_smoothing)``, instead of 0 or 1. That corresponds to label smoothing with two + categories, since the YOLO model does multi-label classification. + + Args: + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. 
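+ + Example (an illustrative sketch with dummy tensors, not taken from a real model; the dictionary keys and shapes + follow the documentation of ``pairwise()`` below): + + >>> import torch + >>> loss_func = YOLOLoss(overlap_func="ciou") + >>> preds = { + ... "boxes": torch.tensor([[10.0, 10.0, 50.0, 60.0], [20.0, 30.0, 80.0, 90.0]]), + ... "confidences": torch.tensor([0.3, -0.1]), + ... "classprobs": torch.randn(2, 80), + ... } + >>> targets = {"boxes": torch.tensor([[12.0, 12.0, 52.0, 62.0]]), "labels": torch.tensor([5])} + >>> losses, overlap = loss_func.pairwise(preds, targets, input_is_normalized=False)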
+ """ + + def __init__( + self, + overlap_func: str = "ciou", + predict_overlap: Optional[float] = None, + label_smoothing: Optional[float] = None, + overlap_multiplier: float = 5.0, + confidence_multiplier: float = 1.0, + class_multiplier: float = 1.0, + ): + self.overlap_func = overlap_func + self.predict_overlap = predict_overlap + self.label_smoothing = label_smoothing + self.overlap_multiplier = overlap_multiplier + self.confidence_multiplier = confidence_multiplier + self.class_multiplier = class_multiplier + + def pairwise( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + input_is_normalized: bool, + ) -> Tuple[Losses, Tensor]: + """Calculates matrices containing the losses for all prediction/target pairs. + + This method is called for obtaining costs for SimOTA matching. + + Args: + preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs". Each tensor + contains `N` rows. + targets: A dictionary of training targets, containing "boxes" and "labels". Each tensor contains `M` rows. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + + Returns: + Loss matrices and an overlap matrix. Each matrix is shaped ``[N, M]``. + """ + loss_shape = torch.Size([len(preds["boxes"]), len(targets["boxes"])]) + + overlap = self._pairwise_overlap(preds["boxes"], targets["boxes"]) + assert overlap.shape == loss_shape + + overlap_loss = 1.0 - overlap + assert overlap_loss.shape == loss_shape + + confidence_loss = _pairwise_confidence_loss( + preds["confidences"], overlap, input_is_normalized, self.predict_overlap + ) + assert confidence_loss.shape == loss_shape + + pred_probs = preds["classprobs"].unsqueeze(1) # [N, 1, classes] + target_probs = _target_labels_to_probs( + targets["labels"], pred_probs.shape[-1], pred_probs.dtype, self.label_smoothing + ) + target_probs = target_probs.unsqueeze(0) # [1, M, classes] + pred_probs, target_probs = torch.broadcast_tensors(pred_probs, target_probs) + class_loss = _binary_cross_entropy( + pred_probs, target_probs, reduction="none", input_is_normalized=input_is_normalized + ) + class_loss = class_loss.sum(-1) + assert class_loss.shape == loss_shape + + losses = Losses( + overlap_loss * self.overlap_multiplier, + confidence_loss * self.confidence_multiplier, + class_loss * self.class_multiplier, + ) + + return losses, overlap + + def elementwise_sums( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + input_is_normalized: bool, + image_size: Tensor, + ) -> Losses: + """Calculates the sums of the losses for optimization, over prediction/target pairs, assuming the + predictions and targets have been matched (there are as many predictions and targets). + + Args: + preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs". + targets: A dictionary of training targets, containing "boxes" and "labels". + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + image_size: Width and height in a vector that defines the scale of the target coordinates. + + Returns: + The final losses. 
+ """ + overlap_loss = self._elementwise_overlap_loss(targets["boxes"], preds["boxes"]) + overlap = 1.0 - overlap_loss + overlap_loss = (overlap_loss * _size_compensation(targets["boxes"], image_size)).sum() + + confidence_loss = _foreground_confidence_loss( + preds["confidences"], overlap, input_is_normalized, self.predict_overlap + ) + confidence_loss += _background_confidence_loss(preds["bg_confidences"], input_is_normalized) + + pred_probs = preds["classprobs"] + target_probs = _target_labels_to_probs( + targets["labels"], pred_probs.shape[-1], pred_probs.dtype, self.label_smoothing + ) + class_loss = _binary_cross_entropy( + pred_probs, target_probs, reduction="sum", input_is_normalized=input_is_normalized + ) + + losses = Losses( + overlap_loss * self.overlap_multiplier, + confidence_loss * self.confidence_multiplier, + class_loss * self.class_multiplier, + ) + + return losses + + def _pairwise_overlap(self, boxes1: Tensor, boxes2: Tensor) -> Tensor: + """Returns the pairwise intersection-over-union values between two sets of boxes. + + Uses the IoU function specified in ``self.overlap_func``. It would be better to save the function in a variable, + but TorchScript doesn't allow this. + + Args: + boxes1: first set of boxes + boxes2: second set of boxes + + Returns: + A matrix containing the pairwise IoU values for every element in ``boxes1`` and ``boxes2``. + """ + if self.overlap_func == "iou": + return box_iou(boxes1, boxes2) + elif self.overlap_func == "giou": + return generalized_box_iou(boxes1, boxes2) + elif self.overlap_func == "diou": + return distance_box_iou(boxes1, boxes2) + elif self.overlap_func == "ciou": + return complete_box_iou(boxes1, boxes2) + else: + raise ValueError(f"Unknown IoU function '{self.overlap_func}'.") + + def _elementwise_overlap_loss(self, boxes1: Tensor, boxes2: Tensor) -> Tensor: + """Returns the elementwise intersection-over-union losses between two sets of boxes. + + Uses the IoU loss function specified in ``self.overlap_func``. It would be better to save the function in a + variable, but TorchScript doesn't allow this. + + Args: + boxes1: first set of boxes + boxes2: second set of boxes + + Returns: + A vector containing the IoU losses between corresponding elements in ``boxes1`` and ``boxes2``. 
+ """ + if self.overlap_func == "iou": + return box_iou_loss(boxes1, boxes2) + elif self.overlap_func == "giou": + return generalized_box_iou_loss(boxes1, boxes2) + elif self.overlap_func == "diou": + return distance_box_iou_loss(boxes1, boxes2) + elif self.overlap_func == "ciou": + return complete_box_iou_loss(boxes1, boxes2) + else: + raise ValueError(f"Unknown IoU function '{self.overlap_func}'.") diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py new file mode 100644 index 00000000000..dd395f95af5 --- /dev/null +++ b/torchvision/models/detection/yolo_networks.py @@ -0,0 +1,2025 @@ +import io +import re +from collections import OrderedDict +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from warnings import warn + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from ...ops import box_convert +from ..yolo import ( + Conv, + CSPSPP, + CSPStage, + ELANStage, + FastSPP, + MaxPool, + RouteLayer, + ShortcutLayer, + YOLOV4Backbone, + YOLOV4TinyBackbone, + YOLOV5Backbone, + YOLOV7Backbone, +) +from .anchor_utils import global_xy +from .target_matching import HighestIoUMatching, IoUThresholdMatching, PRIOR_SHAPES, SimOTAMatching, SizeRatioMatching +from .yolo_loss import YOLOLoss + +DARKNET_CONFIG = Dict[str, Any] +CREATE_LAYER_OUTPUT = Tuple[nn.Module, int] # layer, num_outputs +PRED = Dict[str, Tensor] +PREDS = List[PRED] # TorchScript doesn't allow a tuple +TARGET = Dict[str, Tensor] +TARGETS = List[TARGET] # TorchScript doesn't allow a tuple +NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]] # detections, losses, hits + + +class DetectionLayer(nn.Module): + """A YOLO detection layer. + + A YOLO model has usually 1 - 3 detection layers at different resolutions. The loss is summed from all of them. + + Args: + num_classes: Number of different classes that this layer predicts. + prior_shapes: A list of prior box dimensions for this layer, used for scaling the predicted dimensions. The list + should contain [width, height] pairs in the network input resolution. + matching_func: The matching algorithm to be used for assigning targets to anchors. + loss_func: ``YOLOLoss`` object for calculating the losses. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + input_is_normalized: The input is normalized by logistic activation in the previous layer. In this case the + detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and + height are scaled up so that the maximum value is four times the anchor dimension. This is used by the + Darknet configurations of Scaled-YOLOv4. + """ + + def __init__( + self, + num_classes: int, + prior_shapes: PRIOR_SHAPES, + matching_func: Callable, + loss_func: YOLOLoss, + xy_scale: float = 1.0, + input_is_normalized: bool = False, + ) -> None: + super().__init__() + + self.num_classes = num_classes + self.prior_shapes = prior_shapes + self.matching_func = matching_func + self.loss_func = loss_func + self.xy_scale = xy_scale + self.input_is_normalized = input_is_normalized + + def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, PREDS]: + """Runs a forward pass through this YOLO detection layer. 
+ + Maps cell-local coordinates to global coordinates in the image space, scales the bounding boxes with the + anchors, converts the center coordinates to corner coordinates, and maps probabilities to the `]0, 1[` range + using sigmoid. + + If targets are given, computes also losses from the predictions and the targets. This layer is responsible only + for the targets that best match one of the anchors assigned to this layer. Training losses will be saved to the + ``losses`` attribute. ``hits`` attribute will be set to the number of targets that this layer was responsible + for. ``losses`` is a tensor of three elements: the overlap, confidence, and classification loss. + + Args: + x: The output from the previous layer. The size of this tensor has to be + ``[batch_size, anchors_per_cell * (num_classes + 5), height, width]``. + image_size: Image width and height in a vector (defines the scale of the predicted and target coordinates). + + Returns: + The layer output, with normalized probabilities, in a tensor sized + ``[batch_size, anchors_per_cell * height * width, num_classes + 5]`` and a list of dictionaries, containing + the same predictions, but with unnormalized probabilities (for loss calculation). + """ + batch_size, num_features, height, width = x.shape + num_attrs = self.num_classes + 5 + anchors_per_cell = num_features // num_attrs + if anchors_per_cell != len(self.prior_shapes): + raise ValueError( + "The model predicts {} bounding boxes per spatial location, but {} prior box dimensions are defined " + "for this layer.".format(anchors_per_cell, len(self.prior_shapes)) + ) + + # Reshape the output to have the bounding box attributes of each grid cell on its own row. + x = x.permute(0, 2, 3, 1) # [batch_size, height, width, anchors_per_cell * num_attrs] + x = x.view(batch_size, height, width, anchors_per_cell, num_attrs) + + # Take the sigmoid of the bounding box coordinates, confidence score, and class probabilities, unless the input + # is normalized by the previous layer activation. Confidence and class losses use the unnormalized values if + # possible. + norm_x = x if self.input_is_normalized else torch.sigmoid(x) + xy = norm_x[..., :2] + wh = x[..., 2:4] + confidence = x[..., 4] + classprob = x[..., 5:] + norm_confidence = norm_x[..., 4] + norm_classprob = norm_x[..., 5:] + + # Eliminate grid sensitivity. The previous layer should output extremely high values for the sigmoid to produce + # x/y coordinates close to one. YOLOv4 solves this by scaling the x/y coordinates. + xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) + + image_xy = global_xy(xy, image_size) + prior_shapes = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + if self.input_is_normalized: + image_wh = 4 * torch.square(wh) * prior_shapes + else: + image_wh = torch.exp(wh) * prior_shapes + box = torch.cat((image_xy, image_wh), -1) + box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy") + output = torch.cat((box, norm_confidence.unsqueeze(-1), norm_classprob), -1) + output = output.reshape(batch_size, height * width * anchors_per_cell, num_attrs) + + # It's better to use binary_cross_entropy_with_logits() for loss computation, so we'll provide the unnormalized + # confidence and classprob, when available. 
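+ # Each dictionary in "preds" holds the predictions for one image in the batch. The boxes are already converted + # to (x1, y1, x2, y2) coordinates in the image space, while the confidences and class probabilities here are the + # raw values before the sigmoid (unless the input was already normalized by the previous layer).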
+ preds = [{"boxes": b, "confidences": c, "classprobs": p} for b, c, p in zip(box, confidence, classprob)] + + return output, preds + + def match_targets( + self, + preds: PREDS, + return_preds: PREDS, + targets: TARGETS, + image_size: Tensor, + ) -> Tuple[PRED, TARGET]: + """Matches the predictions to targets. + + Args: + preds: List of predictions for each image, as returned by the ``forward()`` method of this layer. These will + be matched to the training targets. + return_preds: List of predictions for each image. The matched predictions will be returned from this list. + When calculating the auxiliary loss for deep supervision, predictions from a different layer are used + for loss computation. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + + Returns: + Two dictionaries, the matched predictions and targets. + """ + batch_size = len(preds) + if (len(targets) != batch_size) or (len(return_preds) != batch_size): + raise ValueError("Different batch size for predictions and targets.") + + # Creating lists that are concatenated in the end will confuse TorchScript compilation. Instead, we'll create + # tensors and concatenate new matches immediately. + pred_boxes = torch.empty((0, 4), device=return_preds[0]["boxes"].device) + pred_confidences = torch.empty(0, device=return_preds[0]["confidences"].device) + pred_bg_confidences = torch.empty(0, device=return_preds[0]["confidences"].device) + pred_classprobs = torch.empty((0, self.num_classes), device=return_preds[0]["classprobs"].device) + target_boxes = torch.empty((0, 4), device=targets[0]["boxes"].device) + target_labels = torch.empty(0, dtype=torch.int64, device=targets[0]["labels"].device) + + for image_preds, image_return_preds, image_targets in zip(preds, return_preds, targets): + if image_targets["boxes"].shape[0] > 0: + pred_selector, background_selector, target_selector = self.matching_func( + image_preds, image_targets, image_size + ) + pred_boxes = torch.cat((pred_boxes, image_return_preds["boxes"][pred_selector])) + pred_confidences = torch.cat((pred_confidences, image_return_preds["confidences"][pred_selector])) + pred_bg_confidences = torch.cat( + (pred_bg_confidences, image_return_preds["confidences"][background_selector]) + ) + pred_classprobs = torch.cat((pred_classprobs, image_return_preds["classprobs"][pred_selector])) + target_boxes = torch.cat((target_boxes, image_targets["boxes"][target_selector])) + target_labels = torch.cat((target_labels, image_targets["labels"][target_selector])) + else: + pred_bg_confidences = torch.cat((pred_bg_confidences, image_return_preds["confidences"].flatten())) + + matched_preds = { + "boxes": pred_boxes, + "confidences": pred_confidences, + "bg_confidences": pred_bg_confidences, + "classprobs": pred_classprobs, + } + matched_targets = { + "boxes": target_boxes, + "labels": target_labels, + } + return matched_preds, matched_targets + + def calculate_losses( + self, + preds: PREDS, + targets: TARGETS, + image_size: Tensor, + loss_preds: Optional[PREDS] = None, + ) -> Tuple[Tensor, int]: + """Matches the predictions to targets and computes the losses. + + Args: + preds: List of predictions for each image, as returned by ``forward()``. These will be matched to the + training targets and used to compute the losses (unless another set of predictions for loss computation + is given in ``loss_preds``). + targets: List of training targets for each image. 
+ image_size: Width and height in a vector that defines the scale of the target coordinates. + loss_preds: List of predictions for each image. If given, these will be used for loss computation, instead + of the same predictions that were used for matching. This is needed for deep supervision in YOLOv7. + + Returns: + A vector of the overlap, confidence, and classification loss, normalized by batch size, and the number of + targets that were matched to this layer. + """ + if loss_preds is None: + loss_preds = preds + + matched_preds, matched_targets = self.match_targets(preds, loss_preds, targets, image_size) + + losses = self.loss_func.elementwise_sums(matched_preds, matched_targets, self.input_is_normalized, image_size) + losses = torch.stack((losses.overlap, losses.confidence, losses.classification)) / len(preds) + + hits = len(matched_targets["boxes"]) + + return losses, hits + + +def create_detection_layer( + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], + matching_algorithm: Optional[str] = None, + matching_threshold: Optional[float] = None, + spatial_range: float = 5.0, + size_range: float = 4.0, + ignore_bg_threshold: float = 0.7, + overlap_func: str = "ciou", + predict_overlap: Optional[float] = None, + label_smoothing: Optional[float] = None, + overlap_loss_multiplier: float = 5.0, + confidence_loss_multiplier: float = 1.0, + class_loss_multiplier: float = 1.0, + **kwargs: Any, +) -> DetectionLayer: + """Creates a detection layer module and the required loss function and target matching objects. + + Args: + prior_shapes: A list of all the prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 
0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + num_classes: Number of different classes that this layer predicts. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + input_is_normalized: The input is normalized by logistic activation in the previous layer. In this case the + detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and + height are scaled up so that the maximum value is four times the anchor dimension. This is used by the + Darknet configurations of Scaled-YOLOv4. + """ + matching_func: Callable + if matching_algorithm == "simota": + loss_func = YOLOLoss( + overlap_func, None, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier + ) + matching_func = SimOTAMatching(prior_shapes, prior_shape_idxs, loss_func, spatial_range, size_range) + elif matching_algorithm == "size": + if matching_threshold is None: + raise ValueError("matching_threshold is required with size ratio matching.") + matching_func = SizeRatioMatching(prior_shapes, prior_shape_idxs, matching_threshold, ignore_bg_threshold) + elif matching_algorithm == "iou": + if matching_threshold is None: + raise ValueError("matching_threshold is required with IoU threshold matching.") + matching_func = IoUThresholdMatching(prior_shapes, prior_shape_idxs, matching_threshold, ignore_bg_threshold) + elif matching_algorithm == "maxiou" or matching_algorithm is None: + matching_func = HighestIoUMatching(prior_shapes, prior_shape_idxs, ignore_bg_threshold) + else: + raise ValueError(f"Matching algorithm `{matching_algorithm}´ is unknown.") + + loss_func = YOLOLoss( + overlap_func, + predict_overlap, + label_smoothing, + overlap_loss_multiplier, + confidence_loss_multiplier, + class_loss_multiplier, + ) + layer_shapes = [prior_shapes[i] for i in prior_shape_idxs] + return DetectionLayer(prior_shapes=layer_shapes, matching_func=matching_func, loss_func=loss_func, **kwargs) + + +class DetectionStage(nn.Module): + """This is a convenience class for running a detection layer. + + It might be cleaner to implement this as a function, but TorchScript allows only specific types in function + arguments, not modules. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self.detection_layer = create_detection_layer(**kwargs) + + def forward( + self, + layer_input: Tensor, + targets: Optional[TARGETS], + image_size: Tensor, + detections: List[Tensor], + losses: List[Tensor], + hits: List[int], + ) -> None: + """Runs the detection layer on the inputs and appends the output to the ``detections`` list. + + If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. + + Args: + layer_input: Input to the detection layer. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + detections: A list where a tensor containing the detections will be appended to. + losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. 
+ hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is + given. + """ + output, preds = self.detection_layer(layer_input, image_size) + detections.append(output) + + if targets is not None: + layer_losses, layer_hits = self.detection_layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + + +class DetectionStageWithAux(nn.Module): + """This class represents a combination of a lead and an auxiliary detection layer. + + Args: + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target. This parameter specifies `N` for the lead head. + aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target. This parameter specifies `N` for the auxiliary head. + aux_weight: Weight for the loss from the auxiliary head. + """ + + def __init__( + self, spatial_range: float = 5.0, aux_spatial_range: float = 3.0, aux_weight: float = 0.25, **kwargs: Any + ) -> None: + super().__init__() + self.detection_layer = create_detection_layer(spatial_range=spatial_range, **kwargs) + self.aux_detection_layer = create_detection_layer(spatial_range=aux_spatial_range, **kwargs) + self.aux_weight = aux_weight + + def forward( + self, + layer_input: Tensor, + aux_input: Tensor, + targets: Optional[TARGETS], + image_size: Tensor, + detections: List[Tensor], + losses: List[Tensor], + hits: List[int], + ) -> None: + """Runs the detection layer and the auxiliary detection layer on their respective inputs and appends the + outputs to the ``detections`` list. + + If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. + + Args: + layer_input: Input to the lead detection layer. + aux_input: Input to the auxiliary detection layer. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + detections: A list where a tensor containing the detections will be appended to. + losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is + given. + """ + output, preds = self.detection_layer(layer_input, image_size) + detections.append(output) + + if targets is not None: + # Match lead head predictions to targets and calculate losses from lead head outputs. + layer_losses, layer_hits = self.detection_layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + + # Match lead head predictions to targets and calculate losses from auxiliary head outputs. + _, aux_preds = self.aux_detection_layer(aux_input, image_size) + layer_losses, layer_hits = self.aux_detection_layer.calculate_losses( + preds, targets, image_size, loss_preds=aux_preds + ) + losses.append(layer_losses * self.aux_weight) + hits.append(layer_hits) + + +@torch.jit.script +def get_image_size(images: Tensor) -> Tensor: + """Get the image size from an input tensor. + + The function needs the ``@torch.jit.script`` decorator in order for ONNX generation to work. The tracing based + generator will loose track of e.g. ``images.shape[1]`` and treat it as a Python variable and not a tensor. This will + cause the dimension to be treated as a constant in the model, which prevents dynamic input sizes. 
+ + Args: + images: An image batch to take the width and height from. + + Returns: + A tensor that contains the image width and height. + """ + height = images.shape[2] + width = images.shape[3] + return torch.tensor([width, height], device=images.device) + + +class YOLOV4TinyNetwork(nn.Module): + """The "tiny" network architecture from YOLOv4. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + width: The number of channels in the narrowest convolutional layer. The wider convolutional layers will use a + number of channels that is a multiple of this value. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. 
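+ + Example (a minimal sketch with a random input batch; when no targets are given, the returned loss and hit lists + are empty): + + >>> import torch + >>> network = YOLOV4TinyNetwork(num_classes=91) + >>> images = torch.rand(2, 3, 608, 608) + >>> detections, losses, hits = network(images)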
+ """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + width: int = 32, + activation: Optional[str] = "leaky", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[PRIOR_SHAPES] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + [12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401], + ] + anchors_per_cell = 3 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=normalization) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def outputs(in_channels: int) -> nn.Module: + return nn.Conv2d(in_channels, num_outputs, kernel_size=1, stride=1, bias=True) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: + assert prior_shapes is not None + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + self.backbone = backbone or YOLOV4TinyBackbone(width=width, activation=activation, normalization=normalization) + + self.fpn5 = conv(width * 16, width * 8) + self.out5 = nn.Sequential( + OrderedDict( + [ + ("channels", conv(width * 8, width * 16)), + (f"outputs_{num_outputs}", outputs(width * 16)), + ] + ) + ) + self.upsample5 = upsample(width * 8, width * 4) + + self.fpn4 = conv(width * 12, width * 8, kernel_size=3) + self.out4 = nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs(width * 8))])) + self.upsample4 = upsample(width * 8, width * 2) + + self.fpn3 = conv(width * 6, width * 4, kernel_size=3) + self.out3 = nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs(width * 4))])) + + self.detect3 = detect([0, 1, 2]) + self.detect4 = detect([3, 4, 5]) + self.detect5 = detect([6, 7, 8]) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + c3, c4, c5 = self.backbone(x)[-3:] + + p5 = self.fpn5(c5) + x = torch.cat((self.upsample5(p5), c4), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), c3), dim=1) + p3 = self.fpn3(x) + + self.detect5(self.out5(p5), targets, image_size, detections, losses, hits) + self.detect4(self.out4(p4), targets, image_size, detections, losses, hits) + self.detect3(self.out3(p3), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV4Network(nn.Module): + """Network architecture that corresponds approximately to the Cross Stage Partial Network from YOLOv4. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. 
+ widths: Number of channels at each network stage. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[PRIOR_SHAPES] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. 
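+ # There are nine [width, height] pairs in the network input resolution, three for each of the three detection + # layers, ordered from the smallest (used by the highest-resolution layer) to the largest (used by the + # lowest-resolution layer).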
+ if prior_shapes is None: + prior_shapes = [ + [12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401], + ] + anchors_per_cell = 3 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=2, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def out(in_channels: int) -> nn.Module: + conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([("conv", conv), (f"outputs_{num_outputs}", outputs)])) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: + assert prior_shapes is not None + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + if backbone is not None: + self.backbone = backbone + else: + self.backbone = YOLOV4Backbone(widths=widths, activation=activation, normalization=normalization) + + w3 = widths[-3] + w4 = widths[-2] + w5 = widths[-1] + + self.spp = spp(w5, w5) + + self.pre4 = conv(w4, w4 // 2) + self.upsample5 = upsample(w5, w4 // 2) + self.fpn4 = csp(w4, w4) + + self.pre3 = conv(w3, w3 // 2) + self.upsample4 = upsample(w4, w3 // 2) + self.fpn3 = csp(w3, w3) + + self.downsample3 = downsample(w3, w3) + self.pan4 = csp(w3 + w4, w4) + + self.downsample4 = downsample(w4, w4) + self.pan5 = csp(w4 + w5, w5) + + self.out3 = out(w3) + self.out4 = out(w4) + self.out5 = out(w5) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + c3, c4, x = self.backbone(x)[-3:] + c5 = self.spp(x) + + x = torch.cat((self.upsample5(c5), self.pre4(c4)), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1) + n3 = self.fpn3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), c5), dim=1) + n5 = self.pan5(x) + + self.detect3(self.out3(n3), targets, image_size, detections, 
losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV4P6Network(nn.Module): + """Network architecture that corresponds approximately to the variant of YOLOv4 with four detection layers. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + widths: Number of channels at each network stage. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `4N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. 
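As a quick illustration of the prior-shape convention described above: with four detection layers the list must contain `4N` [width, height] pairs, and consecutive slices of `N` shapes are assigned to the detection layers from the highest-resolution to the lowest-resolution one. A minimal sketch of that bookkeeping, using a hypothetical `split_prior_shapes` helper and made-up shapes:

```python
from typing import List, Sequence


def split_prior_shapes(prior_shapes: Sequence[Sequence[int]], num_layers: int = 4) -> List[List[Sequence[int]]]:
    """Splits the flat list of prior shapes into one slice per detection layer."""
    anchors_per_cell, modulo = divmod(len(prior_shapes), num_layers)
    if modulo != 0:
        raise ValueError(f"The number of prior shapes needs to be divisible by {num_layers}.")
    return [list(prior_shapes[i * anchors_per_cell : (i + 1) * anchors_per_cell]) for i in range(num_layers)]


# Eight pairs and four layers -> two anchors per cell; the smallest shapes go to the
# first (highest-resolution) detection layer.
shapes = [[13, 17], [31, 25], [24, 51], [61, 45], [119, 96], [217, 184], [324, 451], [616, 618]]
print(split_prior_shapes(shapes))
```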
+ """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 1024), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[PRIOR_SHAPES] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + [13, 17], + [31, 25], + [24, 51], + [61, 45], + [61, 45], + [48, 102], + [119, 96], + [97, 189], + [97, 189], + [217, 184], + [171, 384], + [324, 451], + [324, 451], + [545, 357], + [616, 618], + [1024, 1024], + ] + anchors_per_cell = 4 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 4) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 4.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=2, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def out(in_channels: int) -> nn.Module: + conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([("conv", conv), (f"outputs_{num_outputs}", outputs)])) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: + assert prior_shapes is not None + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + if backbone is not None: + self.backbone = backbone + else: + self.backbone = YOLOV4Backbone( + widths=widths, depths=(1, 1, 3, 15, 15, 7, 7), activation=activation, normalization=normalization + ) + + w3 = widths[-4] + w4 = widths[-3] + w5 = widths[-2] + w6 = widths[-1] + + self.spp = spp(w6, w6) + + self.pre5 = conv(w5, w5 // 2) + self.upsample6 = upsample(w6, w5 // 2) + self.fpn5 = csp(w5, w5) + + self.pre4 = conv(w4, w4 // 2) + self.upsample5 = upsample(w5, w4 // 2) + self.fpn4 = csp(w4, w4) + + self.pre3 = conv(w3, w3 // 2) + self.upsample4 = upsample(w4, w3 // 2) + self.fpn3 = csp(w3, w3) + + self.downsample3 = downsample(w3, w3) + self.pan4 = csp(w3 + w4, w4) + + self.downsample4 = downsample(w4, w4) + self.pan5 = csp(w4 + w5, w5) + + self.downsample5 = downsample(w5, w5) + self.pan6 = csp(w5 + w6, w6) + + self.out3 = out(w3) + self.out4 = out(w4) + self.out5 = out(w5) + self.out6 = out(w6) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, 
anchors_per_cell * 3)) + self.detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + c3, c4, c5, x = self.backbone(x)[-4:] + c6 = self.spp(x) + + x = torch.cat((self.upsample6(c6), self.pre5(c5)), dim=1) + p5 = self.fpn5(x) + x = torch.cat((self.upsample5(p5), self.pre4(c4)), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1) + n3 = self.fpn3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + x = torch.cat((self.downsample5(n5), c6), dim=1) + n6 = self.pan6(x) + + self.detect3(self.out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) + self.detect6(self.out6(n6), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV5Network(nn.Module): + """The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth`` + and ``width`` parameters. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. The values used by the different variants are 16 (yolov5n), 32 + (yolov5s), 48 (yolov5m), 64 (yolov5l), and 80 (yolov5x). + depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by + the different variants are 1 (yolov5n, yolov5s), 2 (yolov5m), 3 (yolov5l), and 4 (yolov5x). + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. 
+ size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + width: int = 64, + depth: int = 3, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[PRIOR_SHAPES] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + [12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401], + ] + anchors_per_cell = 3 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return FastSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def out(in_channels: int) -> nn.Module: + outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs)])) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=depth, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: + assert prior_shapes is not None + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + self.backbone = backbone or YOLOV5Backbone( + depth=depth, width=width, activation=activation, normalization=normalization + ) + + self.spp = spp(width * 16, width * 16) + + self.pan3 = csp(width * 8, width * 
4) + self.out3 = out(width * 4) + + self.fpn4 = nn.Sequential( + OrderedDict( + [ + ("csp", csp(width * 16, width * 8)), + ("conv", conv(width * 8, width * 4)), + ] + ) + ) + self.pan4 = csp(width * 8, width * 8) + self.out4 = out(width * 8) + + self.fpn5 = conv(width * 16, width * 8) + self.pan5 = csp(width * 16, width * 16) + self.out5 = out(width * 16) + + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + + self.downsample3 = downsample(width * 4, width * 4) + self.downsample4 = downsample(width * 8, width * 8) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + c3, c4, x = self.backbone(x)[-3:] + c5 = self.spp(x) + + p5 = self.fpn5(c5) + x = torch.cat((self.upsample(p5), c4), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample(p4), c3), dim=1) + + n3 = self.pan3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + + self.detect3(self.out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV7Network(nn.Module): + """Network architecture that corresponds to the W6 variant of YOLOv7 with four detection layers. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + widths: Number of channels at each network stage. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `4N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target. This parameter specifies `N` for the lead head. + aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target. 
This parameter specifies `N` for the auxiliary head. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + aux_weight: Weight for the loss from the auxiliary heads. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + widths: Sequence[int] = (64, 128, 256, 512, 768, 1024), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[PRIOR_SHAPES] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. 
+ if prior_shapes is None: + prior_shapes = [ + [13, 17], + [31, 25], + [24, 51], + [61, 45], + [61, 45], + [48, 102], + [119, 96], + [97, 189], + [97, 189], + [217, 184], + [171, 384], + [324, 451], + [324, 451], + [545, 357], + [616, 618], + [1024, 1024], + ] + anchors_per_cell = 4 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 4) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 4.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def elan(in_channels: int, out_channels: int) -> nn.Module: + return ELANStage( + in_channels, + out_channels, + split_channels=out_channels, + depth=4, + block_depth=1, + norm=normalization, + activation=activation, + ) + + def out(in_channels: int, hidden_channels: int) -> nn.Module: + conv = Conv( + in_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=normalization + ) + outputs = nn.Conv2d(hidden_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([("conv", conv), (f"outputs_{num_outputs}", outputs)])) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStageWithAux: + assert prior_shapes is not None + return DetectionStageWithAux( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + if backbone is not None: + self.backbone = backbone + else: + self.backbone = YOLOV7Backbone( + widths=widths, depth=2, block_depth=2, activation=activation, normalization=normalization + ) + + w3 = widths[-4] + w4 = widths[-3] + w5 = widths[-2] + w6 = widths[-1] + + self.spp = spp(w6, w6 // 2) + + self.pre5 = conv(w5, w5 // 2) + self.upsample6 = upsample(w6 // 2, w5 // 2) + self.fpn5 = elan(w5, w5 // 2) + + self.pre4 = conv(w4, w4 // 2) + self.upsample5 = upsample(w5 // 2, w4 // 2) + self.fpn4 = elan(w4, w4 // 2) + + self.pre3 = conv(w3, w3 // 2) + self.upsample4 = upsample(w4 // 2, w3 // 2) + self.fpn3 = elan(w3, w3 // 2) + + self.downsample3 = downsample(w3 // 2, w4 // 2) + self.pan4 = elan(w4, w4 // 2) + + self.downsample4 = downsample(w4 // 2, w5 // 2) + self.pan5 = elan(w5, w5 // 2) + + self.downsample5 = downsample(w5 // 2, w6 // 2) + self.pan6 = elan(w6, w6 // 2) + + self.out3 = out(w3 // 2, w3) + self.aux_out3 = out(w3 // 2, w3 + (w3 // 4)) + self.out4 = out(w4 // 2, w4) + self.aux_out4 = out(w4 // 2, w4 + (w4 // 4)) + self.out5 = out(w5 // 2, w5) + self.aux_out5 = out(w5 // 2, w5 + (w5 // 4)) + self.out6 = out(w6 // 2, w6) + self.aux_out6 = out(w6 // 2, w6 + (w6 // 4)) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + self.detect6 = 
detect(range(anchors_per_cell * 3, anchors_per_cell * 4)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + c3, c4, c5, x = self.backbone(x)[-4:] + c6 = self.spp(x) + + x = torch.cat((self.upsample6(c6), self.pre5(c5)), dim=1) + p5 = self.fpn5(x) + x = torch.cat((self.upsample5(p5), self.pre4(c4)), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1) + n3 = self.fpn3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + x = torch.cat((self.downsample5(n5), c6), dim=1) + n6 = self.pan6(x) + + self.detect3(self.out3(n3), self.aux_out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), self.aux_out4(p4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), self.aux_out5(p5), targets, image_size, detections, losses, hits) + self.detect6(self.out6(n6), self.aux_out6(c6), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOXHead(nn.Module): + """A module that produces features for YOLO detection layer, decoupling the classification and localization + features. + + Args: + in_channels: Number of input channels that the module expects. + hidden_channels: Number of output channels in the hidden layers. + anchors_per_cell: Number of detections made at each spatial location of the feature map. + num_classes: Number of different classes that this model predicts. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + anchors_per_cell: int, + num_classes: int, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=norm) + + def linear(in_channels: int, out_channels: int) -> nn.Module: + return nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def features(num_channels: int) -> nn.Module: + return nn.Sequential( + conv(num_channels, num_channels, kernel_size=3), + conv(num_channels, num_channels, kernel_size=3), + ) + + def classprob(num_channels: int) -> nn.Module: + num_outputs = anchors_per_cell * num_classes + outputs = linear(num_channels, num_outputs) + return nn.Sequential(OrderedDict([("convs", features(num_channels)), (f"outputs_{num_outputs}", outputs)])) + + self.stem = conv(in_channels, hidden_channels) + self.feat = features(hidden_channels) + self.box = linear(hidden_channels, anchors_per_cell * 4) + self.confidence = linear(hidden_channels, anchors_per_cell) + self.classprob = classprob(hidden_channels) + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + features = self.feat(x) + box = self.box(features) + confidence = self.confidence(features) + classprob = self.classprob(x) + return torch.cat((box, confidence, classprob), dim=1) + + +class YOLOXNetwork(nn.Module): + """The YOLOX network architecture. 
Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the + ``depth`` and ``width`` parameters. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. The values used by the different variants are 24 (yolox-tiny), + 32 (yolox-s), 48 (yolox-m), and 64 (yolox-l). + depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by + the different variants are 1 (yolox-tiny, yolox-s), 2 (yolox-m), and 3 (yolox-l). + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. 
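A small sketch of the anchor-free default described above, with illustrative numbers only: by default there is one prior per cell whose size equals the stride of the corresponding feature map, and the decoupled head concatenates 4 box values, 1 confidence value, and `num_classes` class logits per anchor at every spatial location.

```python
# With the anchor-free default, prior sizes equal the feature map strides and each head
# output has anchors_per_cell * (4 + 1 + num_classes) channels per spatial location.
num_classes = 80
anchors_per_cell = 1
strides = (8, 16, 32)

prior_shapes = [[stride, stride] for stride in strides]          # [[8, 8], [16, 16], [32, 32]]
channels_per_location = anchors_per_cell * (4 + 1 + num_classes)
print(prior_shapes, channels_per_location)                       # ... 85
```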
+ """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + width: int = 64, + depth: int = 3, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[PRIOR_SHAPES] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use one anchor per cell and the stride as the prior size. + if prior_shapes is None: + prior_shapes = [[8, 8], [16, 16], [32, 32]] + anchors_per_cell = 1 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return FastSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=normalization) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=depth, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def head(in_channels: int, hidden_channels: int) -> YOLOXHead: + return YOLOXHead( + in_channels, + hidden_channels, + anchors_per_cell, + num_classes, + activation=activation, + norm=normalization, + ) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: + assert prior_shapes is not None + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + self.backbone = backbone or YOLOV5Backbone( + depth=depth, width=width, activation=activation, normalization=normalization + ) + + self.spp = spp(width * 16, width * 16) + + self.pan3 = csp(width * 8, width * 4) + self.out3 = head(width * 4, width * 4) + + self.fpn4 = nn.Sequential( + OrderedDict( + [ + ("csp", csp(width * 16, width * 8)), + ("conv", conv(width * 8, width * 4)), + ] + ) + ) + self.pan4 = csp(width * 8, width * 8) + self.out4 = head(width * 8, width * 4) + + self.fpn5 = conv(width * 16, width * 8) + self.pan5 = csp(width * 16, width * 16) + self.out5 = head(width * 16, width * 4) + + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + + self.downsample3 = downsample(width * 4, width * 4) + self.downsample4 = downsample(width * 8, width * 8) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + c3, c4, x = self.backbone(x)[-3:] + c5 = self.spp(x) + + p5 = self.fpn5(c5) + x = torch.cat((self.upsample(p5), c4), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample(p4), c3), dim=1) + + n3 = self.pan3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + + self.detect3(self.out3(n3), 
targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class DarknetNetwork(nn.Module): + """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. + + Iterates through the layers from the configuration and creates corresponding PyTorch modules. If ``weights_path`` is + given and points to a Darknet model file, loads the convolutional layer weights from the file. + + Args: + config_path: Path to a Darknet configuration file that defines the network architecture. + weights_path: Path to a Darknet model file. If given, the model weights will be read from this file. + in_channels: Number of channels in the input image. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + """ + + def __init__( + self, config_path: str, weights_path: Optional[str] = None, in_channels: Optional[int] = None, **kwargs: Any + ) -> None: + super().__init__() + + with open(config_path) as config_file: + sections = self._read_config(config_file) + + if len(sections) < 2: + raise ValueError("The model configuration file should include at least two sections.") + + self.__dict__.update(sections[0]) + global_config = sections[0] + layer_configs = sections[1:] + + if in_channels is None: + in_channels = global_config.get("channels", 3) + assert isinstance(in_channels, int) + + self.layers = nn.ModuleList() + # num_inputs will contain the number of channels in the input of every layer up to the current layer. It is + # initialized with the number of channels in the input image. 
+ num_inputs = [in_channels] + for layer_config in layer_configs: + config = {**global_config, **layer_config} + layer, num_outputs = _create_layer(config, num_inputs, **kwargs) + self.layers.append(layer) + num_inputs.append(num_outputs) + + if weights_path is not None: + with open(weights_path) as weight_file: + self.load_weights(weight_file) + + # A workaround for TorchScript compilation. For some reason, the compilation will crash with "Unknown type name + # 'ShortcutLayer'" without this. + self._ = ShortcutLayer(0) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + outputs: List[Tensor] = [] # Outputs from all layers + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = get_image_size(x) + + for layer in self.layers: + if isinstance(layer, (RouteLayer, ShortcutLayer)): + x = layer(outputs) + elif isinstance(layer, DetectionLayer): + x, preds = layer(x, image_size) + detections.append(x) + if targets is not None: + layer_losses, layer_hits = layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + else: + x = layer(x) + + outputs.append(x) + + return detections, losses, hits + + def load_weights(self, weight_file: io.IOBase) -> None: + """Loads weights to layer modules from a pretrained Darknet model. + + One may want to continue training from pretrained weights, on a dataset with a different number of object + categories. The number of kernels in the convolutional layers just before each detection layer depends on the + number of output classes. The Darknet solution is to truncate the weight file and stop reading weights at the + first incompatible layer. For this reason the function silently leaves the rest of the layers unchanged, when + the weight file ends. + + Args: + weight_file: A file-like object containing model weights in the Darknet binary format. + """ + if not isinstance(weight_file, io.IOBase): + raise ValueError("weight_file must be a file-like object.") + + version = np.fromfile(weight_file, count=3, dtype=np.int32) + images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) + print( + f"Loading weights from Darknet model version {version[0]}.{version[1]}.{version[2]} " + f"that has been trained on {images_seen[0]} images." + ) + + def read(tensor: Tensor) -> int: + """Reads the contents of ``tensor`` from the current position of ``weight_file``. + + Returns the number of elements read. If there's no more data in ``weight_file``, returns 0. + """ + np_array = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) + num_elements = np_array.size + if num_elements > 0: + source = torch.from_numpy(np_array).view_as(tensor) + with torch.no_grad(): + tensor.copy_(source) + return num_elements + + for layer in self.layers: + # Weights are loaded only to convolutional layers + if not isinstance(layer, Conv): + continue + + # If convolution is followed by batch normalization, read the batch normalization parameters. Otherwise we + # read the convolution bias. 
+ if isinstance(layer.norm, nn.Identity): + assert layer.conv.bias is not None + read(layer.conv.bias) + else: + assert isinstance(layer.norm, nn.BatchNorm2d) + assert layer.norm.running_mean is not None + assert layer.norm.running_var is not None + read(layer.norm.bias) + read(layer.norm.weight) + read(layer.norm.running_mean) + read(layer.norm.running_var) + + read_count = read(layer.conv.weight) + if read_count == 0: + return + + def _read_config(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: + """Reads a Darnet network configuration file and returns a list of configuration sections. + + Args: + config_file: The configuration file to read. + + Returns: + A list of configuration sections. + """ + section_re = re.compile(r"\[([^]]+)\]") + list_variables = ("layers", "anchors", "mask", "scales") + variable_types = { + "activation": str, + "anchors": int, + "angle": float, + "batch": int, + "batch_normalize": bool, + "beta_nms": float, + "burn_in": int, + "channels": int, + "classes": int, + "cls_normalizer": float, + "decay": float, + "exposure": float, + "filters": int, + "from": int, + "groups": int, + "group_id": int, + "height": int, + "hue": float, + "ignore_thresh": float, + "iou_loss": str, + "iou_normalizer": float, + "iou_thresh": float, + "jitter": float, + "layers": int, + "learning_rate": float, + "mask": int, + "max_batches": int, + "max_delta": float, + "momentum": float, + "mosaic": bool, + "new_coords": int, + "nms_kind": str, + "num": int, + "obj_normalizer": float, + "pad": bool, + "policy": str, + "random": bool, + "resize": float, + "saturation": float, + "scales": float, + "scale_x_y": float, + "size": int, + "steps": str, + "stride": int, + "subdivisions": int, + "truth_thresh": float, + "width": int, + } + + section = None + sections = [] + + def convert(key: str, value: str) -> Union[str, int, float, List[Union[str, int, float]]]: + """Converts a value to the correct type based on key.""" + if key not in variable_types: + warn("Unknown YOLO configuration variable: " + key) + return value + if key in list_variables: + return [variable_types[key](v) for v in value.split(",")] + else: + return variable_types[key](value) + + for line in config_file: + line = line.strip() + if (not line) or (line[0] == "#"): + continue + + section_match = section_re.match(line) + if section_match: + if section is not None: + sections.append(section) + section = {"type": section_match.group(1)} + else: + if section is None: + raise RuntimeError("Darknet network configuration file does not start with a section header.") + key, value = line.split("=") + key = key.rstrip() + value = value.lstrip() + section[key] = convert(key, value) + if section is not None: + sections.append(section) + + return sections + + +def _create_layer(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the + layer config. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. 
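A sketch (with made-up channel counts) of the `num_inputs` contract that the layer factories rely on: index 0 holds the channel count of the input image, and index `i + 1` holds the number of output channels of layer `i`, so a factory can look up the output of any earlier layer.

```python
# num_inputs[0] is the channel count of the input image and num_inputs[i + 1] is the
# number of output channels of layer i.
num_inputs = [3]                           # RGB input image
for out_channels in (32, 64, 64):          # three hypothetical convolutional layers
    num_inputs.append(out_channels)

previous_layer_channels = num_inputs[-1]   # 64, feeds the next layer to be created
first_layer_channels = num_inputs[0 + 1]   # 32, output of the layer at index 0
print(previous_layer_channels, first_layer_channels)
```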
+ """ + create_func: Dict[str, Callable[..., CREATE_LAYER_OUTPUT]] = { + "convolutional": _create_convolutional, + "maxpool": _create_maxpool, + "route": _create_route, + "shortcut": _create_shortcut, + "upsample": _create_upsample, + "yolo": _create_yolo, + } + return create_func[config["type"]](config, num_inputs, **kwargs) + + +def _create_convolutional(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a convolutional layer. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + batch_normalize = config.get("batch_normalize", False) + padding = (config["size"] - 1) // 2 if config["pad"] else 0 + + layer = Conv( + num_inputs[-1], + config["filters"], + kernel_size=config["size"], + stride=config["stride"], + padding=padding, + bias=not batch_normalize, + activation=config["activation"], + norm="batchnorm" if batch_normalize else None, + ) + return layer, config["filters"] + + +def _create_maxpool(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a max pooling layer. + + Padding is added so that the output resolution will be the input resolution divided by stride, rounded upwards. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + layer = MaxPool(config["size"], config["stride"]) + return layer, num_inputs[-1] + + +def _create_route(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a routing layer. + + A routing layer concatenates the output (or part of it) from the layers specified by the "layers" configuration + option. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + num_chunks = config.get("groups", 1) + chunk_idx = config.get("group_id", 0) + + # 0 is the first layer, -1 is the previous layer + last = len(num_inputs) - 1 + source_layers = [layer if layer >= 0 else last + layer for layer in config["layers"]] + + layer = RouteLayer(source_layers, num_chunks, chunk_idx) + + # The number of outputs of a source layer is the number of inputs of the next layer. + num_outputs = sum(num_inputs[layer + 1] // num_chunks for layer in source_layers) + + return layer, num_outputs + + +def _create_shortcut(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a shortcut layer. + + A shortcut layer adds a residual connection from the layer specified by the "from" configuration option. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. 
+ """ + layer = ShortcutLayer(config["from"]) + return layer, num_inputs[-1] + + +def _create_upsample(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a layer that upsamples the data. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + layer = nn.Upsample(scale_factor=config["stride"], mode="nearest") + return layer, num_inputs[-1] + + +def _create_yolo( + config: DARKNET_CONFIG, + num_inputs: List[int], + prior_shapes: Optional[PRIOR_SHAPES] = None, + matching_algorithm: Optional[str] = None, + matching_threshold: Optional[float] = None, + spatial_range: float = 5.0, + size_range: float = 4.0, + ignore_bg_threshold: Optional[float] = None, + overlap_func: Optional[str] = None, + predict_overlap: Optional[float] = None, + label_smoothing: Optional[float] = None, + overlap_loss_multiplier: Optional[float] = None, + confidence_loss_multiplier: Optional[float] = None, + class_loss_multiplier: Optional[float] = None, + **kwargs: Any, +) -> CREATE_LAYER_OUTPUT: + """Creates a YOLO detection layer. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. Not used by the detection layer. + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `M x N` pairs, where `M` is the number of detection layers and `N` is the number + of anchors per spatial location. They are assigned to the layers from the lowest (high-resolution) to the + highest (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the + largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. 
+ label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output (always 0 for a detection layer). + """ + if prior_shapes is None: + # The "anchors" list alternates width and height. + dims = config["anchors"] + prior_shapes = [[dims[i], dims[i + 1]] for i in range(0, len(dims), 2)] + if ignore_bg_threshold is None: + ignore_bg_threshold = config.get("ignore_thresh", 1.0) + assert isinstance(ignore_bg_threshold, float) + if overlap_func is None: + overlap_func = config.get("iou_loss", "iou") + assert isinstance(overlap_func, str) + if overlap_loss_multiplier is None: + overlap_loss_multiplier = config.get("iou_normalizer", 1.0) + assert isinstance(overlap_loss_multiplier, float) + if confidence_loss_multiplier is None: + confidence_loss_multiplier = config.get("obj_normalizer", 1.0) + assert isinstance(confidence_loss_multiplier, float) + if class_loss_multiplier is None: + class_loss_multiplier = config.get("cls_normalizer", 1.0) + assert isinstance(class_loss_multiplier, float) + + layer = create_detection_layer( + num_classes=config["classes"], + prior_shapes=prior_shapes, + prior_shape_idxs=config["mask"], + matching_algorithm=matching_algorithm, + matching_threshold=matching_threshold, + spatial_range=spatial_range, + size_range=size_range, + ignore_bg_threshold=ignore_bg_threshold, + overlap_func=overlap_func, + predict_overlap=predict_overlap, + label_smoothing=label_smoothing, + overlap_loss_multiplier=overlap_loss_multiplier, + confidence_loss_multiplier=confidence_loss_multiplier, + class_loss_multiplier=class_loss_multiplier, + xy_scale=config.get("scale_x_y", 1.0), + input_is_normalized=config.get("new_coords", 0) > 0, + ) + return layer, 0 diff --git a/torchvision/models/yolo.py b/torchvision/models/yolo.py new file mode 100644 index 00000000000..8f3be429a63 --- /dev/null +++ b/torchvision/models/yolo.py @@ -0,0 +1,731 @@ +from collections import OrderedDict +from typing import List, Optional, Sequence, Tuple + +import torch +from torch import nn, Tensor + + +def _get_padding(kernel_size: int, stride: int) -> Tuple[int, nn.Module]: + """Returns the amount of padding needed by convolutional and max pooling layers. + + Determines the amount of padding needed to make the output size of the layer the input size divided by the stride. + The first value that the function returns is the amount of padding to be added to all sides of the input matrix + (``padding`` argument of the operation). If an uneven amount of padding is needed in different sides of the input, + the second variable that is returned is an ``nn.ZeroPad2d`` operation that adds an additional column and row of + padding. If the input size is not divisible by the stride, the output size will be rounded upwards. + + Args: + kernel_size: Size of the kernel. + stride: Stride of the operation. + + Returns: + padding, pad_op: The amount of padding to be added to all sides of the input and an ``nn.Identity`` or + ``nn.ZeroPad2d`` operation to add one more column and row of padding if necessary. 
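A few worked values for the padding computation described above, evaluated with the same `divmod` expression that the function body uses:

```python
# kernel 3, stride 1 -> (1, 0): pad one cell on every side, no extra row/column
# kernel 2, stride 2 -> (0, 0): no padding at all
# kernel 2, stride 1 -> (0, 1): no symmetric padding, but one extra column and row
#                               on the right and bottom via nn.ZeroPad2d((0, 1, 0, 1))
for kernel_size, stride in ((3, 1), (2, 2), (2, 1)):
    padding, remainder = divmod(max(kernel_size, stride) - stride, 2)
    print(f"kernel={kernel_size} stride={stride}: padding={padding}, extra row/column={bool(remainder)}")
```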
+ """ + # The output size is generally (input_size + padding - max(kernel_size, stride)) / stride + 1 and we want to + # make it equal to input_size / stride. + padding, remainder = divmod(max(kernel_size, stride) - stride, 2) + + # If the kernel size is an even number, we need one cell of extra padding, on top of the padding added by MaxPool2d + # on both sides. + pad_op: nn.Module = nn.Identity() if remainder == 0 else nn.ZeroPad2d((0, 1, 0, 1)) + + return padding, pad_op + + +def _create_activation_module(name: Optional[str]) -> nn.Module: + """Creates a layer activation module given its type as a string. + + Args: + name: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", "linear", + or "none". + """ + if name == "relu": + return nn.ReLU(inplace=True) + if name == "leaky": + return nn.LeakyReLU(0.1, inplace=True) + if name == "mish": + return Mish() + if name == "silu" or name == "swish": + return nn.SiLU(inplace=True) + if name == "logistic": + return nn.Sigmoid() + if name == "linear" or name == "none" or name is None: + return nn.Identity() + raise ValueError(f"Activation type `{name}´ is unknown.") + + +def _create_normalization_module(name: Optional[str], num_channels: int) -> nn.Module: + """Creates a layer normalization module given its type as a string. + + Group normalization uses always 8 channels. The most common network widths are divisible by this number. + + Args: + name: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + num_channels: The number of input channels that the module expects. + """ + if name == "batchnorm": + return nn.BatchNorm2d(num_channels, eps=0.001) + if name == "groupnorm": + return nn.GroupNorm(8, num_channels, eps=0.001) + if name == "none" or name is None: + return nn.Identity() + raise ValueError(f"Normalization layer type `{name}´ is unknown.") + + +class Conv(nn.Module): + """A convolutional layer with optional layer normalization and activation. + + If ``padding`` is ``None``, the module tries to add padding so much that the output size will be the input size + divided by the stride. If the input size is not divisible by the stride, the output size will be rounded upwards. + + Args: + in_channels: Number of input channels that the layer expects. + out_channels: Number of output channels that the convolution produces. + kernel_size: Size of the convolving kernel. + stride: Stride of the convolution. + padding: Padding added to all four sides of the input. + bias: If ``True``, adds a learnable bias to the output. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: Optional[int] = None, + bias: bool = False, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ): + super().__init__() + + if padding is None: + padding, self.pad = _get_padding(kernel_size, stride) + else: + self.pad = nn.Identity() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.norm = _create_normalization_module(norm, out_channels) + self.act = _create_activation_module(activation) + + def forward(self, x: Tensor) -> Tensor: + x = self.pad(x) + x = self.conv(x) + x = self.norm(x) + return self.act(x) + + +class MaxPool(nn.Module): + """A max pooling layer with padding. + + The module tries to add padding so much that the output size will be the input size divided by the stride. If the + input size is not divisible by the stride, the output size will be rounded upwards. + """ + + def __init__(self, kernel_size: int, stride: int): + super().__init__() + padding, self.pad = _get_padding(kernel_size, stride) + self.maxpool = nn.MaxPool2d(kernel_size, stride, padding) + + def forward(self, x: Tensor) -> Tensor: + x = self.pad(x) + return self.maxpool(x) + + +class RouteLayer(nn.Module): + """A routing layer concatenates the output (or part of it) from given layers. + + Args: + source_layers: Indices of the layers whose output will be concatenated. + num_chunks: Layer outputs will be split into this number of chunks. + chunk_idx: Only the chunks with this index will be concatenated. + """ + + def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None: + super().__init__() + self.source_layers = source_layers + self.num_chunks = num_chunks + self.chunk_idx = chunk_idx + + def forward(self, outputs: List[Tensor]) -> Tensor: + chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers] + return torch.cat(chunks, dim=1) + + +class ShortcutLayer(nn.Module): + """A shortcut layer adds a residual connection from the source layer. + + Args: + source_layer: Index of the layer whose output will be added to the output of the previous layer. + """ + + def __init__(self, source_layer: int) -> None: + super().__init__() + self.source_layer = source_layer + + def forward(self, outputs: List[Tensor]) -> Tensor: + return outputs[-1] + outputs[self.source_layer] + + +class Mish(nn.Module): + """Mish activation.""" + + def forward(self, x: Tensor) -> Tensor: + return x * torch.tanh(nn.functional.softplus(x)) + + +class ReOrg(nn.Module): + """Re-organizes the tensor so that every square region of four cells is placed into four different channels. + + The result is a tensor with half the width and height, and four times as many channels. + """ + + def forward(self, x: Tensor) -> Tensor: + tl = x[..., ::2, ::2] + bl = x[..., 1::2, ::2] + tr = x[..., ::2, 1::2] + br = x[..., 1::2, 1::2] + return torch.cat((tl, bl, tr, br), dim=1) + + +class BottleneckBlock(nn.Module): + """A residual block with a bottleneck layer. + + Args: + in_channels: Number of input channels that the block expects. + out_channels: Number of output channels that the block produces. + hidden_channels: Number of output channels the (hidden) bottleneck layer produces. By default the number of + output channels of the block. + shortcut: Whether the block should include a shortcut connection. + activation: Which layer activation to use. 
Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Optional[int] = None, + shortcut: bool = True, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + if hidden_channels is None: + hidden_channels = out_channels + + self.convs = nn.Sequential( + Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm), + Conv(hidden_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=norm), + ) + self.shortcut = shortcut and in_channels == out_channels + + def forward(self, x: Tensor) -> Tensor: + y = self.convs(x) + return x + y if self.shortcut else y + + +class TinyStage(nn.Module): + """One stage of the "tiny" network architecture from YOLOv4. + + Args: + num_channels: Number of channels in the input of the stage. Partial output will have as many channels and full + output will have twice as many channels. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + num_channels: int, + activation: Optional[str] = "leaky", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + hidden_channels = num_channels // 2 + self.conv1 = Conv(hidden_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=norm) + self.conv2 = Conv(hidden_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=norm) + self.mix = Conv(num_channels, num_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + partial = torch.chunk(x, 2, dim=1)[1] + y1 = self.conv1(partial) + y2 = self.conv2(y1) + partial_output = self.mix(torch.cat((y2, y1), dim=1)) + full_output = torch.cat((x, partial_output), dim=1) + return partial_output, full_output + + +class CSPStage(nn.Module): + """One stage of a Cross Stage Partial Network (CSPNet). + + Encapsulates a number of bottleneck blocks in the "fusion first" CSP structure. + + `Chien-Yao Wang et al. `_ + + Args: + in_channels: Number of input channels that the CSP stage expects. + out_channels: Number of output channels that the CSP stage produces. + depth: Number of bottleneck blocks that the CSP stage contains. + shortcut: Whether the bottleneck blocks should include a shortcut connection. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + depth: int = 1, + shortcut: bool = True, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + # Instead of splitting the N output channels of a convolution into two parts, we can equivalently perform two + # convolutions with N/2 output channels. 
+ hidden_channels = out_channels // 2 + + self.split1 = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + self.split2 = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + bottlenecks: List[nn.Module] = [ + BottleneckBlock(hidden_channels, hidden_channels, shortcut=shortcut, norm=norm, activation=activation) + for _ in range(depth) + ] + self.bottlenecks = nn.Sequential(*bottlenecks) + self.mix = Conv(hidden_channels * 2, out_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + y1 = self.bottlenecks(self.split1(x)) + y2 = self.split2(x) + return self.mix(torch.cat((y1, y2), dim=1)) + + +class ELANStage(nn.Module): + """One stage of an Efficient Layer Aggregation Network (ELAN). + + `Chien-Yao Wang et al. `_ + + Args: + in_channels: Number of input channels that the ELAN stage expects. + out_channels: Number of output channels that the ELAN stage produces. + hidden_channels: Number of output channels that the computational blocks produce. The default value is half the + number of output channels of the block, as in YOLOv7-W6, but the value varies between the variants. + split_channels: Number of channels in each part after splitting the input to the cross stage connection and the + computational blocks. The default value is the number of hidden channels, as in all YOLOv7 backbones. Most + YOLOv7 heads use twice the number of hidden channels. + depth: Number of computational blocks that the ELAN stage contains. The default value is 2. YOLOv7 backbones use + 2 to 4 blocks per stage. + block_depth: Number of convolutional layers in one computational block. The default value is 2. YOLOv7 backbones + have two convolutions per block. YOLOv7 heads (except YOLOv7-X) have 2 to 8 blocks with only one convolution + in each. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Optional[int] = None, + split_channels: Optional[int] = None, + depth: int = 2, + block_depth: int = 2, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def conv3x3(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=norm) + + def block(in_channels: int, out_channels: int) -> nn.Module: + convs = [conv3x3(in_channels, out_channels)] + for _ in range(block_depth - 1): + convs.append(conv3x3(out_channels, out_channels)) + return nn.Sequential(*convs) + + # Instead of splitting the N output channels of a convolution into two parts, we can equivalently perform two + # convolutions with N/2 output channels. However, in many YOLOv7 architectures, the number of hidden channels is + # not exactly half the number of output channels. 
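+        # For example, with out_channels=256 and the defaults (hidden_channels = split_channels = 128, depth=2), the
+        # final 1x1 convolution below mixes 128 * 2 (cross stage splits) + 128 * 2 (block outputs) = 512 channels back
+        # to 256 output channels.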
+ if hidden_channels is None: + hidden_channels = out_channels // 2 + + if split_channels is None: + split_channels = hidden_channels + + self.split1 = Conv(in_channels, split_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + self.split2 = Conv(in_channels, split_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + blocks = [block(split_channels, hidden_channels)] + for _ in range(depth - 1): + blocks.append(block(hidden_channels, hidden_channels)) + self.blocks = nn.ModuleList(blocks) + + total_channels = (split_channels * 2) + (hidden_channels * depth) + self.mix = Conv(total_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + outputs = [self.split1(x), self.split2(x)] + x = outputs[-1] + for block in self.blocks: + x = block(x) + outputs.append(x) + return self.mix(torch.cat(outputs, dim=1)) + + +class CSPSPP(nn.Module): + """Spatial pyramid pooling module from the Cross Stage Partial Network from YOLOv4. + + Args: + in_channels: Number of input channels that the module expects. + out_channels: Number of output channels that the module produces. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ): + super().__init__() + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=kernel_size, stride=1, activation=activation, norm=norm) + + self.conv1 = nn.Sequential( + conv(in_channels, out_channels), + conv(out_channels, out_channels, kernel_size=3), + conv(out_channels, out_channels), + ) + self.conv2 = conv(in_channels, out_channels) + + self.maxpool1 = MaxPool(kernel_size=5, stride=1) + self.maxpool2 = MaxPool(kernel_size=9, stride=1) + self.maxpool3 = MaxPool(kernel_size=13, stride=1) + + self.mix1 = nn.Sequential( + conv(4 * out_channels, out_channels), + conv(out_channels, out_channels, kernel_size=3), + ) + self.mix2 = Conv(2 * out_channels, out_channels) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.conv1(x) + x2 = self.maxpool1(x1) + x3 = self.maxpool2(x1) + x4 = self.maxpool3(x1) + y1 = self.mix1(torch.cat((x1, x2, x3, x4), dim=1)) + y2 = self.conv2(x) + return self.mix2(torch.cat((y1, y2), dim=1)) + + +class FastSPP(nn.Module): + """Fast spatial pyramid pooling module from YOLOv5. + + Args: + in_channels: Number of input channels that the module expects. + out_channels: Number of output channels that the module produces. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". 
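+
+    Applying the same 5x5 max pooling three times in a row covers the receptive fields of the 5x5, 9x9, and 13x13
+    poolings in ``CSPSPP`` while reusing the intermediate results.
+
+    Example (an illustrative shape check; the numbers assume a 20x20 feature map with 512 channels):
+        >>> import torch
+        >>> spp = FastSPP(512, 512)
+        >>> spp(torch.rand(1, 512, 20, 20)).shape
+        torch.Size([1, 512, 20, 20])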
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ): + super().__init__() + hidden_channels = in_channels // 2 + self.conv = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + self.maxpool = MaxPool(kernel_size=5, stride=1) + self.mix = Conv(hidden_channels * 4, out_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + y1 = self.conv(x) + y2 = self.maxpool(y1) + y3 = self.maxpool(y2) + y4 = self.maxpool(y3) + return self.mix(torch.cat((y1, y2, y3, y4), dim=1)) + + +class YOLOV4TinyBackbone(nn.Module): + """Backbone of the "tiny" network architecture from YOLOv4. + + Args: + in_channels: Number of channels in the input image. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + width: int = 32, + activation: Optional[str] = "leaky", + normalization: Optional[str] = "batchnorm", + ): + super().__init__() + + def smooth(num_channels: int) -> nn.Module: + return Conv(num_channels, num_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + conv_module = Conv( + in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization + ) + return nn.Sequential(OrderedDict([("downsample", conv_module), ("smooth", smooth(out_channels))])) + + def maxpool(out_channels: int) -> nn.Module: + return nn.Sequential( + OrderedDict( + [ + ("pad", nn.ZeroPad2d((0, 1, 0, 1))), + ("maxpool", MaxPool(kernel_size=2, stride=2)), + ("smooth", smooth(out_channels)), + ] + ) + ) + + def stage(out_channels: int, use_maxpool: bool) -> nn.Module: + if use_maxpool: + downsample_module = maxpool(out_channels) + else: + downsample_module = downsample(out_channels // 2, out_channels) + stage_module = TinyStage(out_channels, activation=activation, norm=normalization) + return nn.Sequential(OrderedDict([("downsample", downsample_module), ("stage", stage_module)])) + + stages = [ + Conv(in_channels, width, kernel_size=3, stride=2, activation=activation, norm=normalization), + stage(width * 2, False), + stage(width * 4, True), + stage(width * 8, True), + maxpool(width * 16), + ] + self.stages = nn.ModuleList(stages) + + def forward(self, x: Tensor) -> List[Tensor]: + c1 = self.stages[0](x) + c2, x = self.stages[1](c1) + c3, x = self.stages[2](x) + c4, x = self.stages[3](x) + c5 = self.stages[4](x) + return [c1, c2, c3, c4, c5] + + +class YOLOV4Backbone(nn.Module): + """A backbone that corresponds approximately to the Cross Stage Partial Network from YOLOv4. + + Args: + in_channels: Number of channels in the input image. + widths: Number of channels at each network stage. Typically ``(32, 64, 128, 256, 512, 1024)``. The P6 variant + adds one more stage with 1024 channels. + depths: Number of bottleneck layers at each network stage. Typically ``(1, 1, 2, 8, 8, 4)``. The P6 variant uses + ``(1, 1, 3, 15, 15, 7, 7)``. + activation: Which layer activation to use. 
Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024), + depths: Sequence[int] = (1, 1, 2, 8, 8, 4), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + if len(widths) != len(depths): + raise ValueError("Width and depth has to be given for an equal number of stages.") + + def conv3x3(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def stage(in_channels: int, out_channels: int, depth: int) -> nn.Module: + csp = CSPStage( + out_channels, + out_channels, + depth=depth, + shortcut=True, + activation=activation, + norm=normalization, + ) + return nn.Sequential( + OrderedDict( + [ + ("downsample", downsample(in_channels, out_channels)), + ("csp", csp), + ] + ) + ) + + convs = [conv3x3(in_channels, widths[0])] + [conv3x3(widths[0], widths[0]) for _ in range(depths[0] - 1)] + self.stem = nn.Sequential(*convs) + self.stages = nn.ModuleList( + stage(in_channels, out_channels, depth) + for in_channels, out_channels, depth in zip(widths[:-1], widths[1:], depths[1:]) + ) + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.stem(x) + outputs: List[Tensor] = [] + for stage in self.stages: + x = stage(x) + outputs.append(x) + return outputs + + +class YOLOV5Backbone(nn.Module): + """The Cross Stage Partial Network backbone from YOLOv5. + + Args: + in_channels: Number of channels in the input image. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. The values used by the different variants are 16 (yolov5n), 32 + (yolov5s), 48 (yolov5m), 64 (yolov5l), and 80 (yolov5x). + depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by + the different variants are 1 (yolov5n, yolov5s), 2 (yolov5m), 3 (yolov5l), and 4 (yolov5x). + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". 
+ """ + + def __init__( + self, + in_channels: int = 3, + width: int = 64, + depth: int = 3, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def downsample(in_channels: int, out_channels: int, kernel_size: int = 3) -> nn.Module: + return Conv( + in_channels, out_channels, kernel_size=kernel_size, stride=2, activation=activation, norm=normalization + ) + + def stage(in_channels: int, out_channels: int, depth: int) -> nn.Module: + csp = CSPStage( + out_channels, + out_channels, + depth=depth, + shortcut=True, + activation=activation, + norm=normalization, + ) + return nn.Sequential( + OrderedDict( + [ + ("downsample", downsample(in_channels, out_channels)), + ("csp", csp), + ] + ) + ) + + stages = [ + downsample(in_channels, width, kernel_size=6), + stage(width, width * 2, depth), + stage(width * 2, width * 4, depth * 2), + stage(width * 4, width * 8, depth * 3), + stage(width * 8, width * 16, depth), + ] + self.stages = nn.ModuleList(stages) + + def forward(self, x: Tensor) -> List[Tensor]: + c1 = self.stages[0](x) + c2 = self.stages[1](c1) + c3 = self.stages[2](c2) + c4 = self.stages[3](c3) + c5 = self.stages[4](c4) + return [c1, c2, c3, c4, c5] + + +class YOLOV7Backbone(nn.Module): + """A backbone that corresponds to the W6 variant of the Efficient Layer Aggregation Network from YOLOv7. + + Args: + in_channels: Number of channels in the input image. + widths: Number of channels at each network stage. Before the first stage there will be one extra split of + spatial resolution by a ``ReOrg`` layer, producing ``in_channels * 4`` channels. + depth: Number of computational blocks at each network stage. YOLOv7-W6 backbone uses 2. + block_depth: Number of convolutional layers in one computational block. YOLOv7-W6 backbone uses 2. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + widths: Sequence[int] = (64, 128, 256, 512, 768, 1024), + depth: int = 2, + block_depth: int = 2, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def conv3x3(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def stage(in_channels: int, out_channels: int) -> nn.Module: + elan = ELANStage( + out_channels, + out_channels, + depth=depth, + block_depth=block_depth, + activation=activation, + norm=normalization, + ) + return nn.Sequential( + OrderedDict( + [ + ("downsample", downsample(in_channels, out_channels)), + ("elan", elan), + ] + ) + ) + + self.stem = nn.Sequential(*[ReOrg(), conv3x3(in_channels * 4, widths[0])]) + self.stages = nn.ModuleList( + stage(in_channels, out_channels) for in_channels, out_channels in zip(widths[:-1], widths[1:]) + ) + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.stem(x) + outputs: List[Tensor] = [] + for stage in self.stages: + x = stage(x) + outputs.append(x) + return outputs