
Commit f2ad1e1 (2 parents: 19ac822 + d5eb0e9)

Merge pull request #20 from danemadsen/main

Exclude BOS token from output and further improvements to tokenToPiece

3 files changed (+23 −20 lines)

lib/src/llama.dart (16 additions, 8 deletions)

```diff
@@ -1,7 +1,6 @@
 import 'dart:convert';
 import 'dart:ffi';
 import 'dart:math';
-import 'dart:typed_data';
 
 import 'package:ffi/ffi.dart';
 import 'package:llama_cpp_dart/src/sampling_params.dart';
@@ -225,8 +224,14 @@ class Llama {
     // Check if the sampled token is an EOS token.
     bool isEOSToken = newTokenId.value == lib.llama_token_eos(model);
 
-    // Convert the token ID to its string representation.
-    final newTokenStr = tokenToPiece(newTokenId.value);
+    // Prepare the string representation of the sampled token.
+    String newTokenStr = "";
+
+    // Check that the sampled token is not the BOS token.
+    if (newTokenId.value != lib.llama_token_bos(model)) {
+      // Convert the token ID to its string representation.
+      newTokenStr = tokenToPiece(newTokenId.value);
+    }
 
     // Update the batch and context for the next token generation.
     batch.n_tokens = 0;
@@ -335,13 +340,16 @@ class Llama {
   /// It handles the conversion and memory management involved in this process.
   /// This is typically used in decoding the output of the model.
   String tokenToPiece(int token) {
-    Pointer<Char> result = malloc.allocate<Char>(32);
+    int bufferSize = 64;
+    Pointer<Char> result = malloc.allocate<Char>(bufferSize);
     try {
-      int nTokens = lib.llama_token_to_piece(model, token, result, 32);
+      int bytesWritten = lib.llama_token_to_piece(model, token, result, bufferSize);
+
+      bytesWritten = min(bytesWritten, bufferSize - 1);
+
+      final byteBuffer = result.cast<Uint8>().asTypedList(bytesWritten);
 
-      final ByteBuffer byteBuffer = result.cast<Uint8>().asTypedList(nTokens).buffer;
-
-      return utf8.decode(byteBuffer.asUint8List(), allowMalformed: false);
+      return utf8.decode(byteBuffer, allowMalformed: true);
     } finally {
       malloc.free(result);
    }
```
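Two details of the tokenToPiece change are easy to miss: the byte count returned by llama_token_to_piece is clamped so the read never runs past the allocated buffer, and decoding switches to allowMalformed: true because a single token piece can end partway through a multi-byte UTF-8 sequence. A minimal, self-contained Dart sketch of both behaviors (the byte values and the over-long length of 80 are illustrative, not produced by the library):

```dart
import 'dart:convert';
import 'dart:math';

void main() {
  // "é" encodes as two UTF-8 bytes (0xC3 0xA9). If a piece ends after the
  // lead byte, a strict decode throws a FormatException; allowMalformed
  // substitutes U+FFFD so the generation loop keeps streaming.
  final truncated = [0x68, 0x69, 0xC3]; // "hi" plus a dangling lead byte
  print(utf8.decode(truncated, allowMalformed: true)); // hi�

  // The clamp from the diff: cap the count at bufferSize - 1 so the
  // typed-list view never extends past the allocation.
  const bufferSize = 64;
  var bytesWritten = 80; // hypothetical over-long length report
  bytesWritten = min(bytesWritten, bufferSize - 1);
  print(bytesWritten); // 63
}
```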

lib/src/llama_processor.dart (6 additions, 11 deletions)

```diff
@@ -93,7 +93,7 @@ class LlamaProcessor {
     Llama.libraryPath = args['libraryPath'] as String?;
 
     Llama? llama;
-    bool flagForStop = false;
+    Completer stopCompleter = Completer();
 
     isolateReceivePort.listen((message) async {
       if (message is Map) {
@@ -110,23 +110,18 @@ class LlamaProcessor {
           case 'prompt':
             llama?.setPrompt(message['prompt']);
             while (true) {
-              if (flagForStop) {
-                flagForStop = false;
-                break;
-              }
+              if (stopCompleter.isCompleted) break;
+
               var (text, done) = llama!.getNext();
-              if (done) break;
               mainSendPort.send(text);
-              await Future.delayed(Duration.zero);
+
+              if (done) stopCompleter.complete();
             }
             break;
           case 'stop':
-            flagForStop = true;
+            if (!stopCompleter.isCompleted) stopCompleter.complete();
             llama?.clear();
             break;
-          case 'clear':
-            // llama?.unloadModel();
-            break;
         }
       }
     });
```
src/llama.cpp (submodule updated: 187 files)
