Skip to content

Commit 1a8d7ac

Browse files
authored
[Enhancement] Support DeformConv TensorRT fp16 (open-mmlab#468)
* add DCN TensorRT fp16 support
* fix getOutputDimensions
1 parent 843d3c9 commit 1a8d7ac

File tree

4 files changed

+50
-6
lines changed

4 files changed

+50
-6
lines changed

csrc/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, int *permu
6161

6262
template void memcpyPermute<float>(float *dst, const float *src, int *src_size, int *permute,
6363
int src_dim, cudaStream_t stream);
64+
template void memcpyPermute<half>(half *dst, const half *src, int *src_size, int *permute,
65+
int src_dim, cudaStream_t stream);
6466

6567
cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t *cudnn_dtype) {
6668
switch (trt_dtype) {

csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions(
6767
bool DeformableConvPluginDynamic::supportsFormatCombination(
6868
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
6969
if (pos == 0) {
70-
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
70+
return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT ||
71+
ioDesc[pos].type == nvinfer1::DataType::kHALF) &&
7172
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
7273
} else {
7374
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
@@ -136,9 +137,14 @@ int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input
136137
mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle,
137138
stream);
138139
break;
140+
case nvinfer1::DataType::kHALF:
141+
deform_conv<half>((half *)x, (half *)weight, (half *)offset, (half *)output, workSpace, batch,
142+
channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0],
143+
mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1],
144+
mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream);
145+
break;
139146
default:
140147
return 1;
141-
break;
142148
}
143149

144150
return 0;

csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,10 @@ template void deform_conv<float>(const float* input, const float* weight, const
163163
int dW, int dH, int padW, int padH, int dilationW, int dilationH,
164164
int group, int deformable_group, int im2col_step,
165165
cublasHandle_t cublas_handle, cudaStream_t stream);
166+
167+
template void deform_conv<__half>(const __half* input, const __half* weight, const __half* offset,
168+
__half* output, void* workspace, int batchSize, int nInputPlane,
169+
int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH,
170+
int dW, int dH, int padW, int padH, int dilationW, int dilationH,
171+
int group, int deformable_group, int im2col_step,
172+
cublasHandle_t cublas_handle, cudaStream_t stream);

csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,15 @@
6363
// modified from
6464
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
6565

66+
#include <cuda_fp16.h>
67+
6668
#include "common_cuda_helper.hpp"
6769

6870
template <typename scalar_t>
6971
__device__ __forceinline__ scalar_t deformable_im2col_bilinear(const scalar_t* __restrict__ input,
7072
const int height, const int width,
71-
scalar_t h, scalar_t w) {
72-
if (h <= -1.f || height <= h || w <= -1.f || width <= w) {
73+
float h, float w) {
74+
if (h <= -1 || height <= h || w <= -1 || width <= w) {
7375
return 0;
7476
}
7577

@@ -94,6 +96,33 @@ __device__ __forceinline__ scalar_t deformable_im2col_bilinear(const scalar_t* _
9496
return val;
9597
}
9698

99+
// __half specialization of the bilinear sampler.
//
// Reads the four pixels surrounding the fractional location (h, w) from a
// single row-major plane of `height` x `width` __half values and blends them.
// Every corner is widened to float before blending and the result is rounded
// back to __half once, at the end, so the interpolation itself runs in float.
//
// input  : pointer to the first element of the plane being sampled.
// height : number of rows in the plane.
// width  : number of columns in the plane.
// h, w   : fractional row / column of the sampling point, as float.
//
// Returns zero when the point lies at or beyond one pixel outside the plane;
// otherwise the interpolated value, with corners that fall outside the plane
// contributing zero.
template <>
__device__ __forceinline__ __half deformable_im2col_bilinear(const __half* __restrict__ input,
                                                             const int height, const int width,
                                                             float h, float w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    // Convert explicitly: the implicit int -> __half constructor that a bare
    // `return 0;` relies on is compiled out when the translation unit defines
    // __CUDA_NO_HALF_CONVERSIONS__.
    return __float2half(0.0f);
  }

  const int h_low = static_cast<int>(floorf(h));
  const int w_low = static_cast<int>(floorf(w));
  const int w_high = w_low + 1;

  // Fractional distance of the sampling point from the low column.
  const float lw = w - w_low;

  // Upper row (h_low). It may be row -1; the predicates below keep it from
  // being dereferenced in that case.
  input += h_low * width;
  const float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f;
  const float v2 = (h_low >= 0 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f;
  const float v_low = fmaf(v2 - v1, lw, v1);

  // Lower row (h_low + 1), read only while it is still inside the plane.
  input += width;
  const float v3 = (h_low <= height - 2 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f;
  const float v4 =
      (h_low <= height - 2 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f;
  const float v_high = fmaf(v4 - v3, lw, v3);

  // Blend the two row results along the vertical axis.
  const float lh = h - h_low;
  const float val = fmaf(v_high - v_low, lh, v_low);
  return __float2half(val);
}
125+
97126
template <typename scalar_t>
98127
__global__ void deformable_im2col_gpu_kernel(
99128
const int n, const scalar_t* __restrict__ data_im, const scalar_t* __restrict__ data_offset,
@@ -134,8 +163,8 @@ __global__ void deformable_im2col_gpu_kernel(
134163
const scalar_t offset_h = data_offset_ptr[data_offset_h];
135164
const int data_offset_w = data_offset_h + hw_col;
136165
const scalar_t offset_w = data_offset_ptr[data_offset_w];
137-
const scalar_t h_im = h_in + i * dilation_h + offset_h;
138-
const scalar_t w_im = w_in + j * dilation_w + offset_w;
166+
const scalar_t h_im = h_in + i * dilation_h + (float)offset_h;
167+
const scalar_t w_im = w_in + j * dilation_w + (float)offset_w;
139168
const scalar_t val = deformable_im2col_bilinear(data_im_ptr, height, width, h_im, w_im);
140169
*data_col_ptr = val;
141170
data_col_ptr += data_col_step;

0 commit comments

Comments
 (0)