|
| 1 | +#include <ATen/native/mps/kernels/GridSampler.h> |
| 2 | +#include <c10/metal/utils.h> |
| 3 | +#include <metal_array> |
| 4 | +#include <metal_stdlib> |
| 5 | + |
| 6 | +using namespace metal; |
| 7 | +using namespace c10::metal; |
| 8 | + |
// Per-thread element offsets into the output, input, and grid tensors.
// A default-constructed instance has every offset set to zero.
struct GridSamplerOffsets {
  int32_t output = 0;
  int32_t input = 0;
  int32_t grid = 0;
};
| 16 | + |
| 17 | +// Find offsets into the tensors that this thread will operate on, |
| 18 | +// based on the thread ID. |
| 19 | +static GridSamplerOffsets find_grid_sampler_offsets( |
| 20 | + constant int32_t* output_sizes, |
| 21 | + constant int32_t* output_strides, |
| 22 | + constant int32_t* input_sizes, |
| 23 | + constant int32_t* input_strides, |
| 24 | + constant int32_t* grid_sizes, |
| 25 | + constant int32_t* grid_strides, |
| 26 | + int32_t sampler_dims, |
| 27 | + uint tid) { |
| 28 | + auto dims = sampler_dims + 2; |
| 29 | + auto output_idx = static_cast<int32_t>(tid); |
| 30 | + GridSamplerOffsets offsets; |
| 31 | + |
| 32 | + for (auto dim = dims - 1; dim >= 0; dim--) { |
| 33 | + auto dim_idx = output_idx % output_sizes[dim]; |
| 34 | + output_idx = output_idx / output_sizes[dim]; |
| 35 | + |
| 36 | + // Select the output element that this thread will calculate. |
| 37 | + // output shape: |
| 38 | + // 2 sampler dims: (N, C, Hout, Wout) |
| 39 | + // 3 sampler dims: (N, C, Dout, Hout, Wout) |
| 40 | + offsets.output += output_strides[dim] * dim_idx; |
| 41 | + |
| 42 | + // Select the batch and channel for the input. |
| 43 | + // input shape: |
| 44 | + // 2 sampler dims: (N, C, Hin, Win) |
| 45 | + // 3 sampler dims: (N, C, Din, Hin, Win) |
| 46 | + if (dim < 2) { |
| 47 | + offsets.input += input_strides[dim] * dim_idx; |
| 48 | + } |
| 49 | + |
| 50 | + // Select the grid coordinates for the output element. |
| 51 | + // grid shape: |
| 52 | + // 2 sampler dims: (N, Hout, Wout, 2) |
| 53 | + // 3 sampler dims: (N, Dout, Hout, Wout, 3) |
| 54 | + if (dim == 0) { |
| 55 | + offsets.grid += grid_strides[dim] * dim_idx; |
| 56 | + } else if (dim >= 2) { |
| 57 | + offsets.grid += grid_strides[dim - 1] * dim_idx; |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + return offsets; |
| 62 | +} |
| 63 | + |
// Remainder of `a / b` that is shifted by `b` whenever the built-in `%` yields
// a negative value, so the result is non-negative for any `a` when `b` is
// positive.
static int32_t mod(int32_t a, int32_t b) {
  const int32_t remainder = a % b;
  if (remainder < 0) {
    return remainder + b;
  }
  return remainder;
}
| 69 | + |
// Sentinel index value to indicate zero padding. `pad_input_index` returns it
// for out-of-range indices in `Zeros` padding mode, and `get_tensor_val`
// yields 0 instead of reading the input when any index equals it. It is
// negative so that it can never collide with a valid input index.
constant int32_t IDX_ZERO = -1;
| 72 | + |
| 73 | +// Apply padding to an index into the input |
| 74 | +static int32_t pad_input_index( |
| 75 | + int32_t idx, |
| 76 | + int32_t input_size, |
| 77 | + GridSamplerPadding padding_mode, |
| 78 | + bool align_corners) { |
| 79 | + int32_t idx_padded = idx; |
| 80 | + |
| 81 | + if (padding_mode == GridSamplerPadding::Zeros) { |
| 82 | + idx_padded = (idx < 0) ? IDX_ZERO : idx_padded; |
| 83 | + idx_padded = (idx >= input_size) ? IDX_ZERO : idx_padded; |
| 84 | + |
| 85 | + } else if (padding_mode == GridSamplerPadding::Border) { |
| 86 | + idx_padded = (idx < 0) ? 0 : idx_padded; |
| 87 | + idx_padded = (idx >= input_size) ? input_size - 1 : idx_padded; |
| 88 | + |
| 89 | + } else if (padding_mode == GridSamplerPadding::Reflection) { |
| 90 | + auto scale_length = align_corners ? (input_size - 1) : input_size; |
| 91 | + auto idx_mod = mod(idx, scale_length); |
| 92 | + auto idx_mod_reverse = (input_size - 1) - idx_mod; |
| 93 | + bool is_reverse = (abs(idx - idx_mod) / scale_length) % 2 == 1; |
| 94 | + idx_padded = is_reverse ? idx_mod_reverse : idx_mod; |
| 95 | + } |
| 96 | + return idx_padded; |
| 97 | +} |
| 98 | + |
| 99 | +template <int32_t dims, typename T> |
| 100 | +T get_tensor_val( |
| 101 | + constant T* input, |
| 102 | + constant int32_t* input_strides, |
| 103 | + int32_t indices[dims]) { |
| 104 | + bool found_idx_zero = false; |
| 105 | + int32_t offset = 0; |
| 106 | + |
| 107 | + for (auto dim = 0; dim < dims; dim++) { |
| 108 | + auto idx = indices[dim]; |
| 109 | + found_idx_zero = found_idx_zero || (idx == IDX_ZERO); |
| 110 | + offset += (found_idx_zero ? 0 : idx) * input_strides[dim]; |
| 111 | + } |
| 112 | + |
| 113 | + return found_idx_zero ? 0 : input[offset]; |
| 114 | +} |
| 115 | + |
| 116 | +// This function performs 3D linear interpolation for one value. One way to |
| 117 | +// think of how this works is to imagine a unit cube where each corner of the |
| 118 | +// cube has one scalar value associated with it. Inside the cube, the values |
| 119 | +// change linearly, so the gradient is constant. The values associated with each |
| 120 | +// corner are given by the `input`, indexed at all eight different combinations |
| 121 | +// of the `left_indices` and `right_indices`. Given a 3D coordinate anywhere |
| 122 | +// within the cube, specified by the `scales` argument, we must calculate the |
| 123 | +// value associated with that position. |
| 124 | +template <typename T> |
| 125 | +T interpolate_linear_3d( |
| 126 | + constant T* input, |
| 127 | + constant int32_t* input_strides, |
| 128 | + int32_t left_indices[3], |
| 129 | + int32_t right_indices[3], |
| 130 | + opmath_t<T> scales[3]) { |
| 131 | + int32_t a_idx[3] = {left_indices[0], left_indices[1], left_indices[2]}; |
| 132 | + int32_t b_idx[3] = {left_indices[0], left_indices[1], right_indices[2]}; |
| 133 | + int32_t c_idx[3] = {left_indices[0], right_indices[1], left_indices[2]}; |
| 134 | + int32_t d_idx[3] = {left_indices[0], right_indices[1], right_indices[2]}; |
| 135 | + int32_t e_idx[3] = {right_indices[0], left_indices[1], left_indices[2]}; |
| 136 | + int32_t f_idx[3] = {right_indices[0], left_indices[1], right_indices[2]}; |
| 137 | + int32_t g_idx[3] = {right_indices[0], right_indices[1], left_indices[2]}; |
| 138 | + int32_t h_idx[3] = {right_indices[0], right_indices[1], right_indices[2]}; |
| 139 | + auto a = |
| 140 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, a_idx)); |
| 141 | + auto b = |
| 142 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, b_idx)); |
| 143 | + auto c = |
| 144 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, c_idx)); |
| 145 | + auto d = |
| 146 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, d_idx)); |
| 147 | + auto e = |
| 148 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, e_idx)); |
| 149 | + auto f = |
| 150 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, f_idx)); |
| 151 | + auto g = |
| 152 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, g_idx)); |
| 153 | + auto h = |
| 154 | + static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, h_idx)); |
| 155 | + |
| 156 | + auto scale0_right = scales[0]; |
| 157 | + auto scale1_right = scales[1]; |
| 158 | + auto scale2_right = scales[2]; |
| 159 | + auto scale0_left = 1 - scale0_right; |
| 160 | + auto scale1_left = 1 - scale1_right; |
| 161 | + auto scale2_left = 1 - scale2_right; |
| 162 | + |
| 163 | + return static_cast<T>( |
| 164 | + scale0_left * scale1_left * scale2_left * a + |
| 165 | + scale0_left * scale1_left * scale2_right * b + |
| 166 | + scale0_left * scale1_right * scale2_left * c + |
| 167 | + scale0_left * scale1_right * scale2_right * d + |
| 168 | + scale0_right * scale1_left * scale2_left * e + |
| 169 | + scale0_right * scale1_left * scale2_right * f + |
| 170 | + scale0_right * scale1_right * scale2_left * g + |
| 171 | + scale0_right * scale1_right * scale2_right * h); |
| 172 | +} |
| 173 | + |
| 174 | +// Calculates a single output element. |
| 175 | +// `input` shape: |
| 176 | +// 2 sampler dims: (Hin, Win) |
| 177 | +// 3 sampler dims: (Din, Hin, Win) |
| 178 | +// `coords` values: |
| 179 | +// 2 sampler dims: (Wcoord, Hcoord) |
| 180 | +// 3 sampler dims: (Wcoord, Hcoord, Dcoord) |
| 181 | +template <typename T> |
| 182 | +void grid_sampler_single_element( |
| 183 | + device T* output, |
| 184 | + constant T* input, |
| 185 | + constant T* coords, |
| 186 | + int32_t dims, |
| 187 | + constant int32_t* input_sizes, |
| 188 | + constant int32_t* input_strides, |
| 189 | + GridSamplerInterpolation interpolation_mode, |
| 190 | + GridSamplerPadding padding_mode, |
| 191 | + bool align_corners) { |
| 192 | + int32_t left_indices[3]; |
| 193 | + int32_t right_indices[3]; |
| 194 | + opmath_t<T> scales[3]; |
| 195 | + |
| 196 | + // For each dimension, find the pair of indices in the cooresponding dimension |
| 197 | + // of `input` which surround the grid coordinate in that dimension. We'll do |
| 198 | + // this by mapping different coordiante spaces onto each other. There are |
| 199 | + // basically three different coordinate spaces to keep in mind: |
| 200 | + // |
| 201 | + // * aligned grid space |
| 202 | + // - `-1` refers to the leftmost input value. |
| 203 | + // - `1` refers to the rightmost input value. |
| 204 | + // |
| 205 | + // * unaligned grid space |
| 206 | + // - `-1` refers to the midpoint between the leftmost input value and |
| 207 | + // a padding value to the left of that. |
| 208 | + // - `1` refers to the midpoint between the rightmost input value and |
| 209 | + // a padding value to the right of that. |
| 210 | + // |
| 211 | + // * input index space |
| 212 | + // - `n` refers to the n-th value of the input. |
| 213 | + // - `0` refers to the leftmost input value. |
| 214 | + // - `N-1` refers to the rightmost input value. |
| 215 | + // |
| 216 | + // If `align_corners == False`, then the coordinates are is in unaligned grid |
| 217 | + // space, and we will map it onto aligned grid space. If `align_corners == |
| 218 | + // True`, then coordinates are already in aligned grid space. |
| 219 | + // |
| 220 | + // Then we will map unaligned grid space onto input index space, making it |
| 221 | + // relatively simple to find the two input indices that surround the |
| 222 | + // coordinate. |
| 223 | + for (auto coord_dim = 0; coord_dim < dims; coord_dim++) { |
| 224 | + auto input_dim = dims - coord_dim - 1; |
| 225 | + auto input_size = input_sizes[input_dim]; |
| 226 | + auto coord = static_cast<opmath_t<T>>(coords[coord_dim]); |
| 227 | + |
| 228 | + // Interpret nan as -1 |
| 229 | + coord = isnan(coord) ? -1 : coord; |
| 230 | + |
| 231 | + if (!align_corners) { |
| 232 | + // Map unaligned grid space to aligned grid space |
| 233 | + auto corner_alignment_factor = static_cast<opmath_t<T>>(input_size) / |
| 234 | + static_cast<opmath_t<T>>(input_size - 1); |
| 235 | + coord = coord * corner_alignment_factor; |
| 236 | + } |
| 237 | + |
| 238 | + // Map aligned grid space to input index space |
| 239 | + coord = (coord + 1) * (static_cast<opmath_t<T>>(input_size - 1) / 2); |
| 240 | + |
| 241 | + // Get the input indices surrounding the coordinate, apply padding to them, |
| 242 | + // and obtain the scaling factor between the two for interpolation. |
| 243 | + auto left_idx = static_cast<int32_t>(floor(coord)); |
| 244 | + auto right_idx = static_cast<int32_t>(ceil(coord)); |
| 245 | + left_indices[input_dim] = |
| 246 | + pad_input_index(left_idx, input_size, padding_mode, align_corners); |
| 247 | + right_indices[input_dim] = |
| 248 | + pad_input_index(right_idx, input_size, padding_mode, align_corners); |
| 249 | + |
| 250 | + auto scale = coord - left_idx; |
| 251 | + |
| 252 | + if (interpolation_mode == GridSamplerInterpolation::Nearest) { |
| 253 | + // TODO: For some reason, rounding the scale to 0 or 1 and then using |
| 254 | + // linear interpolation seems to work perfectly with zero padding mode, |
| 255 | + // but we get flaky failures with border and reflection padding modes. |
| 256 | + // Need to investigate and fix it. |
| 257 | + scale = (scale <= 0.5) ? 0 : 1; |
| 258 | + } |
| 259 | + scales[input_dim] = scale; |
| 260 | + } |
| 261 | + |
| 262 | + // Now that we have the bounding indices and scale factor for each dimension |
| 263 | + // of the input, we can interpolate. |
| 264 | + if (dims == 3) { |
| 265 | + *output = interpolate_linear_3d( |
| 266 | + input, input_strides, left_indices, right_indices, scales); |
| 267 | + } |
| 268 | +} |
| 269 | + |
| 270 | +template <typename T> |
| 271 | +kernel void grid_sampler( |
| 272 | + device T* output [[buffer(0)]], |
| 273 | + constant T* input [[buffer(1)]], |
| 274 | + constant T* grid [[buffer(2)]], |
| 275 | + constant GridSamplerParams<5>& params [[buffer(3)]], |
| 276 | + uint tid [[thread_position_in_grid]]) { |
| 277 | + auto output_sizes = params.output_sizes.data(); |
| 278 | + auto output_strides = params.output_strides.data(); |
| 279 | + auto input_sizes = params.input_sizes.data(); |
| 280 | + auto input_strides = params.input_strides.data(); |
| 281 | + auto grid_sizes = params.grid_sizes.data(); |
| 282 | + auto grid_strides = params.grid_strides.data(); |
| 283 | + auto sampler_dims = params.sampler_dims; |
| 284 | + |
| 285 | + auto offsets = find_grid_sampler_offsets( |
| 286 | + output_sizes, |
| 287 | + output_strides, |
| 288 | + input_sizes, |
| 289 | + input_strides, |
| 290 | + grid_sizes, |
| 291 | + grid_strides, |
| 292 | + sampler_dims, |
| 293 | + tid); |
| 294 | + |
| 295 | + output += offsets.output; |
| 296 | + input += offsets.input; |
| 297 | + auto coords = grid + offsets.grid; |
| 298 | + |
| 299 | + input_sizes += 2; |
| 300 | + input_strides += 2; |
| 301 | + |
| 302 | + auto interpolation_mode = params.interpolation_mode; |
| 303 | + auto padding_mode = params.padding_mode; |
| 304 | + auto align_corners = params.align_corners; |
| 305 | + |
| 306 | + grid_sampler_single_element( |
| 307 | + output, |
| 308 | + input, |
| 309 | + coords, |
| 310 | + sampler_dims, |
| 311 | + input_sizes, |
| 312 | + input_strides, |
| 313 | + interpolation_mode, |
| 314 | + padding_mode, |
| 315 | + align_corners); |
| 316 | +} |
| 317 | + |
// Explicitly instantiates the `grid_sampler` kernel template for DTYPE and
// gives the instantiation the host-visible name "grid_sampler_" followed by
// the type name (e.g. "grid_sampler_float"). The parameter list must stay in
// sync with the `grid_sampler` template declaration.
#define REGISTER_GRID_SAMPLER_OP(DTYPE)                     \
  template [[host_name("grid_sampler_" #DTYPE)]]            \
  kernel void grid_sampler<DTYPE>(                          \
      device DTYPE * output [[buffer(0)]],                  \
      constant DTYPE * input [[buffer(1)]],                 \
      constant DTYPE * grid [[buffer(2)]],                  \
      constant GridSamplerParams<5> & params [[buffer(3)]], \
      uint tid [[thread_position_in_grid]]);

// Instantiate the kernel for each supported floating-point element type.
REGISTER_GRID_SAMPLER_OP(float);
REGISTER_GRID_SAMPLER_OP(half);
REGISTER_GRID_SAMPLER_OP(bfloat);
0 commit comments