diff --git a/lib/src/llama.dart b/lib/src/llama.dart
index b7fd9b7..e045caf 100644
--- a/lib/src/llama.dart
+++ b/lib/src/llama.dart
@@ -1,7 +1,6 @@
 import 'dart:convert';
 import 'dart:ffi';
 import 'dart:math';
-import 'dart:typed_data';
 
 import 'package:ffi/ffi.dart';
 import 'package:llama_cpp_dart/src/sampling_params.dart';
@@ -225,8 +224,14 @@ class Llama {
     // Check if the sampled token is an EOS token.
     bool isEOSToken = newTokenId.value == lib.llama_token_eos(model);
 
-    // Convert the token ID to its string representation.
-    final newTokenStr = tokenToPiece(newTokenId.value);
+    // Prepare the string representation of the sampled token.
+    String newTokenStr = "";
+
+    // Check that the sampled token is not the BOS token.
+    if (newTokenId.value != lib.llama_token_bos(model)) {
+      // Convert the token ID to its string representation.
+      newTokenStr = tokenToPiece(newTokenId.value);
+    }
 
     // Update the batch and context for the next token generation.
     batch.n_tokens = 0;
@@ -335,13 +340,16 @@ class Llama {
   /// It handles the conversion and memory management involved in this process.
   /// This is typically used in decoding the output of the model.
   String tokenToPiece(int token) {
-    Pointer<Char> result = malloc.allocate<Char>(32);
+    int bufferSize = 64;
+    Pointer<Char> result = malloc.allocate<Char>(bufferSize);
     try {
-      int nTokens = lib.llama_token_to_piece(model, token, result, 32);
+      int bytesWritten = lib.llama_token_to_piece(model, token, result, bufferSize);
+
+      bytesWritten = min(bytesWritten, bufferSize - 1);
+
+      final byteBuffer = result.cast<Uint8>().asTypedList(bytesWritten);
 
-      final ByteBuffer byteBuffer = result.cast<Uint8>().asTypedList(nTokens).buffer;
-
-      return utf8.decode(byteBuffer.asUint8List(), allowMalformed: false);
+      return utf8.decode(byteBuffer, allowMalformed: true);
     } finally {
       malloc.free(result);
     }
diff --git a/lib/src/llama_processor.dart b/lib/src/llama_processor.dart
index dbe4a9d..12ca28c 100644
--- a/lib/src/llama_processor.dart
+++ b/lib/src/llama_processor.dart
@@ -93,7 +93,7 @@ class LlamaProcessor {
     Llama.libraryPath = args['libraryPath'] as String?;
     Llama? llama;
 
-    bool flagForStop = false;
+    Completer stopCompleter = Completer();
 
     isolateReceivePort.listen((message) async {
       if (message is Map) {
@@ -110,23 +110,18 @@ class LlamaProcessor {
           case 'prompt':
            llama?.setPrompt(message['prompt']);
            while (true) {
-              if (flagForStop) {
-                flagForStop = false;
-                break;
-              }
+              if (stopCompleter.isCompleted) break;
+
              var (text, done) = llama!.getNext();
-              if (done) break;
              mainSendPort.send(text);
-              await Future.delayed(Duration.zero);
+
+              if (done) stopCompleter.complete();
            }
            break;
          case 'stop':
-            if (flagForStop) break;
-            flagForStop = true;
+            if (!stopCompleter.isCompleted) stopCompleter.complete();
            llama?.clear();
            break;
-          case 'clear':
-            // llama?.unloadModel();
-            break;
        }
      }
    });
diff --git a/src/llama.cpp b/src/llama.cpp
index 8bd91da..4a3d35a 160000
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1 +1 @@
-Subproject commit 8bd91daa2172ff6446476f9beac52c2738cdceb6
+Subproject commit 4a3d35a77d515ab24ec3638fab7828cadea9515f
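
Note: the llama_processor.dart hunk replaces a mutable bool stop flag with a Completer that acts as a one-shot stop latch which the generation loop polls. The following is a minimal, self-contained Dart sketch of that pattern only; the StoppableGenerator, generate, stop, and reset names are illustrative assumptions, not part of the llama_cpp_dart API.

import 'dart:async';

class StoppableGenerator {
  // One-shot latch: once completed, the loop stops on its next check.
  Completer<void> _stop = Completer<void>();

  // Emits tokens until the source runs out or stop() is called.
  Stream<String> generate(Iterable<String> tokens) async* {
    for (final token in tokens) {
      if (_stop.isCompleted) break; // latch observed: end the loop
      yield token;
    }
  }

  // Guarded complete, mirroring the patch's
  // `if (!stopCompleter.isCompleted) stopCompleter.complete();`.
  void stop() {
    if (!_stop.isCompleted) _stop.complete();
  }

  // Re-arm the latch so the generator can be reused for the next prompt.
  void reset() => _stop = Completer<void>();
}

Future<void> main() async {
  final gen = StoppableGenerator();
  var count = 0;
  await for (final t in gen.generate(List.generate(100, (i) => 'tok$i'))) {
    print(t);
    if (++count == 5) gen.stop(); // request cancellation mid-stream
  }
}

Unlike a plain bool, the Completer cannot be "un-stopped" by accident and also exposes a Future that other code could await, which appears to be the motivation for the change.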