@@ -64,7 +64,7 @@ if n_kv_req > n_ctx {
 
 var buffer: [CChar] = []
 for id: llama_token in tokens {
-    print(token_to_piece(token: id), terminator: "")
+    print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "")
 }
 
 print("\n")
@@ -101,6 +101,7 @@ if n_parallel > 1 {
 }
 
 var streams: [String] = .init(repeating: "", count: n_parallel)
+var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel)
 var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel)
 
 var n_cur = batch.n_tokens
@@ -157,12 +158,13 @@ while n_cur <= n_len {
             continue
         }
 
+        let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? ""
+
         // if there is only one stream, we print immediately to stdout
         if n_parallel == 1 {
-            print(token_to_piece(token: new_token_id), terminator: "")
+            print(nextStringPiece, terminator: "")
         }
-
-        streams[i] += token_to_piece(token: new_token_id)
+        streams[i] += nextStringPiece
 
         // push this new token for next evaluation
         batch.token[Int(batch.n_tokens)] = new_token_id
@@ -216,11 +218,38 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     return swiftTokens
 }
 
-private func token_to_piece(token: llama_token) -> String {
-    let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
-    result.initialize(repeating: Int8(0), count: 8)
-    let _ = llama_token_to_piece(model, token, result, 8)
-    let resultStr = String(cString: result)
-    result.deallocate()
-    return resultStr
+private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
+    var result = [CChar](repeating: 0, count: 8)
+    let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count))
+    if nTokens < 0 {
+        if result.count >= -Int(nTokens) {
+            result.removeLast(-Int(nTokens))
+        } else {
+            result.removeAll()
+        }
+        let check = llama_token_to_piece(
+            model,
+            token,
+            &result,
+            Int32(result.count)
+        )
+        assert(check == nTokens)
+    } else {
+        result.removeLast(result.count - Int(nTokens))
+    }
+    if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) {
+        return utfString
+    } else {
+        buffer.append(contentsOf: result)
+        let data = Data(buffer.map { UInt8(bitPattern: $0) })
+        if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
+            buffer = []
+        }
+        guard let bufferString = String(data: data, encoding: .utf8) else {
+            return nil
+        }
+        buffer = []
+        return bufferString
+    }
+    return nil
 }
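
Note (illustration, not part of the commit): the reason token_to_piece now threads a buffer: inout [CChar] through the call sites, with one buffer per parallel stream, is that a multi-byte UTF-8 character can be split across two token pieces, so a piece may not decode on its own until the remaining bytes arrive on a later call. The standalone sketch below mimics that accumulation idea with plain byte arrays; flush(piece:buffer:) and the sample bytes are hypothetical names for illustration, and it deliberately omits the diff's extra details (the second llama_token_to_piece call that resizes the result, and the reset once the pending buffer reaches 4 bytes).

import Foundation

// Minimal sketch, assuming a hypothetical flush(piece:buffer:) in place of token_to_piece.
func flush(piece: [CChar], buffer: inout [CChar]) -> String? {
    // Fast path: nothing pending and the piece alone is valid UTF-8.
    if buffer.isEmpty, let s = String(bytes: piece.map { UInt8(bitPattern: $0) }, encoding: .utf8) {
        return s
    }
    // Otherwise accumulate the raw bytes and retry on this (and later) calls.
    buffer.append(contentsOf: piece)
    let data = Data(buffer.map { UInt8(bitPattern: $0) })
    guard let s = String(data: data, encoding: .utf8) else {
        return nil // still an incomplete sequence; keep the bytes for the next piece
    }
    buffer = []
    return s
}

// "é" is 0xC3 0xA9 in UTF-8; feed it as two separate one-byte pieces.
var pending: [CChar] = []
for half: [CChar] in [[CChar(bitPattern: 0xC3)], [CChar(bitPattern: 0xA9)]] {
    if let s = flush(piece: half, buffer: &pending) {
        print(s, terminator: "") // prints "é" only once both bytes have arrived
    }
}
print()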