Skip to content

Commit 06c4018

Browse files
committed
Use of torch7 naming scheme for ROIAlign forward and backward
1 parent f2a3ec8 commit 06c4018

File tree

2 files changed

+76
-72
lines changed

2 files changed

+76
-72
lines changed

torchvision/csrc/cpu/ROIAlign_cpu.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,17 @@ void pre_calc_for_bilinear_interpolate(
112112
template <typename T>
113113
void ROIAlignForward_cpu_kernel(
114114
const int nthreads,
115-
const T* bottom_data,
115+
const T* input,
116116
const T& spatial_scale,
117117
const int channels,
118118
const int height,
119119
const int width,
120120
const int pooled_height,
121121
const int pooled_width,
122122
const int sampling_ratio,
123-
const T* bottom_rois,
123+
const T* rois,
124124
//int roi_cols,
125-
T* top_data) {
125+
T* output) {
126126
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
127127
int roi_cols = 5;
128128

@@ -134,22 +134,22 @@ void ROIAlignForward_cpu_kernel(
134134
int index_n = n * channels * pooled_width * pooled_height;
135135

136136
// roi could have 4 or 5 columns
137-
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
137+
const T* offset_rois = rois + n * roi_cols;
138138
int roi_batch_ind = 0;
139139
if (roi_cols == 5) {
140-
roi_batch_ind = offset_bottom_rois[0];
141-
offset_bottom_rois++;
140+
roi_batch_ind = offset_rois[0];
141+
offset_rois++;
142142
}
143143

144144
// Do not use rounding; this implementation detail is critical
145-
T roi_start_w = offset_bottom_rois[0] * spatial_scale;
146-
T roi_start_h = offset_bottom_rois[1] * spatial_scale;
147-
T roi_end_w = offset_bottom_rois[2] * spatial_scale;
148-
T roi_end_h = offset_bottom_rois[3] * spatial_scale;
149-
// T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
150-
// T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
151-
// T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
152-
// T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
145+
T roi_start_w = offset_rois[0] * spatial_scale;
146+
T roi_start_h = offset_rois[1] * spatial_scale;
147+
T roi_end_w = offset_rois[2] * spatial_scale;
148+
T roi_end_h = offset_rois[3] * spatial_scale;
149+
// T roi_start_w = round(offset_rois[0] * spatial_scale);
150+
// T roi_start_h = round(offset_rois[1] * spatial_scale);
151+
// T roi_end_w = round(offset_rois[2] * spatial_scale);
152+
// T roi_end_h = round(offset_rois[3] * spatial_scale);
153153

154154
// Force malformed ROIs to be 1x1
155155
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
@@ -188,8 +188,8 @@ void ROIAlignForward_cpu_kernel(
188188

189189
for (int c = 0; c < channels; c++) {
190190
int index_n_c = index_n + c * pooled_width * pooled_height;
191-
const T* offset_bottom_data =
192-
bottom_data + (roi_batch_ind * channels + c) * height * width;
191+
const T* offset_input =
192+
input + (roi_batch_ind * channels + c) * height * width;
193193
int pre_calc_index = 0;
194194

195195
for (int ph = 0; ph < pooled_height; ph++) {
@@ -200,17 +200,17 @@ void ROIAlignForward_cpu_kernel(
200200
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
201201
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
202202
PreCalc<T> pc = pre_calc[pre_calc_index];
203-
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
204-
pc.w2 * offset_bottom_data[pc.pos2] +
205-
pc.w3 * offset_bottom_data[pc.pos3] +
206-
pc.w4 * offset_bottom_data[pc.pos4];
203+
output_val += pc.w1 * offset_input[pc.pos1] +
204+
pc.w2 * offset_input[pc.pos2] +
205+
pc.w3 * offset_input[pc.pos3] +
206+
pc.w4 * offset_input[pc.pos4];
207207

208208
pre_calc_index += 1;
209209
}
210210
}
211211
output_val /= count;
212212

213-
top_data[index] = output_val;
213+
output[index] = output_val;
214214
} // for pw
215215
} // for ph
216216
} // for c

torchvision/csrc/cuda/ROIAlign_cuda.cu

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
template <typename T>
15-
__device__ T bilinear_interpolate(const T* bottom_data,
15+
__device__ T bilinear_interpolate(const T* input,
1616
const int height, const int width,
1717
T y, T x,
1818
const int index /* index for debug only*/) {
@@ -48,11 +48,12 @@ __device__ T bilinear_interpolate(const T* bottom_data,
4848
T ly = y - y_low;
4949
T lx = x - x_low;
5050
T hy = 1. - ly, hx = 1. - lx;
51+
5152
// do bilinear interpolation
52-
T v1 = bottom_data[y_low * width + x_low];
53-
T v2 = bottom_data[y_low * width + x_high];
54-
T v3 = bottom_data[y_high * width + x_low];
55-
T v4 = bottom_data[y_high * width + x_high];
53+
T v1 = input[y_low * width + x_low];
54+
T v2 = input[y_low * width + x_high];
55+
T v3 = input[y_high * width + x_low];
56+
T v4 = input[y_high * width + x_high];
5657
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
5758

5859
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
@@ -61,39 +62,35 @@ __device__ T bilinear_interpolate(const T* bottom_data,
6162
}
6263

6364
template <typename T>
64-
__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
65+
__global__ void RoIAlignForward(const int nthreads, const T* input,
6566
const T spatial_scale, const int channels,
6667
const int height, const int width,
6768
const int pooled_height, const int pooled_width,
6869
const int sampling_ratio,
69-
const T* bottom_rois, T* top_data) {
70+
const T* rois, T* output) {
7071
CUDA_1D_KERNEL_LOOP(index, nthreads) {
7172
// (n, c, ph, pw) is an element in the pooled output
7273
int pw = index % pooled_width;
7374
int ph = (index / pooled_width) % pooled_height;
7475
int c = (index / pooled_width / pooled_height) % channels;
7576
int n = index / pooled_width / pooled_height / channels;
7677

77-
const T* offset_bottom_rois = bottom_rois + n * 5;
78-
int roi_batch_ind = offset_bottom_rois[0];
78+
const T* offset_rois = rois + n * 5;
79+
int roi_batch_ind = offset_rois[0];
7980

8081
// Do not use rounding; this implementation detail is critical
81-
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
82-
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
83-
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
84-
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
85-
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
86-
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
87-
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
88-
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
82+
T roi_start_w = offset_rois[1] * spatial_scale;
83+
T roi_start_h = offset_rois[2] * spatial_scale;
84+
T roi_end_w = offset_rois[3] * spatial_scale;
85+
T roi_end_h = offset_rois[4] * spatial_scale;
8986

9087
// Force malformed ROIs to be 1x1
9188
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
9289
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
9390
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
9491
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
9592

96-
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
93+
const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
9794

9895
// We use roi_bin_grid to sample the grid and mimic integral
9996
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -110,13 +107,13 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
110107
{
111108
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
112109

113-
T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
110+
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
114111
output_val += val;
115112
}
116113
}
117114
output_val /= count;
118115

119-
top_data[index] = output_val;
116+
output[index] = output_val;
120117
}
121118
}
122119

@@ -162,10 +159,10 @@ __device__ void bilinear_interpolate_gradient(
162159
T hy = 1. - ly, hx = 1. - lx;
163160

164161
// reference in forward
165-
// T v1 = bottom_data[y_low * width + x_low];
166-
// T v2 = bottom_data[y_low * width + x_high];
167-
// T v3 = bottom_data[y_high * width + x_low];
168-
// T v4 = bottom_data[y_high * width + x_high];
162+
// T v1 = input[y_low * width + x_low];
163+
// T v2 = input[y_low * width + x_high];
164+
// T v3 = input[y_high * width + x_low];
165+
// T v4 = input[y_high * width + x_high];
169166
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
170167

171168
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
@@ -174,44 +171,42 @@ __device__ void bilinear_interpolate_gradient(
174171
}
175172

176173
template <typename T>
177-
__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
174+
__global__ void RoIAlignBackwardFeature(const int nthreads, const T* grad_output,
178175
const int num_rois, const T spatial_scale,
179176
const int channels, const int height, const int width,
180177
const int pooled_height, const int pooled_width,
181178
const int sampling_ratio,
182-
T* bottom_diff,
183-
const T* bottom_rois) {
179+
T* grad_input,
180+
const T* rois,
181+
const int n_stride, const int c_stride,
182+
const int h_stride, const int w_stride) {
184183
CUDA_1D_KERNEL_LOOP(index, nthreads) {
185184
// (n, c, ph, pw) is an element in the pooled output
186185
int pw = index % pooled_width;
187186
int ph = (index / pooled_width) % pooled_height;
188187
int c = (index / pooled_width / pooled_height) % channels;
189188
int n = index / pooled_width / pooled_height / channels;
190189

191-
const T* offset_bottom_rois = bottom_rois + n * 5;
192-
int roi_batch_ind = offset_bottom_rois[0];
190+
const T* offset_rois = rois + n * 5;
191+
int roi_batch_ind = offset_rois[0];
193192

194193
// Do not use rounding; this implementation detail is critical
195-
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
196-
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
197-
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
198-
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
199-
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
200-
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
201-
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
202-
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
203-
194+
T roi_start_w = offset_rois[1] * spatial_scale;
195+
T roi_start_h = offset_rois[2] * spatial_scale;
196+
T roi_end_w = offset_rois[3] * spatial_scale;
197+
T roi_end_h = offset_rois[4] * spatial_scale;
198+
204199
// Force malformed ROIs to be 1x1
205200
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
206201
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
207202
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
208203
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
209204

210-
T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
205+
T* offset_grad_input = grad_input + (roi_batch_ind * channels + c) * height * width;
211206

212-
int top_offset = (n * channels + c) * pooled_height * pooled_width;
213-
const T* offset_top_diff = top_diff + top_offset;
214-
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
207+
int top_offset = (n * channels + c) * pooled_height * pooled_width;
208+
const T* offset_grad_output = grad_output + top_offset;
209+
const T grad_output_this_bin = offset_grad_output[ph * pooled_width + pw];
215210

216211
// We use roi_bin_grid to sample the grid and mimic integral
217212
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -235,17 +230,17 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
235230
x_low, x_high, y_low, y_high,
236231
index);
237232

238-
T g1 = top_diff_this_bin * w1 / count;
239-
T g2 = top_diff_this_bin * w2 / count;
240-
T g3 = top_diff_this_bin * w3 / count;
241-
T g4 = top_diff_this_bin * w4 / count;
233+
T g1 = grad_output_this_bin * w1 / count;
234+
T g2 = grad_output_this_bin * w2 / count;
235+
T g3 = grad_output_this_bin * w3 / count;
236+
T g4 = grad_output_this_bin * w4 / count;
242237

243238
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
244239
{
245-
atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
246-
atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
247-
atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
248-
atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
240+
atomicAdd(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
241+
atomicAdd(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
242+
atomicAdd(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
243+
atomicAdd(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
249244
} // if
250245
} // ix
251246
} // iy
@@ -326,6 +321,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
326321
return grad_input;
327322
}
328323

324+
int n_stride = grad.stride(0);
325+
int c_stride = grad.stride(1);
326+
int h_stride = grad.stride(2);
327+
int w_stride = grad.stride(3);
328+
329329
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
330330
RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
331331
grad.numel(),
@@ -339,7 +339,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
339339
pooled_width,
340340
sampling_ratio,
341341
grad_input.data<scalar_t>(),
342-
rois.data<scalar_t>());
342+
rois.data<scalar_t>(),
343+
n_stride,
344+
c_stride,
345+
h_stride,
346+
w_stride);
343347
});
344348
THCudaCheck(cudaGetLastError());
345349
return grad_input;

0 commit comments

Comments
 (0)