Skip to content

Commit eb36362

Browse files
cuda : deduplicated dequantization code (#1453)
1 parent 79b2d5b commit eb36362

File tree

1 file changed

+33
-121
lines changed

1 file changed

+33
-121
lines changed

ggml-cuda.cu

Lines changed: 33 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ typedef struct {
8383
} block_q8_0;
8484
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
8585

86-
#define CUDA_DMMV_BLOCK_SIZE 32
86+
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
87+
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
8788

8889
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
8990
const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -170,104 +171,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
170171
v1 = __half2float(x[ib + 1]);
171172
}
172173

173-
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
174-
static const int qk = QK4_0;
174+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
175+
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
176+
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
175177

176-
const block_q4_0 * x = (const block_q4_0 *) vx;
177-
178-
const int i = blockIdx.x;
179-
180-
const float d = x[i].d;
181-
182-
for (int j = 0; j < qk/2; ++j) {
183-
const int x0 = (x[i].qs[j] & 0xf) - 8;
184-
const int x1 = (x[i].qs[j] >> 4) - 8;
185-
186-
y[i*qk + j + 0 ] = x0*d;
187-
y[i*qk + j + qk/2] = x1*d;
188-
}
189-
}
190-
191-
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
192-
static const int qk = QK4_1;
193-
194-
const block_q4_1 * x = (const block_q4_1 *) vx;
195-
196-
const int i = blockIdx.x;
197-
198-
const float d = x[i].d;
199-
const float m = x[i].m;
200-
201-
for (int j = 0; j < qk/2; ++j) {
202-
const int x0 = (x[i].qs[j] & 0xf);
203-
const int x1 = (x[i].qs[j] >> 4);
204-
205-
y[i*qk + j + 0 ] = x0*d + m;
206-
y[i*qk + j + qk/2] = x1*d + m;
207-
}
208-
}
209-
210-
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
211-
static const int qk = QK5_0;
212-
213-
const block_q5_0 * x = (const block_q5_0 *) vx;
214-
215-
const int i = blockIdx.x;
216-
217-
const float d = x[i].d;
218-
219-
uint32_t qh;
220-
memcpy(&qh, x[i].qh, sizeof(qh));
221-
222-
for (int j = 0; j < qk/2; ++j) {
223-
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
224-
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
225-
226-
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
227-
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
228-
229-
y[i*qk + j + 0 ] = x0*d;
230-
y[i*qk + j + qk/2] = x1*d;
178+
if (i >= k) {
179+
return;
231180
}
232-
}
233-
234-
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
235-
static const int qk = QK5_1;
236-
237-
const block_q5_1 * x = (const block_q5_1 *) vx;
238-
239-
const int i = blockIdx.x;
240-
241-
const float d = x[i].d;
242-
const float m = x[i].m;
243181

244-
uint32_t qh;
245-
memcpy(&qh, x[i].qh, sizeof(qh));
246-
247-
for (int j = 0; j < qk/2; ++j) {
248-
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
249-
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
250-
251-
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
252-
const int x1 = (x[i].qs[j] >> 4) | xh_1;
253-
254-
y[i*qk + j + 0 ] = x0*d + m;
255-
y[i*qk + j + qk/2] = x1*d + m;
256-
}
257-
}
258-
259-
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
260-
static const int qk = QK8_0;
261-
262-
const block_q8_0 * x = (const block_q8_0 *) vx;
263-
264-
const int i = blockIdx.x;
265-
266-
const float d = x[i].d;
182+
const int ib = i/qk; // block index
183+
const int iqs = (i%qk)/qr; // quant index
184+
const int iybs = i - i%qk; // y block start index
185+
const int y_offset = qr == 1 ? 1 : qk/2;
267186

268-
for (int j = 0; j < qk; ++j) {
269-
y[i*qk + j] = x[i].qs[j]*d;
270-
}
187+
// dequantize
188+
float & v0 = y[iybs + iqs + 0];
189+
float & v1 = y[iybs + iqs + y_offset];
190+
dequantize_kernel(vx, ib, iqs, v0, v1);
271191
}
272192

273193
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -308,29 +228,29 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
308228
}
309229
}
310230

311-
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
312-
const int nb = k / QK4_0;
313-
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
231+
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
232+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
233+
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
314234
}
315235

316-
static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
317-
const int nb = k / QK4_1;
318-
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
236+
static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
237+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
238+
dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
319239
}
320240

321-
static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
322-
const int nb = k / QK5_0;
323-
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
241+
static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
242+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
243+
dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
324244
}
325245

326-
static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
327-
const int nb = k / QK5_1;
328-
dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
246+
static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
247+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
248+
dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
329249
}
330250

331-
static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
332-
const int nb = k / QK8_0;
333-
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
251+
static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
252+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
253+
dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
334254
}
335255

336256
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -363,17 +283,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
363283
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
364284
}
365285

366-
// TODO: optimize
367-
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
368-
const half * x = (const half *) vx;
369-
370-
const int i = blockIdx.x;
371-
372-
y[i] = __half2float(x[i]);
373-
}
374-
375-
static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
376-
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
286+
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
287+
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
288+
dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
377289
}
378290

379291
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {

0 commit comments

Comments
 (0)