ROIPool: Support for all datatypes #632

Merged: merged 13 commits on Oct 18, 2018
350 changes: 262 additions & 88 deletions test/test_layers.py

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions torchvision/csrc/ROIAlign.h
@@ -7,12 +7,13 @@
#endif

// Interface for Python
at::Tensor ROIAlign_forward(const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
at::Tensor ROIAlign_forward(const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
const float spatial_scale, // The scale of the image features. ROIs will be scaled to this.
const int pooled_height, // The height of the pooled feature map.
const int pooled_width, // The width of the pooled feature map.
const int sampling_ratio) // The number of points to sample in each bin along each axis.
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -40,6 +41,6 @@ at::Tensor ROIAlign_backward(const at::Tensor& grad,
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
}
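The header comment above calls these functions the interface for Python; in torchvision they are typically exposed through a pybind11 extension module elsewhere in csrc. The following is a minimal sketch, not part of this diff, and the module setup and op names are assumptions for illustration only.

#include <torch/extension.h>  // assumed extension header; pulls in pybind11 and ATen

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Illustrative names only; the real bindings live in a separate source file.
  m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign forward");
  m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign backward");
}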

250 changes: 215 additions & 35 deletions torchvision/csrc/cpu/ROIAlign_cpu.cpp
@@ -110,46 +110,37 @@ void pre_calc_for_bilinear_interpolate(
}

template <typename T>
void ROIAlignForward_cpu_kernel(
void ROIAlignForward(
const int nthreads,
const T* bottom_data,
const T* input,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* bottom_rois,
//int roi_cols,
T* top_data) {
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
int roi_cols = 5;

const T* rois,
T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output
// can be parallelized using omp
// #pragma omp parallel for num_threads(32)
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;

// roi could have 4 or 5 columns
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
int roi_batch_ind = 0;
if (roi_cols == 5) {
roi_batch_ind = offset_bottom_rois[0];
offset_bottom_rois++;
}
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];

// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_bottom_rois[0] * spatial_scale;
T roi_start_h = offset_bottom_rois[1] * spatial_scale;
T roi_end_w = offset_bottom_rois[2] * spatial_scale;
T roi_end_h = offset_bottom_rois[3] * spatial_scale;
// T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
// T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
// T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
// T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
// T roi_start_w = round(offset_rois[0] * spatial_scale);
// T roi_start_h = round(offset_rois[1] * spatial_scale);
// T roi_end_w = round(offset_rois[2] * spatial_scale);
// T roi_end_h = round(offset_rois[3] * spatial_scale);

// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
@@ -188,8 +179,8 @@ void ROIAlignForward_cpu_kernel(

for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_bottom_data =
bottom_data + (roi_batch_ind * channels + c) * height * width;
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;

for (int ph = 0; ph < pooled_height; ph++) {
@@ -200,46 +191,186 @@
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
pc.w2 * offset_bottom_data[pc.pos2] +
pc.w3 * offset_bottom_data[pc.pos3] +
pc.w4 * offset_bottom_data[pc.pos4];
output_val += pc.w1 * offset_input[pc.pos1] +
pc.w2 * offset_input[pc.pos2] +
pc.w3 * offset_input[pc.pos3] +
pc.w4 * offset_input[pc.pos4];

pre_calc_index += 1;
}
}
output_val /= count;

top_data[index] = output_val;
output[index] = output_val;
} // for pw
} // for ph
} // for c
} // for n
}
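The rewritten kernel drops the old 4-or-5 column handling and always expects each rois row to be [batch_index, x1, y1, x2, y2]. A small sketch of building one such ROI with ATen follows; it is not part of the diff, and the coordinate values are arbitrary.

#include <ATen/ATen.h>

at::Tensor make_single_roi() {
  at::Tensor rois = at::zeros({1, 5});   // one ROI, five columns
  rois.select(1, 0).fill_(0);            // batch index into the input tensor
  rois.select(1, 1).fill_(2);            // x1
  rois.select(1, 2).fill_(2);            // y1
  rois.select(1, 3).fill_(10);           // x2
  rois.select(1, 4).fill_(10);           // y2
  return rois;
}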

template <typename T>
void bilinear_interpolate_gradient(
const int height, const int width,
T y, T x,
T& w1, T& w2, T& w3, T& w4,
int& x_low, int& x_high, int& y_low, int& y_high,
const int index /* index for debug only*/) {

// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}

if (y <= 0) y = 0;
if (x <= 0) x = 0;

y_low = (int)y;
x_low = (int)x;

if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}

if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}

T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;

// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

return;
}
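For reference, the commented-out "reference in forward" block above corresponds to the interpolation sketched below: the value at a sampling point is the weighted sum of its four neighbours using the same w1..w4 weights. This helper is illustrative and not part of the diff.

template <typename T>
T bilinear_interpolate_value(const T* input, const int width,
                             const int x_low, const int x_high,
                             const int y_low, const int y_high,
                             const T w1, const T w2, const T w3, const T w4) {
  T v1 = input[y_low * width + x_low];    // top-left neighbour
  T v2 = input[y_low * width + x_high];   // top-right neighbour
  T v3 = input[y_high * width + x_low];   // bottom-left neighbour
  T v4 = input[y_high * width + x_high];  // bottom-right neighbour
  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}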

template <class T>
inline void add(T* address, const T& val) {
*address += val;
}

template <typename T>
void ROIAlignBackward(
const int nthreads,
const T* grad_output,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
T* grad_input,
const T* rois,
const int n_stride, const int c_stride,
const int h_stride, const int w_stride) {
for (int index = 0; index < nthreads; index++) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];

// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;

// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

T* offset_grad_input = grad_input + ((roi_batch_ind * channels + c) * height * width);

int output_offset = n*n_stride + c*c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin = offset_grad_output[ph*h_stride + pw*w_stride];

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

for (int iy = 0; iy < roi_bin_grid_h; iy++)
{
const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++)
{
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);

T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;

bilinear_interpolate_gradient(height, width, y, x,
w1, w2, w3, w4,
x_low, x_high, y_low, y_high,
index);

T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
// atomic add is not needed for now since it is single threaded
add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // for
} // ROIAlignBackward


at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");

auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);

at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());

auto output_size = num_rois * pooled_height * pooled_width * channels;

if (output.numel() == 0)
return output;

AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
ROIAlignForward_cpu_kernel<scalar_t>(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
ROIAlignForward<scalar_t>(
output_size,
input.data<scalar_t>(),
spatial_scale,
@@ -254,3 +385,52 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
});
return output;
}
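The switch from AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND_HALF is what lets this entry point accept half inputs in addition to float and double. A usage sketch calling the CPU forward on a double input follows; it is not part of the diff, and the shapes and values are arbitrary.

#include <ATen/ATen.h>

at::Tensor roi_align_forward_cpu_example() {
  at::Tensor input = at::rand({1, 3, 16, 16}, at::kDouble);  // N, C, H, W feature map
  at::Tensor rois = at::zeros({1, 5}, at::kDouble);          // [batch_idx, x1, y1, x2, y2]
  rois.select(1, 3).fill_(15);                               // x2
  rois.select(1, 4).fill_(15);                               // y2
  // pooled 7x7 output, spatial_scale 1.0, sampling_ratio 2
  return ROIAlign_forward_cpu(input, rois, 1.0f, 7, 7, 2);
}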


at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio) {
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");

at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

// handle possibly empty gradients
if (grad.numel() == 0)
{
return grad_input;
}

// get stride values to ensure indexing into gradients is correct.
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
ROIAlignBackward<scalar_t>(
grad.numel(),
grad.data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
grad_input.data<scalar_t>(),
rois.data<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
return grad_input;
}
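The stride-based indexing above keeps the backward correct when grad is not contiguous: the element for (n, c, ph, pw) lives at n*n_stride + c*c_stride + ph*h_stride + pw*w_stride rather than at the offset a contiguous NCHW layout would imply. A sketch showing that the two addressings agree on a non-contiguous gradient follows; it is not part of the diff.

#include <ATen/ATen.h>
#include <cassert>

void stride_indexing_sketch() {
  // Slicing channels keeps the logical NCHW layout but leaves the tensor non-contiguous.
  at::Tensor grad = at::rand({2, 6, 7, 7}).slice(/*dim=*/1, /*start=*/0, /*end=*/3);
  int64_t n = 1, c = 2, ph = 4, pw = 5;
  int64_t offset = n * grad.stride(0) + c * grad.stride(1) +
                   ph * grad.stride(2) + pw * grad.stride(3);
  // The strided offset and ordinary indexing reach the same element.
  assert(grad.data<float>()[offset] == grad[n][c][ph][pw].item<float>());
}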