Skip to content

Commit da04003

Browse files
authored
ggml-cuda : perform cublas fp16 matrix multiplication as fp16 (#3370)
* ggml-cuda : perform cublas fp16 matrix multiplication as fp16
* try to fix rocm build
* restrict fp16 mat mul to volta and up
1 parent e519621 commit da04003

File tree

1 file changed

+96
-24
lines changed

1 file changed

+96
-24
lines changed

ggml-cuda.cu

Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
// for rocblas_initialize()
1515
#include "rocblas/rocblas.h"
1616
#endif // __HIP_PLATFORM_AMD__
17+
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
1718
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
1819
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
1920
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
21+
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
2022
#define CUBLAS_OP_N HIPBLAS_OP_N
2123
#define CUBLAS_OP_T HIPBLAS_OP_T
2224
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
@@ -235,8 +237,12 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
235237
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
236238
}
237239

240+
// Function-pointer type shared by the host-side conversion launchers:
// x is the untyped source buffer, y the destination buffer of element type T,
// k the size argument forwarded to the conversion kernel, and stream the CUDA
// stream the work is issued on.
template<typename T>
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
using to_fp32_cuda_t = to_t_cuda_t<float>; // launcher writing float output
using to_fp16_cuda_t = to_t_cuda_t<half>;  // launcher writing half output
244+
238245
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
239-
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
240246
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
241247
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
242248
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -1515,6 +1521,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
15151521
v.y = x[ib + iqs + 1];
15161522
}
15171523

1524+
// Element fetch for an F32 source buffer, with the same signature as the other
// dequantize_kernel_t callbacks: reads the two consecutive floats starting at
// index (ib + iqs) of vx and stores them in v.x and v.y.
static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
    const float * src  = (const float *) vx;
    const int     base = ib + iqs;

    // the assignment implicitly converts float to the component type of dfloat2
    // (presumably a float -> half narrowing when dfloat is half; dfloat is declared elsewhere)
    v.x = src[base + 0];
    v.y = src[base + 1];
}
1531+
15181532
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
15191533
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
15201534

@@ -1554,8 +1568,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
15541568
reinterpret_cast<half&>(y[ib].ds.y) = sum;
15551569
}
15561570

1557-
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
1558-
static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
1571+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1572+
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
15591573
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
15601574

15611575
if (i >= k) {
@@ -4826,6 +4840,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
48264840
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
48274841
}
48284842

4843+
// Host-side launcher: runs dequantize_block<1, 1, convert_f32> on `stream` to
// convert the float data at vx into half values at y.
//
//   vx     - source buffer, read as floats by convert_f32
//   y      - destination buffer of half values
//   k      - size argument forwarded to the kernel, which uses it as its bounds check
//   stream - CUDA stream the kernel is launched on (the launch is asynchronous)
//
// The grid size is derived from CUDA_DEQUANTIZE_BLOCK_SIZE, the same constant
// that is passed as the block size in the launch. Previously it was derived
// from CUDA_QUANTIZE_BLOCK_SIZE, which only gives the right number of blocks
// while the two constants happen to be equal.
static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
4847+
48294848
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
48304849
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
48314850
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -4835,6 +4854,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
48354854
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
48364855
}
48374856

4857+
// Returns the launcher that converts a buffer of the given ggml type to half.
// Only GGML_TYPE_F32 has a converter; every other type yields nullptr, so
// callers must check the result before invoking it.
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
    if (type == GGML_TYPE_F32) {
        return convert_fp32_to_fp16_cuda;
    }
    return nullptr;
}
4865+
48384866
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
48394867
switch (type) {
48404868
case GGML_TYPE_Q4_0:
@@ -6016,8 +6044,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
60166044
GGML_ASSERT(src1_ddf_i != nullptr);
60176045
GGML_ASSERT(dst_dd_i != nullptr);
60186046

6019-
const float alpha = 1.0f;
6020-
const float beta = 0.0f;
60216047

60226048
const int64_t ne00 = src0->ne[0];
60236049

@@ -6026,33 +6052,79 @@ inline void ggml_cuda_op_mul_mat_cublas(
60266052
const int64_t ne0 = dst->ne[0];
60276053
const int64_t row_diff = row_high - row_low;
60286054

6029-
float * src0_ddq_as_f32;
6030-
size_t src0_as = 0;
6031-
6032-
if (src0->type != GGML_TYPE_F32) {
6033-
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
6034-
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
6035-
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
6036-
}
6037-
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
6038-
60396055
int id;
60406056
CUDA_CHECK(cudaGetDevice(&id));
60416057

60426058
// the main device has a larger memory buffer to hold the results from all GPUs
60436059
// ldc == nrows of the matrix that cuBLAS writes into
60446060
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
60456061

6046-
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
6047-
CUBLAS_CHECK(
6048-
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
6049-
row_diff, src1_ncols, ne10,
6050-
&alpha, src0_ddf_i, ne00,
6051-
src1_ddf_i, ne10,
6052-
&beta, dst_dd_i, ldc));
6062+
const int compute_capability = g_compute_capabilities[id];
6063+
6064+
if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
6065+
// convert src1 to fp16, multiply as fp16, convert dst to fp32
6066+
half * src1_as_f16 = nullptr;
6067+
size_t src1_as = 0;
6068+
if (src1->type != GGML_TYPE_F16) {
6069+
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
6070+
GGML_ASSERT(to_fp16_cuda != nullptr);
6071+
size_t ne = src1_ncols*ne10;
6072+
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
6073+
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
6074+
}
6075+
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
6076+
6077+
size_t dst_as = 0;
6078+
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
6079+
6080+
const half alpha_f16 = 1.0f;
6081+
const half beta_f16 = 0.0f;
6082+
6083+
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
6084+
CUBLAS_CHECK(
6085+
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
6086+
row_diff, src1_ncols, ne10,
6087+
&alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
6088+
src1_ptr, CUDA_R_16F, ne10,
6089+
&beta_f16, dst_f16, CUDA_R_16F, ldc,
6090+
CUBLAS_COMPUTE_16F,
6091+
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
6092+
6093+
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
6094+
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
6095+
6096+
ggml_cuda_pool_free(dst_f16, dst_as);
60536097

6054-
if (src0_as > 0) {
6055-
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
6098+
if (src1_as != 0) {
6099+
ggml_cuda_pool_free(src1_as_f16, src1_as);
6100+
}
6101+
}
6102+
else {
6103+
float * src0_ddq_as_f32 = nullptr;
6104+
size_t src0_as = 0;
6105+
6106+
if (src0->type != GGML_TYPE_F32) {
6107+
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
6108+
GGML_ASSERT(to_fp32_cuda != nullptr);
6109+
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
6110+
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
6111+
}
6112+
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
6113+
6114+
const float alpha = 1.0f;
6115+
const float beta = 0.0f;
6116+
6117+
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
6118+
CUBLAS_CHECK(
6119+
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
6120+
row_diff, src1_ncols, ne10,
6121+
&alpha, src0_ddf_i, ne00,
6122+
src1_ddf_i, ne10,
6123+
&beta, dst_dd_i, ldc));
6124+
6125+
if (src0_as != 0) {
6126+
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
6127+
}
60566128
}
60576129

60586130
(void) dst;

0 commit comments

Comments
 (0)