@@ -12,7 +12,7 @@


 template <typename T>
-__device__ T bilinear_interpolate(const T* bottom_data,
+__device__ T bilinear_interpolate(const T* input,
                                   const int height, const int width,
                                   T y, T x,
                                   const int index /* index for debug only*/) {
@@ -48,11 +48,12 @@ __device__ T bilinear_interpolate(const T* bottom_data,
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
+
   // do bilinear interpolation
-  T v1 = bottom_data[y_low * width + x_low];
-  T v2 = bottom_data[y_low * width + x_high];
-  T v3 = bottom_data[y_high * width + x_low];
-  T v4 = bottom_data[y_high * width + x_high];
+  T v1 = input[y_low * width + x_low];
+  T v2 = input[y_low * width + x_high];
+  T v3 = input[y_high * width + x_low];
+  T v4 = input[y_high * width + x_high];
   T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

   T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
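
The weights above are standard bilinear interpolation: a sample at fractional coordinates (y, x) mixes its four integer-grid neighbors, each weighted by the area of the opposite sub-rectangle, so the four weights sum to 1. A minimal host-side C++ sketch of the same arithmetic (function name and test values are illustrative, and the kernel's out-of-bounds clamping is omitted):

    #include <cstdio>

    // Host-side mirror of the kernel's bilinear weighting for one sample.
    float bilinear_sample(const float* input, int width, float y, float x) {
      int y_low = (int)y, x_low = (int)x;
      int y_high = y_low + 1, x_high = x_low + 1;
      float ly = y - y_low, lx = x - x_low;  // fractional offsets
      float hy = 1.f - ly, hx = 1.f - lx;    // complementary offsets
      float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
      return w1 * input[y_low * width + x_low] + w2 * input[y_low * width + x_high] +
             w3 * input[y_high * width + x_low] + w4 * input[y_high * width + x_high];
    }

    int main() {
      // 2x2 image; sampling at the exact center averages all four pixels.
      float img[4] = {0.f, 1.f, 2.f, 3.f};
      printf("%f\n", bilinear_sample(img, 2, 0.5f, 0.5f));  // prints 1.500000
      return 0;
    }
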
@@ -61,39 +62,35 @@ __device__ T bilinear_interpolate(const T* bottom_data,
 }

 template <typename T>
-__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
+__global__ void RoIAlignForward(const int nthreads, const T* input,
                                 const T spatial_scale, const int channels,
                                 const int height, const int width,
                                 const int pooled_height, const int pooled_width,
                                 const int sampling_ratio,
-                                const T* bottom_rois, T* top_data) {
+                                const T* rois, T* output) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;

-    const T* offset_bottom_rois = bottom_rois + n * 5;
-    int roi_batch_ind = offset_bottom_rois[0];
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];

     // Do not use rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;

     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

-    const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
+    const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;

     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -110,13 +107,13 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
       {
         const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);

-        T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
+        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
         output_val += val;
       }
     }
     output_val /= count;

-    top_data[index] = output_val;
+    output[index] = output_val;
   }
 }

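
Each thread thus produces exactly one pooled output element: it decodes its flat index into (n, c, ph, pw), averages roi_bin_grid_h * roi_bin_grid_w bilinear samples over its bin, and writes the mean. A small host-side check of the mixed-radix index arithmetic (names and shapes here are illustrative only):

    #include <cassert>

    // Decode a flat NCHW index into (n, c, ph, pw), mirroring the kernel.
    void decode(int index, int channels, int pooled_height, int pooled_width,
                int& n, int& c, int& ph, int& pw) {
      pw = index % pooled_width;
      ph = (index / pooled_width) % pooled_height;
      c  = (index / pooled_width / pooled_height) % channels;
      n  = index / pooled_width / pooled_height / channels;
    }

    int main() {
      int n, c, ph, pw;
      // Element (n=1, c=2, ph=3, pw=4) in a 2x3x7x7 pooled output.
      int index = ((1 * 3 + 2) * 7 + 3) * 7 + 4;
      decode(index, 3, 7, 7, n, c, ph, pw);
      assert(n == 1 && c == 2 && ph == 3 && pw == 4);
      return 0;
    }
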
@@ -162,10 +159,10 @@ __device__ void bilinear_interpolate_gradient(
   T hy = 1. - ly, hx = 1. - lx;

   // reference in forward
-  // T v1 = bottom_data[y_low * width + x_low];
-  // T v2 = bottom_data[y_low * width + x_high];
-  // T v3 = bottom_data[y_high * width + x_low];
-  // T v4 = bottom_data[y_high * width + x_high];
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
@@ -174,44 +171,42 @@ __device__ void bilinear_interpolate_gradient(
 }

 template <typename T>
-__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
+__global__ void RoIAlignBackwardFeature(const int nthreads, const T* grad_output,
                                         const int num_rois, const T spatial_scale,
                                         const int channels, const int height, const int width,
                                         const int pooled_height, const int pooled_width,
                                         const int sampling_ratio,
-                                        T* bottom_diff,
-                                        const T* bottom_rois) {
+                                        T* grad_input,
+                                        const T* rois,
+                                        const int n_stride, const int c_stride,
+                                        const int h_stride, const int w_stride) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;

-    const T* offset_bottom_rois = bottom_rois + n * 5;
-    int roi_batch_ind = offset_bottom_rois[0];
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];

     // Do not use rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
-
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
+
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

-    T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
+    T* offset_grad_input = grad_input + (roi_batch_ind * channels + c) * height * width;

-    int top_offset = (n * channels + c) * pooled_height * pooled_width;
-    const T* offset_top_diff = top_diff + top_offset;
-    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_grad_output = grad_output + top_offset;
+    const T grad_output_this_bin = offset_grad_output[ph * pooled_width + pw];

     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -235,17 +230,17 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
                                        x_low, x_high, y_low, y_high,
                                        index);

-        T g1 = top_diff_this_bin * w1 / count;
-        T g2 = top_diff_this_bin * w2 / count;
-        T g3 = top_diff_this_bin * w3 / count;
-        T g4 = top_diff_this_bin * w4 / count;
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;

         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
         {
-          atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
-          atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
-          atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
-          atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
+          atomicAdd(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          atomicAdd(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          atomicAdd(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          atomicAdd(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
     } // iy
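
The atomicAdd calls matter here: overlapping ROIs, and neighboring sampling points within one ROI, can route gradients to the same grad_input element, so a plain read-modify-write from concurrent threads could drop contributions. A stripped-down illustration of the same scatter-add pattern (a hypothetical kernel, not from this file):

    // Threads may collide on the same target index; atomicAdd makes each
    // accumulation safe, at the cost of serializing colliding updates.
    __global__ void scatter_add(float* out, const int* target, const float* g, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n)
        atomicAdd(out + target[i], g[i]);
    }
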
@@ -326,6 +321,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
     return grad_input;
   }

+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
   AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
     RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
@@ -339,7 +339,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
         pooled_width,
         sampling_ratio,
         grad_input.data<scalar_t>(),
-        rois.data<scalar_t>());
+        rois.data<scalar_t>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
   });
   THCudaCheck(cudaGetLastError());
   return grad_input;
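
The four stride arguments make the backward pass robust to a non-contiguous grad tensor: rather than assuming a packed NCHW layout, the kernel can address grad_output through the strides reported by grad.stride(). The hunks shown above still contain the contiguous top_offset arithmetic, so the following is only a sketch of how the strides would typically be applied inside RoIAlignBackwardFeature, not a quote from this change:

    // Assumed stride-based load of the incoming gradient for bin (n, c, ph, pw);
    // correct for any layout, including non-contiguous views of grad.
    template <typename T>
    __device__ T load_grad_output_this_bin(
        const T* grad_output, int n, int c, int ph, int pw,
        int n_stride, int c_stride, int h_stride, int w_stride) {
      return grad_output[n * n_stride + c * c_stride + ph * h_stride + pw * w_stride];
    }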