Skip to content

Commit bca9ddf

Browse files
kurtamohlermarkc-614
authored and committed
[MPS] Add grid_sampler_3d for MPS (pytorch#160541)
This PR adds support for `grid_sampler_3d` for MPS with "bilinear" interpolation. NOTE: "nearest" interpolation is not yet supported. Fixes pytorch#159882. Pull Request resolved: pytorch#160541. Approved by: https://github.com/malfet
1 parent 44b8d6c commit bca9ddf

File tree

8 files changed

+559
-1
lines changed

8 files changed

+559
-1
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
#include <c10/metal/common.h>
3+
4+
// Interpolation and padding mode enums used by the grid sampler parameters.
// When compiled as Metal the enums are declared here; otherwise the
// existing `at::native` definitions are pulled into scope.
// NOTE(review): the enumerator values below presumably have to stay in sync
// with the `at::native` enums, since both sides share one parameter struct --
// confirm against ATen/native/GridSamplerUtils.h.
#ifdef __METAL__
enum class GridSamplerInterpolation {
  Bilinear = 0,
  Nearest = 1,
  Bicubic = 2,
};
enum class GridSamplerPadding {
  Zeros = 0,
  Border = 1,
  Reflection = 2,
};
#else
#include <ATen/native/GridSamplerUtils.h>
using at::native::GridSamplerInterpolation;
using at::native::GridSamplerPadding;
#endif
12+
13+
template <unsigned N = 5, typename idx_type_t = int32_t>
14+
struct GridSamplerParams {
15+
int32_t sampler_dims;
16+
::c10::metal::array<idx_type_t, N> output_sizes;
17+
::c10::metal::array<idx_type_t, N> output_strides;
18+
::c10::metal::array<idx_type_t, N> input_sizes;
19+
::c10::metal::array<idx_type_t, N> input_strides;
20+
::c10::metal::array<idx_type_t, N> grid_sizes;
21+
::c10::metal::array<idx_type_t, N> grid_strides;
22+
GridSamplerInterpolation interpolation_mode;
23+
GridSamplerPadding padding_mode;
24+
bool align_corners;
25+
};
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
#include <ATen/native/mps/kernels/GridSampler.h>
2+
#include <c10/metal/utils.h>
3+
#include <metal_array>
4+
#include <metal_stdlib>
5+
6+
using namespace metal;
7+
using namespace c10::metal;
8+
9+
// Element offsets into the output, input and grid tensors for one thread.
// A default-constructed instance has all three offsets set to zero.
struct GridSamplerOffsets {
  int32_t output = 0;
  int32_t input = 0;
  int32_t grid = 0;

  GridSamplerOffsets() {}
};
16+
17+
// Computes the tensor offsets that the thread with ID `tid` operates on.
//
// The flat thread ID is decomposed, innermost dimension first, into one index
// per output dimension. Those indices are then combined with the strides to
// obtain:
//   * `output`: offset of the output element this thread calculates.
//   * `input`:  offset of the (batch, channel) slice of the input.
//   * `grid`:   offset of the grid coordinates for the output element.
//
// Tensor shapes:
//   output: (N, C, Hout, Wout)    or (N, C, Dout, Hout, Wout)
//   input:  (N, C, Hin, Win)      or (N, C, Din, Hin, Win)
//   grid:   (N, Hout, Wout, 2)    or (N, Dout, Hout, Wout, 3)
//
// `input_sizes` and `grid_sizes` are accepted but not read.
static GridSamplerOffsets find_grid_sampler_offsets(
    constant int32_t* output_sizes,
    constant int32_t* output_strides,
    constant int32_t* input_sizes,
    constant int32_t* input_strides,
    constant int32_t* grid_sizes,
    constant int32_t* grid_strides,
    int32_t sampler_dims,
    uint tid) {
  GridSamplerOffsets result;
  const int32_t total_dims = sampler_dims + 2;
  int32_t remaining = static_cast<int32_t>(tid);

  for (int32_t d = total_dims - 1; d >= 0; --d) {
    const int32_t extent = output_sizes[d];
    const int32_t pos = remaining % extent;
    remaining = remaining / extent;

    const bool is_batch = (d == 0);
    const bool is_channel = (d == 1);

    // Every output dimension contributes to the output offset.
    result.output += output_strides[d] * pos;

    // Only batch and channel select the input slice.
    if (is_batch || is_channel) {
      result.input += input_strides[d] * pos;
    }

    // The grid has no channel dimension, so the dimensions after the channel
    // are shifted down by one when indexing the grid strides.
    if (is_batch) {
      result.grid += grid_strides[d] * pos;
    } else if (!is_channel) {
      result.grid += grid_strides[d - 1] * pos;
    }
  }

  return result;
}
63+
64+
// Remainder of `a / b` that is shifted by `b` whenever the built-in `%`
// yields a negative value, so the result is positive for negative `a` and
// positive `b`.
static int32_t mod(int32_t a, int32_t b) {
  const int32_t remainder = a % b;
  if (remainder < 0) {
    return remainder + b;
  }
  return remainder;
}
69+
70+
// Sentinel index value to indicate zero padding
constant int32_t IDX_ZERO = -1;

// Apply padding to an index into the input.
//
// Parameters:
//   idx           - index along one input dimension; may be out of range.
//   input_size    - size of that input dimension.
//   padding_mode  - how out-of-range indices are handled.
//   align_corners - selects the reflection period (see below).
//
// Returns:
//   * Zeros:      `idx` if it is in `[0, input_size)`, otherwise `IDX_ZERO`.
//   * Border:     `idx` clamped to `[0, input_size - 1]`.
//   * Reflection: `idx` mirrored back into the input. The mirrored segment
//                 has length `input_size - 1` when `align_corners` is true
//                 and `input_size` otherwise.
//   * Any other mode: `idx` unchanged.
static int32_t pad_input_index(
    int32_t idx,
    int32_t input_size,
    GridSamplerPadding padding_mode,
    bool align_corners) {
  int32_t idx_padded = idx;

  if (padding_mode == GridSamplerPadding::Zeros) {
    idx_padded = (idx < 0) ? IDX_ZERO : idx_padded;
    idx_padded = (idx >= input_size) ? IDX_ZERO : idx_padded;

  } else if (padding_mode == GridSamplerPadding::Border) {
    idx_padded = (idx < 0) ? 0 : idx_padded;
    idx_padded = (idx >= input_size) ? input_size - 1 : idx_padded;

  } else if (padding_mode == GridSamplerPadding::Reflection) {
    auto scale_length = align_corners ? (input_size - 1) : input_size;

    // A single-element dimension with `align_corners` has a zero-length
    // reflection segment. Every index reflects onto element 0, and the
    // modulo/division below must not be evaluated with a zero divisor.
    // NOTE(review): assumes empty dimensions are rejected before launch.
    if (scale_length <= 0) {
      return 0;
    }

    auto idx_mod = mod(idx, scale_length);
    auto idx_mod_reverse = (input_size - 1) - idx_mod;
    // Odd-numbered segments away from the input run in reverse direction.
    bool is_reverse = (abs(idx - idx_mod) / scale_length) % 2 == 1;
    idx_padded = is_reverse ? idx_mod_reverse : idx_mod;
  }
  return idx_padded;
}
98+
99+
// Reads one element of `input` at the given per-dimension `indices`, using
// `input_strides` to form the flat offset. If any index equals `IDX_ZERO`,
// zero is returned and `input` is not read.
template <int32_t dims, typename T>
T get_tensor_val(
    constant T* input,
    constant int32_t* input_strides,
    int32_t indices[dims]) {
  int32_t flat_offset = 0;

  for (int32_t d = 0; d < dims; ++d) {
    const int32_t index = indices[d];
    if (index == IDX_ZERO) {
      // Zero-padded position: nothing to load.
      return static_cast<T>(0);
    }
    flat_offset += index * input_strides[d];
  }

  return input[flat_offset];
}
115+
116+
// Trilinear interpolation of a single value.
//
// Picture a unit cube with one scalar attached to each of its eight corners;
// inside the cube the value varies linearly along every axis. The corner
// values are read from `input` at all eight combinations of `left_indices`
// and `right_indices`. `scales` gives the position inside the cube along each
// axis, where 0 selects the left index and 1 selects the right index. The
// result is the weighted sum of the eight corner values.
template <typename T>
T interpolate_linear_3d(
    constant T* input,
    constant int32_t* input_strides,
    int32_t left_indices[3],
    int32_t right_indices[3],
    opmath_t<T> scales[3]) {
  // Corner `k` uses the right index in dimension 0/1/2 when bit 2/1/0 of `k`
  // is set, and the left index otherwise.
  opmath_t<T> corner_vals[8];
  for (int32_t corner = 0; corner < 8; corner++) {
    int32_t corner_idx[3] = {
        (corner & 4) ? right_indices[0] : left_indices[0],
        (corner & 2) ? right_indices[1] : left_indices[1],
        (corner & 1) ? right_indices[2] : left_indices[2]};
    corner_vals[corner] = static_cast<opmath_t<T>>(
        get_tensor_val<3>(input, input_strides, corner_idx));
  }

  // Per-dimension weights of the right ("hi") and left ("lo") neighbours.
  auto hi0 = scales[0];
  auto hi1 = scales[1];
  auto hi2 = scales[2];
  auto lo0 = 1 - hi0;
  auto lo1 = 1 - hi1;
  auto lo2 = 1 - hi2;

  auto acc = lo0 * lo1 * lo2 * corner_vals[0];
  acc = acc + lo0 * lo1 * hi2 * corner_vals[1];
  acc = acc + lo0 * hi1 * lo2 * corner_vals[2];
  acc = acc + lo0 * hi1 * hi2 * corner_vals[3];
  acc = acc + hi0 * lo1 * lo2 * corner_vals[4];
  acc = acc + hi0 * lo1 * hi2 * corner_vals[5];
  acc = acc + hi0 * hi1 * lo2 * corner_vals[6];
  acc = acc + hi0 * hi1 * hi2 * corner_vals[7];

  return static_cast<T>(acc);
}
173+
174+
// Calculates a single output element.
// `input` shape:
//    2 sampler dims: (Hin, Win)
//    3 sampler dims: (Din, Hin, Win)
// `coords` values:
//    2 sampler dims: (Wcoord, Hcoord)
//    3 sampler dims: (Wcoord, Hcoord, Dcoord)
//
// `input_sizes` and `input_strides` describe only the sampled dimensions of
// `input`. The result is written to `*output`. Only `dims == 3` produces a
// result; for any other value `*output` is left untouched.
template <typename T>
void grid_sampler_single_element(
    device T* output,
    constant T* input,
    constant T* coords,
    int32_t dims,
    constant int32_t* input_sizes,
    constant int32_t* input_strides,
    GridSamplerInterpolation interpolation_mode,
    GridSamplerPadding padding_mode,
    bool align_corners) {
  int32_t left_indices[3];
  int32_t right_indices[3];
  opmath_t<T> scales[3];

  // For each dimension, find the pair of indices in the corresponding
  // dimension of `input` which surround the grid coordinate in that
  // dimension. We'll do this by mapping different coordinate spaces onto each
  // other. There are basically three different coordinate spaces to keep in
  // mind:
  //
  //  * aligned grid space
  //    - `-1` refers to the leftmost input value.
  //    - `1` refers to the rightmost input value.
  //
  //  * unaligned grid space
  //    - `-1` refers to the midpoint between the leftmost input value and
  //      a padding value to the left of that.
  //    - `1` refers to the midpoint between the rightmost input value and
  //      a padding value to the right of that.
  //
  //  * input index space
  //    - `n` refers to the n-th value of the input.
  //    - `0` refers to the leftmost input value.
  //    - `N-1` refers to the rightmost input value.
  //
  // If `align_corners == True`, the coordinates are in aligned grid space;
  // otherwise they are in unaligned grid space. Either way they are mapped
  // onto input index space, which makes it relatively simple to find the two
  // input indices that surround the coordinate.
  for (auto coord_dim = 0; coord_dim < dims; coord_dim++) {
    // Coordinates are ordered innermost dimension first, while sizes and
    // indices are ordered outermost dimension first.
    auto input_dim = dims - coord_dim - 1;
    auto input_size = input_sizes[input_dim];
    auto coord = static_cast<opmath_t<T>>(coords[coord_dim]);

    // Interpret nan as -1
    coord = isnan(coord) ? -1 : coord;

    if (align_corners) {
      // Map aligned grid space to input index space:
      // -1 -> 0 and 1 -> input_size - 1.
      coord = (coord + 1) * (static_cast<opmath_t<T>>(input_size - 1) / 2);
    } else {
      // Map unaligned grid space to input index space:
      // -1 -> -0.5 and 1 -> input_size - 0.5.
      // This is algebraically the same as first scaling by
      // `input_size / (input_size - 1)` and then applying the aligned
      // mapping, but it does not divide by zero when `input_size == 1`.
      coord = ((coord + 1) * static_cast<opmath_t<T>>(input_size) - 1) / 2;
    }

    // Get the input indices surrounding the coordinate, apply padding to them,
    // and obtain the scaling factor between the two for interpolation.
    auto left_idx = static_cast<int32_t>(floor(coord));
    auto right_idx = static_cast<int32_t>(ceil(coord));
    left_indices[input_dim] =
        pad_input_index(left_idx, input_size, padding_mode, align_corners);
    right_indices[input_dim] =
        pad_input_index(right_idx, input_size, padding_mode, align_corners);

    auto scale = coord - left_idx;

    if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      // TODO: For some reason, rounding the scale to 0 or 1 and then using
      // linear interpolation seems to work perfectly with zero padding mode,
      // but we get flaky failures with border and reflection padding modes.
      // Need to investigate and fix it.
      scale = (scale <= 0.5) ? 0 : 1;
    }
    scales[input_dim] = scale;
  }

  // Now that we have the bounding indices and scale factor for each dimension
  // of the input, we can interpolate.
  if (dims == 3) {
    *output = interpolate_linear_3d(
        input, input_strides, left_indices, right_indices, scales);
  }
}
269+
270+
// Kernel entry point: each thread computes one output element.
//
// The thread ID is turned into offsets into `output`, `input` and `grid`, and
// the element at that output offset is then sampled from the selected
// (batch, channel) slice of `input` using the coordinates read from `grid`.
template <typename T>
kernel void grid_sampler(
    device T* output [[buffer(0)]],
    constant T* input [[buffer(1)]],
    constant T* grid [[buffer(2)]],
    constant GridSamplerParams<5>& params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]) {
  auto full_input_sizes = params.input_sizes.data();
  auto full_input_strides = params.input_strides.data();

  auto thread_offsets = find_grid_sampler_offsets(
      params.output_sizes.data(),
      params.output_strides.data(),
      full_input_sizes,
      full_input_strides,
      params.grid_sizes.data(),
      params.grid_strides.data(),
      params.sampler_dims,
      tid);

  // The first two entries of the input sizes/strides belong to the batch and
  // channel dimensions, which are already accounted for in the input offset,
  // so only the remaining entries are passed on.
  grid_sampler_single_element(
      output + thread_offsets.output,
      input + thread_offsets.input,
      grid + thread_offsets.grid,
      params.sampler_dims,
      full_input_sizes + 2,
      full_input_strides + 2,
      params.interpolation_mode,
      params.padding_mode,
      params.align_corners);
}
317+
318+
// Explicitly instantiates the `grid_sampler` kernel for the element type
// `ELEM_T` and gives it the host-visible name "grid_sampler_<ELEM_T>".
#define REGISTER_GRID_SAMPLER_OP(ELEM_T)                      \
  template [[host_name("grid_sampler_" #ELEM_T)]]             \
  kernel void grid_sampler<ELEM_T>(                           \
      device ELEM_T * output [[buffer(0)]],                   \
      constant ELEM_T * input [[buffer(1)]],                  \
      constant ELEM_T * grid [[buffer(2)]],                   \
      constant GridSamplerParams<5> & params [[buffer(3)]],   \
      uint tid [[thread_position_in_grid]]);

// Element types for which the kernel is instantiated.
REGISTER_GRID_SAMPLER_OP(float);
REGISTER_GRID_SAMPLER_OP(half);
REGISTER_GRID_SAMPLER_OP(bfloat);

0 commit comments

Comments
 (0)