@@ -11,28 +11,28 @@ actor LlamaContext {
    private var context: OpaquePointer
    private var batch: llama_batch
    private var tokens_list: [llama_token]

    var n_len: Int32 = 512
    var n_cur: Int32 = 0
    var n_decode: Int32 = 0

    init(model: OpaquePointer, context: OpaquePointer) {
        self.model = model
        self.context = context
        self.tokens_list = []
        self.batch = llama_batch_init(512, 0, 1)
    }

    deinit {
        llama_free(context)
        llama_free_model(model)
        llama_backend_free()
    }

    static func createContext(path: String) throws -> LlamaContext {
        llama_backend_init(false)
        let model_params = llama_model_default_params()

        let model = llama_load_model_from_file(path, model_params)
        guard let model else {
            print("Could not load model at \(path)")
@@ -43,41 +43,41 @@ actor LlamaContext {
        ctx_params.n_ctx = 2048
        ctx_params.n_threads = 8
        ctx_params.n_threads_batch = 8

        let context = llama_new_context_with_model(model, ctx_params)
        guard let context else {
            print("Could not load context!")
            throw LlamaError.couldNotInitializeContext
        }

        return LlamaContext(model: model, context: context)
    }

    func get_n_tokens() -> Int32 {
        return batch.n_tokens
    }

    func completion_init(text: String) {
        print("attempting to complete \"\(text)\"")

        tokens_list = tokenize(text: text, add_bos: true)

        let n_ctx = llama_n_ctx(context)
        let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)

        print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")

        if n_kv_req > n_ctx {
            print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
        }

        for id in tokens_list {
            print(token_to_piece(token: id))
        }

        // batch = llama_batch_init(512, 0) // done in init()
        batch.n_tokens = Int32(tokens_list.count)

        for i1 in 0..<batch.n_tokens {
            let i = Int(i1)
            batch.token[i] = tokens_list[i]
@@ -87,90 +87,90 @@ actor LlamaContext {
            batch.logits[i] = 0
        }
        // request logits only for the last token of the prompt
        batch.logits[Int(batch.n_tokens) - 1] = 1 // true

        if llama_decode(context, batch) != 0 {
            print("llama_decode() failed")
        }

        n_cur = batch.n_tokens
    }

    func completion_loop() -> String {
        var new_token_id: llama_token = 0

        let n_vocab = llama_n_vocab(model)
        let logits = llama_get_logits_ith(context, batch.n_tokens - 1)

        var candidates = Array<llama_token_data>()
        candidates.reserveCapacity(Int(n_vocab))

        for token_id in 0..<n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }

        candidates.withUnsafeMutableBufferPointer { buffer in
            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)

            new_token_id = llama_sample_token_greedy(context, &candidates_p)
        }

        if new_token_id == llama_token_eos(context) || n_cur == n_len {
            print("\n")
            return ""
        }

        let new_token_str = token_to_piece(token: new_token_id)
        print(new_token_str)
        // tokens_list.append(new_token_id)

        // reuse the batch as a single-token decode step for the sampled token
        batch.n_tokens = 0

        batch.token[Int(batch.n_tokens)] = new_token_id
        batch.pos[Int(batch.n_tokens)] = n_cur
        batch.n_seq_id[Int(batch.n_tokens)] = 1
        batch.seq_id[Int(batch.n_tokens)]![0] = 0
        batch.logits[Int(batch.n_tokens)] = 1 // true
        batch.n_tokens += 1

        n_decode += 1

        n_cur += 1

        if llama_decode(context, batch) != 0 {
            print("failed to evaluate llama!")
        }

        return new_token_str
    }

    func clear() {
        tokens_list.removeAll()
    }

    private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
        let utf8Count = text.utf8.count // llama_tokenize takes the text length in bytes, not characters
        let n_tokens = utf8Count + (add_bos ? 1 : 0)
        let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
        let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)

        var swiftTokens: [llama_token] = []
        for i in 0..<tokenCount {
            swiftTokens.append(tokens[Int(i)])
        }

        tokens.deallocate()

        return swiftTokens
    }

    private func token_to_piece(token: llama_token) -> String {
        var capacity: Int32 = 8
        var result = UnsafeMutablePointer<Int8>.allocate(capacity: Int(capacity) + 1)
        var nChars = llama_token_to_piece(model, token, result, capacity)
        if nChars < 0 {
            // the 8-byte buffer was too small; -nChars is the required size, so retry once
            result.deallocate()
            capacity = -nChars
            result = UnsafeMutablePointer<Int8>.allocate(capacity: Int(capacity) + 1)
            nChars = llama_token_to_piece(model, token, result, capacity)
        }
        result[Int(nChars)] = 0 // llama_token_to_piece does not null-terminate

        let resultStr = String(cString: result)

        result.deallocate()

        return resultStr
    }
}
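
A minimal driver sketch showing how these pieces fit together (an illustration, not part of this commit; the model path is a placeholder and error handling is elided):

Task {
    // hypothetical GGUF model path: substitute a real bundled model
    let llamaContext = try LlamaContext.createContext(path: "/path/to/model.gguf")
    await llamaContext.completion_init(text: "Hello my name is")

    var output = ""
    while true {
        // completion_loop() returns "" once EOS is sampled or n_cur reaches n_len
        let piece = await llamaContext.completion_loop()
        if piece.isEmpty { break }
        output += piece
    }
    print(output)
}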