Skip to content

Commit 7d58c84

Browse files
committed
Generalize quantize_fns for simpler FP16 handling
1 parent f4cef87 commit 7d58c84

File tree

7 files changed

+157
-653
lines changed

7 files changed

+157
-653
lines changed

examples/quantize-stats/quantize-stats.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ void test_roundtrip_on_chunk(
143143
const ggml_tensor * layer,
144144
int64_t offset,
145145
int64_t chunk_size,
146-
const quantize_fns_t & qfns,
146+
const ggml_type_handling_t & qfns,
147147
bool use_reference,
148148
float * input_scratch,
149149
char * quantized_scratch,
@@ -159,11 +159,11 @@ void test_roundtrip_on_chunk(
159159
}
160160

161161
if (use_reference) {
162-
qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size);
162+
qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size);
163163
} else {
164-
qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size);
164+
qfns.from_float(input_scratch, quantized_scratch, chunk_size);
165165
}
166-
qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size);
166+
qfns.to_float(quantized_scratch, output_scratch, chunk_size);
167167

168168
update_error_stats(chunk_size, input_scratch, output_scratch, stats);
169169
}
@@ -173,7 +173,7 @@ void test_roundtrip_on_chunk(
173173
void test_roundtrip_on_layer(
174174
std::string & name,
175175
bool print_layer_stats,
176-
const quantize_fns_t & qfns,
176+
const ggml_type_handling_t & qfns,
177177
bool use_reference,
178178
const ggml_tensor * layer,
179179
std::vector<float> & input_scratch,
@@ -374,8 +374,8 @@ int main(int argc, char ** argv) {
374374
if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) {
375375
continue;
376376
}
377-
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
378-
if (qfns.quantize_row_q && qfns.dequantize_row_q) {
377+
ggml_type_handling_t qfns = ggml_internal_get_type_handling(type);
378+
if (qfns.from_float && qfns.to_float) {
379379
if (params.verbose) {
380380
printf("testing %s ...\n", ggml_type_name(type));
381381
}

0 commit comments

Comments
 (0)