Commit 4f0154b

llama : support requantizing models instead of only allowing quantization from 16/32bit (#1691)
* Add support for quantizing already quantized models
* Threaded dequantizing and f16 to f32 conversion
* Clean up thread blocks with spares calculation a bit
* Use std::runtime_error exceptions.
1 parent ef3171d commit 4f0154b

File tree

3 files changed: +134 additions, -40 deletions

examples/quantize/quantize.cpp
llama.cpp
llama.h


examples/quantize/quantize.cpp

Lines changed: 38 additions & 19 deletions
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <cstdio>
+#include <cstring>
 #include <map>
 #include <string>
 
@@ -53,27 +54,49 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st
 // usage:
 // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
 //
+void usage(const char * executable) {
+    fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n", executable);
+    fprintf(stderr, " --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
+    fprintf(stderr, " --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+    fprintf(stderr, "Allowed quantization types:\n");
+    for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
+        fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
+    }
+    exit(1);
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
-        fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
-        for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
-            fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
+        usage(argv[0]);
+    }
+
+    llama_model_quantize_params params = llama_model_quantize_default_params();
+
+    int arg_idx = 1;
+
+    for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
+        if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
+            params.quantize_output_tensor = false;
+        } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
+            params.allow_requantize = true;
+        } else {
+            usage(argv[0]);
         }
-        return 1;
+    }
+
+    if (argc - arg_idx < 3) {
+        usage(argv[0]);
    }
 
    llama_init_backend();
 
    // parse command line arguments
-    const std::string fname_inp = argv[1];
+    const std::string fname_inp = argv[arg_idx];
+    arg_idx++;
    std::string fname_out;
-    int nthread;
-    llama_ftype ftype;
 
-    int arg_idx = 2;
    std::string ftype_str;
-    if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
-        // argv[2] is the ftype
+    if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
        std::string fpath;
        const size_t pos = fname_inp.find_last_of('/');
        if (pos != std::string::npos) {
@@ -84,16 +107,14 @@ int main(int argc, char ** argv) {
        arg_idx++;
    }
    else {
-        // argv[2] is the output path
        fname_out = argv[arg_idx];
        arg_idx++;
 
        if (argc <= arg_idx) {
            fprintf(stderr, "%s: missing ftype\n", __func__);
            return 1;
        }
-        // argv[3] is the ftype
-        if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
+        if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
            return 1;
        }
@@ -103,21 +124,19 @@ int main(int argc, char ** argv) {
    // parse nthreads
    if (argc > arg_idx) {
        try {
-            nthread = std::stoi(argv[arg_idx]);
+            params.nthread = std::stoi(argv[arg_idx]);
        }
        catch (const std::exception & e) {
            fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
            return 1;
        }
-    } else {
-        nthread = 0;
    }
 
    fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
 
    fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
-    if (nthread > 0) {
-        fprintf(stderr, " using %d threads", nthread);
+    if (params.nthread > 0) {
+        fprintf(stderr, " using %d threads", params.nthread);
    }
    fprintf(stderr, "\n");
 
@@ -129,7 +148,7 @@ int main(int argc, char ** argv) {
    {
        const int64_t t_start_us = llama_time_us();
 
-        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
+        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), &params)) {
            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
            return 1;
        }
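
Note: the new argument handling scans leading "--" options before the positional model, ftype, and nthreads arguments, so arg_idx always lands on the model path no matter how many flags precede it. A minimal standalone sketch of that scan (not part of the commit; the Options struct is a hypothetical stand-in for llama_model_quantize_params):

// Illustration only: leading-option scan as performed by the new main().
#include <cstdio>
#include <cstring>

struct Options {
    bool allow_requantize       = false;
    bool quantize_output_tensor = true;
};

int main(int argc, char ** argv) {
    Options opts;
    int arg_idx = 1;

    // consume leading "--" options until the first positional argument
    for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
        if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
            opts.quantize_output_tensor = false;
        } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
            opts.allow_requantize = true;
        } else {
            fprintf(stderr, "unknown option: %s\n", argv[arg_idx]);
            return 1;
        }
    }

    // argv[arg_idx] is now the first positional argument (the model path)
    printf("allow_requantize=%d quantize_output_tensor=%d first positional=%s\n",
           opts.allow_requantize, opts.quantize_output_tensor,
           arg_idx < argc ? argv[arg_idx] : "(none)");
    return 0;
}

Run as "./a.out --allow-requantize model.bin q4_0" it would report allow_requantize=1 with model.bin as the first positional argument.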

llama.cpp

Lines changed: 86 additions & 17 deletions
@@ -886,6 +886,17 @@ struct llama_context_params llama_context_default_params() {
     return result;
 }
 
+struct llama_model_quantize_params llama_model_quantize_default_params() {
+    struct llama_model_quantize_params result = {
+        /*.nthread =*/ 0,
+        /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
+        /*.allow_requantize =*/ false,
+        /*.quantize_output_tensor =*/ true,
+    };
+
+    return result;
+}
+
 bool llama_mmap_supported() {
     return llama_mmap::SUPPORTED;
 }
@@ -2231,9 +2242,70 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
 // quantization
 //
 
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype, int nthread) {
+static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llama_buffer & output, const int nelements, const int nthread) {
+    if (output.size < nelements * sizeof(float)) {
+        output.resize(nelements * sizeof(float));
+    }
+    float * f32_output = (float *) output.addr;
+
+    quantize_fns_t qtype;
+    if (ggml_is_quantized(tensor.type)) {
+        qtype = ggml_internal_get_quantize_fn(tensor.type);
+        if (qtype.dequantize_row_q == NULL) {
+            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type)));
+        }
+    } else if (tensor.type != GGML_TYPE_F16) {
+        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor.type)));
+    }
+
+    if (nthread < 2) {
+        if (tensor.type == GGML_TYPE_F16) {
+            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements);
+        } else if (ggml_is_quantized(tensor.type)) {
+            qtype.dequantize_row_q(tensor.data, f32_output, nelements);
+        } else {
+            LLAMA_ASSERT(false); // unreachable
+        }
+        return;
+    }
+
+    auto block_size = tensor.type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor.type);
+    auto block_size_bytes = ggml_type_size(tensor.type);
+
+    LLAMA_ASSERT(nelements % block_size == 0);
+    auto nblocks = nelements / block_size;
+    auto blocks_per_thread = nblocks / nthread;
+    auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
+
+    std::vector<std::thread> workers;
+    for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) {
+        auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
+        auto thr_elems = thr_blocks * block_size; // number of elements for this thread
+        auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
+
+        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
+            if (typ == GGML_TYPE_F16) {
+                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
+            } else {
+                qtype.dequantize_row_q(inbuf, outbuf, nels);
+            }
+        };
+        workers.push_back(std::thread(compute, tensor.type, tensor.data + in_buff_offs, f32_output + out_buff_offs, thr_elems));
+        in_buff_offs += thr_block_bytes;
+        out_buff_offs += thr_elems;
+    }
+    for (auto & worker : workers) {
+        worker.join();
+    }
+
+}
+
+static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type quantized_type;
-    switch (ftype) {
+    llama_ftype ftype = params->ftype;
+    int nthread = params->nthread;
+
+    switch (params->ftype) {
        case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
        case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
        case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break;
@@ -2259,7 +2331,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
    std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false,
                                                                            /*vocab_only*/ false));
-    llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype);
+    llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), params->ftype);
 
    int n_attention_wv = 0;
    int n_feed_forward_w2 = 0;
@@ -2301,9 +2373,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
        quantize &= (tensor.ne.size() == 2);
 
        // uncomment this to keep the output layer in FP16
-        //if (tensor.name == "output.weight") {
-        //    quantize = false;
-        //}
+        if (!params->quantize_output_tensor && tensor.name == "output.weight") {
+            quantize = false;
+        }
+        quantize = quantize && quantized_type != tensor.type;
 
        enum ggml_type new_type;
        void * new_data;
@@ -2346,17 +2419,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
            float * f32_data;
            size_t nelements = tensor.ne.at(0) * tensor.ne.at(1);
            llama_buffer f32_conv_buf;
+
            if (tensor.type == GGML_TYPE_F32) {
                f32_data = (float *) tensor.data;
-            } else if (tensor.type == GGML_TYPE_F16) {
-                f32_conv_buf.resize(nelements * sizeof(float));
-                f32_data = (float *) f32_conv_buf.addr;
-                const auto * f16_data = (const ggml_fp16_t *) tensor.data;
-                for (size_t i = 0; i < nelements; i++) {
-                    f32_data[i] = ggml_fp16_to_fp32(f16_data[i]);
-                }
+            } else if (ggml_is_quantized(tensor.type) && !params->allow_requantize) {
+                throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor.type)));
            } else {
-                throw std::runtime_error(format("type %s unsupported for integer quantization", ggml_type_name(tensor.type)));
+                llama_convert_tensor_internal(tensor, f32_conv_buf, nelements, nthread);
+                f32_data = (float *) f32_conv_buf.addr;
            }
 
            printf("quantizing .. ");
@@ -2566,10 +2636,9 @@ void llama_free(struct llama_context * ctx) {
 int llama_model_quantize(
        const char * fname_inp,
        const char * fname_out,
-        enum llama_ftype ftype,
-        int nthread) {
+        const llama_model_quantize_params *params) {
    try {
-        llama_model_quantize_internal(fname_inp, fname_out, ftype, nthread);
+        llama_model_quantize_internal(fname_inp, fname_out, params);
        return 0;
    } catch (const std::exception & err) {
        fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.what());

llama.h

Lines changed: 10 additions & 4 deletions
@@ -115,7 +115,16 @@ extern "C" {
        LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
    };
 
+    // model quantization parameters
+    typedef struct llama_model_quantize_params {
+        int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+        enum llama_ftype ftype; // quantize to this llama_ftype
+        bool allow_requantize; // allow quantizing non-f32/f16 tensors
+        bool quantize_output_tensor; // quantize output.weight
+    } llama_model_quantize_params;
+
    LLAMA_API struct llama_context_params llama_context_default_params();
+    LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params();
 
    LLAMA_API bool llama_mmap_supported();
    LLAMA_API bool llama_mlock_supported();
@@ -137,14 +146,11 @@ extern "C" {
    // Frees all allocated memory
    LLAMA_API void llama_free(struct llama_context * ctx);
 
-    // TODO: not great API - very likely to change
    // Returns 0 on success
-    // nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
    LLAMA_API int llama_model_quantize(
            const char * fname_inp,
            const char * fname_out,
-            enum llama_ftype ftype,
-            int nthread);
+            const llama_model_quantize_params * params);
 
    // Apply a LoRA adapter to a loaded model
    // path_base_model is the path to a higher quality model to use as a base for
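
Note: with this header change, callers build a llama_model_quantize_params (normally starting from llama_model_quantize_default_params()) instead of passing the old (ftype, nthread) pair. A minimal sketch of the new call, assuming llama.h from this commit is on the include path; file names and parameter values are placeholders:

// Sketch of calling the new quantization API; paths are placeholders.
#include <cstdio>
#include "llama.h"

int main() {
    llama_init_backend();

    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype                  = LLAMA_FTYPE_MOSTLY_Q4_0; // target quantization type
    params.nthread                = 4;                       // <=0 falls back to std::thread::hardware_concurrency()
    params.allow_requantize       = true;                    // the input may itself already be quantized
    params.quantize_output_tensor = false;                   // keep output.weight as-is

    // returns 0 on success
    if (llama_model_quantize("ggml-model-in.bin", "ggml-model-q4_0.bin", &params)) {
        fprintf(stderr, "quantization failed\n");
        return 1;
    }
    return 0;
}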
