diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
index 3850b2833ab..7ed01733d9d 100644
--- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
@@ -112,7 +112,7 @@ void pre_calc_for_bilinear_interpolate(
 template <typename T>
 void ROIAlignForward_cpu_kernel(
     const int nthreads,
-    const T* bottom_data,
+    const T* input,
     const T& spatial_scale,
     const int channels,
     const int height,
@@ -120,9 +120,9 @@ void ROIAlignForward_cpu_kernel(
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
-    const T* bottom_rois,
+    const T* rois,
     //int roi_cols,
-    T* top_data) {
+    T* output) {
   //AT_ASSERT(roi_cols == 4 || roi_cols == 5);
   int roi_cols = 5;
@@ -134,22 +134,22 @@ void ROIAlignForward_cpu_kernel(
     int index_n = n * channels * pooled_width * pooled_height;
     // roi could have 4 or 5 columns
-    const T* offset_bottom_rois = bottom_rois + n * roi_cols;
+    const T* offset_rois = rois + n * roi_cols;
     int roi_batch_ind = 0;
     if (roi_cols == 5) {
-      roi_batch_ind = offset_bottom_rois[0];
-      offset_bottom_rois++;
+      roi_batch_ind = offset_rois[0];
+      offset_rois++;
     }
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+    T roi_start_w = offset_rois[0] * spatial_scale;
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_end_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    // T roi_start_w = round(offset_rois[0] * spatial_scale);
+    // T roi_start_h = round(offset_rois[1] * spatial_scale);
+    // T roi_end_w = round(offset_rois[2] * spatial_scale);
+    // T roi_end_h = round(offset_rois[3] * spatial_scale);
     // Force malformed ROIs to be 1x1
     T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
@@ -188,8 +188,8 @@ void ROIAlignForward_cpu_kernel(
     for (int c = 0; c < channels; c++) {
       int index_n_c = index_n + c * pooled_width * pooled_height;
-      const T* offset_bottom_data =
-          bottom_data + (roi_batch_ind * channels + c) * height * width;
+      const T* offset_input =
+          input + (roi_batch_ind * channels + c) * height * width;
       int pre_calc_index = 0;
       for (int ph = 0; ph < pooled_height; ph++) {
@@ -200,17 +200,17 @@ void ROIAlignForward_cpu_kernel(
           for (int iy = 0; iy < roi_bin_grid_h; iy++) {
             for (int ix = 0; ix < roi_bin_grid_w; ix++) {
               PreCalc<T> pc = pre_calc[pre_calc_index];
-              output_val += pc.w1 * offset_bottom_data[pc.pos1] +
-                  pc.w2 * offset_bottom_data[pc.pos2] +
-                  pc.w3 * offset_bottom_data[pc.pos3] +
-                  pc.w4 * offset_bottom_data[pc.pos4];
+              output_val += pc.w1 * offset_input[pc.pos1] +
+                  pc.w2 * offset_input[pc.pos2] +
+                  pc.w3 * offset_input[pc.pos3] +
+                  pc.w4 * offset_input[pc.pos4];
               pre_calc_index += 1;
             }
           }
           output_val /= count;
-          top_data[index] = output_val;
+          output[index] = output_val;
         } // for pw
       } // for ph
     } // for c
diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp
index 8ae35930533..eba66c18a14 100644
--- a/torchvision/csrc/cpu/ROIPool_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -16,8 +16,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
   int input_height = input.size(2);
   int input_width = input.size(3);
-  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
-  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
   // define accessors for indexing
   auto input_a = input.accessor();
@@ -107,7 +107,7 @@ at::Tensor ROIPool_backward_cpu(const at::Tensor &grad,
   auto num_rois = rois.size(0);
-  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
   // handle possibly empty gradients
   if (grad.numel() == 0)
diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu
index 9cc5ae28934..adbec2fa1bb 100644
--- a/torchvision/csrc/cuda/ROIAlign_cuda.cu
+++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu
@@ -5,14 +5,11 @@
 #include
 #include
-// TODO make it in a common file
-#define CUDA_1D_KERNEL_LOOP(i, n)                            \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
-       i += blockDim.x * gridDim.x)
+#include "cuda_helpers.h"
 template <typename T>
-__device__ T bilinear_interpolate(const T* bottom_data,
+__device__ T bilinear_interpolate(const T* input,
                                   const int height, const int width,
                                   T y, T x,
                                   const int index /* index for debug only*/) {
@@ -48,11 +45,12 @@ __device__ T bilinear_interpolate(const T* bottom_data,
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
+
   // do bilinear interpolation
-  T v1 = bottom_data[y_low * width + x_low];
-  T v2 = bottom_data[y_low * width + x_high];
-  T v3 = bottom_data[y_high * width + x_low];
-  T v4 = bottom_data[y_high * width + x_high];
+  T v1 = input[y_low * width + x_low];
+  T v2 = input[y_low * width + x_high];
+  T v3 = input[y_high * width + x_low];
+  T v4 = input[y_high * width + x_high];
   T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
   T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
@@ -61,12 +59,12 @@ __device__ T bilinear_interpolate(const T* bottom_data,
 }
 template <typename T>
-__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
+__global__ void RoIAlignForward(const int nthreads, const T* input,
                                 const T spatial_scale, const int channels,
                                 const int height, const int width,
                                 const int pooled_height, const int pooled_width,
                                 const int sampling_ratio,
-                                const T* bottom_rois, T* top_data) {
+                                const T* rois, T* output) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -74,18 +72,14 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
-    const T* offset_bottom_rois = bottom_rois + n * 5;
-    int roi_batch_ind = offset_bottom_rois[0];
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
@@ -93,7 +87,7 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
-    const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
+    const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -110,13 +104,13 @@
       {
         const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
-        T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
+        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
         output_val += val;
       }
     }
     output_val /= count;
-    top_data[index] = output_val;
+    output[index] = output_val;
   }
 }
@@ -162,10 +156,10 @@ __device__ void bilinear_interpolate_gradient(
   T hy = 1. - ly, hx = 1. - lx;
   // reference in forward
-  // T v1 = bottom_data[y_low * width + x_low];
-  // T v2 = bottom_data[y_low * width + x_high];
-  // T v3 = bottom_data[y_high * width + x_low];
-  // T v4 = bottom_data[y_high * width + x_high];
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
@@ -174,13 +168,15 @@ __device__ void bilinear_interpolate_gradient(
 }
 template <typename T>
-__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
+__global__ void RoIAlignBackward(const int nthreads, const T* grad_output,
                                 const int num_rois, const T spatial_scale,
                                 const int channels, const int height, const int width,
                                 const int pooled_height, const int pooled_width,
                                 const int sampling_ratio,
-                                T* bottom_diff,
-                                const T* bottom_rois) {
+                                T* grad_input,
+                                const T* rois,
+                                const int n_stride, const int c_stride,
+                                const int h_stride, const int w_stride) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -188,30 +184,26 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
-    const T* offset_bottom_rois = bottom_rois + n * 5;
-    int roi_batch_ind = offset_bottom_rois[0];
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
-
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
+
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
-    T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
+    T* offset_grad_input = grad_input + (roi_batch_ind * channels + c) * height * width;
-    int top_offset = (n * channels + c) * pooled_height * pooled_width;
-    const T* offset_top_diff = top_diff + top_offset;
-    const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+    int top_offset = (n * channels + c) * pooled_height * pooled_width;
+    const T* offset_grad_output = grad_output + top_offset;
+    const T grad_output_this_bin = offset_grad_output[ph * pooled_width + pw];
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -235,17 +227,17 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
                                       x_low, x_high, y_low, y_high,
                                       index);
-        T g1 = top_diff_this_bin * w1 / count;
-        T g2 = top_diff_this_bin * w2 / count;
-        T g3 = top_diff_this_bin * w3 / count;
-        T g4 = top_diff_this_bin * w4 / count;
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
         {
-          atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
-          atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
-          atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
-          atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
+          atomicAdd(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          atomicAdd(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          atomicAdd(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          atomicAdd(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
     } // iy
@@ -326,8 +318,13 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
     return grad_input;
   }
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
   AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
-    RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
+    RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
        grad.numel(),
        grad.data<scalar_t>(),
        num_rois,
@@ -339,7 +336,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
        pooled_width,
        sampling_ratio,
        grad_input.data<scalar_t>(),
-       rois.data<scalar_t>());
+       rois.data<scalar_t>(),
+       n_stride,
+       c_stride,
+       h_stride,
+       w_stride);
   });
   THCudaCheck(cudaGetLastError());
   return grad_input;
diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
index 2ba8dc33e25..514400b0267 100644
--- a/torchvision/csrc/cuda/ROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -108,16 +108,16 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
                                 const int pooled_height,
                                 const int pooled_width) {
-  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);
-  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
-  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +159,13 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
                                 const int height,
                                 const int width) {
   // Check if input tensors are CUDA tensors
-  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
-  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
-  AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor");
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(argmax.device().is_cuda(), "argmax must be a CUDA tensor");
   auto num_rois = rois.size(0);
-  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
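
Note on the allocation and device-check hunks above: the deprecated at::zeros(sizes, tensor.type()) calls are replaced by the at::TensorOptions-based overload (tensor.options(), with .dtype(at::kInt) where an integer buffer is needed), and tensor.type().is_cuda() checks become tensor.device().is_cuda(). A minimal standalone sketch of that pattern is below; it is not part of the patch, and the function name, the `like` parameter, and the shape values are illustrative only.

    // Sketch (illustrative, not from the patch): TensorOptions-based allocation
    // as adopted in the diff above. `like` stands in for the real input/grad tensor.
    #include <ATen/ATen.h>
    #include <tuple>

    std::tuple<at::Tensor, at::Tensor> make_pool_buffers(const at::Tensor& like) {
      // Device check goes through Device rather than the deprecated Type API.
      AT_ASSERTM(like.device().is_cuda(), "like must be a CUDA tensor");
      // `output` inherits dtype, device, and layout from `like`;
      // `argmax` keeps the same device but switches the dtype to int32.
      at::Tensor output = at::zeros({8, 256, 7, 7}, like.options());
      at::Tensor argmax = at::zeros({8, 256, 7, 7}, like.options().dtype(at::kInt));
      return std::make_tuple(output, argmax);
    }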