|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. |
| 2 | +import torch |
| 3 | +from torch import nn |
| 4 | + |
| 5 | +from torch.jit.annotations import List, Optional, Dict |
| 6 | +from .image_list import ImageList |
| 7 | + |
| 8 | + |
class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]): anchor scales, one inner tuple per feature
            map. For an aspect ratio of 1 a scale ``s`` yields an ``s x s``
            box (before rounding). A flat sequence ``(s0, s1, ...)`` is also
            accepted and is expanded to ``((s0,), (s1,), ...)``.
        aspect_ratios (Tuple[Tuple[float]]): height / width ratios, one inner
            tuple per feature map. A flat sequence is also accepted and is
            reused unchanged for every entry of ``sizes``.
    """

    # Attribute types declared explicitly for TorchScript.
    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]]
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            # Flat sizes: one single-scale tuple per feature map.
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            # Flat ratios: share the same ratios across all feature maps.
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        # Zero-centered base anchors, built lazily by set_cell_anchors().
        self.cell_anchors = None
        # Maps str(grid_sizes) + str(strides) to the anchors for that layout.
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.sizes, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
        """
        Build the zero-centered base anchors for one feature map.

        Arguments:
            scales: anchor scales.
            aspect_ratios: height / width ratios.
            dtype: dtype of the returned tensor.
            device: device of the returned tensor.

        Returns:
            Tensor of shape (len(aspect_ratios) * len(scales), 4) whose rows
            are rounded [-w / 2, -h / 2, w / 2, h / 2] boxes. Rows are ordered
            ratio-major: all scales for the first ratio, then the second, ...
        """
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        # With h / w = ratio and h * w = scale ** 2:
        # h = scale * sqrt(ratio) and w = scale / sqrt(ratio).
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None  # noqa: F821
        """
        Ensure self.cell_anchors holds base anchors with the given dtype and device.

        The base anchors are reused when they already match both the requested
        device and dtype; otherwise they are regenerated from self.sizes and
        self.aspect_ratios.
        """
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device and dtype
            # which is a valid assumption since they are always created together below.
            # The dtype must be compared as well: checking only the device would keep
            # serving stale anchors after the input dtype changes on the same device.
            if cell_anchors[0].device == device and cell_anchors[0].dtype == dtype:
                return

        cell_anchors = [
            self.generate_anchors(
                sizes,
                aspect_ratios,
                dtype,
                device
            )
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        """Return, for each feature map, the number of anchors per spatial location."""
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        """
        Tile the base anchors over every location of every feature map.

        Arguments:
            grid_sizes: (height, width) of each feature map.
            strides: (stride_height, stride_width) of each feature map.

        Returns:
            One tensor per feature map of shape
            (grid_height * grid_width * num_base_anchors, 4). Rows are ordered
            by grid row, then grid column, then base anchor.

        set_cell_anchors() must have been called first so that
        self.cell_anchors is populated.
        """
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None
        assert len(grid_sizes) == len(strides) == len(cell_anchors)

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            # Relies on meshgrid's default matrix ("ij") indexing: the first
            # output varies along rows (y), the second along columns (x).
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        """
        Return grid_anchors(grid_sizes, strides), memoized in self._cache.

        The cache key is the concatenated string form of both arguments.
        """
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor]) -> List[Tensor]
        """
        Generate the anchors for every image of a batch.

        Arguments:
            image_list: object exposing ``tensors`` (the batched image tensor)
                and ``image_sizes`` (one entry per image).
            feature_maps: feature maps whose last two dimensions are
                (height, width).

        Returns:
            One tensor per entry of ``image_list.image_sizes``, each the
            concatenation over all feature maps of that map's grid anchors.
            Every image receives the same anchors.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        # Stride of a feature map = batched image size // feature map size (floor division).
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        # The anchors do not depend on the individual image sizes, so every
        # image gets its own list holding the same per-feature-map tensors.
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = [
                anchors_per_feature_map
                for anchors_per_feature_map in anchors_over_all_feature_maps
            ]
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache in case that memory leaks.
        self._cache.clear()
        return anchors
0 commit comments