
Commit 603c771

Configurable sparse prediction threshold (ggml-org#7)
* remove warning at gpu split
* remove dead code
* adaptive sparsity threshold reading from model file
* convert models with sparse threshold
1 parent 597ef34 commit 603c771

File tree

9 files changed: +96 −41 lines changed


convert.py

Lines changed: 32 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 import argparse
 import concurrent.futures
+import dataclasses
 import enum
 import faulthandler
 import functools
@@ -138,6 +139,28 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
 # hparams loading
 #
 
+@dataclass
+class PredictorParams:
+    sparse_threshold: float | None = None
+
+    @staticmethod
+    def loadPredictorJson(model: LazyModel, config_path: Path) -> PredictorParams:
+        config = json.load(open(config_path))
+        return PredictorParams(
+            sparse_threshold = config.get("sparse_threshold"),
+        )
+
+    @staticmethod
+    def load(model_plus: ModelPlus) -> PredictorParams:
+        config_path = model_plus.paths[0].parent / "config.json"
+
+        if config_path.exists():
+            params = PredictorParams.loadPredictorJson(model_plus.model, config_path)
+        else:
+            params = PredictorParams()
+
+        return params
+
 @dataclass
 class Params:
     n_vocab: int
@@ -160,6 +183,9 @@ class Params:
     # path to the directory containing the model files
     path_model: Path | None = None
 
+    # MLP predictor parameters
+    predictor_params: PredictorParams = dataclasses.field(default_factory=PredictorParams)
+
     @staticmethod
     def guessed(model: LazyModel) -> Params:
         # try transformer naming first
@@ -843,6 +869,9 @@ def add_meta_arch(self, params: Params) -> None:
         if params.ftype is not None:
            self.gguf.add_file_type(params.ftype)
 
+        if params.predictor_params.sparse_threshold is not None:
+            self.gguf.add_sparse_threshold(params.predictor_params.sparse_threshold)
+
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
         scores = []
@@ -1181,10 +1210,13 @@ def main(args_in: list[str] | None = None) -> None:
 
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
+        params = Params.load(model_plus)
         mlp_predictor_plus = load_mlp_model(args.mlp_model)
+        params.predictor_params = PredictorParams.load(mlp_predictor_plus)
         model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
     else:
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+        params = Params.load(model_plus)
 
     if args.dump:
         do_dump_model(model_plus)
@@ -1193,7 +1225,6 @@ def main(args_in: list[str] | None = None) -> None:
     if args.bigendian:
         endianess = gguf.GGUFEndian.BIG
 
-    params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
             raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
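
Side note (not part of the diff): the flow above takes the threshold from the MLP predictor's config.json and, via add_sparse_threshold(), stores it as GGUF metadata. A minimal standalone sketch of that path using the gguf-py API is shown below; the file names are invented for illustration, and the key string matches the constant added in gguf-py/gguf/constants.py.

# sketch.py - illustration only, assuming a predictor config.json sits next to the model
import json
from pathlib import Path

import gguf

config_path = Path("mlp-predictor/config.json")    # hypothetical predictor directory
config = json.loads(config_path.read_text())
sparse_threshold = config.get("sparse_threshold")  # None when the key is absent

writer = gguf.GGUFWriter("model.powerinfer.gguf", arch="llama")
if sparse_threshold is not None:
    writer.add_float32("powerinfer.sparse_threshold", float(sparse_threshold))

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()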

ggml-cuda.cu

Lines changed: 10 additions & 11 deletions
@@ -108,6 +108,8 @@
 // max batch size to use MMQ kernels when tensor cores are available
 #define MMQ_MAX_BATCH_SIZE 32
 
+__constant__ float dev_sparse_threshold;
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -4483,7 +4485,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__
     //     printf("row in gpu %d cols %d, value %d %d %d\n", id, ncols, *d, *(d+1), *(d+4095));
     // }
     // int id = row;
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
 
@@ -4552,12 +4554,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
         return;
     }
     int id = lst[row];
-    // int id = row;
-    // if (idx[id] < 0.0f) {
-    //     return;
-    // }
     const int bid = blockIdx.y;
-    // if (bid == 0) global_lock = 0;
 
     extern __shared__ float shared_dst[]; // TODO:dynamic
 
@@ -4578,7 +4575,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
     // __syncthreads();
     for (int col_id = 0; col_id < src1_ncols; col_id++) {
         __syncthreads();
-        if (loop_idx[id] < 0.0f) {
+        if (loop_idx[id] < dev_sparse_threshold) {
             loop_dst += ncols;
             loop_idx += src1_ne0;
             loop_y += src1_ne0;
@@ -4640,7 +4637,7 @@ static __global__ void dequantize_axpy_sparse(const void * __restrict__ vx, cons
         return;
     }
     int id = lst[row];
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
 
@@ -4689,8 +4686,7 @@ static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ v
         return;
     }
     int id = lst[row];
-    // int id = row;
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
 
@@ -4782,7 +4778,7 @@ static __global__ void dequantize_mul_mat_batch_sparse(const void * __restrict__
     {
         __syncthreads();
         tmp = 0.0f;
-        if (loop_idx[id] < 0.0f)
+        if (loop_idx[id] < dev_sparse_threshold)
         {
             loop_dst += dst_ne0;
             loop_idx += dst_ne0;
@@ -9618,3 +9614,6 @@ ggml_backend_t ggml_backend_cuda_init() {
     return cuda_backend;
 }
 
+void ggml_cuda_set_device_constants(float sparse_pred_threshold) {
+    CUDA_CHECK(cudaMemcpyToSymbol(dev_sparse_threshold, &sparse_pred_threshold, sizeof(float)));
+}
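
For context on the pattern above (not part of the diff): dev_sparse_threshold lives in CUDA __constant__ memory, and ggml_cuda_set_device_constants() copies the host value into it with cudaMemcpyToSymbol so every kernel reads the same threshold. A standalone sketch with invented names:

// sketch.cu - illustration only
#include <cuda_runtime.h>

__constant__ float dev_threshold;                   // device-wide constant, written from the host

__global__ void mark_active_rows(const float * idx, int * keep, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // rows whose predicted activation falls below the threshold get skipped
        keep[i] = (idx[i] >= dev_threshold) ? 1 : 0;
    }
}

int main() {
    float host_threshold = 0.0f;                    // in llama.cpp this comes from hparams / GGUF metadata
    cudaMemcpyToSymbol(dev_threshold, &host_threshold, sizeof(float));
    // ... cudaMalloc idx/keep, then launch mark_active_rows<<<blocks, threads>>>(idx, keep, n) ...
    return 0;
}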

ggml-cuda.h

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ GGML_API int ggml_cuda_get_device_count(void);
 GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 GGML_API size_t ggml_cuda_get_free_memory(int device);
 
+GGML_API void ggml_cuda_set_device_constants(float sparse_pred_threshold);
+
 // backend API
 GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
 

ggml.c

Lines changed: 9 additions & 8 deletions
@@ -14059,6 +14059,8 @@ static void ggml_compute_forward_mul_mat_sparse(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
     GGML_ASSERT(ne2 == ne12);
@@ -14262,7 +14264,7 @@ static void ggml_compute_forward_mul_mat_sparse(
             float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
 
             // if (ffdata[ir0] <= 0.0f) {
-            if (gid[ir0] == 1 || ffdata[ir0] < -0.0f) {
+            if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
                 dst_col[ir0] = 0;
                 continue;
             }
@@ -14413,11 +14415,6 @@ static void ggml_compute_forward_mul_mat_axpy_dense(
         const int ir0 = atomic_fetch_add(params->aic, dr);
         for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
             if (ir1 >= nr) break;
-            // if (gid[ir1] == 1)
-            //     continue;
-            // if (idx[ir1] < 0.0f)
-            //     continue;
-            // ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
             ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
         }
         if (ir0 + dr >= nr)
@@ -14482,6 +14479,8 @@ static void ggml_compute_forward_mul_mat_axpy(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     // GGML_ASSERT(ne0 == ne01);
     // GGML_ASSERT(ne1 == ne11);
     // GGML_ASSERT(ne2 == ne12);
@@ -14569,7 +14568,7 @@ static void ggml_compute_forward_mul_mat_axpy(
             if (gid[ir1] == 1) {
                 continue;
             }
-            if (idx[ir1] < -0.0f)
+            if (idx[ir1] < threshold)
                 continue;
             // ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
             ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, src1_ptr[ir1]);
@@ -14632,6 +14631,8 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     // GGML_ASSERT(ne0 == ne01);
     // GGML_ASSERT(ne1 == ne11);
     // GGML_ASSERT(ne2 == ne12);
@@ -14713,7 +14714,7 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
                 break;
             if (gid[ir1] == 1)
                 continue;
-            if (idx[ir1] < 0.0f)
+            if (idx[ir1] < threshold)
                 continue;
             int bid = ir1 / QK8_0;
             int qsid = ir1 % QK8_0;

ggml.h

Lines changed: 6 additions & 0 deletions
@@ -2196,6 +2196,12 @@ extern "C" {
     GGML_API int ggml_cpu_has_ssse3 (void);
     GGML_API int ggml_cpu_has_vsx   (void);
 
+    //
+    // global variables
+    //
+    // TODO: these should be moved to the context
+    extern float sparse_pred_threshold;
+
     //
     // Internal types and functions exposed for tests and benchmarks
     //

gguf-py/gguf/constants.py

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,9 @@ class Tokenizer:
        ADD_EOS = "tokenizer.ggml.add_eos_token"
        HF_JSON = "tokenizer.huggingface.json"
        RWKV    = "tokenizer.rwkv.world"
+
+    class PowerInfer:
+        SPARSE_THRESHOLD = "powerinfer.sparse_threshold"
 
 
 #

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -399,6 +399,9 @@ def add_add_bos_token(self, value: bool) -> None:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_sparse_threshold(self, value: float) -> None:
+        self.add_float32(Keys.PowerInfer.SPARSE_THRESHOLD, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:

llama.cpp

Lines changed: 30 additions & 20 deletions
@@ -93,6 +93,13 @@
 
 #define LLAMA_MAX_NODES 4096
 
+//
+// global variables
+//
+
+// sparsity threshold for sparse matrix multiplication prediction
+float sparse_pred_threshold = 0.;
+
 //
 // logging
 //
@@ -257,6 +264,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PAD_ID,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
+
+    LLM_KV_SPARSE_THRESHOLD,
 };
 
 static std::map<llm_kv, std::string> LLM_KV_NAMES = {
@@ -305,6 +314,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_PAD_ID,  "tokenizer.ggml.padding_token_id" },
     { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,    "tokenizer.rwkv.world" },
+
+    { LLM_KV_SPARSE_THRESHOLD,  "powerinfer.sparse_threshold" },
 };
 
 struct LLM_KV {
@@ -1150,6 +1161,9 @@ struct llama_hparams {
 
     float f_clamp_kqv;
     float f_max_alibi_bias;
+
+    // sparse predictor threshold if sparse inference is enabled
+    float sparse_pred_threshold = atof(getenv("LLAMA_SPARSE_PRED_THRESHOLD") ?: "0.0");
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -2220,6 +2234,11 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }
 
+    if (gguf_get_sparse_deriv(ctx)) {
+        // read sparse threshold override if sparse deriv is enabled
+        GGUF_GET_KEY(ctx, hparams.sparse_pred_threshold, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_SPARSE_THRESHOLD));
+    }
+
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -2607,6 +2626,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
     if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
     if (vocab.linefeed_id    != -1) { LLAMA_LOG_INFO( "%s: LF token  = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+
+    // sparse inference
+    LLAMA_LOG_INFO("%s: sparse_pred_threshold = %.2f\n", __func__, hparams.sparse_pred_threshold);
 }
 
 
@@ -2808,7 +2830,7 @@ struct llama_augmentation_model_loader {
             return NULL;
         }
         // allocate and copy selected weights to gpu
-#ifdef GGML_USE_CUBLAS
+        #ifdef GGML_USE_CUBLAS
         int64_t row_len = src->ne[0];
         int64_t gpu_rows = gpu_bucket->ne[0];
         if (gpu_rows == 0)
@@ -2841,10 +2863,9 @@ struct llama_augmentation_model_loader {
         ggml_set_no_alloc(aux_ctx, false);
 
         return gpu_dst;
-#else
-        printf("As you do not support CUDA. Split to GPU is not allowed.\n");
+        #else
         return NULL;
-#endif
+        #endif
     }
 
     void slice_ffn_mat_to_gpu(llama_layer & layer) {
@@ -2882,22 +2903,11 @@ struct llama_augmentation_model_loader {
         const int64_t t_start_aug_us = ggml_time_us();
         std::vector<uint8_t> work_buffer;
 
-        // transpose ffn_down to use axpy
-        // ggml_cgraph * tmp_transpose_gf = ggml_new_graph(aux_ctx);
-        // for (llama_layer &model_layer : model -> layers) {
-        //     // gpu_w2 transpose load
-        //     ggml_tensor * ffn_down_t = ggml_cont(aux_ctx, ggml_transpose(aux_ctx, model_layer.ffn_down));
-        //     ggml_build_forward_expand(tmp_transpose_gf, ffn_down_t);
-        //     model_layer.ffn_down_t = ffn_down_t;
-        //     LLAMA_LOG_INFO(".");
-        // }
-        // ggml_graph_compute_helper(work_buffer, tmp_transpose_gf, 2);
-        // for (llama_layer &model_layer : model -> layers) {
-        //     model_layer.ffn_down_t->op = GGML_OP_NONE;
-        //     model_layer.ffn_down_t->src[0] = NULL;
-        //     model_layer.ffn_down_t->src[1] = NULL;
-        //     model_layer.ffn_down_t->src[2] = NULL;
-        // }
+        // Set sparsity threshold via global virables
+        sparse_pred_threshold = model->hparams.sparse_pred_threshold;
+#if defined (GGML_USE_CUBLAS)
+        ggml_cuda_set_device_constants(model->hparams.sparse_pred_threshold);
+#endif
 
         // load gpu_idx and slice mat to gpu
         for (llama_layer &model_layer : model -> layers) {
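
Two notes on the above (not part of the diff). First, the in-struct default for hparams.sparse_pred_threshold uses the GNU "?:" extension, so the threshold can also be seeded from the LLAMA_SPARSE_PRED_THRESHOLD environment variable when no GGUF key is present. Second, GGUF_GET_KEY is a llama.cpp-internal macro over ggml's public gguf API; a rough standalone sketch of reading the same key (file name invented):

// read_threshold.c - illustration only, links against ggml
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
    struct gguf_context * ctx = gguf_init_from_file("model.powerinfer.gguf", params);
    if (ctx == NULL) {
        return 1;
    }

    float threshold = 0.0f;   // default used when the key is missing
    int key_id = gguf_find_key(ctx, "powerinfer.sparse_threshold");
    if (key_id >= 0) {
        threshold = gguf_get_val_f32(ctx, key_id);
    }
    printf("sparse_pred_threshold = %.2f\n", threshold);

    gguf_free(ctx);
    return 0;
}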

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 numpy==1.24.4
 sentencepiece==0.1.98
-gguf>=0.1.0
+-e ./gguf-py
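
With this change, installing the Python requirements (pip install -r requirements.txt) pulls in the in-tree gguf-py package in editable mode instead of the gguf wheel from PyPI, so convert.py picks up the new powerinfer.sparse_threshold key and the add_sparse_threshold() writer method added in this commit.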
