Skip to content

Commit 62c2f37

Browse files
committed
Generalize quantize_fns for simpler FP16 handling
1 parent 98ed165 commit 62c2f37

File tree

7 files changed: +167 additions, -613 deletions (lines changed)

7 files changed: +167 additions, -613 deletions (lines changed)

examples/quantize-stats/quantize-stats.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ void test_roundtrip_on_chunk(
143143
const ggml_tensor * layer,
144144
int64_t offset,
145145
int64_t chunk_size,
146-
const quantize_fns_t & qfns,
146+
const ggml_type_handling_t & qfns,
147147
bool use_reference,
148148
float * input_scratch,
149149
char * quantized_scratch,
@@ -159,11 +159,11 @@ void test_roundtrip_on_chunk(
159159
}
160160

161161
if (use_reference) {
162-
qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size);
162+
qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size);
163163
} else {
164-
qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size);
164+
qfns.from_float(input_scratch, quantized_scratch, chunk_size);
165165
}
166-
qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size);
166+
qfns.to_float(quantized_scratch, output_scratch, chunk_size);
167167

168168
update_error_stats(chunk_size, input_scratch, output_scratch, stats);
169169
}
@@ -173,7 +173,7 @@ void test_roundtrip_on_chunk(
173173
void test_roundtrip_on_layer(
174174
std::string & name,
175175
bool print_layer_stats,
176-
const quantize_fns_t & qfns,
176+
const ggml_type_handling_t & qfns,
177177
bool use_reference,
178178
const ggml_tensor * layer,
179179
std::vector<float> & input_scratch,
@@ -374,8 +374,8 @@ int main(int argc, char ** argv) {
374374
if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) {
375375
continue;
376376
}
377-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
378-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
377+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
378+
if (qfns.from_float && qfns.to_float) {
379379
if (params.verbose) {
380380
printf("testing %s ...\n", ggml_type_name(type));
381381
}

ggml.c

Lines changed: 110 additions & 553 deletions
Large diffs are not rendered by default.

ggml.h

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -224,8 +224,8 @@ extern "C" {
224224
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
225225
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
226226

227-
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
228-
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
227+
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
228+
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
229229

230230
struct ggml_object;
231231
struct ggml_context;
@@ -1169,26 +1169,19 @@ extern "C" {
11691169
// Internal types and functions exposed for tests and benchmarks
11701170
//
11711171

1172-
#ifdef __cplusplus
1173-
// restrict not standard in C++
1174-
#define GGML_RESTRICT
1175-
#else
1176-
#define GGML_RESTRICT restrict
1177-
#endif
1178-
typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
1179-
typedef void (*quantize_row_q_t) (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
1180-
typedef void (*vec_dot_q_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
1172+
typedef void (*ggml_to_float_t)(const void * x, float * y, int k);
1173+
typedef void (*ggml_from_float_t)(const float * x, void * y, int k);
1174+
typedef void (*ggml_vec_dot_t)(const int n, float * s, const void * x, const void * y);
11811175

11821176
typedef struct {
1183-
dequantize_row_q_t dequantize_row_q;
1184-
quantize_row_q_t quantize_row_q;
1185-
quantize_row_q_t quantize_row_q_reference;
1186-
quantize_row_q_t quantize_row_q_dot;
1187-
vec_dot_q_t vec_dot_q;
1188-
enum ggml_type vec_dot_type;
1189-
} quantize_fns_t;
1190-
1191-
quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
1177+
ggml_to_float_t to_float;
1178+
ggml_from_float_t from_float;
1179+
ggml_from_float_t from_float_reference;
1180+
ggml_vec_dot_t vec_dot;
1181+
enum ggml_type vec_dot_type;
1182+
} ggml_type_handling_t;
1183+
1184+
ggml_type_handling_t ggml_internal_get_type_handling(enum ggml_type i);
11921185

11931186
#ifdef __cplusplus
11941187
}

pocs/vdot/q8dot.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ int main(int argc, char** argv) {
136136

137137
auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1;
138138

139-
auto funcs = ggml_internal_get_quantize_fn(ggml_type);
139+
auto funcs = ggml_internal_get_type_handling(ggml_type);
140140

141141
Stat simple, ggml;
142142

@@ -156,8 +156,8 @@ int main(int argc, char** argv) {
156156

157157
t1 = std::chrono::high_resolution_clock::now();
158158
float fs;
159-
if (type == 0) funcs.vec_dot_q(kVecSize * QK4_1, &fs, x40.data(), y.data());
160-
else funcs.vec_dot_q(kVecSize * QK4_1, &fs, x41.data(), y.data());
159+
if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, x40.data(), y.data());
160+
else funcs.vec_dot(kVecSize * QK4_1, &fs, x41.data(), y.data());
161161
t2 = std::chrono::high_resolution_clock::now();
162162
t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
163163
if (iloop > 3) ggml.addResult(fs, t);

pocs/vdot/vdot.cpp

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ int main(int argc, char** argv) {
231231
int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64);
232232
int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64);
233233

234-
auto funcs = useQ4_1 ? ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1) : ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
234+
auto funcs = useQ4_1 ? ggml_internal_get_type_handling(GGML_TYPE_Q4_1) : ggml_internal_get_type_handling(GGML_TYPE_Q4_0);
235235

236236
std::vector<block_q4_0> q40;
237237
std::vector<block_q4_1> q41;
@@ -257,9 +257,9 @@ int main(int argc, char** argv) {
257257
// Note, we do not include this in the timing as in practical application
258258
// we already have the quantized model weights.
259259
if (useQ4_1) {
260-
funcs.quantize_row_q(x1.data(), q41.data(), kVecSize);
260+
funcs.from_float(x1.data(), q41.data(), kVecSize);
261261
} else {
262-
funcs.quantize_row_q(x1.data(), q40.data(), kVecSize);
262+
funcs.from_float(x1.data(), q40.data(), kVecSize);
263263
}
264264

265265
// Now measure time the dot product needs using the "scalar" version above
@@ -278,9 +278,10 @@ int main(int argc, char** argv) {
278278
dot_q4_q8(kVecSize, &result, q40.data(), q8.data());
279279
}
280280
else {
281-
funcs.quantize_row_q_dot(y1.data(), q8.data(), kVecSize);
282-
if (useQ4_1) funcs.vec_dot_q(kVecSize, &result, q41.data(), q8.data());
283-
else funcs.vec_dot_q(kVecSize, &result, q40.data(), q8.data());
281+
auto vdot = ggml_internal_get_type_handling(funcs.vec_dot_type);
282+
vdot.from_float(y1.data(), q8.data(), kVecSize);
283+
if (useQ4_1) funcs.vec_dot(kVecSize, &result, q41.data(), q8.data());
284+
else funcs.vec_dot(kVecSize, &result, q40.data(), q8.data());
284285
}
285286
sumq += result;
286287
t2 = std::chrono::high_resolution_clock::now();

tests/test-quantize-fns.cpp

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -37,26 +37,26 @@ float array_rmse(const float * a1, const float * a2, size_t n) {
3737
}
3838

3939
// Total quantization error on test data
40-
float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
40+
float total_quantization_error(ggml_type_handling_t & qfns, size_t test_size, const float * test_data) {
4141
std::vector<uint8_t> tmp_q(2*test_size);
4242
std::vector<float> tmp_out(test_size);
4343

44-
qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
45-
qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
44+
qfns.from_float(test_data, tmp_q.data(), test_size);
45+
qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
4646
return array_rmse(test_data, tmp_out.data(), test_size);
4747
}
4848

4949
// Total quantization error on test data
50-
float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
50+
float reference_quantization_error(ggml_type_handling_t & qfns, size_t test_size, const float * test_data) {
5151
std::vector<uint8_t> tmp_q(2*test_size);
5252
std::vector<float> tmp_out(test_size);
5353
std::vector<float> tmp_out_ref(test_size);
5454

55-
qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
56-
qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
55+
qfns.from_float(test_data, tmp_q.data(), test_size);
56+
qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
5757

58-
qfns.quantize_row_q_reference(test_data, tmp_q.data(), test_size);
59-
qfns.dequantize_row_q(tmp_q.data(), tmp_out_ref.data(), test_size);
58+
qfns.from_float_reference(test_data, tmp_q.data(), test_size);
59+
qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
6060

6161
return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
6262
}
@@ -70,15 +70,17 @@ float dot_product(const float * a1, const float * a2, size_t test_size) {
7070
}
7171

7272
// Total dot product error
73-
float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
73+
float dot_product_error(ggml_type_handling_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
7474
std::vector<uint8_t> tmp_q1(2*test_size);
7575
std::vector<uint8_t> tmp_q2(2*test_size);
7676

77-
qfns.quantize_row_q (test_data1, tmp_q1.data(), test_size);
78-
qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
77+
auto vdot = ggml_internal_get_type_handling(qfns.vec_dot_type);
78+
79+
qfns.from_float(test_data1, tmp_q1.data(), test_size);
80+
vdot.from_float(test_data2, tmp_q2.data(), test_size);
7981

8082
float result = INFINITY;
81-
qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data());
83+
qfns.vec_dot(test_size, &result, tmp_q1.data(), tmp_q2.data());
8284

8385
const float dot_ref = dot_product(test_data1, test_data2, test_size);
8486

@@ -120,9 +122,9 @@ int main(int argc, char * argv[]) {
120122

121123
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
122124
ggml_type type = (ggml_type) i;
123-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
125+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
124126

125-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
127+
if (qfns.from_float && qfns.to_float) {
126128
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
127129
const float max_quantization_error =
128130
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :

tests/test-quantize-perf.cpp

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -220,20 +220,20 @@ int main(int argc, char * argv[]) {
220220

221221
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
222222
ggml_type type = (ggml_type) i;
223-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
223+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
224224
if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
225225
continue;
226226
}
227227

228-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
228+
if (qfns.from_float && qfns.to_float) {
229229
printf("%s\n", ggml_type_name(type));
230230

231231
if (params.op_quantize_row_q_reference) {
232232
printf(" quantize_row_q_reference\n");
233233
for (size_t size : params.test_sizes) {
234234
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
235235
auto quantize_fn = [&](void ) {
236-
qfns.quantize_row_q_reference(test_data1, test_q1, size);
236+
qfns.from_float_reference(test_data1, test_q1, size);
237237
return test_q1[0];
238238
};
239239
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -247,7 +247,7 @@ int main(int argc, char * argv[]) {
247247
for (size_t size : params.test_sizes) {
248248
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
249249
auto quantize_fn = [&](void ) {
250-
qfns.quantize_row_q(test_data1, test_q1, size);
250+
qfns.from_float(test_data1, test_q1, size);
251251
return test_q1[0];
252252
};
253253
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -258,11 +258,11 @@ int main(int argc, char * argv[]) {
258258

259259
if (params.op_dequantize_row_q) {
260260
printf(" dequantize_row_q\n");
261-
qfns.quantize_row_q(test_data1, test_q1, largest);
261+
qfns.from_float(test_data1, test_q1, largest);
262262
for (size_t size : params.test_sizes) {
263263
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
264264
auto quantize_fn = [&](void ) {
265-
qfns.dequantize_row_q(test_q1, test_out, size);
265+
qfns.to_float(test_q1, test_out, size);
266266
return test_out[0];
267267
};
268268
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -276,7 +276,8 @@ int main(int argc, char * argv[]) {
276276
for (size_t size : params.test_sizes) {
277277
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
278278
auto quantize_fn = [&](void ) {
279-
qfns.quantize_row_q_dot(test_data1, test_q1, size);
279+
auto vdot = ggml_internal_get_type_handling(qfns.vec_dot_type);
280+
vdot.from_float(test_data1, test_q1, size);
280281
return test_q1[0];
281282
};
282283
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -287,13 +288,13 @@ int main(int argc, char * argv[]) {
287288

288289
if (params.op_vec_dot_q) {
289290
printf(" vec_dot_q\n");
290-
qfns.quantize_row_q(test_data1, test_q1, largest);
291-
qfns.quantize_row_q(test_data2, test_q2, largest);
291+
qfns.from_float(test_data1, test_q1, largest);
292+
qfns.from_float(test_data2, test_q2, largest);
292293
for (size_t size : params.test_sizes) {
293294
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
294295
auto quantize_fn = [&](void ) {
295296
float result;
296-
qfns.vec_dot_q(size, &result, test_q1, test_q2);
297+
qfns.vec_dot(size, &result, test_q1, test_q2);
297298
return result;
298299
};
299300
size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);

0 commit comments

Comments
 (0)