
Commit 1b3b907

[examples/swift] fix multi-byte unicode character parsing
1 parent 90c5462 · commit 1b3b907

File tree

1 file changed: +40 -11 lines changed

examples/swift/Sources/main.swift

Lines changed: 40 additions & 11 deletions
@@ -64,7 +64,7 @@ if n_kv_req > n_ctx {
 
 var buffer: [CChar] = []
 for id: llama_token in tokens {
-    print(token_to_piece(token: id), terminator: "")
+    print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "")
 }
 
 print("\n")
@@ -101,6 +101,7 @@ if n_parallel > 1 {
 }
 
 var streams: [String] = .init(repeating: "", count: n_parallel)
+var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel)
 var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel)
 
 var n_cur = batch.n_tokens
@@ -157,12 +158,13 @@ while n_cur <= n_len {
             continue
         }
 
+        let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? ""
+
         // if there is only one stream, we print immediately to stdout
         if n_parallel == 1 {
-            print(token_to_piece(token: new_token_id), terminator: "")
+            print(nextStringPiece, terminator: "")
        }
-
-        streams[i] += token_to_piece(token: new_token_id)
+        streams[i] += nextStringPiece
 
         // push this new token for next evaluation
        batch.token[Int(batch.n_tokens)] = new_token_id
@@ -216,11 +218,38 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     return swiftTokens
 }
 
-private func token_to_piece(token: llama_token) -> String {
-    let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
-    result.initialize(repeating: Int8(0), count: 8)
-    let _ = llama_token_to_piece(model, token, result, 8)
-    let resultStr = String(cString: result)
-    result.deallocate()
-    return resultStr
+private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
+    var result = [CChar](repeating: 0, count: 8)
+    let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count))
+    if nTokens < 0 {
+        if result.count >= -Int(nTokens) {
+            result.removeLast(-Int(nTokens))
+        } else {
+            result.removeAll()
+        }
+        let check = llama_token_to_piece(
+            model,
+            token,
+            &result,
+            Int32(result.count)
+        )
+        assert(check == nTokens)
+    } else {
+        result.removeLast(result.count - Int(nTokens))
+    }
+    if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) {
+        return utfString
+    } else {
+        buffer.append(contentsOf: result)
+        let data = Data(buffer.map { UInt8(bitPattern: $0) })
+        if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
+            buffer = []
+        }
+        guard let bufferString = String(data: data, encoding: .utf8) else {
+            return nil
+        }
+        buffer = []
+        return bufferString
+    }
+    return nil
 }
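The whole change hinges on one observation: a single model token does not have to decode to a complete UTF-8 string, so the bytes produced by llama_token_to_piece are now collected in a per-stream [CChar] buffer and only turned into text once they form valid UTF-8, with the buffer dropped once it exceeds 4 bytes, the longest possible UTF-8 sequence. Below is a minimal standalone sketch of that buffering idea, using [UInt8] and made-up names and byte values purely for illustration; it is not the example's actual API.

```swift
import Foundation

// Sketch only: pieces arrive as raw bytes, and a multi-byte UTF-8 character
// may be split across two pieces, so bytes are accumulated until they decode.
func appendPiece(_ piece: [UInt8], to buffer: inout [UInt8]) -> String? {
    buffer.append(contentsOf: piece)
    if let decoded = String(data: Data(buffer), encoding: .utf8) {
        buffer = []        // complete character(s): flush and start fresh
        return decoded
    }
    if buffer.count >= 4 { // give up after 4 bytes, the max UTF-8 sequence length
        buffer = []
    }
    return nil             // incomplete sequence: wait for the next piece
}

var buffer: [UInt8] = []
// "hi ", then "é" (0xC3 0xA9) split across two pieces, then "!"
let pieces: [[UInt8]] = [[0x68, 0x69, 0x20], [0xC3], [0xA9], [0x21]]
for piece in pieces {
    if let text = appendPiece(piece, to: &buffer) {
        print(text, terminator: "") // over the whole loop this prints "hi é!"
    }
}
print()
```

The 4-byte threshold mirrors the comment in the diff: it is a simple safeguard so that one malformed or misaligned piece gets dropped instead of keeping every later piece stuck in the buffer.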
