// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
+ #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
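Note: on the HIP build the two new defines map the cuBLAS FP16 compute type and tensor-op algorithm onto existing hipBLAS constants, since hipBLAS reuses its `hipblasDatatype_t` enum for compute types and has no tensor-op-specific algorithm. Assuming the `cublasGemmEx` -> `hipblasGemmEx` alias this shim block already provides, the FP16 GEMM call added at the bottom of this diff would effectively expand to something like the following (a sketch only, not part of the patch):

```cpp
// Hypothetical HIP expansion of the new cublasGemmEx call, with
// CUBLAS_COMPUTE_16F and CUBLAS_GEMM_DEFAULT_TENSOR_OP replaced
// by their hipBLAS mappings from the defines above.
hipblasGemmEx(handle, HIPBLAS_OP_T, HIPBLAS_OP_N,
              row_diff, src1_ncols, ne10,
              &alpha_f16, src0_dd_i, HIPBLAS_R_16F, ne00,
                          src1_ptr,  HIPBLAS_R_16F, ne10,
              &beta_f16,  dst_f16,   HIPBLAS_R_16F, ldc,
              HIPBLAS_R_16F,         // compute type == data type on HIP
              HIPBLAS_GEMM_DEFAULT); // no tensor-op variant in hipBLAS
```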
@@ -235,8 +237,12 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}

+ template<typename T>
+ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+ typedef to_t_cuda_t<float> to_fp32_cuda_t;
+ typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
- typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
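The templated `to_t_cuda_t` alias replaces the old standalone `to_fp32_cuda_t` typedef: both conversion directions now share one function-pointer shape that differs only in the output element type. A minimal standalone sketch of how such an alias is consumed (`example_dispatch` is hypothetical, not part of the patch):

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Same alias as in the patch: a pointer to a host-side launcher that
// converts k elements of some source format into T on the given stream.
template <typename T>
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);

// Hypothetical caller: picking the instantiation fixes the output type,
// so fp32 and fp16 converters can be stored and passed uniformly.
static void example_dispatch(const void * src, half * dst, int k, cudaStream_t stream,
                             to_t_cuda_t<half> conv) {
    conv(src, dst, k, stream); // e.g. conv == convert_fp32_to_fp16_cuda later in this diff
}
```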
@@ -1515,6 +1521,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
    v.y = x[ib + iqs + 1];
}

+ static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+     const float * x = (const float *) vx;
+
+     // automatic half -> float type cast if dfloat == float
+     v.x = x[ib + iqs + 0];
+     v.y = x[ib + iqs + 1];
+ }
+
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
    const int ix = blockDim.x*blockIdx.x + threadIdx.x;
@@ -1554,8 +1568,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
    reinterpret_cast<half&>(y[ib].ds.y) = sum;
}

- template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
- static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;

    if (i >= k) {
@@ -4826,6 +4840,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

+ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
+     const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+     dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+ }
+
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
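Because `dequantize_block` is now templated on `dst_t` (previous hunk), the output element type is deduced from the pointer passed at the launch site: `convert_fp16_to_fp32_cuda` instantiates it with `float`, while the new `convert_fp32_to_fp16_cuda` instantiates it with `half`. A reduced standalone sketch of the mechanism (simplified kernel, names assumed):

```cpp
#include <cuda_fp16.h>

// Simplified stand-in for dequantize_block: one kernel body serves both
// conversion directions because dst_t is deduced from the y argument.
template <typename dst_t>
static __global__ void convert_block(const float * x, dst_t * y, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i < k) {
        y[i] = x[i]; // implicit float -> dst_t conversion (narrowing for half)
    }
}

// convert_block<<<grid, block>>>(x, y_f32, k); // instantiates convert_block<float>
// convert_block<<<grid, block>>>(x, y_f16, k); // instantiates convert_block<half>
```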
@@ -4835,6 +4854,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

+ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+     switch (type) {
+         case GGML_TYPE_F32:
+             return convert_fp32_to_fp16_cuda;
+         default:
+             return nullptr;
+     }
+ }
+
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
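Unlike `ggml_get_to_fp32_cuda`, the new FP16 dispatch returns `nullptr` for types without a converter, so callers are expected to assert before use, as the cuBLAS hunk below does. A hypothetical standalone usage sketch (the helper and its name are not in the patch):

```cpp
#include <cassert>

// Hypothetical helper: convert a device-resident FP32 buffer to FP16
// via the new dispatch table.
static void upload_as_f16(const void * src_f32, half * dst_f16, int n, cudaStream_t stream) {
    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32);
    assert(to_fp16_cuda != nullptr); // only GGML_TYPE_F32 is handled so far
    to_fp16_cuda(src_f32, dst_f16, n, stream);
}
```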
@@ -6016,8 +6044,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_dd_i != nullptr);

-     const float alpha = 1.0f;
-     const float beta = 0.0f;


    const int64_t ne00 = src0->ne[0];
@@ -6026,33 +6052,79 @@ inline void ggml_cuda_op_mul_mat_cublas(
    const int64_t ne0 = dst->ne[0];
    const int64_t row_diff = row_high - row_low;

-     float * src0_ddq_as_f32;
-     size_t src0_as = 0;
-
-     if (src0->type != GGML_TYPE_F32) {
-         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-         src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
-         to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
-     }
-     const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
-
    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    // the main device has a larger memory buffer to hold the results from all GPUs
    // ldc == nrows of the matrix that cuBLAS writes into
    int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;

-     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
-     CUBLAS_CHECK(
-         cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                 row_diff, src1_ncols, ne10,
-                 &alpha, src0_ddf_i, ne00,
-                         src1_ddf_i, ne10,
-                 &beta,  dst_dd_i,   ldc));
+     const int compute_capability = g_compute_capabilities[id];
+
+     if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
+         // convert src1 to fp16, multiply as fp16, convert dst to fp32
+         half * src1_as_f16 = nullptr;
+         size_t src1_as = 0;
+         if (src1->type != GGML_TYPE_F16) {
+             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+             GGML_ASSERT(to_fp16_cuda != nullptr);
+             size_t ne = src1_ncols*ne10;
+             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+         }
+         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+
+         size_t dst_as = 0;
+         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+
+         const half alpha_f16 = 1.0f;
+         const half beta_f16 = 0.0f;
+
+         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+         CUBLAS_CHECK(
+             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                     row_diff, src1_ncols, ne10,
+                     &alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
+                                 src1_ptr,  CUDA_R_16F, ne10,
+                     &beta_f16,  dst_f16,   CUDA_R_16F, ldc,
+                     CUBLAS_COMPUTE_16F,
+                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
+
+         ggml_cuda_pool_free(dst_f16, dst_as);

-     if (src0_as > 0) {
-         ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+         if (src1_as != 0) {
+             ggml_cuda_pool_free(src1_as_f16, src1_as);
+         }
+     }
+     else {
+         float * src0_ddq_as_f32 = nullptr;
+         size_t src0_as = 0;
+
+         if (src0->type != GGML_TYPE_F32) {
+             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+             GGML_ASSERT(to_fp32_cuda != nullptr);
+             src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+         }
+         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+
+         const float alpha = 1.0f;
+         const float beta = 0.0f;
+
+         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+         CUBLAS_CHECK(
+             cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                     row_diff, src1_ncols, ne10,
+                     &alpha, src0_ddf_i, ne00,
+                             src1_ddf_i, ne10,
+                     &beta,  dst_dd_i,   ldc));
+
+         if (src0_as != 0) {
+             ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+         }
    }

    (void) dst;
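The new fast path computes C = A^T * B entirely in FP16 when the device is Turing or newer (`CC_TURING`), `src0` is already FP16 and contiguous, and no column padding is needed (`ldc == row_diff`); only the final FP16 -> FP32 copy of `dst` runs outside the GEMM. A standalone sketch of the same `cublasGemmEx` pattern, without ggml's pool allocator and error macros (`gemm_f16` is a hypothetical wrapper):

```cpp
#include <cstdio>
#include <cuda_fp16.h>
#include <cublas_v2.h>

// Hypothetical wrapper around the call pattern used above. A, B, C are
// device pointers; m, n, k mirror row_diff, src1_ncols, ne10.
static void gemm_f16(cublasHandle_t handle, cudaStream_t stream,
                     const half * A, const half * B, half * C,
                     int m, int n, int k, int lda, int ldb, int ldc) {
    const half alpha = 1.0f; // with CUBLAS_COMPUTE_16F, alpha/beta must be half
    const half beta  = 0.0f;

    cublasSetStream(handle, stream);
    const cublasStatus_t status =
        cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                     m, n, k,
                     &alpha, A, CUDA_R_16F, lda,
                             B, CUDA_R_16F, ldb,
                     &beta,  C, CUDA_R_16F, ldc,
                     CUBLAS_COMPUTE_16F,             // FP16 accumulation
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP); // allow tensor cores
    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "cublasGemmEx failed: %d\n", (int) status);
    }
}
```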