// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
6565
66+ #include < cuda_fp16.h>
67+
6668#include " common_cuda_helper.hpp"
6769
6870template <typename scalar_t >
6971__device__ __forceinline__ scalar_t deformable_im2col_bilinear (const scalar_t * __restrict__ input,
7072 const int height, const int width,
71- scalar_t h, scalar_t w) {
72- if (h <= -1 . f || height <= h || w <= -1 . f || width <= w) {
73+ float h, float w) {
74+ if (h <= -1 || height <= h || w <= -1 || width <= w) {
7375 return 0 ;
7476 }
7577
@@ -94,6 +96,33 @@ __device__ __forceinline__ scalar_t deformable_im2col_bilinear(const scalar_t* _
9496 return val;
9597}
9698
// Half-precision specialization of deformable_im2col_bilinear.
//
// Bilinearly samples `input`, addressed as a row-major image with row stride
// `width`, at the fractional location (h, w). The four neighbouring pixels are
// read as __half, widened to float, blended with fused multiply-adds in float,
// and the result is rounded back to __half once at the end.
//
// Parameters:
//   input  - pointer to the first element of the image plane to sample.
//   height - number of rows in the plane.
//   width  - number of columns in the plane (also the row stride).
//   h, w   - fractional row / column coordinate of the sample, in float.
//
// Returns:
//   The interpolated value as __half. Zero is returned when the coordinate is
//   outside the open interval (-1, height) x (-1, width). Any neighbour that
//   falls outside the image contributes zero to the blend.
template <>
__device__ __forceinline__ __half deformable_im2col_bilinear(const __half* __restrict__ input,
                                                             const int height, const int width,
                                                             float h, float w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    // Convert explicitly: the implicit int -> __half constructor is not
    // available when __CUDA_NO_HALF_CONVERSIONS__ is defined, and this keeps
    // the early-out consistent with the __float2half() used for the result.
    return __float2half(0.0f);
  }

  // Integer coordinates of the top-left neighbour; either may be -1 here.
  const int h_low = floorf(h);
  const int w_low = floorf(w);

  // Advance to row h_low. When h_low == -1 this points one row before the
  // image, but every load from that row is guarded by `h_low >= 0` below.
  input += h_low * width;
  const float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f;
  const int w_high = w_low + 1;
  const float v2 = (h_low >= 0 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f;
  // Horizontal blend weight, then v1 + lw * (v2 - v1) along the upper row.
  const float lw = w - w_low;
  const float v_low = fmaf(v2 - v1, lw, v1);

  // Advance to row h_low + 1; loads are guarded by `h_low <= height - 2`.
  input += width;
  const float v3 = (h_low <= height - 2 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f;
  const float v4 =
      (h_low <= height - 2 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f;
  const float v_high = fmaf(v4 - v3, lw, v3);

  // Vertical blend between the two row results, rounded to half at the end.
  const float lh = h - h_low;
  const float val = fmaf(v_high - v_low, lh, v_low);
  return __float2half(val);
}
125+
97126template <typename scalar_t >
98127__global__ void deformable_im2col_gpu_kernel (
99128 const int n, const scalar_t * __restrict__ data_im, const scalar_t * __restrict__ data_offset,
@@ -134,8 +163,8 @@ __global__ void deformable_im2col_gpu_kernel(
134163 const scalar_t offset_h = data_offset_ptr[data_offset_h];
135164 const int data_offset_w = data_offset_h + hw_col;
136165 const scalar_t offset_w = data_offset_ptr[data_offset_w];
137- const scalar_t h_im = h_in + i * dilation_h + offset_h;
138- const scalar_t w_im = w_in + j * dilation_w + offset_w;
166+ const scalar_t h_im = h_in + i * dilation_h + ( float ) offset_h;
167+ const scalar_t w_im = w_in + j * dilation_w + ( float ) offset_w;
139168 const scalar_t val = deformable_im2col_bilinear (data_im_ptr, height, width, h_im, w_im);
140169 *data_col_ptr = val;
141170 data_col_ptr += data_col_step;
0 commit comments