@@ -39,73 +39,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
     return true;
 }
 
-// TODO: use common/sampling.h
-static llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
-    auto & sparams = params.sparams;
-
-    // out of user input, sample next token
-    const float   temp      = sparams.temp;
-    const int32_t top_k     = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : sparams.top_k;
-    const float   top_p     = sparams.top_p;
-    const float   tfs_z     = sparams.tfs_z;
-    const float   typical_p = sparams.typical_p;
-    // const int32_t repeat_last_n   = sparams.repeat_last_n < 0 ? n_ctx : sparams.repeat_last_n;
-    // const float   repeat_penalty  = sparams.repeat_penalty;
-    // const float   alpha_presence  = sparams.presence_penalty;
-    // const float   alpha_frequency = sparams.frequency_penalty;
-    const int     mirostat     = sparams.mirostat;
-    const float   mirostat_tau = sparams.mirostat_tau;
-    const float   mirostat_eta = sparams.mirostat_eta;
-    // const bool    penalize_nl     = sparams.penalize_nl;
-
-    llama_token id = 0;
-    {
-        auto logits  = llama_get_logits(ctx_llama);
-        auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
-
-        // Apply params.logit_bias map
-        for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
-            logits[it->first] += it->second;
-        }
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-        if (temp <= 0) {
-            // Greedy sampling
-            id = llama_sample_token_greedy(ctx_llama, &candidates_p);
-        } else {
-            if (mirostat == 1) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                const int mirostat_m = 100;
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-            } else if (mirostat == 2) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-            } else {
-                // Temperature sampling
-                llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
-                llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
-                llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
-                llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token(ctx_llama, &candidates_p);
-            }
-        }
-    }
-
-    return id;
-}
-
-static const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
-    int id = sample_id(ctx_llama, params);
+static const char * sample(struct llama_sampling_context * ctx_sampling,
+                           struct llama_context * ctx_llama,
+                           int * n_past) {
+    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
     if (id == llama_token_eos(llama_get_model(ctx_llama))) {
         ret = "</s>";
@@ -174,8 +112,8 @@ struct llava_context {
 };
 
 static void show_additional_info(int /*argc*/, char ** argv) {
-    printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
-    printf("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
+    fprintf(stderr, "\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    fprintf(stderr, "  note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }
 
 static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
@@ -185,7 +123,7 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
     auto prompt = params->prompt;
     if (prompt_contains_image(prompt)) {
         if (!params->image.empty()) {
-            printf("using base64 encoded image instead of command line image path\n");
+            fprintf(stderr, "using base64 encoded image instead of command line image path\n");
         }
         embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
         if (!embed) {
@@ -217,16 +155,19 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     // generate the response
 
-    printf("\n");
+    fprintf(stderr, "\n");
+
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
 
     for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(ctx_llava->ctx_llama, *params, &n_past);
+        const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
         if (strcmp(tmp, "</s>") == 0) break;
 
         printf("%s", tmp);
         fflush(stdout);
     }
 
+    llama_sampling_free(ctx_sampling);
     printf("\n");
 }
 
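
For readers following the change: the patch drops the hand-rolled top-k/top-p/tail-free/mirostat code in favor of the shared sampler from common/sampling.h, which is driven through an init / sample / accept / free cycle around the usual decode loop. The sketch below is a minimal, illustrative reconstruction of that cycle, not code taken from this patch; the generate() wrapper and the decode step via llama_decode/llama_batch_get_one are assumptions added to make the example self-contained.

    // Illustrative sketch only: mirrors the llama_sampling_* flow adopted above,
    // with assumed helpers around it (not the exact code in this example program).
    #include "common.h"
    #include "sampling.h"
    #include "llama.h"

    #include <cstdio>

    static void generate(llama_context * ctx, const gpt_params & params, int n_past, int max_tgt_len) {
        // one sampling context per generation, seeded from the CLI sampling params
        llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);

        for (int i = 0; i < max_tgt_len; i++) {
            // pick the next token from the current logits (no guidance context, hence NULL)
            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);

            // record the choice so repetition penalties / grammar state stay consistent
            llama_sampling_accept(ctx_sampling, ctx, id, true);

            if (id == llama_token_eos(llama_get_model(ctx))) {
                break;
            }

            printf("%s", llama_token_to_piece(ctx, id).c_str());
            fflush(stdout);

            // feed the token back so the next llama_sampling_sample sees fresh logits
            llama_token tok = id;
            llama_decode(ctx, llama_batch_get_one(&tok, 1, n_past, 0));
            n_past += 1;
        }

        llama_sampling_free(ctx_sampling);
    }

The upshot of the design is that per-sampler state (mirostat mu, penalty history, grammar) lives in the llama_sampling_context instead of function-local statics, which is what lets the example's sample() shrink to the two calls shown in the diff.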