Skip to content

Commit 76a8849

Browse files
authored
ggml : add CLBlast q5_0, q5_1, q8_0 dequant kernels (#1225)
* Implement q5_0, q5_1 and q8_0 * Work around q5_0 OpenCL issue * Fix q8_0 dequant kernel * Move cl kernels into ggml-opencl.c * Use two memcpy calls for q5_0 buffer transfer
1 parent 6bc4400 commit 76a8849

File tree

2 files changed

+205
-78
lines changed

2 files changed

+205
-78
lines changed

ggml-opencl-dequant.cl

Lines changed: 0 additions & 63 deletions
This file was deleted.

ggml-opencl.c

Lines changed: 205 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,141 @@
33
#define CL_TARGET_OPENCL_VERSION 110
44
#include <clblast_c.h>
55

6+
#include <stdlib.h>
67
#include <stdio.h>
78
#include <string.h>
89

910
#include "ggml.h"
1011

11-
#include "ggml-opencl-dequant.cl"
12+
#define MULTILINE_QUOTE(...) #__VA_ARGS__
13+
const char * clblast_dequant = MULTILINE_QUOTE(
14+
15+
struct block_q4_0
16+
{
17+
float d;
18+
uchar qs[16];
19+
};
20+
21+
__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
22+
const uint i = get_global_id(0) / 32;
23+
const uint l = get_local_id(0);
24+
25+
const float d = blocks[i].d;
26+
27+
const uchar vi = blocks[i].qs[l];
28+
29+
const uint index = i*32 + l*2;
30+
result[index + 0] = ((vi & 0xf) - 8)*d;
31+
result[index + 1] = ((vi >> 4) - 8)*d;
32+
}
33+
34+
struct block_q4_1
35+
{
36+
float d;
37+
float m;
38+
uchar qs[16];
39+
};
40+
41+
__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
42+
const uint i = get_global_id(0) / 32;
43+
const uint l = get_local_id(0);
44+
45+
const float d = blocks[i].d;
46+
const float m = blocks[i].m;
47+
48+
const uchar vi = blocks[i].qs[l];
49+
50+
const uint index = i*32 + l*2;
51+
result[index + 0] = (vi & 0xf) * d + m;
52+
result[index + 1] = (vi >> 4) * d + m;
53+
}
54+
55+
struct block_q4_2
56+
{
57+
ushort d;
58+
uchar qs[8];
59+
};
60+
61+
__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
62+
const uint i = get_global_id(0) / 16;
63+
const uint l = get_local_id(0);
64+
65+
const float d = vload_half(0, (__global half*) &blocks[i].d);
66+
67+
const uchar vi = blocks[i].qs[l];
68+
69+
const uint index = i*16 + l*2;
70+
result[index + 0] = ((vi & 0xf) - 8)*d;
71+
result[index + 1] = ((vi >> 4) - 8)*d;
72+
}
73+
74+
75+
struct block_q5_0
76+
{
77+
float d;
78+
uint qh;
79+
uchar qs[16];
80+
};
81+
82+
__kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
83+
const uint i = get_global_id(0) / 32;
84+
const uint l = get_local_id(0);
85+
86+
const float d = blocks[i].d;
87+
88+
const uchar vi = blocks[i].qs[l];
89+
90+
const uint l2 = l * 2;
91+
92+
const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
93+
const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
94+
95+
const uint index = i*32 + l2;
96+
result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
97+
result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
98+
}
99+
100+
struct block_q5_1
101+
{
102+
ushort d;
103+
ushort m;
104+
uint qh;
105+
uchar qs[16];
106+
};
107+
108+
__kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
109+
const uint i = get_global_id(0) / 32;
110+
const uint l = get_local_id(0);
111+
112+
const float d = vload_half(0, (__global half*) &blocks[i].d);
113+
const float m = vload_half(0, (__global half*) &blocks[i].m);
114+
115+
const uchar vi = blocks[i].qs[l];
116+
117+
const uint l2 = l * 2;
118+
119+
const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
120+
const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
121+
122+
const uint index = i*32 + l2;
123+
result[index + 0] = ((vi & 0xf) | vh0)*d + m;
124+
result[index + 1] = ((vi >> 4) | vh1)*d + m;
125+
}
126+
127+
struct block_q8_0
128+
{
129+
float d;
130+
char qs[32];
131+
};
132+
133+
__kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
134+
const uint i = get_global_id(0) / 32;
135+
const uint l = get_local_id(0);
136+
137+
result[i*32 + l] = blocks[i].qs[l] * blocks[i].d;
138+
}
139+
140+
);
12141

13142
#define CL_CHECK(err, name) \
14143
do { \
@@ -19,12 +148,26 @@
19148
} \
20149
} while (0)
21150

151+
#define QK5_0 32
152+
typedef struct {
153+
ggml_fp16_t d; // delta
154+
uint8_t qh[4]; // 5-th bit of quants
155+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
156+
} block_q5_0;
157+
158+
159+
typedef struct {
160+
float d; // delta
161+
uint32_t qh; // 5-th bit of quants
162+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
163+
} cl_block_q5_0;
164+
22165
static cl_platform_id platform;
23166
static cl_device_id device;
24167
static cl_context context;
25168
static cl_command_queue queue;
26169
static cl_program program;
27-
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2;
170+
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
28171
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
29172
static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
30173

@@ -97,6 +240,12 @@ void ggml_cl_init(void) {
97240
CL_CHECK(err, "clCreateKernel");
98241
kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
99242
CL_CHECK(err, "clCreateKernel");
243+
kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
244+
CL_CHECK(err, "clCreateKernel");
245+
kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
246+
CL_CHECK(err, "clCreateKernel");
247+
kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
248+
CL_CHECK(err, "clCreateKernel");
100249
}
101250

102251
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -125,6 +274,7 @@ void ggml_cl_sgemm_wrapper(
125274
cl_kernel kernel;
126275
size_t global = n * k, local, size_qb;
127276
bool dequant;
277+
cl_block_q5_0* cl_host_b;
128278

129279
switch (btype) {
130280
case GGML_TYPE_F32:
@@ -146,7 +296,36 @@ void ggml_cl_sgemm_wrapper(
146296
dequant = true;
147297
kernel = kernel_q4_2;
148298
local = 8;
149-
size_qb = global * (sizeof(short) + local) / 16;
299+
size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
300+
break;
301+
case GGML_TYPE_Q5_0:
302+
dequant = true;
303+
kernel = kernel_q5_0;
304+
local = 16;
305+
// For some reason OpenCL seems to be incapable of working with structs of size 22.
306+
// 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
307+
// TODO Find the reason, fix and remove workaround.
308+
const block_q5_0* b = (const block_q5_0*) host_b;
309+
cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
310+
for (size_t i = 0; i < global / 32; i++) {
311+
cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
312+
memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
313+
memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
314+
}
315+
host_b = (const float*) cl_host_b;
316+
size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
317+
break;
318+
case GGML_TYPE_Q5_1:
319+
dequant = true;
320+
kernel = kernel_q5_1;
321+
local = 16;
322+
size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32;
323+
break;
324+
case GGML_TYPE_Q8_0:
325+
dequant = true;
326+
kernel = kernel_q8_0;
327+
local = 32;
328+
size_qb = global * (sizeof(float) + local) / 32;
150329
break;
151330
default:
152331
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
@@ -171,12 +350,15 @@ void ggml_cl_sgemm_wrapper(
171350
err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
172351
err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
173352
CL_CHECK(err, "clSetKernelArg");
174-
clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
353+
err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
354+
CL_CHECK(err, "clEnqueueWriteBuffer qb");
175355
} else {
176-
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
356+
err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
357+
CL_CHECK(err, "clEnqueueWriteBuffer b");
177358
}
178359

179-
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
360+
err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
361+
CL_CHECK(err, "clEnqueueWriteBuffer a");
180362
if (dequant) {
181363
err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
182364
CL_CHECK(err, "clEnqueueNDRangeKernel");
@@ -188,15 +370,20 @@ void ggml_cl_sgemm_wrapper(
188370
clReleaseEvent(ev_b);
189371

190372
cl_event ev_sgemm;
191-
CLBlastSgemm((CLBlastLayout)order,
192-
(CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
193-
m, n, k,
194-
alpha,
195-
cl_buffer_a, 0, lda,
196-
cl_buffer_b, 0, ldb,
197-
beta,
198-
cl_buffer_c, 0, ldc,
199-
&queue, &ev_sgemm);
373+
CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
374+
(CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
375+
m, n, k,
376+
alpha,
377+
cl_buffer_a, 0, lda,
378+
cl_buffer_b, 0, ldb,
379+
beta,
380+
cl_buffer_c, 0, ldc,
381+
&queue, &ev_sgemm);
382+
383+
if (status != CLBlastSuccess) {
384+
fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
385+
abort();
386+
}
200387

201388
cl_event ev_c;
202389
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
@@ -205,4 +392,7 @@ void ggml_cl_sgemm_wrapper(
205392
clWaitForEvents(1, &ev_c);
206393
clReleaseEvent(ev_sgemm);
207394
clReleaseEvent(ev_c);
395+
if (btype == GGML_TYPE_Q5_0) {
396+
free((void*) cl_host_b);
397+
}
208398
}

0 commit comments

Comments
 (0)