Skip to content

Commit 7a8c050

Browse files
committed
ggml : reuse quantum structs across backends
ggml-ci
1 parent a167b6d commit 7a8c050

File tree

6 files changed

+420
-695
lines changed

6 files changed

+420
-695
lines changed

ggml-common.h

Lines changed: 361 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,363 @@
1-
#pragma once
1+
#ifndef GGML_COMMON_DECL
2+
3+
// Backend selection for the shared block declarations.
// The first GGML_COMMON_DECL_* macro found to be defined picks:
//   - ggml_half / ggml_half2 : the types used for half scalars and half pairs
//   - GGML_COMMON_AGGR       : the token placed after the inner struct of the
//                              d/m (or d/s, d/dmin) unions declared further down
// and then defines GGML_COMMON_DECL. If none of the three macros is defined,
// GGML_COMMON_DECL is not defined by this chain.
#if defined(GGML_COMMON_DECL_C)
4+
#include <stdint.h>
5+
6+
// C branch: halves are carried as uint16_t, half pairs as uint32_t
typedef uint16_t ggml_half;
7+
typedef uint32_t ggml_half2;
8+
9+
// expands to nothing -> the inner struct of each union is declared without a member name
#define GGML_COMMON_AGGR
10+
11+
#define GGML_COMMON_DECL
12+
#elif defined(GGML_COMMON_DECL_METAL)
13+
#include <metal_stdlib>
14+
15+
// Metal branch: use the half / half2 types directly
typedef half ggml_half;
16+
typedef half2 ggml_half2;
17+
18+
// expands to nothing, same as the C branch
#define GGML_COMMON_AGGR
19+
20+
#define GGML_COMMON_DECL
21+
#elif defined(GGML_COMMON_DECL_CUDA)
22+
#include <cuda_fp16.h>
23+
#include <cstdint>
24+
25+
// CUDA branch: use the half / half2 types directly
typedef half ggml_half;
26+
typedef half2 ggml_half2;
27+
28+
// expands to `data` -> the inner struct of each union becomes a member named `data`
#define GGML_COMMON_AGGR data
29+
30+
#define GGML_COMMON_DECL
31+
#endif
32+
33+
#if defined(GGML_COMMON_DECL)
34+
35+
// Provide static_assert(cond, msg) when compiling as C (not C++) and no
// static_assert macro is defined yet.
#ifndef __cplusplus
36+
#ifndef static_assert
37+
// __STDC_VERSION__ >= 201100L: forward to the _Static_assert keyword
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
38+
#define static_assert(cond, msg) _Static_assert(cond, msg)
39+
#else
40+
// Older C: expand to a struct forward declaration so that
// `static_assert(...);` stays a valid declaration; cond and msg are discarded
// and nothing is checked.
#define static_assert(cond, msg) struct global_scope_noop_trick
41+
#endif
42+
#endif
43+
#endif
44+
45+
// QK = number of values after dequantization
46+
// QR = QK / number of values before dequantization
47+
// QI = number of 32 bit integers before dequantization
48+
49+
#define QK4_0 32
50+
#define QI4_0 (QK4_0 / (4 * QR4_0))
51+
#define QR4_0 2
52+
typedef struct {
53+
ggml_half d; // delta
54+
uint8_t qs[QK4_0 / 2]; // nibbles / quants
55+
} block_q4_0;
56+
static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding");
57+
58+
#define QK4_1 32
59+
#define QI4_1 (QK4_1 / (4 * QR4_1))
60+
#define QR4_1 2
61+
typedef struct {
62+
union {
63+
struct {
64+
ggml_half d; // delta
65+
ggml_half m; // min
66+
} GGML_COMMON_AGGR;
67+
ggml_half2 dm;
68+
};
69+
uint8_t qs[QK4_1 / 2]; // nibbles / quants
70+
} block_q4_1;
71+
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
72+
73+
#define QK5_0 32
74+
#define QI5_0 (QK5_0 / (4 * QR5_0))
75+
#define QR5_0 2
76+
typedef struct {
77+
ggml_half d; // delta
78+
uint8_t qh[4]; // 5-th bit of quants
79+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
80+
} block_q5_0;
81+
static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
82+
83+
#define QK5_1 32
84+
#define QI5_1 (QK5_1 / (4 * QR5_1))
85+
#define QR5_1 2
86+
typedef struct {
87+
union {
88+
struct {
89+
ggml_half d; // delta
90+
ggml_half m; // min
91+
} GGML_COMMON_AGGR;
92+
ggml_half2 dm;
93+
};
94+
uint8_t qh[4]; // 5-th bit of quants
95+
uint8_t qs[QK5_1 / 2]; // nibbles / quants
96+
} block_q5_1;
97+
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
98+
99+
#define QK8_0 32
100+
#define QI8_0 (QK8_0 / (4 * QR8_0))
101+
#define QR8_0 1
102+
typedef struct {
103+
ggml_half d; // delta
104+
int8_t qs[QK8_0]; // quants
105+
} block_q8_0;
106+
static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding");
107+
108+
#define QK8_1 32
109+
#define QI8_1 (QK8_1 / (4 * QR8_1))
110+
#define QR8_1 1
111+
// q8_1 block: two half values (also addressable as one ggml_half2 via `ds`)
// followed by QK8_1 signed 8-bit quants.
typedef struct {
112+
union {
113+
struct {
114+
// NOTE(review): the `xxx` prefix on the next two members looks like a
// temporary rename (e.g. to locate remaining users) - confirm the intended
// final member names.
ggml_half xxxd; // delta
115+
ggml_half xxxs; // d * sum(qs[i])
116+
} GGML_COMMON_AGGR;
117+
ggml_half2 ds; // same storage as the two halves above, viewed as one pair
118+
};
119+
int8_t qs[QK8_1]; // quants
120+
} block_q8_1;
121+
// layout check: two halves + QK8_1 quant bytes, i.e. no padding
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
122+
123+
//
124+
// Super-block quantization structures
125+
//
126+
127+
// Super-block size
128+
#ifdef GGML_QKK_64
129+
#define QK_K 64
130+
#define K_SCALE_SIZE 4
131+
#else
132+
#define QK_K 256
133+
#define K_SCALE_SIZE 12
134+
#endif
135+
136+
// 2-bit quantization
137+
// weight is represented as x = a * q + b
138+
// 16 blocks of 16 elements each
139+
// Effectively 2.625 bits per weight
140+
#define QI2_K (QK_K / (4*QR2_K))
141+
#define QR2_K 4
142+
typedef struct {
143+
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
144+
uint8_t qs[QK_K/4]; // quants
145+
union {
146+
struct {
147+
ggml_half d; // super-block scale for quantized scales
148+
ggml_half dmin; // super-block scale for quantized mins
149+
} GGML_COMMON_AGGR;
150+
ggml_half2 dm;
151+
};
152+
} block_q2_K;
153+
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
154+
155+
// 3-bit quantization
156+
// weight is represented as x = a * q
157+
// 16 blocks of 16 elements each
158+
// Effectively 3.4375 bits per weight
159+
#define QI3_K (QK_K / (4*QR3_K))
160+
#define QR3_K 4
161+
#ifdef GGML_QKK_64
162+
typedef struct {
163+
uint8_t hmask[QK_K/8]; // quants - high bit
164+
uint8_t qs[QK_K/4]; // quants - low 2 bits
165+
uint8_t scales[2];
166+
ggml_half d; // super-block scale
167+
} block_q3_K;
168+
static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
169+
#else
170+
typedef struct {
171+
uint8_t hmask[QK_K/8]; // quants - high bit
172+
uint8_t qs[QK_K/4]; // quants - low 2 bits
173+
uint8_t scales[12]; // scales, quantized with 6 bits
174+
ggml_half d; // super-block scale
175+
} block_q3_K;
176+
static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
177+
#endif
178+
179+
// 4-bit quantization
180+
// 8 blocks of 32 elements each
181+
// weight is represented as x = a * q + b
182+
// Effectively 4.5 bits per weight
183+
#define QI4_K (QK_K / (4*QR4_K))
184+
#define QR4_K 2
185+
#ifdef GGML_QKK_64
186+
// q4_K super-block, variant used when GGML_QKK_64 is defined:
// two half super-block values, two scale bytes, then QK_K/2 bytes of quants.
typedef struct {
187+
ggml_half d[2]; // super-block scales/mins
188+
uint8_t scales[2]; // 4-bit block scales/mins
189+
uint8_t qs[QK_K/2]; // 4-bit quants
190+
} block_q4_K;
191+
// layout check: d[2] + qs + scales[2], i.e. no padding
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + QK_K/2 + 2, "wrong q4_K block size/padding");
192+
#else
193+
// q4_K super-block, variant used when GGML_QKK_64 is not defined:
// a d/dmin half pair (also addressable as one ggml_half2 via `dm`),
// K_SCALE_SIZE bytes of packed scales/mins, then QK_K/2 bytes of quants.
typedef struct {
194+
union {
195+
struct {
196+
ggml_half d; // super-block scale for quantized scales
197+
ggml_half dmin; // super-block scale for quantized mins
198+
} GGML_COMMON_AGGR;
199+
ggml_half2 dm; // same storage as d/dmin, viewed as one pair
200+
};
201+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
202+
uint8_t qs[QK_K/2]; // 4-bit quants
203+
} block_q4_K;
204+
// layout check: d/dmin pair + scales + qs, i.e. no padding
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
205+
#endif
206+
207+
// 5-bit quantization
208+
// 8 blocks of 32 elements each
209+
// weight is represented as x = a * q + b
210+
// Effectively 5.5 bits per weight
211+
#define QI5_K (QK_K / (4*QR5_K))
212+
#define QR5_K 2
213+
#ifdef GGML_QKK_64
214+
typedef struct {
215+
ggml_half d; // super-block scale
216+
int8_t scales[QK_K/16]; // 8-bit block scales
217+
uint8_t qh[QK_K/8]; // quants, high bit
218+
uint8_t qs[QK_K/2]; // quants, low 4 bits
219+
} block_q5_K;
220+
static_assert(sizeof(block_q5_K) == sizeof(ggml_half) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
221+
#else
222+
typedef struct {
223+
union {
224+
struct {
225+
ggml_half d; // super-block scale for quantized scales
226+
ggml_half dmin; // super-block scale for quantized mins
227+
} GGML_COMMON_AGGR;
228+
ggml_half2 dm;
229+
};
230+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
231+
uint8_t qh[QK_K/8]; // quants, high bit
232+
uint8_t qs[QK_K/2]; // quants, low 4 bits
233+
} block_q5_K;
234+
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
235+
#endif
236+
237+
// 6-bit quantization
238+
// weight is represented as x = a * q
239+
// 16 blocks of 16 elements each
240+
// Effectively 6.5625 bits per weight
241+
#define QI6_K (QK_K / (4*QR6_K))
242+
#define QR6_K 2
243+
typedef struct {
244+
uint8_t ql[QK_K/2]; // quants, lower 4 bits
245+
uint8_t qh[QK_K/4]; // quants, upper 2 bits
246+
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
247+
ggml_half d; // super-block scale
248+
} block_q6_K;
249+
static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
250+
251+
// This is only used for intermediate quantization and dot products
252+
typedef struct {
253+
float d; // delta
254+
int8_t qs[QK_K]; // quants
255+
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
256+
} block_q8_K;
257+
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
258+
259+
// (Almost) "true" 2-bit quantization.
260+
// Due to the need to use blocks as per ggml design, it ends up using
261+
// 2.0625 bpw because of the 16-bit scale for each block of 256.
262+
#define QI2_XXS (QK_K / (4*QR2_XXS))
263+
#define QR2_XXS 8
264+
typedef struct {
265+
ggml_half d;
266+
uint16_t qs[QK_K/8];
267+
} block_iq2_xxs;
268+
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
269+
270+
// 2.3125 bpw quants
271+
#define QI2_XS (QK_K / (4*QR2_XS))
272+
#define QR2_XS 8
273+
typedef struct {
274+
ggml_half d;
275+
uint16_t qs[QK_K/8];
276+
uint8_t scales[QK_K/32];
277+
} block_iq2_xs;
278+
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
279+
280+
// 2.5625 bpw quants
281+
#define QI2_S (QK_K / (4*QR2_S))
282+
#define QR2_S 8
283+
// Block layout for the 2.5625 bpw quants: one half `d`, then three byte arrays.
typedef struct {
284+
ggml_half d;
285+
uint8_t qs[QK_K/4];
286+
uint8_t qh[QK_K/32];
287+
uint8_t scales[QK_K/32];
288+
} block_iq2_s;
289+
// layout check: the QK_K/16 term is qh (QK_K/32) + scales (QK_K/32) combined
static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
290+
291+
// (Almost) "true" 3-bit quantization.
292+
// Due to the need to use blocks as per ggml design, it ends up using
293+
// 3.0625 bpw because of the 16-bit scale for each block of 256.
294+
#define QI3_XXS (QK_K / (4*QR3_XXS))
295+
#define QR3_XXS 8
296+
typedef struct {
297+
ggml_half d;
298+
uint8_t qs[3*QK_K/8];
299+
} block_iq3_xxs;
300+
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
301+
302+
// 3.4375 bpw
303+
#if QK_K == 64
304+
#define IQ3S_N_SCALE 2
305+
#else
306+
// Number of scale bytes in block_iq3_s when QK_K != 64.
// Parenthesized so the macro expands as a single value inside larger
// expressions (an unparenthesized body makes `x / IQ3S_N_SCALE` parse as `(x / QK_K) / 64`).
#define IQ3S_N_SCALE (QK_K/64)
307+
#endif
308+
#define QI3_XS (QK_K / (4*QR3_XS))
309+
#define QR3_XS 8
310+
// Block layout for the 3.4375 bpw quants: one half `d`, then four byte arrays.
typedef struct {
311+
ggml_half d;
312+
uint8_t qs[QK_K/4];
313+
uint8_t qh[QK_K/32];
314+
uint8_t signs[QK_K/8];
315+
uint8_t scales[IQ3S_N_SCALE];
316+
} block_iq3_s;
317+
// layout check: 13*(QK_K/32) = qs (8*QK_K/32) + qh (1*QK_K/32) + signs (4*QK_K/32)
static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
318+
319+
#define QI1_S (QK_K / (4*QR1_S))
320+
#define QR1_S 8
321+
typedef struct {
322+
ggml_half d;
323+
uint8_t qs[QK_K/8];
324+
uint8_t scales[QK_K/16];
325+
} block_iq1_s;
326+
static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
327+
328+
// Non-linear quants
329+
#define QK4_NL 32
330+
#define QI4_NL (QK4_NL / (4*QR4_NL))
331+
#define QR4_NL 2
332+
typedef struct {
333+
ggml_half d;
334+
uint8_t qs[QK4_NL/2];
335+
} block_iq4_nl;
336+
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
337+
338+
// block_iq4_xs and its QI/QR constants.
// With QK_K == 64 no separate struct is declared: the type name and both
// constants are macro aliases of their iq4_nl counterparts.
#if QK_K == 64
339+
#define block_iq4_xs block_iq4_nl
340+
#define QI4_XS QI4_NL
341+
#define QR4_XS QR4_NL
342+
// Disabled alternative to the macro alias above. As written it refers to a
// struct tag `block_iq4_nl`, whereas block_iq4_nl is declared as a typedef of
// an untagged struct.
//typedef struct block_iq4_nl block_iq4_xs;
343+
#else
344+
// QI4_XS mentions QR4_XS before it is defined; this is fine because macro
// bodies are only expanded at the point of use.
#define QI4_XS (QK_K / (4*QR4_XS))
345+
#define QR4_XS 8
346+
typedef struct {
347+
ggml_half d;
348+
uint16_t scales_h;
349+
uint8_t scales_l[QK_K/64];
350+
uint8_t qs[QK_K/2];
351+
} block_iq4_xs;
352+
// layout check: d + scales_h + scales_l + qs, i.e. no padding
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
353+
#endif
354+
355+
#endif // GGML_COMMON_DECL
356+
#endif // GGML_COMMON_DECL
357+
358+
////////////////////////////////////////////////////////////////////////////////
359+
360+
#ifndef GGML_COMMON_IMPL
2361

3362
#if defined(GGML_COMMON_IMPL_C)
4363
#include <stdint.h>
@@ -777,3 +1136,4 @@ GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ2XXS)
7771136
GGML_TABLE_END()
7781137

7791138
#endif // GGML_COMMON_IMPL
1139+
#endif // GGML_COMMON_IMPL

0 commit comments

Comments
 (0)