Skip to content

Commit ad62d3a

Browse files
JohannesGaesslerhodlen
authored andcommitted
CUDA: fixed redundant value dequantization (ggml-org#4809)
1 parent b167e41 commit ad62d3a

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

ggml-cuda.cu

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,14 +1872,6 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
18721872
v.y = x[ib + iqs + 1];
18731873
}
18741874

1875-
static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
1876-
const float * x = (const float *) vx;
1877-
1878-
// automatic half -> float type cast if dfloat == float
1879-
v.x = x[ib + iqs + 0];
1880-
v.y = x[ib + iqs + 1];
1881-
}
1882-
18831875
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
18841876
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
18851877

@@ -1983,7 +1975,7 @@ static __global__ void k_get_rows_float(
19831975

19841976
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
19851977
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
1986-
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
1978+
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
19871979

19881980
if (i >= k) {
19891981
return;
@@ -2002,6 +1994,19 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
20021994
y[iybs + iqs + y_offset] = v.y;
20031995
}
20041996

1997+
template <typename src_t, typename dst_t>
1998+
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
1999+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
2000+
2001+
if (i >= k) {
2002+
return;
2003+
}
2004+
2005+
const src_t * x = (src_t *) vx;
2006+
2007+
y[i] = x[i];
2008+
}
2009+
20052010
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
20062011
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
20072012

@@ -5609,7 +5614,7 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
56095614

56105615
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
56115616
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
5612-
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
5617+
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
56135618
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
56145619
}
56155620

@@ -5659,6 +5664,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
56595664
#endif
56605665
}
56615666

5667+
template <typename src_t, typename dst_t>
5668+
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
5669+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
5670+
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
5671+
}
5672+
56625673
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
56635674
switch (type) {
56645675
case GGML_TYPE_Q4_0:
@@ -5682,7 +5693,7 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
56825693
case GGML_TYPE_Q6_K:
56835694
return dequantize_row_q6_K_cuda;
56845695
case GGML_TYPE_F32:
5685-
return dequantize_block_cuda<1, 1, convert_f32>;
5696+
return convert_unary_cuda<float>;
56865697
default:
56875698
return nullptr;
56885699
}
@@ -5711,7 +5722,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
57115722
case GGML_TYPE_Q6_K:
57125723
return dequantize_row_q6_K_cuda;
57135724
case GGML_TYPE_F16:
5714-
return dequantize_block_cuda<1, 1, convert_f16>;
5725+
return convert_unary_cuda<half>;
57155726
default:
57165727
return nullptr;
57175728
}

0 commit comments

Comments
 (0)