diff --git a/.eslintrc.json b/.eslintrc.json
index b6e2476a..a974ecd3 100644
--- a/.eslintrc.json
+++ b/.eslintrc.json
@@ -5,7 +5,13 @@
         "browser": false,
         "es6": true
     },
-    "ignorePatterns": ["/dist", "/llama", "/docs-site", "/packages/create-node-llama-cpp/dist"],
+    "ignorePatterns": [
+        "/dist",
+        "/llama",
+        "/docs-site",
+        "/packages/create-node-llama-cpp/dist",
+        "/packages/@node-llama-cpp/*/dist"
+    ],
     "extends": [
         "eslint:recommended",
         "plugin:jsdoc/recommended"
@@ -133,7 +139,9 @@
         "no-duplicate-imports": ["error", {
             "includeExports": true
         }],
-        "camelcase": ["warn"],
+        "camelcase": ["warn", {
+            "allow": ["\\d+_\\d+"]
+        }],
         "jsx-quotes": ["warn"],
         "yoda": ["error", "never", {
             "exceptRange": true
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 7b165472..1f0ecf01 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -60,7 +60,7 @@ jobs:
             os: windows-2022
             artifact: "win-arm"
           - name: "Ubuntu"
-            os: ubuntu-20.04
+            os: ubuntu-22.04
             artifact: "linux"
           - name: "macOS"
             os: macos-13
@@ -131,8 +131,8 @@ jobs:
       - name: Install Vulkan SDK on Ubuntu
         if: matrix.config.name == 'Ubuntu'
         run: |
-          wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
-          sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-focal.list https://packages.lunarg.com/vulkan/lunarg-vulkan-focal.list
+          wget -qO- https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo tee /etc/apt/trusted.gpg.d/lunarg.asc
+          sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
           sudo apt update
           sudo apt install vulkan-sdk

@@ -513,7 +513,7 @@ jobs:
           - name: "Windows"
             os: windows-2022
           - name: "Ubuntu"
-            os: ubuntu-20.04
+            os: ubuntu-22.04
           - name: "macOS"
             os: macos-13
diff --git a/llama/CMakeLists.txt b/llama/CMakeLists.txt
index 5f4a8888..40544e6d 100644
--- a/llama/CMakeLists.txt
+++ b/llama/CMakeLists.txt
@@ -129,7 +129,7 @@ if (GGML_METAL)
     )
 endif()

-file(GLOB SOURCE_FILES "addon.cpp" ${GPU_INFO_SOURCES})
+file(GLOB SOURCE_FILES "addon/*.cpp" "addon/**/*.cpp" ${GPU_INFO_SOURCES})

 if(APPLE)
     set(CMAKE_SKIP_BUILD_RPATH FALSE)
diff --git a/llama/addon.cpp b/llama/addon.cpp
deleted file mode 100644
index 045d47a4..00000000
--- a/llama/addon.cpp
+++ /dev/null
@@ -1,1997 +0,0 @@
-#include 
-
-#include 
-#include 
-#include 
-#include 
-
-#include "common.h"
-#include "common/grammar-parser.h"
-#include "llama.h"
-#include "napi.h"
-
-#ifdef GPU_INFO_USE_CUDA
-#    include "gpuInfo/cuda-gpu-info.h"
-#endif
-#ifdef GPU_INFO_USE_VULKAN
-#    include "gpuInfo/vulkan-gpu-info.h"
-#endif
-#ifdef GPU_INFO_USE_METAL
-#    include "gpuInfo/metal-gpu-info.h"
-#endif
-
-
-struct addon_logger_log {
-    public:
-        const int logLevelNumber;
-        const std::stringstream* stringStream;
-};
-
-static void addonLlamaCppLogCallback(ggml_log_level level, const char* text, void* user_data);
-
-using AddonThreadSafeLogCallbackFunctionContext = Napi::Reference;
-void addonCallJsLogCallback(
-    Napi::Env env, Napi::Function callback, AddonThreadSafeLogCallbackFunctionContext* context, addon_logger_log* data
-);
-using AddonThreadSafeLogCallbackFunction =
-    Napi::TypedThreadSafeFunction;
-
-
-struct addon_progress_event {
-    public:
-        const float progress;
-};
-
-using AddonThreadSafeProgressCallbackFunctionContext = Napi::Reference;
-void addonCallJsProgressCallback(
-    Napi::Env env, Napi::Function callback, AddonThreadSafeProgressCallbackFunctionContext* context, addon_progress_event* data
-);
-using AddonThreadSafeProgressEventCallbackFunction =
- Napi::TypedThreadSafeFunction; - - -AddonThreadSafeLogCallbackFunction addonThreadSafeLoggerCallback; -bool addonJsLoggerCallbackSet = false; -int addonLoggerLogLevel = 5; -bool backendInitialized = false; -bool backendDisposed = false; - -void addonCallJsProgressCallback( - Napi::Env env, Napi::Function callback, AddonThreadSafeProgressCallbackFunctionContext* context, addon_progress_event* data -) { - if (env != nullptr && callback != nullptr && addonJsLoggerCallbackSet) { - try { - callback.Call({Napi::Number::New(env, data->progress)}); - } catch (const Napi::Error& e) {} - } - - if (data != nullptr) { - delete data; - } -} - -static uint64_t calculateBatchMemorySize(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - uint64_t totalSize = 0; - - if (embd) { - totalSize += sizeof(float) * n_tokens_alloc * embd; - } else { - totalSize += sizeof(llama_token) * n_tokens_alloc; - } - - totalSize += sizeof(llama_pos) * n_tokens_alloc; - totalSize += sizeof(int32_t) * n_tokens_alloc; - totalSize += sizeof(llama_seq_id *) * (n_tokens_alloc + 1); - - totalSize += sizeof(llama_seq_id) * n_seq_max * n_tokens_alloc; - - totalSize += sizeof(int8_t) * n_tokens_alloc; - - return totalSize; -} - -static void adjustNapiExternalMemoryAdd(Napi::Env env, uint64_t size) { - const uint64_t chunkSize = std::numeric_limits::max(); - while (size > 0) { - int64_t adjustSize = std::min(size, chunkSize); - Napi::MemoryManagement::AdjustExternalMemory(env, adjustSize); - size -= adjustSize; - } -} - -static void adjustNapiExternalMemorySubtract(Napi::Env env, uint64_t size) { - const uint64_t chunkSize = std::numeric_limits::max(); - while (size > 0) { - int64_t adjustSize = std::min(size, chunkSize); - Napi::MemoryManagement::AdjustExternalMemory(env, -adjustSize); - size -= adjustSize; - } -} - -#ifdef GPU_INFO_USE_CUDA -void logCudaError(const char* message) { - addonLlamaCppLogCallback(GGML_LOG_LEVEL_ERROR, (std::string("CUDA error: ") + std::string(message)).c_str(), nullptr); -} -#endif -#ifdef GPU_INFO_USE_VULKAN -void logVulkanWarning(const char* message) { - addonLlamaCppLogCallback(GGML_LOG_LEVEL_WARN, (std::string("Vulkan warning: ") + std::string(message)).c_str(), nullptr); -} -#endif - -Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) { - uint64_t total = 0; - uint64_t used = 0; - -#ifdef GPU_INFO_USE_CUDA - size_t cudaDeviceTotal = 0; - size_t cudaDeviceUsed = 0; - bool cudeGetInfoSuccess = gpuInfoGetTotalCudaDevicesInfo(&cudaDeviceTotal, &cudaDeviceUsed, logCudaError); - - if (cudeGetInfoSuccess) { - total += cudaDeviceTotal; - used += cudaDeviceUsed; - } -#endif - -#ifdef GPU_INFO_USE_VULKAN - uint64_t vulkanDeviceTotal = 0; - uint64_t vulkanDeviceUsed = 0; - const bool vulkanDeviceSupportsMemoryBudgetExtension = gpuInfoGetTotalVulkanDevicesInfo(&vulkanDeviceTotal, &vulkanDeviceUsed, logVulkanWarning); - - if (vulkanDeviceSupportsMemoryBudgetExtension) { - total += vulkanDeviceTotal; - used += vulkanDeviceUsed; - } -#endif - -#ifdef GPU_INFO_USE_METAL - uint64_t metalDeviceTotal = 0; - uint64_t metalDeviceUsed = 0; - getMetalGpuInfo(&metalDeviceTotal, &metalDeviceUsed); - - total += metalDeviceTotal; - used += metalDeviceUsed; -#endif - - Napi::Object result = Napi::Object::New(info.Env()); - result.Set("total", Napi::Number::From(info.Env(), total)); - result.Set("used", Napi::Number::From(info.Env(), used)); - - return result; -} - -Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info) { - std::vector deviceNames; - -#ifdef GPU_INFO_USE_CUDA - 
gpuInfoGetCudaDeviceNames(&deviceNames, logCudaError); -#endif - -#ifdef GPU_INFO_USE_VULKAN - gpuInfoGetVulkanDeviceNames(&deviceNames, logVulkanWarning); -#endif - -#ifdef GPU_INFO_USE_METAL - getMetalGpuDeviceNames(&deviceNames); -#endif - - Napi::Object result = Napi::Object::New(info.Env()); - - Napi::Array deviceNamesNapiArray = Napi::Array::New(info.Env(), deviceNames.size()); - for (size_t i = 0; i < deviceNames.size(); ++i) { - deviceNamesNapiArray[i] = Napi::String::New(info.Env(), deviceNames[i]); - } - result.Set("deviceNames", deviceNamesNapiArray); - - return result; -} - -Napi::Value getGpuType(const Napi::CallbackInfo& info) { -#ifdef GPU_INFO_USE_CUDA - return Napi::String::New(info.Env(), "cuda"); -#endif - -#ifdef GPU_INFO_USE_VULKAN - return Napi::String::New(info.Env(), "vulkan"); -#endif - -#ifdef GPU_INFO_USE_METAL - return Napi::String::New(info.Env(), "metal"); -#endif - - return info.Env().Undefined(); -} - -static Napi::Value getNapiToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) { - if (token < 0) { - return Napi::Number::From(info.Env(), -1); - } - - auto tokenAttributes = llama_token_get_attr(model, token); - - if (tokenAttributes & LLAMA_TOKEN_ATTR_UNDEFINED || tokenAttributes & LLAMA_TOKEN_ATTR_UNKNOWN) { - return Napi::Number::From(info.Env(), -1); - } - - return Napi::Number::From(info.Env(), token); -} - -static Napi::Value getNapiControlToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) { - if (token < 0) { - return Napi::Number::From(info.Env(), -1); - } - - auto tokenAttributes = llama_token_get_attr(model, token); - - if (!(tokenAttributes & LLAMA_TOKEN_ATTR_CONTROL) && !(tokenAttributes & LLAMA_TOKEN_ATTR_UNDEFINED)) { - return Napi::Number::From(info.Env(), -1); - } - - return Napi::Number::From(info.Env(), token); -} - -static bool llamaModelParamsProgressCallback(float progress, void * user_data); - -class AddonModel : public Napi::ObjectWrap { - public: - llama_model_params model_params; - llama_model* model; - uint64_t loadedModelSize = 0; - Napi::Reference addonExportsRef; - bool hasAddonExportsRef = false; - - std::string modelPath; - bool modelLoaded = false; - bool abortModelLoad = false; - bool model_load_stopped = false; - float rawModelLoadPercentage = 0; - unsigned modelLoadPercentage = 0; - AddonThreadSafeProgressEventCallbackFunction addonThreadSafeOnLoadProgressEventCallback; - bool onLoadProgressEventCallbackSet = false; - bool hasLoadAbortSignal = false; - - bool disposed = false; - - AddonModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { - model_params = llama_model_default_params(); - - // Get the model path - modelPath = info[0].As().Utf8Value(); - - if (info.Length() > 1 && info[1].IsObject()) { - Napi::Object options = info[1].As(); - - if (options.Has("addonExports")) { - addonExportsRef = Napi::Persistent(options.Get("addonExports").As()); - hasAddonExportsRef = true; - } - - if (options.Has("gpuLayers")) { - model_params.n_gpu_layers = options.Get("gpuLayers").As().Int32Value(); - } - - if (options.Has("vocabOnly")) { - model_params.vocab_only = options.Get("vocabOnly").As().Value(); - } - - if (options.Has("useMmap")) { - model_params.use_mmap = options.Get("useMmap").As().Value(); - } - - if (options.Has("useMlock")) { - model_params.use_mlock = options.Get("useMlock").As().Value(); - } - - if (options.Has("checkTensors")) { - model_params.check_tensors = options.Get("checkTensors").As().Value(); - } - - if (options.Has("onLoadProgress")) { - auto 
onLoadProgressJSCallback = options.Get("onLoadProgress").As(); - if (onLoadProgressJSCallback.IsFunction()) { - AddonThreadSafeProgressCallbackFunctionContext* context = new Napi::Reference(Napi::Persistent(info.This())); - addonThreadSafeOnLoadProgressEventCallback = AddonThreadSafeProgressEventCallbackFunction::New( - info.Env(), - onLoadProgressJSCallback, - "onLoadProgressCallback", - 0, - 1, - context, - [](Napi::Env, AddonModel* addonModel, AddonThreadSafeProgressCallbackFunctionContext* ctx) { - addonModel->onLoadProgressEventCallbackSet = false; - - delete ctx; - }, - this - ); - onLoadProgressEventCallbackSet = true; - } - } - - if (options.Has("hasLoadAbortSignal")) { - hasLoadAbortSignal = options.Get("hasLoadAbortSignal").As().Value(); - } - - if (onLoadProgressEventCallbackSet || hasLoadAbortSignal) { - model_params.progress_callback_user_data = &(*this); - model_params.progress_callback = llamaModelParamsProgressCallback; - } - } - } - - ~AddonModel() { - dispose(); - } - - void dispose() { - if (disposed) { - return; - } - - disposed = true; - if (modelLoaded) { - modelLoaded = false; - llama_free_model(model); - - adjustNapiExternalMemorySubtract(Env(), loadedModelSize); - loadedModelSize = 0; - } - - if (hasAddonExportsRef) { - addonExportsRef.Unref(); - hasAddonExportsRef = false; - } - } - - Napi::Value Init(const Napi::CallbackInfo& info); - Napi::Value LoadLora(const Napi::CallbackInfo& info); - Napi::Value AbortActiveModelLoad(const Napi::CallbackInfo& info) { - abortModelLoad = true; - return info.Env().Undefined(); - } - Napi::Value Dispose(const Napi::CallbackInfo& info); - - Napi::Value Tokenize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - std::string text = info[0].As().Utf8Value(); - bool specialTokens = info[1].As().Value(); - - std::vector tokens = llama_tokenize(model, text, false, specialTokens); - - Napi::Uint32Array result = Napi::Uint32Array::New(info.Env(), tokens.size()); - for (size_t i = 0; i < tokens.size(); ++i) { - result[i] = static_cast(tokens[i]); - } - - return result; - } - Napi::Value Detokenize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - Napi::Uint32Array tokens = info[0].As(); - bool decodeSpecialTokens = info.Length() > 0 - ? 
info[1].As().Value() - : false; - - std::vector result(8, 0); - const int n_length = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens); - - if (n_length < 0) { - result.resize(-n_length); - int check = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens); - GGML_ASSERT(check == -n_length); - } else { - result.resize(n_length); - } - - return Napi::String::New(info.Env(), result.data(), result.size()); - } - - Napi::Value GetTrainContextSize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return Napi::Number::From(info.Env(), llama_n_ctx_train(model)); - } - - Napi::Value GetEmbeddingVectorSize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return Napi::Number::From(info.Env(), llama_n_embd(model)); - } - - Napi::Value GetTotalSize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return Napi::Number::From(info.Env(), llama_model_size(model)); - } - - Napi::Value GetTotalParameters(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return Napi::Number::From(info.Env(), llama_model_n_params(model)); - } - - Napi::Value GetModelDescription(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - char model_desc[128]; - int actual_length = llama_model_desc(model, model_desc, sizeof(model_desc)); - - return Napi::String::New(info.Env(), model_desc, actual_length); - } - - Napi::Value TokenBos(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiControlToken(info, model, llama_token_bos(model)); - } - Napi::Value TokenEos(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiControlToken(info, model, llama_token_eos(model)); - } - Napi::Value TokenNl(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiToken(info, model, llama_token_nl(model)); - } - Napi::Value PrefixToken(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiControlToken(info, model, llama_token_prefix(model)); - } - Napi::Value MiddleToken(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiControlToken(info, model, llama_token_middle(model)); - } - Napi::Value SuffixToken(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is 
disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiControlToken(info, model, llama_token_suffix(model)); - } - Napi::Value EotToken(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return getNapiControlToken(info, model, llama_token_eot(model)); - } - Napi::Value GetTokenString(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - int token = info[0].As().Int32Value(); - std::stringstream ss; - - const char* str = llama_token_get_text(model, token); - if (str == nullptr) { - return info.Env().Undefined(); - } - - ss << str; - - return Napi::String::New(info.Env(), ss.str()); - } - - Napi::Value GetTokenAttributes(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - if (info[0].IsNumber() == false) { - return Napi::Number::From(info.Env(), int32_t(LLAMA_TOKEN_ATTR_UNDEFINED)); - } - - int token = info[0].As().Int32Value(); - auto tokenAttributes = llama_token_get_attr(model, token); - - return Napi::Number::From(info.Env(), int32_t(tokenAttributes)); - } - Napi::Value IsEogToken(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - if (info[0].IsNumber() == false) { - return Napi::Boolean::New(info.Env(), false); - } - - int token = info[0].As().Int32Value(); - - return Napi::Boolean::New(info.Env(), llama_token_is_eog(model, token)); - } - Napi::Value GetVocabularyType(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - auto vocabularyType = llama_vocab_type(model); - - return Napi::Number::From(info.Env(), int32_t(vocabularyType)); - } - Napi::Value ShouldPrependBosToken(const Napi::CallbackInfo& info) { - const int addBos = llama_add_bos_token(model); - - bool shouldPrependBos = addBos != -1 ? 
bool(addBos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); - - return Napi::Boolean::New(info.Env(), shouldPrependBos); - } - - Napi::Value GetModelSize(const Napi::CallbackInfo& info) { - return Napi::Number::From(info.Env(), llama_model_size(model)); - } - - static void init(Napi::Object exports) { - exports.Set( - "AddonModel", - DefineClass( - exports.Env(), - "AddonModel", - { - InstanceMethod("init", &AddonModel::Init), - InstanceMethod("loadLora", &AddonModel::LoadLora), - InstanceMethod("abortActiveModelLoad", &AddonModel::AbortActiveModelLoad), - InstanceMethod("tokenize", &AddonModel::Tokenize), - InstanceMethod("detokenize", &AddonModel::Detokenize), - InstanceMethod("getTrainContextSize", &AddonModel::GetTrainContextSize), - InstanceMethod("getEmbeddingVectorSize", &AddonModel::GetEmbeddingVectorSize), - InstanceMethod("getTotalSize", &AddonModel::GetTotalSize), - InstanceMethod("getTotalParameters", &AddonModel::GetTotalParameters), - InstanceMethod("getModelDescription", &AddonModel::GetModelDescription), - InstanceMethod("tokenBos", &AddonModel::TokenBos), - InstanceMethod("tokenEos", &AddonModel::TokenEos), - InstanceMethod("tokenNl", &AddonModel::TokenNl), - InstanceMethod("prefixToken", &AddonModel::PrefixToken), - InstanceMethod("middleToken", &AddonModel::MiddleToken), - InstanceMethod("suffixToken", &AddonModel::SuffixToken), - InstanceMethod("eotToken", &AddonModel::EotToken), - InstanceMethod("getTokenString", &AddonModel::GetTokenString), - InstanceMethod("getTokenAttributes", &AddonModel::GetTokenAttributes), - InstanceMethod("isEogToken", &AddonModel::IsEogToken), - InstanceMethod("getVocabularyType", &AddonModel::GetVocabularyType), - InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken), - InstanceMethod("getModelSize", &AddonModel::GetModelSize), - InstanceMethod("dispose", &AddonModel::Dispose), - } - ) - ); - } -}; - -static bool llamaModelParamsProgressCallback(float progress, void * user_data) { - AddonModel* addonModel = (AddonModel *) user_data; - unsigned percentage = (unsigned) (100 * progress); - - if (percentage > addonModel->modelLoadPercentage) { - addonModel->modelLoadPercentage = percentage; - - // original llama.cpp logs - addonLlamaCppLogCallback(GGML_LOG_LEVEL_INFO, ".", nullptr); - if (percentage >= 100) { - addonLlamaCppLogCallback(GGML_LOG_LEVEL_INFO, "\n", nullptr); - } - } - - if (progress > addonModel->rawModelLoadPercentage) { - addonModel->rawModelLoadPercentage = progress; - - if (addonModel->onLoadProgressEventCallbackSet) { - addon_progress_event* data = new addon_progress_event { - progress - }; - - auto status = addonModel->addonThreadSafeOnLoadProgressEventCallback.NonBlockingCall(data); - - if (status != napi_ok) { - delete data; - } - } - } - - return !(addonModel->abortModelLoad); -} - -class AddonModelLoadModelWorker : public Napi::AsyncWorker { - public: - AddonModel* model; - - AddonModelLoadModelWorker(const Napi::Env& env, AddonModel* model) - : Napi::AsyncWorker(env, "AddonModelLoadModelWorker"), - model(model), - deferred(Napi::Promise::Deferred::New(env)) { - model->Ref(); - } - ~AddonModelLoadModelWorker() { - model->Unref(); - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - model->model = llama_load_model_from_file(model->modelPath.c_str(), model->model_params); - - model->modelLoaded = model->model != nullptr && model->model != NULL; - } catch (const std::exception& e) { - 
SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_load_model_from_file\""); - } - } - void OnOK() { - if (model->modelLoaded) { - uint64_t modelSize = llama_model_size(model->model); - adjustNapiExternalMemoryAdd(Env(), modelSize); - model->loadedModelSize = modelSize; - } - - deferred.Resolve(Napi::Boolean::New(Env(), model->modelLoaded)); - if (model->onLoadProgressEventCallbackSet) { - model->addonThreadSafeOnLoadProgressEventCallback.Release(); - } - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; -class AddonModelUnloadModelWorker : public Napi::AsyncWorker { - public: - AddonModel* model; - - AddonModelUnloadModelWorker(const Napi::Env& env, AddonModel* model) - : Napi::AsyncWorker(env, "AddonModelUnloadModelWorker"), - model(model), - deferred(Napi::Promise::Deferred::New(env)) { - model->Ref(); - } - ~AddonModelUnloadModelWorker() { - model->Unref(); - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - llama_free_model(model->model); - model->modelLoaded = false; - - model->dispose(); - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_free_model\""); - } - } - void OnOK() { - adjustNapiExternalMemorySubtract(Env(), model->loadedModelSize); - model->loadedModelSize = 0; - - deferred.Resolve(Env().Undefined()); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; -class AddonModelLoadLoraWorker : public Napi::AsyncWorker { - public: - AddonModel* model; - std::string loraFilePath; - float loraScale; - int32_t loraThreads; - std::string baseModelPath; - - AddonModelLoadLoraWorker( - const Napi::Env& env, - AddonModel* model, - std::string loraFilePath, - float loraScale, - int32_t loraThreads, - std::string baseModelPath - ) - : Napi::AsyncWorker(env, "AddonModelLoadLoraWorker"), - model(model), - loraFilePath(loraFilePath), - loraScale(loraScale), - loraThreads(loraThreads), - baseModelPath(baseModelPath), - deferred(Napi::Promise::Deferred::New(env)) { - model->Ref(); - } - ~AddonModelLoadLoraWorker() { - model->Unref(); - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - const auto res = llama_model_apply_lora_from_file( - model->model, - loraFilePath.c_str(), - loraScale, - baseModelPath.empty() ? NULL : baseModelPath.c_str(), - loraThreads - ); - - if (res != 0) { - SetError( - std::string( - std::string("Failed to apply LoRA \"") + loraFilePath + std::string("\"") + ( - baseModelPath.empty() - ? std::string("") - : (std::string(" with base model \"") + baseModelPath + std::string("\"")) - ) - ) - ); - } - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) 
{ - SetError("Unknown error when calling \"llama_model_apply_lora_from_file\""); - } - } - void OnOK() { - deferred.Resolve(Env().Undefined()); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; - -Napi::Value AddonModel::Init(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - AddonModelLoadModelWorker* worker = new AddonModelLoadModelWorker(this->Env(), this); - worker->Queue(); - return worker->GetPromise(); -} -Napi::Value AddonModel::LoadLora(const Napi::CallbackInfo& info) { - std::string loraFilePath = info[0].As().Utf8Value(); - float scale = info[1].As().FloatValue(); - int32_t threads = info[2].As().Int32Value(); - std::string baseModelPath = (info.Length() > 3 && info[3].IsString()) ? info[3].As().Utf8Value() : std::string(""); - - int32_t resolvedThreads = threads == 0 ? std::thread::hardware_concurrency() : threads; - - AddonModelLoadLoraWorker* worker = new AddonModelLoadLoraWorker(this->Env(), this, loraFilePath, scale, threads, baseModelPath); - worker->Queue(); - return worker->GetPromise(); -} -Napi::Value AddonModel::Dispose(const Napi::CallbackInfo& info) { - if (disposed) { - return info.Env().Undefined(); - } - - if (modelLoaded) { - modelLoaded = false; - - AddonModelUnloadModelWorker* worker = new AddonModelUnloadModelWorker(this->Env(), this); - worker->Queue(); - return worker->GetPromise(); - } else { - dispose(); - - Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); - deferred.Resolve(info.Env().Undefined()); - return deferred.Promise(); - } -} - -class AddonGrammar : public Napi::ObjectWrap { - public: - grammar_parser::parse_state parsed_grammar; - Napi::Reference addonExportsRef; - bool hasAddonExportsRef = false; - - AddonGrammar(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { - // Get the model path - std::string grammarCode = info[0].As().Utf8Value(); - bool should_print_grammar = false; - - if (info.Length() > 1 && info[1].IsObject()) { - Napi::Object options = info[1].As(); - - if (options.Has("addonExports")) { - addonExportsRef = Napi::Persistent(options.Get("addonExports").As()); - hasAddonExportsRef = true; - } - - if (options.Has("printGrammar")) { - should_print_grammar = options.Get("printGrammar").As().Value(); - } - } - - parsed_grammar = grammar_parser::parse(grammarCode.c_str()); - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - Napi::Error::New(info.Env(), "Failed to parse grammar").ThrowAsJavaScriptException(); - return; - } - - if (should_print_grammar) { - grammar_parser::print_grammar(stderr, parsed_grammar); - } - } - - ~AddonGrammar() { - if (hasAddonExportsRef) { - addonExportsRef.Unref(); - hasAddonExportsRef = false; - } - } - - static void init(Napi::Object exports) { - exports.Set("AddonGrammar", DefineClass(exports.Env(), "AddonGrammar", {})); - } -}; - -class AddonGrammarEvaluationState : public Napi::ObjectWrap { - public: - AddonGrammar* grammarDef; - llama_grammar* grammar = nullptr; - - AddonGrammarEvaluationState(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { - grammarDef = Napi::ObjectWrap::Unwrap(info[0].As()); - grammarDef->Ref(); - - std::vector grammar_rules(grammarDef->parsed_grammar.c_rules()); - grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), grammarDef->parsed_grammar.symbol_ids.at("root")); - } - - ~AddonGrammarEvaluationState() { - 
grammarDef->Unref(); - - if (grammar != nullptr) { - llama_grammar_free(grammar); - grammar = nullptr; - } - } - - static void init(Napi::Object exports) { - exports.Set("AddonGrammarEvaluationState", DefineClass(exports.Env(), "AddonGrammarEvaluationState", {})); - } -}; - -class AddonContext : public Napi::ObjectWrap { - public: - AddonModel* model; - llama_context_params context_params; - llama_context* ctx; - llama_batch batch; - uint64_t batchMemorySize = 0; - bool has_batch = false; - int32_t batch_n_tokens = 0; - int n_cur = 0; - - uint64_t loadedContextMemorySize = 0; - bool contextLoaded = false; - - bool disposed = false; - - AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { - model = Napi::ObjectWrap::Unwrap(info[0].As()); - model->Ref(); - - context_params = llama_context_default_params(); - context_params.seed = -1; - context_params.n_ctx = 4096; - context_params.n_threads = 6; - context_params.n_threads_batch = context_params.n_threads; - - if (info.Length() > 1 && info[1].IsObject()) { - Napi::Object options = info[1].As(); - - if (options.Has("noSeed")) { - context_params.seed = time(NULL); - } else if (options.Has("seed")) { - context_params.seed = options.Get("seed").As().Uint32Value(); - } - - if (options.Has("contextSize")) { - context_params.n_ctx = options.Get("contextSize").As().Uint32Value(); - } - - if (options.Has("batchSize")) { - context_params.n_batch = options.Get("batchSize").As().Uint32Value(); - context_params.n_ubatch = context_params.n_batch; // the batch queue is managed in the JS side, so there's no need for managing it on the C++ side - } - - if (options.Has("sequences")) { - context_params.n_seq_max = options.Get("sequences").As().Uint32Value(); - } - - if (options.Has("embeddings")) { - context_params.embeddings = options.Get("embeddings").As().Value(); - } - - if (options.Has("flashAttention")) { - context_params.flash_attn = options.Get("flashAttention").As().Value(); - } - - if (options.Has("threads")) { - const auto n_threads = options.Get("threads").As().Uint32Value(); - const auto resolved_n_threads = n_threads == 0 ? 
std::thread::hardware_concurrency() : n_threads; - - context_params.n_threads = resolved_n_threads; - context_params.n_threads_batch = resolved_n_threads; - } - } - } - ~AddonContext() { - dispose(); - } - - void dispose() { - if (disposed) { - return; - } - - disposed = true; - if (contextLoaded) { - contextLoaded = false; - llama_free(ctx); - - adjustNapiExternalMemorySubtract(Env(), loadedContextMemorySize); - loadedContextMemorySize = 0; - } - - model->Unref(); - - disposeBatch(); - } - void disposeBatch() { - if (!has_batch) { - return; - } - - llama_batch_free(batch); - has_batch = false; - batch_n_tokens = 0; - - adjustNapiExternalMemorySubtract(Env(), batchMemorySize); - batchMemorySize = 0; - } - - Napi::Value Init(const Napi::CallbackInfo& info); - Napi::Value Dispose(const Napi::CallbackInfo& info); - - Napi::Value GetContextSize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return Napi::Number::From(info.Env(), llama_n_ctx(ctx)); - } - Napi::Value InitBatch(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - if (has_batch) { - llama_batch_free(batch); - } - - int32_t n_tokens = info[0].As().Int32Value(); - - batch = llama_batch_init(n_tokens, 0, 1); - has_batch = true; - batch_n_tokens = n_tokens; - - uint64_t newBatchMemorySize = calculateBatchMemorySize(n_tokens, llama_n_embd(model->model), context_params.n_batch); - if (newBatchMemorySize > batchMemorySize) { - adjustNapiExternalMemoryAdd(Env(), newBatchMemorySize - batchMemorySize); - batchMemorySize = newBatchMemorySize; - } else if (newBatchMemorySize < batchMemorySize) { - adjustNapiExternalMemorySubtract(Env(), batchMemorySize - newBatchMemorySize); - batchMemorySize = newBatchMemorySize; - } - - return info.Env().Undefined(); - } - Napi::Value DisposeBatch(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - disposeBatch(); - - return info.Env().Undefined(); - } - Napi::Value AddToBatch(const Napi::CallbackInfo& info) { - if (!has_batch) { - Napi::Error::New(info.Env(), "No batch is initialized").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - int32_t sequenceId = info[0].As().Int32Value(); - int32_t firstTokenContextIndex = info[1].As().Int32Value(); - Napi::Uint32Array tokens = info[2].As(); - bool generateLogitAtTheEnd = info[3].As().Value(); - - auto tokensLength = tokens.ElementLength(); - GGML_ASSERT(batch.n_tokens + tokensLength <= batch_n_tokens); - - for (size_t i = 0; i < tokensLength; i++) { - llama_batch_add(batch, static_cast(tokens[i]), firstTokenContextIndex + i, { sequenceId }, false); - } - - if (generateLogitAtTheEnd) { - batch.logits[batch.n_tokens - 1] = true; - - auto logit_index = batch.n_tokens - 1; - - return Napi::Number::From(info.Env(), logit_index); - } - - return info.Env().Undefined(); - } - Napi::Value DisposeSequence(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - int32_t sequenceId = info[0].As().Int32Value(); - - bool result = llama_kv_cache_seq_rm(ctx, sequenceId, -1, -1); - - if (!result) { - Napi::Error::New(info.Env(), "Failed to dispose 
sequence").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return info.Env().Undefined(); - } - Napi::Value RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - int32_t sequenceId = info[0].As().Int32Value(); - int32_t startPos = info[1].As().Int32Value(); - int32_t endPos = info[2].As().Int32Value(); - - bool result = llama_kv_cache_seq_rm(ctx, sequenceId, startPos, endPos); - - return Napi::Boolean::New(info.Env(), result); - } - Napi::Value ShiftSequenceTokenCells(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - int32_t sequenceId = info[0].As().Int32Value(); - int32_t startPos = info[1].As().Int32Value(); - int32_t endPos = info[2].As().Int32Value(); - int32_t shiftDelta = info[3].As().Int32Value(); - - llama_kv_cache_seq_add(ctx, sequenceId, startPos, endPos, shiftDelta); - - return info.Env().Undefined(); - } - Napi::Value DecodeBatch(const Napi::CallbackInfo& info); - Napi::Value SampleToken(const Napi::CallbackInfo& info); - - Napi::Value AcceptGrammarEvaluationStateToken(const Napi::CallbackInfo& info) { - AddonGrammarEvaluationState* grammar_evaluation_state = - Napi::ObjectWrap::Unwrap(info[0].As()); - llama_token tokenId = info[1].As().Int32Value(); - - if ((grammar_evaluation_state)->grammar != nullptr) { - llama_grammar_accept_token(ctx, (grammar_evaluation_state)->grammar, tokenId); - } - - return info.Env().Undefined(); - } - - Napi::Value CanBeNextTokenForGrammarEvaluationState(const Napi::CallbackInfo& info) { - AddonGrammarEvaluationState* grammar_evaluation_state = - Napi::ObjectWrap::Unwrap(info[0].As()); - llama_token tokenId = info[1].As().Int32Value(); - - if ((grammar_evaluation_state)->grammar != nullptr) { - std::vector candidates; - candidates.reserve(1); - candidates.emplace_back(llama_token_data { tokenId, 1, 0.0f }); - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - llama_sample_grammar(ctx, &candidates_p, (grammar_evaluation_state)->grammar); - - if (candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) { - return Napi::Boolean::New(info.Env(), false); - } - - return Napi::Boolean::New(info.Env(), true); - } - - return Napi::Boolean::New(info.Env(), false); - } - - Napi::Value GetEmbedding(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - int32_t inputTokensLength = info[0].As().Int32Value(); - - if (inputTokensLength <= 0) { - Napi::Error::New(info.Env(), "Invalid input tokens length").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - const int n_embd = llama_n_embd(model->model); - const auto* embeddings = llama_get_embeddings_seq(ctx, 0); - if (embeddings == NULL) { - embeddings = llama_get_embeddings_ith(ctx, inputTokensLength - 1); - - if (embeddings == NULL) { - Napi::Error::New(info.Env(), std::string("Failed to get embeddings for token ") + std::to_string(inputTokensLength - 1)).ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - } - - Napi::Float64Array result = Napi::Float64Array::New(info.Env(), n_embd); - for (size_t i = 0; i < n_embd; ++i) { - result[i] = embeddings[i]; - } - - return result; - } - - Napi::Value 
GetStateSize(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - return Napi::Number::From(info.Env(), llama_state_get_size(ctx)); - } - - Napi::Value PrintTimings(const Napi::CallbackInfo& info) { - llama_print_timings(ctx); - llama_reset_timings(ctx); - return info.Env().Undefined(); - } - - static void init(Napi::Object exports) { - exports.Set( - "AddonContext", - DefineClass( - exports.Env(), - "AddonContext", - { - InstanceMethod("init", &AddonContext::Init), - InstanceMethod("getContextSize", &AddonContext::GetContextSize), - InstanceMethod("initBatch", &AddonContext::InitBatch), - InstanceMethod("addToBatch", &AddonContext::AddToBatch), - InstanceMethod("disposeSequence", &AddonContext::DisposeSequence), - InstanceMethod("removeTokenCellsFromSequence", &AddonContext::RemoveTokenCellsFromSequence), - InstanceMethod("shiftSequenceTokenCells", &AddonContext::ShiftSequenceTokenCells), - InstanceMethod("decodeBatch", &AddonContext::DecodeBatch), - InstanceMethod("sampleToken", &AddonContext::SampleToken), - InstanceMethod("acceptGrammarEvaluationStateToken", &AddonContext::AcceptGrammarEvaluationStateToken), - InstanceMethod("canBeNextTokenForGrammarEvaluationState", &AddonContext::CanBeNextTokenForGrammarEvaluationState), - InstanceMethod("getEmbedding", &AddonContext::GetEmbedding), - InstanceMethod("getStateSize", &AddonContext::GetStateSize), - InstanceMethod("printTimings", &AddonContext::PrintTimings), - InstanceMethod("dispose", &AddonContext::Dispose), - } - ) - ); - } -}; - - -class AddonContextDecodeBatchWorker : public Napi::AsyncWorker { - public: - AddonContext* ctx; - - AddonContextDecodeBatchWorker(const Napi::Env& env, AddonContext* ctx) - : Napi::AsyncWorker(env, "AddonContextDecodeBatchWorker"), - ctx(ctx), - deferred(Napi::Promise::Deferred::New(env)) { - ctx->Ref(); - } - ~AddonContextDecodeBatchWorker() { - ctx->Unref(); - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - // Perform the evaluation using llama_decode. - int r = llama_decode(ctx->ctx, ctx->batch); - - if (r != 0) { - if (r == 1) { - SetError("could not find a KV slot for the batch (try reducing the size of the batch or increase the context)"); - } else { - SetError("Eval has failed"); - } - - return; - } - - llama_synchronize(ctx->ctx); - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) 
{ - SetError("Unknown error when calling \"llama_decode\""); - } - } - void OnOK() { - deferred.Resolve(Env().Undefined()); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; - -Napi::Value AddonContext::DecodeBatch(const Napi::CallbackInfo& info) { - AddonContextDecodeBatchWorker* worker = new AddonContextDecodeBatchWorker(info.Env(), this); - worker->Queue(); - return worker->GetPromise(); -} - -class AddonContextLoadContextWorker : public Napi::AsyncWorker { - public: - AddonContext* context; - - AddonContextLoadContextWorker(const Napi::Env& env, AddonContext* context) - : Napi::AsyncWorker(env, "AddonContextLoadContextWorker"), - context(context), - deferred(Napi::Promise::Deferred::New(env)) { - context->Ref(); - } - ~AddonContextLoadContextWorker() { - context->Unref(); - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - context->ctx = llama_new_context_with_model(context->model->model, context->context_params); - - context->contextLoaded = context->ctx != nullptr && context->ctx != NULL; - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_new_context_with_model\""); - } - } - void OnOK() { - if (context->contextLoaded) { - uint64_t contextMemorySize = llama_state_get_size(context->ctx); - adjustNapiExternalMemoryAdd(Env(), contextMemorySize); - context->loadedContextMemorySize = contextMemorySize; - } - - deferred.Resolve(Napi::Boolean::New(Env(), context->contextLoaded)); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; -class AddonContextUnloadContextWorker : public Napi::AsyncWorker { - public: - AddonContext* context; - - AddonContextUnloadContextWorker(const Napi::Env& env, AddonContext* context) - : Napi::AsyncWorker(env, "AddonContextUnloadContextWorker"), - context(context), - deferred(Napi::Promise::Deferred::New(env)) { - context->Ref(); - } - ~AddonContextUnloadContextWorker() { - context->Unref(); - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - llama_free(context->ctx); - context->contextLoaded = false; - - try { - if (context->has_batch) { - llama_batch_free(context->batch); - context->has_batch = false; - context->batch_n_tokens = 0; - } - - context->dispose(); - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_batch_free\""); - } - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) 
{ - SetError("Unknown error when calling \"llama_free\""); - } - } - void OnOK() { - adjustNapiExternalMemorySubtract(Env(), context->loadedContextMemorySize); - context->loadedContextMemorySize = 0; - - adjustNapiExternalMemorySubtract(Env(), context->batchMemorySize); - context->batchMemorySize = 0; - - deferred.Resolve(Env().Undefined()); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; - -Napi::Value AddonContext::Init(const Napi::CallbackInfo& info) { - if (disposed) { - Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - - AddonContextLoadContextWorker* worker = new AddonContextLoadContextWorker(this->Env(), this); - worker->Queue(); - return worker->GetPromise(); -} -Napi::Value AddonContext::Dispose(const Napi::CallbackInfo& info) { - if (disposed) { - return info.Env().Undefined(); - } - - if (contextLoaded) { - contextLoaded = false; - - AddonContextUnloadContextWorker* worker = new AddonContextUnloadContextWorker(this->Env(), this); - worker->Queue(); - return worker->GetPromise(); - } else { - dispose(); - - Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); - deferred.Resolve(info.Env().Undefined()); - return deferred.Promise(); - } -} - -class AddonContextSampleTokenWorker : public Napi::AsyncWorker { - public: - AddonContext* ctx; - AddonGrammarEvaluationState* grammar_evaluation_state; - int32_t batchLogitIndex; - bool use_grammar = false; - llama_token result; - float temperature = 0.0f; - float min_p = 0; - int32_t top_k = 40; - float top_p = 0.95f; - float repeat_penalty = 1.10f; // 1.0 = disabled - float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled - float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled - std::vector repeat_penalty_tokens; - std::unordered_map tokenBiases; - bool useTokenBiases = false; - bool use_repeat_penalty = false; - - AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx) - : Napi::AsyncWorker(info.Env(), "AddonContextSampleTokenWorker"), - ctx(ctx), - deferred(Napi::Promise::Deferred::New(info.Env())) { - ctx->Ref(); - - batchLogitIndex = info[0].As().Int32Value(); - - if (info.Length() > 1 && info[1].IsObject()) { - Napi::Object options = info[1].As(); - - if (options.Has("temperature")) { - temperature = options.Get("temperature").As().FloatValue(); - } - - if (options.Has("minP")) { - min_p = options.Get("minP").As().FloatValue(); - } - - if (options.Has("topK")) { - top_k = options.Get("topK").As().Int32Value(); - } - - if (options.Has("topP")) { - top_p = options.Get("topP").As().FloatValue(); - } - - if (options.Has("repeatPenalty")) { - repeat_penalty = options.Get("repeatPenalty").As().FloatValue(); - } - - if (options.Has("repeatPenaltyTokens")) { - Napi::Uint32Array repeat_penalty_tokens_uint32_array = options.Get("repeatPenaltyTokens").As(); - - repeat_penalty_tokens.reserve(repeat_penalty_tokens_uint32_array.ElementLength()); - for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) { - repeat_penalty_tokens.push_back(static_cast(repeat_penalty_tokens_uint32_array[i])); - } - - use_repeat_penalty = true; - } - - if (options.Has("tokenBiasKeys") && options.Has("tokenBiasValues")) { - Napi::Uint32Array tokenBiasKeys = options.Get("tokenBiasKeys").As(); - Napi::Float32Array tokenBiasValues = options.Get("tokenBiasValues").As(); - - if (tokenBiasKeys.ElementLength() == tokenBiasValues.ElementLength()) { - for (size_t i = 0; i < 
tokenBiasKeys.ElementLength(); i++) { - tokenBiases[static_cast(tokenBiasKeys[i])] = tokenBiasValues[i]; - } - - useTokenBiases = true; - } - } - - if (options.Has("repeatPenaltyPresencePenalty")) { - repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As().FloatValue(); - } - - if (options.Has("repeatPenaltyFrequencyPenalty")) { - repeat_penalty_frequency_penalty = options.Get("repeatPenaltyFrequencyPenalty").As().FloatValue(); - } - - if (options.Has("grammarEvaluationState")) { - grammar_evaluation_state = - Napi::ObjectWrap::Unwrap(options.Get("grammarEvaluationState").As()); - grammar_evaluation_state->Ref(); - use_grammar = true; - } - } - } - ~AddonContextSampleTokenWorker() { - ctx->Unref(); - - if (use_grammar) { - grammar_evaluation_state->Unref(); - use_grammar = false; - } - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - SampleToken(); - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"SampleToken\""); - } - } - - void SampleToken() { - llama_token new_token_id = 0; - - // Select the best prediction. - if (llama_get_logits(ctx->ctx) == nullptr) { - SetError("This model does not support token generation"); - return; - } - - auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex); - auto n_vocab = llama_n_vocab(ctx->model->model); - - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - auto logit = logits[token_id]; - - if (useTokenBiases) { - bool hasTokenBias = tokenBiases.find(token_id) != tokenBiases.end(); - if (hasTokenBias) { - auto logitBias = tokenBiases.at(token_id); - if (logitBias == -INFINITY || logitBias < -INFINITY) { - if (!llama_token_is_eog(ctx->model->model, token_id)) { - logit = -INFINITY; - } - } else { - logit += logitBias; - } - } - } - - candidates.emplace_back(llama_token_data { token_id, logit, 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - if (use_repeat_penalty && !repeat_penalty_tokens.empty()) { - llama_sample_repetition_penalties( - ctx->ctx, - &candidates_p, - repeat_penalty_tokens.data(), - repeat_penalty_tokens.size(), - repeat_penalty, - repeat_penalty_frequency_penalty, - repeat_penalty_presence_penalty - ); - } - - if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) { - llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar); - - if ((candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) && useTokenBiases) { - // logit biases caused grammar sampling to fail, so sampling again without logit biases - useTokenBiases = false; - SampleToken(); - return; - } - } - - if (temperature <= 0) { - new_token_id = llama_sample_token_greedy(ctx->ctx, &candidates_p); - } else { - const int32_t resolved_top_k = - top_k <= 0 ? 
llama_n_vocab(ctx->model->model) : std::min(top_k, llama_n_vocab(ctx->model->model)); - const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled - const float tfs_z = 1.00f; // Tail free sampling - 1.0 = disabled - const float typical_p = 1.00f; // Typical probability - 1.0 = disabled - const float resolved_top_p = top_p; // Top p sampling - 1.0 = disabled - - // Temperature sampling - size_t min_keep = std::max(1, n_probs); - llama_sample_top_k(ctx->ctx, &candidates_p, resolved_top_k, min_keep); - llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep); - llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep); - llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep); - llama_sample_min_p(ctx->ctx, &candidates_p, min_p, min_keep); - llama_sample_temp(ctx->ctx, &candidates_p, temperature); - new_token_id = llama_sample_token(ctx->ctx, &candidates_p); - } - - if (!llama_token_is_eog(ctx->model->model, new_token_id) && use_grammar && (grammar_evaluation_state)->grammar != nullptr) { - llama_grammar_accept_token(ctx->ctx, (grammar_evaluation_state)->grammar, new_token_id); - } - - result = new_token_id; - } - void OnOK() { - Napi::Number resultValue = Napi::Number::New(Env(), static_cast(result)); - deferred.Resolve(resultValue); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; - -Napi::Value AddonContext::SampleToken(const Napi::CallbackInfo& info) { - AddonContextSampleTokenWorker* worker = new AddonContextSampleTokenWorker(info, this); - worker->Queue(); - return worker->GetPromise(); -} - -Napi::Value systemInfo(const Napi::CallbackInfo& info) { - return Napi::String::From(info.Env(), llama_print_system_info()); -} - -Napi::Value addonGetSupportsGpuOffloading(const Napi::CallbackInfo& info) { - return Napi::Boolean::New(info.Env(), llama_supports_gpu_offload()); -} - -Napi::Value addonGetSupportsMmap(const Napi::CallbackInfo& info) { - return Napi::Boolean::New(info.Env(), llama_supports_mmap()); -} - -Napi::Value addonGetSupportsMlock(const Napi::CallbackInfo& info) { - return Napi::Boolean::New(info.Env(), llama_supports_mlock()); -} - -Napi::Value addonGetBlockSizeForGgmlType(const Napi::CallbackInfo& info) { - const int ggmlType = info[0].As().Int32Value(); - - if (ggmlType < 0 || ggmlType > GGML_TYPE_COUNT) { - return info.Env().Undefined(); - } - - const auto blockSize = ggml_blck_size(static_cast(ggmlType)); - - return Napi::Number::New(info.Env(), blockSize); -} - -Napi::Value addonGetTypeSizeForGgmlType(const Napi::CallbackInfo& info) { - const int ggmlType = info[0].As().Int32Value(); - - if (ggmlType < 0 || ggmlType > GGML_TYPE_COUNT) { - return info.Env().Undefined(); - } - - const auto typeSize = ggml_type_size(static_cast(ggmlType)); - - return Napi::Number::New(info.Env(), typeSize); -} - -Napi::Value addonGetConsts(const Napi::CallbackInfo& info) { - Napi::Object consts = Napi::Object::New(info.Env()); - consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS)); - consts.Set("ggmlTypeF16Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F16))); - consts.Set("ggmlTypeF32Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F32))); - consts.Set("ggmlTensorOverhead", Napi::Number::New(info.Env(), ggml_tensor_overhead())); - consts.Set("llamaMaxRngState", Napi::Number::New(info.Env(), LLAMA_MAX_RNG_STATE)); - consts.Set("llamaPosSize", Napi::Number::New(info.Env(), sizeof(llama_pos))); - consts.Set("llamaSeqIdSize", Napi::Number::New(info.Env(), 
sizeof(llama_seq_id))); - - return consts; -} - -int addonGetGgmlLogLevelNumber(ggml_log_level level) { - switch (level) { - case GGML_LOG_LEVEL_ERROR: return 2; - case GGML_LOG_LEVEL_WARN: return 3; - case GGML_LOG_LEVEL_INFO: return 4; - case GGML_LOG_LEVEL_DEBUG: return 5; - } - - return 1; -} - -void addonCallJsLogCallback( - Napi::Env env, Napi::Function callback, AddonThreadSafeLogCallbackFunctionContext* context, addon_logger_log* data -) { - bool called = false; - - if (env != nullptr && callback != nullptr && addonJsLoggerCallbackSet) { - try { - callback.Call({ - Napi::Number::New(env, data->logLevelNumber), - Napi::String::New(env, data->stringStream->str()), - }); - called = true; - } catch (const Napi::Error& e) { - called = false; - } - } - - if (!called && data != nullptr) { - if (data->logLevelNumber == 2) { - fputs(data->stringStream->str().c_str(), stderr); - fflush(stderr); - } else { - fputs(data->stringStream->str().c_str(), stdout); - fflush(stdout); - } - } - - if (data != nullptr) { - delete data->stringStream; - delete data; - } -} - -static void addonLlamaCppLogCallback(ggml_log_level level, const char* text, void* user_data) { - int logLevelNumber = addonGetGgmlLogLevelNumber(level); - - if (logLevelNumber > addonLoggerLogLevel) { - return; - } - - if (addonJsLoggerCallbackSet) { - std::stringstream* stringStream = new std::stringstream(); - if (text != nullptr) { - *stringStream << text; - } - - addon_logger_log* data = new addon_logger_log { - logLevelNumber, - stringStream, - }; - - auto status = addonThreadSafeLoggerCallback.NonBlockingCall(data); - - if (status == napi_ok) { - return; - } else { - delete stringStream; - delete data; - } - } - - if (text != nullptr) { - if (level == 2) { - fputs(text, stderr); - fflush(stderr); - } else { - fputs(text, stdout); - fflush(stdout); - } - } -} - -Napi::Value setLogger(const Napi::CallbackInfo& info) { - if (info.Length() < 1 || !info[0].IsFunction()) { - if (addonJsLoggerCallbackSet) { - addonJsLoggerCallbackSet = false; - addonThreadSafeLoggerCallback.Release(); - } - - return info.Env().Undefined(); - } - - auto addonLoggerJSCallback = info[0].As(); - AddonThreadSafeLogCallbackFunctionContext* context = new Napi::Reference(Napi::Persistent(info.This())); - addonThreadSafeLoggerCallback = AddonThreadSafeLogCallbackFunction::New( - info.Env(), - addonLoggerJSCallback, - "loggerCallback", - 0, - 1, - context, - [](Napi::Env, void*, AddonThreadSafeLogCallbackFunctionContext* ctx) { - addonJsLoggerCallbackSet = false; - - delete ctx; - } - ); - addonJsLoggerCallbackSet = true; - - // prevent blocking the main node process from exiting due to active resources - addonThreadSafeLoggerCallback.Unref(info.Env()); - - return info.Env().Undefined(); -} - -Napi::Value setLoggerLogLevel(const Napi::CallbackInfo& info) { - if (info.Length() < 1 || !info[0].IsNumber()) { - addonLoggerLogLevel = 5; - - return info.Env().Undefined(); - } - - addonLoggerLogLevel = info[0].As().Int32Value(); - - return info.Env().Undefined(); -} - -class AddonBackendLoadWorker : public Napi::AsyncWorker { - public: - AddonBackendLoadWorker(const Napi::Env& env) - : Napi::AsyncWorker(env, "AddonBackendLoadWorker"), - deferred(Napi::Promise::Deferred::New(env)) { - } - ~AddonBackendLoadWorker() { - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - llama_backend_init(); - - try { - if (backendDisposed) { - llama_backend_free(); - } else { - 
backendInitialized = true; - } - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_backend_free\""); - } - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_backend_init\""); - } - } - void OnOK() { - deferred.Resolve(Env().Undefined()); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; - - -class AddonBackendUnloadWorker : public Napi::AsyncWorker { - public: - AddonBackendUnloadWorker(const Napi::Env& env) - : Napi::AsyncWorker(env, "AddonBackendUnloadWorker"), - deferred(Napi::Promise::Deferred::New(env)) { - } - ~AddonBackendUnloadWorker() { - } - - Napi::Promise GetPromise() { - return deferred.Promise(); - } - - protected: - Napi::Promise::Deferred deferred; - - void Execute() { - try { - if (backendInitialized) { - backendInitialized = false; - llama_backend_free(); - } - } catch (const std::exception& e) { - SetError(e.what()); - } catch(...) { - SetError("Unknown error when calling \"llama_backend_free\""); - } - } - void OnOK() { - deferred.Resolve(Env().Undefined()); - } - void OnError(const Napi::Error& err) { - deferred.Reject(err.Value()); - } -}; - -Napi::Value addonInit(const Napi::CallbackInfo& info) { - if (backendInitialized) { - Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); - deferred.Resolve(info.Env().Undefined()); - return deferred.Promise(); - } - - AddonBackendLoadWorker* worker = new AddonBackendLoadWorker(info.Env()); - worker->Queue(); - return worker->GetPromise(); -} - -Napi::Value addonDispose(const Napi::CallbackInfo& info) { - if (backendDisposed) { - Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); - deferred.Resolve(info.Env().Undefined()); - return deferred.Promise(); - } - - backendDisposed = true; - - AddonBackendUnloadWorker* worker = new AddonBackendUnloadWorker(info.Env()); - worker->Queue(); - return worker->GetPromise(); -} - -static void addonFreeLlamaBackend(Napi::Env env, int* data) { - if (backendDisposed) { - return; - } - - backendDisposed = true; - if (backendInitialized) { - backendInitialized = false; - llama_backend_free(); - } -} - -Napi::Object registerCallback(Napi::Env env, Napi::Object exports) { - exports.DefineProperties({ - Napi::PropertyDescriptor::Function("systemInfo", systemInfo), - Napi::PropertyDescriptor::Function("getSupportsGpuOffloading", addonGetSupportsGpuOffloading), - Napi::PropertyDescriptor::Function("getSupportsMmap", addonGetSupportsMmap), - Napi::PropertyDescriptor::Function("getSupportsMlock", addonGetSupportsMlock), - Napi::PropertyDescriptor::Function("getBlockSizeForGgmlType", addonGetBlockSizeForGgmlType), - Napi::PropertyDescriptor::Function("getTypeSizeForGgmlType", addonGetTypeSizeForGgmlType), - Napi::PropertyDescriptor::Function("getConsts", addonGetConsts), - Napi::PropertyDescriptor::Function("setLogger", setLogger), - Napi::PropertyDescriptor::Function("setLoggerLogLevel", setLoggerLogLevel), - Napi::PropertyDescriptor::Function("getGpuVramInfo", getGpuVramInfo), - Napi::PropertyDescriptor::Function("getGpuDeviceInfo", getGpuDeviceInfo), - Napi::PropertyDescriptor::Function("getGpuType", getGpuType), - Napi::PropertyDescriptor::Function("init", addonInit), - Napi::PropertyDescriptor::Function("dispose", addonDispose), - }); - AddonModel::init(exports); - AddonGrammar::init(exports); - AddonGrammarEvaluationState::init(exports); - AddonContext::init(exports); - - 
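// Note: llama_log_set() below routes every llama.cpp log line through
// addonLlamaCppLogCallback (defined above), which maps ggml levels to the numeric levels
// the JS logger receives - ERROR -> 2, WARN -> 3, INFO -> 4, DEBUG -> 5, anything else -> 1 -
// drops messages above addonLoggerLogLevel, and falls back to the process streams when no
// JS callback is set. A simplified sketch of that fallback path (condensed from the two
// callback functions above; not a new API):
//
//     if (!addonJsLoggerCallbackSet && text != nullptr) {
//         FILE* target = (logLevelNumber == 2) ? stderr : stdout;  // 2 == error
//         fputs(text, target);
//         fflush(target);
//     }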
llama_log_set(addonLlamaCppLogCallback, nullptr); - - exports.AddFinalizer(addonFreeLlamaBackend, static_cast(nullptr)); - - return exports; -} - -NODE_API_MODULE(NODE_GYP_MODULE_NAME, registerCallback) diff --git a/llama/addon/AddonContext.cpp b/llama/addon/AddonContext.cpp new file mode 100644 index 00000000..24bc63c1 --- /dev/null +++ b/llama/addon/AddonContext.cpp @@ -0,0 +1,772 @@ +#include +#include +#include "common.h" +#include "llama.h" + +#include "addonGlobals.h" +#include "AddonModel.h" +#include "AddonModelLora.h" +#include "AddonGrammarEvaluationState.h" +#include "AddonContext.h" + +static uint64_t calculateBatchMemorySize(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + uint64_t totalSize = 0; + + if (embd) { + totalSize += sizeof(float) * n_tokens_alloc * embd; + } else { + totalSize += sizeof(llama_token) * n_tokens_alloc; + } + + totalSize += sizeof(llama_pos) * n_tokens_alloc; + totalSize += sizeof(int32_t) * n_tokens_alloc; + totalSize += sizeof(llama_seq_id *) * (n_tokens_alloc + 1); + + totalSize += sizeof(llama_seq_id) * n_seq_max * n_tokens_alloc; + + totalSize += sizeof(int8_t) * n_tokens_alloc; + + return totalSize; +} + +class AddonContextDecodeBatchWorker : public Napi::AsyncWorker { + public: + AddonContext* ctx; + + AddonContextDecodeBatchWorker(const Napi::Env& env, AddonContext* ctx) + : Napi::AsyncWorker(env, "AddonContextDecodeBatchWorker"), + ctx(ctx), + deferred(Napi::Promise::Deferred::New(env)) { + ctx->Ref(); + } + ~AddonContextDecodeBatchWorker() { + ctx->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + // Perform the evaluation using llama_decode. + int r = llama_decode(ctx->ctx, ctx->batch); + + if (r != 0) { + if (r == 1) { + SetError("could not find a KV slot for the batch (try reducing the size of the batch or increase the context)"); + } else { + SetError("Eval has failed"); + } + + return; + } + + llama_synchronize(ctx->ctx); + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) { + SetError("Unknown error when calling \"llama_decode\""); + } + } + void OnOK() { + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + +class AddonContextLoadContextWorker : public Napi::AsyncWorker { + public: + AddonContext* context; + + AddonContextLoadContextWorker(const Napi::Env& env, AddonContext* context) + : Napi::AsyncWorker(env, "AddonContextLoadContextWorker"), + context(context), + deferred(Napi::Promise::Deferred::New(env)) { + context->Ref(); + } + ~AddonContextLoadContextWorker() { + context->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + context->ctx = llama_new_context_with_model(context->model->model, context->context_params); + + context->contextLoaded = context->ctx != nullptr && context->ctx != NULL; + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) 
{ + SetError("Unknown error when calling \"llama_new_context_with_model\""); + } + } + void OnOK() { + if (context->contextLoaded) { + uint64_t contextMemorySize = llama_state_get_size(context->ctx); + adjustNapiExternalMemoryAdd(Env(), contextMemorySize); + context->loadedContextMemorySize = contextMemorySize; + } + + deferred.Resolve(Napi::Boolean::New(Env(), context->contextLoaded)); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; +class AddonContextUnloadContextWorker : public Napi::AsyncWorker { + public: + AddonContext* context; + + AddonContextUnloadContextWorker(const Napi::Env& env, AddonContext* context) + : Napi::AsyncWorker(env, "AddonContextUnloadContextWorker"), + context(context), + deferred(Napi::Promise::Deferred::New(env)) { + context->Ref(); + } + ~AddonContextUnloadContextWorker() { + context->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + llama_free(context->ctx); + context->contextLoaded = false; + + try { + if (context->has_batch) { + llama_batch_free(context->batch); + context->has_batch = false; + context->batch_n_tokens = 0; + } + + context->dispose(); + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) { + SetError("Unknown error when calling \"llama_batch_free\""); + } + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) { + SetError("Unknown error when calling \"llama_free\""); + } + } + void OnOK() { + adjustNapiExternalMemorySubtract(Env(), context->loadedContextMemorySize); + context->loadedContextMemorySize = 0; + + adjustNapiExternalMemorySubtract(Env(), context->batchMemorySize); + context->batchMemorySize = 0; + + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + + +class AddonContextSampleTokenWorker : public Napi::AsyncWorker { + public: + AddonContext* ctx; + AddonGrammarEvaluationState* grammar_evaluation_state; + int32_t batchLogitIndex; + bool use_grammar = false; + llama_token result; + float temperature = 0.0f; + float min_p = 0; + int32_t top_k = 40; + float top_p = 0.95f; + float repeat_penalty = 1.10f; // 1.0 = disabled + float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled + float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled + std::vector repeat_penalty_tokens; + std::unordered_map tokenBiases; + bool useTokenBiases = false; + bool use_repeat_penalty = false; + + AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx) + : Napi::AsyncWorker(info.Env(), "AddonContextSampleTokenWorker"), + ctx(ctx), + deferred(Napi::Promise::Deferred::New(info.Env())) { + ctx->Ref(); + + batchLogitIndex = info[0].As().Int32Value(); + + if (info.Length() > 1 && info[1].IsObject()) { + Napi::Object options = info[1].As(); + + if (options.Has("temperature")) { + temperature = options.Get("temperature").As().FloatValue(); + } + + if (options.Has("minP")) { + min_p = options.Get("minP").As().FloatValue(); + } + + if (options.Has("topK")) { + top_k = options.Get("topK").As().Int32Value(); + } + + if (options.Has("topP")) { + top_p = options.Get("topP").As().FloatValue(); + } + + if (options.Has("repeatPenalty")) { + repeat_penalty = options.Get("repeatPenalty").As().FloatValue(); + } + + if (options.Has("repeatPenaltyTokens")) { + Napi::Uint32Array repeat_penalty_tokens_uint32_array = options.Get("repeatPenaltyTokens").As(); + + 
                    repeat_penalty_tokens.reserve(repeat_penalty_tokens_uint32_array.ElementLength());
+                    for (size_t i = 0; i < repeat_penalty_tokens_uint32_array.ElementLength(); i++) {
+                        repeat_penalty_tokens.push_back(static_cast<llama_token>(repeat_penalty_tokens_uint32_array[i]));
+                    }
+
+                    use_repeat_penalty = true;
+                }
+
+                if (options.Has("tokenBiasKeys") && options.Has("tokenBiasValues")) {
+                    Napi::Uint32Array tokenBiasKeys = options.Get("tokenBiasKeys").As<Napi::Uint32Array>();
+                    Napi::Float32Array tokenBiasValues = options.Get("tokenBiasValues").As<Napi::Float32Array>();
+
+                    if (tokenBiasKeys.ElementLength() == tokenBiasValues.ElementLength()) {
+                        for (size_t i = 0; i < tokenBiasKeys.ElementLength(); i++) {
+                            tokenBiases[static_cast<llama_token>(tokenBiasKeys[i])] = tokenBiasValues[i];
+                        }
+
+                        useTokenBiases = true;
+                    }
+                }
+
+                if (options.Has("repeatPenaltyPresencePenalty")) {
+                    repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("repeatPenaltyFrequencyPenalty")) {
+                    repeat_penalty_frequency_penalty = options.Get("repeatPenaltyFrequencyPenalty").As<Napi::Number>().FloatValue();
+                }
+
+                if (options.Has("grammarEvaluationState")) {
+                    grammar_evaluation_state =
+                        Napi::ObjectWrap<AddonGrammarEvaluationState>::Unwrap(options.Get("grammarEvaluationState").As<Napi::Object>());
+                    grammar_evaluation_state->Ref();
+                    use_grammar = true;
+                }
+            }
+        }
+        ~AddonContextSampleTokenWorker() {
+            ctx->Unref();
+
+            if (use_grammar) {
+                grammar_evaluation_state->Unref();
+                use_grammar = false;
+            }
+        }
+
+        Napi::Promise GetPromise() {
+            return deferred.Promise();
+        }
+
+    protected:
+        Napi::Promise::Deferred deferred;
+
+        void Execute() {
+            try {
+                SampleToken();
+            } catch (const std::exception& e) {
+                SetError(e.what());
+            } catch(...) {
+                SetError("Unknown error when calling \"SampleToken\"");
+            }
+        }
+
+        void SampleToken() {
+            llama_token new_token_id = 0;
+
+            // Select the best prediction.
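            // Sampling pipeline overview (annotation): the logits for the requested batch index
            // are copied into a llama_token_data array, token biases are applied (a -INFINITY
            // bias bans a token unless it is an end-of-generation token), repetition penalties
            // and the grammar constraint filter the candidates, and then either greedy sampling
            // (temperature <= 0) or the top-k -> tail-free -> typical -> top-p -> min-p ->
            // temperature -> llama_sample_token chain below picks the token. For example, with
            // the defaults declared above (temperature 0.0f would be greedy; say temperature 0.8,
            // topK 40, topP 0.95, minP 0), a 32000-token vocabulary is first cut to the 40
            // highest-logit candidates, trimmed further by top-p/min-p, scaled by 1/temperature,
            // and sampled from; tail-free and typical sampling stay disabled at 1.0.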
+            if (llama_get_logits(ctx->ctx) == nullptr) {
+                SetError("This model does not support token generation");
+                return;
+            }
+
+            auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
+            auto n_vocab = llama_n_vocab(ctx->model->model);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                auto logit = logits[token_id];
+
+                if (useTokenBiases) {
+                    bool hasTokenBias = tokenBiases.find(token_id) != tokenBiases.end();
+                    if (hasTokenBias) {
+                        auto logitBias = tokenBiases.at(token_id);
+                        if (logitBias == -INFINITY || logitBias < -INFINITY) {
+                            if (!llama_token_is_eog(ctx->model->model, token_id)) {
+                                logit = -INFINITY;
+                            }
+                        } else {
+                            logit += logitBias;
+                        }
+                    }
+                }
+
+                candidates.emplace_back(llama_token_data { token_id, logit, 0.0f });
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
+                llama_sample_repetition_penalties(
+                    ctx->ctx,
+                    &candidates_p,
+                    repeat_penalty_tokens.data(),
+                    repeat_penalty_tokens.size(),
+                    repeat_penalty,
+                    repeat_penalty_frequency_penalty,
+                    repeat_penalty_presence_penalty
+                );
+            }
+
+            if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
+                llama_grammar_sample((grammar_evaluation_state)->grammar, ctx->ctx, &candidates_p);
+
+                if ((candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) && useTokenBiases) {
+                    // logit biases caused grammar sampling to fail, so sampling again without logit biases
+                    useTokenBiases = false;
+                    SampleToken();
+                    return;
+                }
+            }
+
+            if (temperature <= 0) {
+                new_token_id = llama_sample_token_greedy(ctx->ctx, &candidates_p);
+            } else {
+                const int32_t resolved_top_k =
+                    top_k <= 0 ? llama_n_vocab(ctx->model->model) : std::min(top_k, llama_n_vocab(ctx->model->model));
+                const int32_t n_probs = 0; // Number of probabilities to keep - 0 = disabled
+                const float tfs_z = 1.00f; // Tail free sampling - 1.0 = disabled
+                const float typical_p = 1.00f; // Typical probability - 1.0 = disabled
+                const float resolved_top_p = top_p; // Top p sampling - 1.0 = disabled
+
+                // Temperature sampling
+                size_t min_keep = std::max(1, n_probs);
+                llama_sample_top_k(ctx->ctx, &candidates_p, resolved_top_k, min_keep);
+                llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep);
+                llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep);
+                llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep);
+                llama_sample_min_p(ctx->ctx, &candidates_p, min_p, min_keep);
+                llama_sample_temp(ctx->ctx, &candidates_p, temperature);
+                new_token_id = llama_sample_token(ctx->ctx, &candidates_p);
+            }
+
+            if (!llama_token_is_eog(ctx->model->model, new_token_id) && use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
+                llama_grammar_accept_token((grammar_evaluation_state)->grammar, ctx->ctx, new_token_id);
+            }
+
+            result = new_token_id;
+        }
+        void OnOK() {
+            Napi::Number resultValue = Napi::Number::New(Env(), static_cast<uint32_t>(result));
+            deferred.Resolve(resultValue);
+        }
+        void OnError(const Napi::Error& err) {
+            deferred.Reject(err.Value());
+        }
+};
+
+AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonContext>(info) {
+    batchMemorySize = 0;
+    has_batch = false;
+    batch_n_tokens = 0;
+    n_cur = 0;
+
+    uint64_t loadedContextMemorySize = 0;
+    bool contextLoaded = false;
+
+    bool disposed = false;
+
+    model = Napi::ObjectWrap<AddonModel>::Unwrap(info[0].As<Napi::Object>());
+    model->Ref();
+
+    context_params = llama_context_default_params();
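    // Note: the defaults set below (seed -1, n_ctx 4096, 6 threads) are only a baseline that the
    // options object from JS overrides - "noSeed" replaces the seed with time(NULL), "batchSize"
    // sets both n_batch and n_ubatch (the batch queue is managed on the JS side), and
    // "threads": 0 resolves to std::thread::hardware_concurrency(). A worked example of the
    // thread resolution, assuming a hypothetical 8-core machine:
    //
    //     const auto n_threads = 0;                          // "threads": 0 passed from JS
    //     const auto resolved_n_threads = n_threads == 0
    //         ? std::thread::hardware_concurrency()          // -> 8 on this machine
    //         : n_threads;
    //     context_params.n_threads = resolved_n_threads;     // n_threads_batch gets the same value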
context_params.seed = -1; + context_params.n_ctx = 4096; + context_params.n_threads = 6; + context_params.n_threads_batch = context_params.n_threads; + + if (info.Length() > 1 && info[1].IsObject()) { + Napi::Object options = info[1].As(); + + if (options.Has("noSeed")) { + context_params.seed = time(NULL); + } else if (options.Has("seed")) { + context_params.seed = options.Get("seed").As().Uint32Value(); + } + + if (options.Has("contextSize")) { + context_params.n_ctx = options.Get("contextSize").As().Uint32Value(); + } + + if (options.Has("batchSize")) { + context_params.n_batch = options.Get("batchSize").As().Uint32Value(); + context_params.n_ubatch = context_params.n_batch; // the batch queue is managed in the JS side, so there's no need for managing it on the C++ side + } + + if (options.Has("sequences")) { + context_params.n_seq_max = options.Get("sequences").As().Uint32Value(); + } + + if (options.Has("embeddings")) { + context_params.embeddings = options.Get("embeddings").As().Value(); + } + + if (options.Has("flashAttention")) { + context_params.flash_attn = options.Get("flashAttention").As().Value(); + } + + if (options.Has("threads")) { + const auto n_threads = options.Get("threads").As().Uint32Value(); + const auto resolved_n_threads = n_threads == 0 ? std::thread::hardware_concurrency() : n_threads; + + context_params.n_threads = resolved_n_threads; + context_params.n_threads_batch = resolved_n_threads; + } + } +} +AddonContext::~AddonContext() { + dispose(); +} + +void AddonContext::dispose() { + if (disposed) { + return; + } + + disposed = true; + if (contextLoaded) { + contextLoaded = false; + llama_free(ctx); + + adjustNapiExternalMemorySubtract(Env(), loadedContextMemorySize); + loadedContextMemorySize = 0; + } + + model->Unref(); + + disposeBatch(); +} +void AddonContext::disposeBatch() { + if (!has_batch) { + return; + } + + llama_batch_free(batch); + has_batch = false; + batch_n_tokens = 0; + + adjustNapiExternalMemorySubtract(Env(), batchMemorySize); + batchMemorySize = 0; +} + +Napi::Value AddonContext::Init(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + AddonContextLoadContextWorker* worker = new AddonContextLoadContextWorker(this->Env(), this); + worker->Queue(); + return worker->GetPromise(); +} +Napi::Value AddonContext::Dispose(const Napi::CallbackInfo& info) { + if (disposed) { + return info.Env().Undefined(); + } + + if (contextLoaded) { + contextLoaded = false; + + AddonContextUnloadContextWorker* worker = new AddonContextUnloadContextWorker(this->Env(), this); + worker->Queue(); + return worker->GetPromise(); + } else { + dispose(); + + Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); + deferred.Resolve(info.Env().Undefined()); + return deferred.Promise(); + } +} + +Napi::Value AddonContext::GetContextSize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_n_ctx(ctx)); +} +Napi::Value AddonContext::InitBatch(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + if (has_batch) { + llama_batch_free(batch); + } + + int32_t n_tokens = info[0].As().Int32Value(); + + batch = llama_batch_init(n_tokens, 0, 1); + has_batch = 
true; + batch_n_tokens = n_tokens; + + uint64_t newBatchMemorySize = calculateBatchMemorySize(n_tokens, llama_n_embd(model->model), context_params.n_batch); + if (newBatchMemorySize > batchMemorySize) { + adjustNapiExternalMemoryAdd(Env(), newBatchMemorySize - batchMemorySize); + batchMemorySize = newBatchMemorySize; + } else if (newBatchMemorySize < batchMemorySize) { + adjustNapiExternalMemorySubtract(Env(), batchMemorySize - newBatchMemorySize); + batchMemorySize = newBatchMemorySize; + } + + return info.Env().Undefined(); +} +Napi::Value AddonContext::DisposeBatch(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + disposeBatch(); + + return info.Env().Undefined(); +} +Napi::Value AddonContext::AddToBatch(const Napi::CallbackInfo& info) { + if (!has_batch) { + Napi::Error::New(info.Env(), "No batch is initialized").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + int32_t sequenceId = info[0].As().Int32Value(); + int32_t firstTokenContextIndex = info[1].As().Int32Value(); + Napi::Uint32Array tokens = info[2].As(); + bool generateLogitAtTheEnd = info[3].As().Value(); + + auto tokensLength = tokens.ElementLength(); + GGML_ASSERT(batch.n_tokens + tokensLength <= batch_n_tokens); + + for (size_t i = 0; i < tokensLength; i++) { + llama_batch_add(batch, static_cast(tokens[i]), firstTokenContextIndex + i, { sequenceId }, false); + } + + if (generateLogitAtTheEnd) { + batch.logits[batch.n_tokens - 1] = true; + + auto logit_index = batch.n_tokens - 1; + + return Napi::Number::From(info.Env(), logit_index); + } + + return info.Env().Undefined(); +} +Napi::Value AddonContext::DisposeSequence(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + int32_t sequenceId = info[0].As().Int32Value(); + + bool result = llama_kv_cache_seq_rm(ctx, sequenceId, -1, -1); + + if (!result) { + Napi::Error::New(info.Env(), "Failed to dispose sequence").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return info.Env().Undefined(); +} +Napi::Value AddonContext::RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + int32_t sequenceId = info[0].As().Int32Value(); + int32_t startPos = info[1].As().Int32Value(); + int32_t endPos = info[2].As().Int32Value(); + + bool result = llama_kv_cache_seq_rm(ctx, sequenceId, startPos, endPos); + + return Napi::Boolean::New(info.Env(), result); +} +Napi::Value AddonContext::ShiftSequenceTokenCells(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + int32_t sequenceId = info[0].As().Int32Value(); + int32_t startPos = info[1].As().Int32Value(); + int32_t endPos = info[2].As().Int32Value(); + int32_t shiftDelta = info[3].As().Int32Value(); + + llama_kv_cache_seq_add(ctx, sequenceId, startPos, endPos, shiftDelta); + + return info.Env().Undefined(); +} +Napi::Value AddonContext::DecodeBatch(const Napi::CallbackInfo& info) { + AddonContextDecodeBatchWorker* worker = new AddonContextDecodeBatchWorker(info.Env(), this); + worker->Queue(); + return worker->GetPromise(); +} +Napi::Value AddonContext::SampleToken(const 
Napi::CallbackInfo& info) { + AddonContextSampleTokenWorker* worker = new AddonContextSampleTokenWorker(info, this); + worker->Queue(); + return worker->GetPromise(); +} + +Napi::Value AddonContext::AcceptGrammarEvaluationStateToken(const Napi::CallbackInfo& info) { + AddonGrammarEvaluationState* grammar_evaluation_state = + Napi::ObjectWrap::Unwrap(info[0].As()); + llama_token tokenId = info[1].As().Int32Value(); + + if ((grammar_evaluation_state)->grammar != nullptr) { + llama_grammar_accept_token((grammar_evaluation_state)->grammar, ctx, tokenId); + } + + return info.Env().Undefined(); +} + +Napi::Value AddonContext::CanBeNextTokenForGrammarEvaluationState(const Napi::CallbackInfo& info) { + AddonGrammarEvaluationState* grammar_evaluation_state = + Napi::ObjectWrap::Unwrap(info[0].As()); + llama_token tokenId = info[1].As().Int32Value(); + + if ((grammar_evaluation_state)->grammar != nullptr) { + std::vector candidates; + candidates.reserve(1); + candidates.emplace_back(llama_token_data { tokenId, 1, 0.0f }); + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + llama_grammar_sample((grammar_evaluation_state)->grammar, ctx, &candidates_p); + + if (candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) { + return Napi::Boolean::New(info.Env(), false); + } + + return Napi::Boolean::New(info.Env(), true); + } + + return Napi::Boolean::New(info.Env(), false); +} + +Napi::Value AddonContext::GetEmbedding(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + int32_t inputTokensLength = info[0].As().Int32Value(); + + if (inputTokensLength <= 0) { + Napi::Error::New(info.Env(), "Invalid input tokens length").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + const int n_embd = llama_n_embd(model->model); + const auto* embeddings = llama_get_embeddings_seq(ctx, 0); + if (embeddings == NULL) { + embeddings = llama_get_embeddings_ith(ctx, inputTokensLength - 1); + + if (embeddings == NULL) { + Napi::Error::New(info.Env(), std::string("Failed to get embeddings for token ") + std::to_string(inputTokensLength - 1)).ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + } + + Napi::Float64Array result = Napi::Float64Array::New(info.Env(), n_embd); + for (size_t i = 0; i < n_embd; ++i) { + result[i] = embeddings[i]; + } + + return result; +} + +Napi::Value AddonContext::GetStateSize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_state_get_size(ctx)); +} + +Napi::Value AddonContext::PrintTimings(const Napi::CallbackInfo& info) { + llama_print_timings(ctx); + llama_reset_timings(ctx); + return info.Env().Undefined(); +} + +Napi::Value AddonContext::SetLora(const Napi::CallbackInfo& info) { + AddonModelLora* lora = Napi::ObjectWrap::Unwrap(info[0].As()); + float scale = info[1].As().FloatValue(); + + llama_lora_adapter_set(ctx, lora->lora_adapter, scale); + + return info.Env().Undefined(); +} + +void AddonContext::init(Napi::Object exports) { + exports.Set( + "AddonContext", + DefineClass( + exports.Env(), + "AddonContext", + { + InstanceMethod("init", &AddonContext::Init), + InstanceMethod("getContextSize", &AddonContext::GetContextSize), + InstanceMethod("initBatch", &AddonContext::InitBatch), + InstanceMethod("addToBatch", 
&AddonContext::AddToBatch), + InstanceMethod("disposeSequence", &AddonContext::DisposeSequence), + InstanceMethod("removeTokenCellsFromSequence", &AddonContext::RemoveTokenCellsFromSequence), + InstanceMethod("shiftSequenceTokenCells", &AddonContext::ShiftSequenceTokenCells), + InstanceMethod("decodeBatch", &AddonContext::DecodeBatch), + InstanceMethod("sampleToken", &AddonContext::SampleToken), + InstanceMethod("acceptGrammarEvaluationStateToken", &AddonContext::AcceptGrammarEvaluationStateToken), + InstanceMethod("canBeNextTokenForGrammarEvaluationState", &AddonContext::CanBeNextTokenForGrammarEvaluationState), + InstanceMethod("getEmbedding", &AddonContext::GetEmbedding), + InstanceMethod("getStateSize", &AddonContext::GetStateSize), + InstanceMethod("printTimings", &AddonContext::PrintTimings), + InstanceMethod("setLora", &AddonContext::SetLora), + InstanceMethod("dispose", &AddonContext::Dispose), + } + ) + ); +} diff --git a/llama/addon/AddonContext.h b/llama/addon/AddonContext.h new file mode 100644 index 00000000..f100b6b9 --- /dev/null +++ b/llama/addon/AddonContext.h @@ -0,0 +1,53 @@ +#pragma once +#include "llama.h" +#include "napi.h" +#include "addonGlobals.h" + +class AddonContext : public Napi::ObjectWrap { + public: + AddonModel* model; + llama_context_params context_params; + llama_context* ctx; + llama_batch batch; + uint64_t batchMemorySize; + bool has_batch; + int32_t batch_n_tokens; + int n_cur; + + uint64_t loadedContextMemorySize; + bool contextLoaded; + + bool disposed; + + AddonContext(const Napi::CallbackInfo& info); + ~AddonContext(); + + void dispose(); + void disposeBatch(); + + Napi::Value Init(const Napi::CallbackInfo& info); + Napi::Value Dispose(const Napi::CallbackInfo& info); + + Napi::Value GetContextSize(const Napi::CallbackInfo& info); + Napi::Value InitBatch(const Napi::CallbackInfo& info); + Napi::Value DisposeBatch(const Napi::CallbackInfo& info); + Napi::Value AddToBatch(const Napi::CallbackInfo& info); + Napi::Value DisposeSequence(const Napi::CallbackInfo& info); + Napi::Value RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info); + Napi::Value ShiftSequenceTokenCells(const Napi::CallbackInfo& info); + Napi::Value DecodeBatch(const Napi::CallbackInfo& info); + Napi::Value SampleToken(const Napi::CallbackInfo& info); + + Napi::Value AcceptGrammarEvaluationStateToken(const Napi::CallbackInfo& info); + + Napi::Value CanBeNextTokenForGrammarEvaluationState(const Napi::CallbackInfo& info); + + Napi::Value GetEmbedding(const Napi::CallbackInfo& info); + Napi::Value GetStateSize(const Napi::CallbackInfo& info); + + Napi::Value PrintTimings(const Napi::CallbackInfo& info); + + Napi::Value SetLora(const Napi::CallbackInfo& info); + + static void init(Napi::Object exports); +}; \ No newline at end of file diff --git a/llama/addon/AddonGrammar.cpp b/llama/addon/AddonGrammar.cpp new file mode 100644 index 00000000..85e5327f --- /dev/null +++ b/llama/addon/AddonGrammar.cpp @@ -0,0 +1,44 @@ +#include "addonGlobals.h" +#include "AddonGrammar.h" + +AddonGrammar::AddonGrammar(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { + hasAddonExportsRef = false; + + // Get the model path + std::string grammarCode = info[0].As().Utf8Value(); + bool should_print_grammar = false; + + if (info.Length() > 1 && info[1].IsObject()) { + Napi::Object options = info[1].As(); + + if (options.Has("addonExports")) { + addonExportsRef = Napi::Persistent(options.Get("addonExports").As()); + hasAddonExportsRef = true; + } + + if (options.Has("debugPrintGrammar")) { + 
should_print_grammar = options.Get("debugPrintGrammar").As().Value(); + } + } + + parsed_grammar = grammar_parser::parse(grammarCode.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + Napi::Error::New(info.Env(), "Failed to parse grammar").ThrowAsJavaScriptException(); + return; + } + + if (should_print_grammar) { + grammar_parser::print_grammar(stderr, parsed_grammar); + } +} +AddonGrammar::~AddonGrammar() { + if (hasAddonExportsRef) { + addonExportsRef.Unref(); + hasAddonExportsRef = false; + } +} + +void AddonGrammar::init(Napi::Object exports) { + exports.Set("AddonGrammar", DefineClass(exports.Env(), "AddonGrammar", {})); +} \ No newline at end of file diff --git a/llama/addon/AddonGrammar.h b/llama/addon/AddonGrammar.h new file mode 100644 index 00000000..e98abc4d --- /dev/null +++ b/llama/addon/AddonGrammar.h @@ -0,0 +1,18 @@ +#pragma once +#include "llama.h" +#include "common.h" +#include "common/grammar-parser.h" +#include "napi.h" +#include "addonGlobals.h" + +class AddonGrammar : public Napi::ObjectWrap { + public: + grammar_parser::parse_state parsed_grammar; + Napi::Reference addonExportsRef; + bool hasAddonExportsRef; + + AddonGrammar(const Napi::CallbackInfo& info); + ~AddonGrammar(); + + static void init(Napi::Object exports); +}; \ No newline at end of file diff --git a/llama/addon/AddonGrammarEvaluationState.cpp b/llama/addon/AddonGrammarEvaluationState.cpp new file mode 100644 index 00000000..8007b7e0 --- /dev/null +++ b/llama/addon/AddonGrammarEvaluationState.cpp @@ -0,0 +1,28 @@ +#include +#include "addonGlobals.h" +#include "common.h" +#include "llama.h" +#include "AddonGrammarEvaluationState.h" +#include "AddonGrammar.h" + +AddonGrammarEvaluationState::AddonGrammarEvaluationState(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { + grammar = nullptr; + + grammarDef = Napi::ObjectWrap::Unwrap(info[0].As()); + grammarDef->Ref(); + + std::vector grammar_rules(grammarDef->parsed_grammar.c_rules()); + grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), grammarDef->parsed_grammar.symbol_ids.at("root")); +} +AddonGrammarEvaluationState::~AddonGrammarEvaluationState() { + grammarDef->Unref(); + + if (grammar != nullptr) { + llama_grammar_free(grammar); + grammar = nullptr; + } +} + +void AddonGrammarEvaluationState::init(Napi::Object exports) { + exports.Set("AddonGrammarEvaluationState", DefineClass(exports.Env(), "AddonGrammarEvaluationState", {})); +} \ No newline at end of file diff --git a/llama/addon/AddonGrammarEvaluationState.h b/llama/addon/AddonGrammarEvaluationState.h new file mode 100644 index 00000000..f304d30c --- /dev/null +++ b/llama/addon/AddonGrammarEvaluationState.h @@ -0,0 +1,15 @@ +#pragma once +#include "llama.h" +#include "napi.h" +#include "addonGlobals.h" + +class AddonGrammarEvaluationState : public Napi::ObjectWrap { + public: + AddonGrammar* grammarDef; + llama_grammar* grammar; + + AddonGrammarEvaluationState(const Napi::CallbackInfo& info); + ~AddonGrammarEvaluationState(); + + static void init(Napi::Object exports); +}; \ No newline at end of file diff --git a/llama/addon/AddonModel.cpp b/llama/addon/AddonModel.cpp new file mode 100644 index 00000000..e0cd7da5 --- /dev/null +++ b/llama/addon/AddonModel.cpp @@ -0,0 +1,681 @@ +#include +#include "addonGlobals.h" +#include "globals/addonLog.h" +#include "common.h" +#include "llama.h" +#include "AddonModel.h" +#include "AddonModelData.h" +#include "AddonModelLora.h" + +static Napi::Value getNapiToken(const 
Napi::CallbackInfo& info, llama_model* model, llama_token token) { + if (token < 0) { + return Napi::Number::From(info.Env(), -1); + } + + auto tokenAttributes = llama_token_get_attr(model, token); + + if (tokenAttributes & LLAMA_TOKEN_ATTR_UNDEFINED || tokenAttributes & LLAMA_TOKEN_ATTR_UNKNOWN) { + return Napi::Number::From(info.Env(), -1); + } + + return Napi::Number::From(info.Env(), token); +} + +static Napi::Value getNapiControlToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) { + if (token < 0) { + return Napi::Number::From(info.Env(), -1); + } + + auto tokenAttributes = llama_token_get_attr(model, token); + + if (!(tokenAttributes & LLAMA_TOKEN_ATTR_CONTROL) && !(tokenAttributes & LLAMA_TOKEN_ATTR_UNDEFINED)) { + return Napi::Number::From(info.Env(), -1); + } + + return Napi::Number::From(info.Env(), token); +} + +static bool llamaModelParamsProgressCallback(float progress, void * user_data) { + AddonModel* addonModel = (AddonModel *) user_data; + unsigned percentage = (unsigned) (100 * progress); + + if (percentage > addonModel->modelLoadPercentage) { + addonModel->modelLoadPercentage = percentage; + + // original llama.cpp logs + addonLlamaCppLogCallback(GGML_LOG_LEVEL_INFO, ".", nullptr); + if (percentage >= 100) { + addonLlamaCppLogCallback(GGML_LOG_LEVEL_INFO, "\n", nullptr); + } + } + + if (progress > addonModel->rawModelLoadPercentage) { + addonModel->rawModelLoadPercentage = progress; + + if (addonModel->onLoadProgressEventCallbackSet) { + addon_progress_event* data = new addon_progress_event { + progress + }; + + auto status = addonModel->addonThreadSafeOnLoadProgressEventCallback.NonBlockingCall(data); + + if (status != napi_ok) { + delete data; + } + } + } + + return !(addonModel->abortModelLoad); +} + +class AddonModelLoadModelWorker : public Napi::AsyncWorker { + public: + AddonModel* model; + + AddonModelLoadModelWorker(const Napi::Env& env, AddonModel* model) + : Napi::AsyncWorker(env, "AddonModelLoadModelWorker"), + model(model), + deferred(Napi::Promise::Deferred::New(env)) { + model->Ref(); + } + ~AddonModelLoadModelWorker() { + model->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + model->model = llama_load_model_from_file(model->modelPath.c_str(), model->model_params); + + model->modelLoaded = model->model != nullptr && model->model != NULL; + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) 
{ + SetError("Unknown error when calling \"llama_load_model_from_file\""); + } + } + void OnOK() { + if (model->modelLoaded) { + uint64_t modelSize = llama_model_size(model->model); + adjustNapiExternalMemoryAdd(Env(), modelSize); + model->loadedModelSize = modelSize; + } + + deferred.Resolve(Napi::Boolean::New(Env(), model->modelLoaded)); + if (model->onLoadProgressEventCallbackSet) { + model->addonThreadSafeOnLoadProgressEventCallback.Release(); + } + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + +class AddonModelUnloadModelWorker : public Napi::AsyncWorker { + public: + AddonModel* model; + + AddonModelUnloadModelWorker(const Napi::Env& env, AddonModel* model) + : Napi::AsyncWorker(env, "AddonModelUnloadModelWorker"), + model(model), + deferred(Napi::Promise::Deferred::New(env)) { + model->Ref(); + } + ~AddonModelUnloadModelWorker() { + model->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + llama_free_model(model->model); + model->modelLoaded = false; + + model->dispose(); + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) { + SetError("Unknown error when calling \"llama_free_model\""); + } + } + void OnOK() { + adjustNapiExternalMemorySubtract(Env(), model->loadedModelSize); + model->loadedModelSize = 0; + + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + +class AddonModelLoadLoraWorker : public Napi::AsyncWorker { + public: + AddonModelLora* modelLora; + + AddonModelLoadLoraWorker( + const Napi::Env& env, + AddonModelLora* modelLora + ) + : Napi::AsyncWorker(env, "AddonModelLoadLoraWorker"), + modelLora(modelLora), + deferred(Napi::Promise::Deferred::New(env)) { + modelLora->model->Ref(); + modelLora->Ref(); + } + ~AddonModelLoadLoraWorker() { + modelLora->model->Unref(); + modelLora->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + const auto loraAdapter = llama_lora_adapter_init(modelLora->model->model, modelLora->loraFilePath.c_str()); + + if (loraAdapter == nullptr) { + SetError( + std::string( + std::string("Failed to initialize LoRA adapter \"" + modelLora->loraFilePath + "\"") + ) + ); + return; + } + + modelLora->lora_adapter = loraAdapter; + modelLora->model->Ref(); + + if (modelLora->model->data != nullptr) { + modelLora->model->data->loraAdapters.insert(modelLora); + } else { + modelLora->dispose(true); + SetError("Model data is not initialized"); + } + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) 
{ + SetError("Unknown error when calling \"llama_lora_adapter_init\""); + } + } + void OnOK() { + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + +AddonModel::AddonModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { + loadedModelSize = 0; + hasAddonExportsRef = false; + modelLoaded = false; + abortModelLoad = false; + model_load_stopped = false; + rawModelLoadPercentage = 0; + modelLoadPercentage = 0; + onLoadProgressEventCallbackSet = false; + hasLoadAbortSignal = false; + disposed = false; + + data = new AddonModelData(); + model_params = llama_model_default_params(); + + // Get the model path + modelPath = info[0].As().Utf8Value(); + + if (info.Length() > 1 && info[1].IsObject()) { + Napi::Object options = info[1].As(); + + if (options.Has("addonExports")) { + addonExportsRef = Napi::Persistent(options.Get("addonExports").As()); + hasAddonExportsRef = true; + } + + if (options.Has("gpuLayers")) { + model_params.n_gpu_layers = options.Get("gpuLayers").As().Int32Value(); + } + + if (options.Has("vocabOnly")) { + model_params.vocab_only = options.Get("vocabOnly").As().Value(); + } + + if (options.Has("useMmap")) { + model_params.use_mmap = options.Get("useMmap").As().Value(); + } + + if (options.Has("useMlock")) { + model_params.use_mlock = options.Get("useMlock").As().Value(); + } + + if (options.Has("checkTensors")) { + model_params.check_tensors = options.Get("checkTensors").As().Value(); + } + + if (options.Has("onLoadProgress")) { + auto onLoadProgressJSCallback = options.Get("onLoadProgress").As(); + if (onLoadProgressJSCallback.IsFunction()) { + AddonThreadSafeProgressCallbackFunctionContext* context = new Napi::Reference(Napi::Persistent(info.This())); + addonThreadSafeOnLoadProgressEventCallback = AddonThreadSafeProgressEventCallbackFunction::New( + info.Env(), + onLoadProgressJSCallback, + "onLoadProgressCallback", + 0, + 1, + context, + [](Napi::Env, AddonModel* addonModel, AddonThreadSafeProgressCallbackFunctionContext* ctx) { + addonModel->onLoadProgressEventCallbackSet = false; + + delete ctx; + }, + this + ); + onLoadProgressEventCallbackSet = true; + } + } + + if (options.Has("hasLoadAbortSignal")) { + hasLoadAbortSignal = options.Get("hasLoadAbortSignal").As().Value(); + } + + if (options.Has("overridesList")) { + Napi::Array overridesList = options.Get("overridesList").As(); + kv_overrides.reserve(overridesList.Length()); + + for (uint32_t i = 0; i < overridesList.Length(); i++) { + Napi::Array overrideItem = overridesList.Get(i).As(); + auto key = overrideItem.Get((uint32_t)0).As().Utf8Value(); + auto value = overrideItem.Get((uint32_t)1); + + if (key.length() > 127) { + continue; + } + + llama_model_kv_override kvo; + std::strncpy(kvo.key, key.c_str(), key.length()); + kvo.key[key.length()] = 0; + + if (value.IsString()) { + auto valueString = value.As().Utf8Value(); + if (valueString.length() > 127) { + continue; + } + + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + std::strncpy(kvo.val_str, valueString.c_str(), valueString.length()); + kvo.val_str[valueString.length()] = 0; + + fputs(std::string("Override: " + key + " = " + valueString + "\n").c_str(), stdout); + fflush(stdout); + } else if (value.IsNumber() || value.IsBigInt()) { + auto numberType = overrideItem.Get((uint32_t)2).As().Int32Value(); + if (numberType == 0) { + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = value.As().Int64Value(); + } else { + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.val_f64 = 
value.As().DoubleValue(); + } + + continue; + } else if (value.IsBoolean()) { + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + kvo.val_bool = value.As().Value(); + } + + kv_overrides.emplace_back(std::move(kvo)); + } + + if (!kv_overrides.empty()) { + kv_overrides.emplace_back(); + kv_overrides.back().key[0] = 0; + } + + model_params.kv_overrides = kv_overrides.data(); + } + + if (onLoadProgressEventCallbackSet || hasLoadAbortSignal) { + model_params.progress_callback_user_data = &(*this); + model_params.progress_callback = llamaModelParamsProgressCallback; + } + } +} + +AddonModel::~AddonModel() { + dispose(); +} +void AddonModel::dispose() { + if (disposed) { + return; + } + + disposed = true; + if (modelLoaded) { + modelLoaded = false; + llama_free_model(model); + + adjustNapiExternalMemorySubtract(Env(), loadedModelSize); + loadedModelSize = 0; + } + + if (data != nullptr) { + auto currentData = data; + data = nullptr; + delete currentData; + } + + if (hasAddonExportsRef) { + addonExportsRef.Unref(); + hasAddonExportsRef = false; + } +} + +Napi::Value AddonModel::Init(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + AddonModelLoadModelWorker* worker = new AddonModelLoadModelWorker(this->Env(), this); + worker->Queue(); + return worker->GetPromise(); +} +Napi::Value AddonModel::LoadLora(const Napi::CallbackInfo& info) { + AddonModelLora* modelLora = Napi::ObjectWrap::Unwrap(info[0].As()); + AddonModelLoadLoraWorker* worker = new AddonModelLoadLoraWorker(this->Env(), modelLora); + worker->Queue(); + return worker->GetPromise(); +} +Napi::Value AddonModel::AbortActiveModelLoad(const Napi::CallbackInfo& info) { + abortModelLoad = true; + return info.Env().Undefined(); +} +Napi::Value AddonModel::Dispose(const Napi::CallbackInfo& info) { + if (disposed) { + return info.Env().Undefined(); + } + + if (modelLoaded) { + modelLoaded = false; + + AddonModelUnloadModelWorker* worker = new AddonModelUnloadModelWorker(this->Env(), this); + worker->Queue(); + return worker->GetPromise(); + } else { + dispose(); + + Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); + deferred.Resolve(info.Env().Undefined()); + return deferred.Promise(); + } +} + +Napi::Value AddonModel::Tokenize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + std::string text = info[0].As().Utf8Value(); + bool specialTokens = info[1].As().Value(); + + std::vector tokens = llama_tokenize(model, text, false, specialTokens); + + Napi::Uint32Array result = Napi::Uint32Array::New(info.Env(), tokens.size()); + for (size_t i = 0; i < tokens.size(); ++i) { + result[i] = static_cast(tokens[i]); + } + + return result; +} +Napi::Value AddonModel::Detokenize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + Napi::Uint32Array tokens = info[0].As(); + bool decodeSpecialTokens = info.Length() > 0 + ? 
info[1].As().Value() + : false; + + std::vector result(8, 0); + const int n_length = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens); + + if (n_length < 0) { + result.resize(-n_length); + int check = llama_detokenize(model, (llama_token*)tokens.Data(), tokens.ElementLength(), result.data(), result.size(), false, decodeSpecialTokens); + GGML_ASSERT(check == -n_length); + } else { + result.resize(n_length); + } + + return Napi::String::New(info.Env(), result.data(), result.size()); +} + +Napi::Value AddonModel::GetTrainContextSize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_n_ctx_train(model)); +} + +Napi::Value AddonModel::GetEmbeddingVectorSize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_n_embd(model)); +} + +Napi::Value AddonModel::GetTotalSize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_model_size(model)); +} + +Napi::Value AddonModel::GetTotalParameters(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_model_n_params(model)); +} + +Napi::Value AddonModel::GetModelDescription(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + char model_desc[128]; + int actual_length = llama_model_desc(model, model_desc, sizeof(model_desc)); + + return Napi::String::New(info.Env(), model_desc, actual_length); +} + +Napi::Value AddonModel::TokenBos(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiControlToken(info, model, llama_token_bos(model)); +} +Napi::Value AddonModel::TokenEos(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiControlToken(info, model, llama_token_eos(model)); +} +Napi::Value AddonModel::TokenNl(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiToken(info, model, llama_token_nl(model)); +} +Napi::Value AddonModel::PrefixToken(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiControlToken(info, model, llama_token_prefix(model)); +} +Napi::Value AddonModel::MiddleToken(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiControlToken(info, model, llama_token_middle(model)); +} +Napi::Value AddonModel::SuffixToken(const 
Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiControlToken(info, model, llama_token_suffix(model)); +} +Napi::Value AddonModel::EotToken(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return getNapiControlToken(info, model, llama_token_eot(model)); +} +Napi::Value AddonModel::GetTokenString(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + int token = info[0].As().Int32Value(); + std::stringstream ss; + + const char* str = llama_token_get_text(model, token); + if (str == nullptr) { + return info.Env().Undefined(); + } + + ss << str; + + return Napi::String::New(info.Env(), ss.str()); +} + +Napi::Value AddonModel::GetTokenAttributes(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + if (info[0].IsNumber() == false) { + return Napi::Number::From(info.Env(), int32_t(LLAMA_TOKEN_ATTR_UNDEFINED)); + } + + int token = info[0].As().Int32Value(); + auto tokenAttributes = llama_token_get_attr(model, token); + + return Napi::Number::From(info.Env(), int32_t(tokenAttributes)); +} +Napi::Value AddonModel::IsEogToken(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + if (info[0].IsNumber() == false) { + return Napi::Boolean::New(info.Env(), false); + } + + int token = info[0].As().Int32Value(); + + return Napi::Boolean::New(info.Env(), llama_token_is_eog(model, token)); +} +Napi::Value AddonModel::GetVocabularyType(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + auto vocabularyType = llama_vocab_type(model); + + return Napi::Number::From(info.Env(), int32_t(vocabularyType)); +} +Napi::Value AddonModel::ShouldPrependBosToken(const Napi::CallbackInfo& info) { + const int addBos = llama_add_bos_token(model); + + bool shouldPrependBos = addBos != -1 ? 
bool(addBos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); + + return Napi::Boolean::New(info.Env(), shouldPrependBos); +} + +Napi::Value AddonModel::GetModelSize(const Napi::CallbackInfo& info) { + return Napi::Number::From(info.Env(), llama_model_size(model)); +} + +void AddonModel::init(Napi::Object exports) { + exports.Set( + "AddonModel", + DefineClass( + exports.Env(), + "AddonModel", + { + InstanceMethod("init", &AddonModel::Init), + InstanceMethod("loadLora", &AddonModel::LoadLora), + InstanceMethod("abortActiveModelLoad", &AddonModel::AbortActiveModelLoad), + InstanceMethod("tokenize", &AddonModel::Tokenize), + InstanceMethod("detokenize", &AddonModel::Detokenize), + InstanceMethod("getTrainContextSize", &AddonModel::GetTrainContextSize), + InstanceMethod("getEmbeddingVectorSize", &AddonModel::GetEmbeddingVectorSize), + InstanceMethod("getTotalSize", &AddonModel::GetTotalSize), + InstanceMethod("getTotalParameters", &AddonModel::GetTotalParameters), + InstanceMethod("getModelDescription", &AddonModel::GetModelDescription), + InstanceMethod("tokenBos", &AddonModel::TokenBos), + InstanceMethod("tokenEos", &AddonModel::TokenEos), + InstanceMethod("tokenNl", &AddonModel::TokenNl), + InstanceMethod("prefixToken", &AddonModel::PrefixToken), + InstanceMethod("middleToken", &AddonModel::MiddleToken), + InstanceMethod("suffixToken", &AddonModel::SuffixToken), + InstanceMethod("eotToken", &AddonModel::EotToken), + InstanceMethod("getTokenString", &AddonModel::GetTokenString), + InstanceMethod("getTokenAttributes", &AddonModel::GetTokenAttributes), + InstanceMethod("isEogToken", &AddonModel::IsEogToken), + InstanceMethod("getVocabularyType", &AddonModel::GetVocabularyType), + InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken), + InstanceMethod("getModelSize", &AddonModel::GetModelSize), + InstanceMethod("dispose", &AddonModel::Dispose), + } + ) + ); +} diff --git a/llama/addon/AddonModel.h b/llama/addon/AddonModel.h new file mode 100644 index 00000000..f56b45d8 --- /dev/null +++ b/llama/addon/AddonModel.h @@ -0,0 +1,61 @@ +#pragma once +#include "llama.h" +#include "napi.h" +#include "addonGlobals.h" +#include "globals/addonProgress.h" + +class AddonModel : public Napi::ObjectWrap { + public: + llama_model_params model_params; + std::vector kv_overrides; + llama_model* model; + uint64_t loadedModelSize; + Napi::Reference addonExportsRef; + bool hasAddonExportsRef; + AddonModelData* data; + + std::string modelPath; + bool modelLoaded; + bool abortModelLoad; + bool model_load_stopped; + float rawModelLoadPercentage; + unsigned modelLoadPercentage; + AddonThreadSafeProgressEventCallbackFunction addonThreadSafeOnLoadProgressEventCallback; + bool onLoadProgressEventCallbackSet; + bool hasLoadAbortSignal; + + bool disposed; + + AddonModel(const Napi::CallbackInfo& info); + ~AddonModel(); + void dispose(); + + Napi::Value Init(const Napi::CallbackInfo& info); + Napi::Value LoadLora(const Napi::CallbackInfo& info); + Napi::Value AbortActiveModelLoad(const Napi::CallbackInfo& info); + Napi::Value Dispose(const Napi::CallbackInfo& info); + Napi::Value Tokenize(const Napi::CallbackInfo& info); + Napi::Value Detokenize(const Napi::CallbackInfo& info); + Napi::Value GetTrainContextSize(const Napi::CallbackInfo& info); + Napi::Value GetEmbeddingVectorSize(const Napi::CallbackInfo& info); + Napi::Value GetTotalSize(const Napi::CallbackInfo& info); + Napi::Value GetTotalParameters(const Napi::CallbackInfo& info); + Napi::Value GetModelDescription(const Napi::CallbackInfo& 
info); + + Napi::Value TokenBos(const Napi::CallbackInfo& info); + Napi::Value TokenEos(const Napi::CallbackInfo& info); + Napi::Value TokenNl(const Napi::CallbackInfo& info); + Napi::Value PrefixToken(const Napi::CallbackInfo& info); + Napi::Value MiddleToken(const Napi::CallbackInfo& info); + Napi::Value SuffixToken(const Napi::CallbackInfo& info); + Napi::Value EotToken(const Napi::CallbackInfo& info); + Napi::Value GetTokenString(const Napi::CallbackInfo& info); + + Napi::Value GetTokenAttributes(const Napi::CallbackInfo& info); + Napi::Value IsEogToken(const Napi::CallbackInfo& info); + Napi::Value GetVocabularyType(const Napi::CallbackInfo& info); + Napi::Value ShouldPrependBosToken(const Napi::CallbackInfo& info); + Napi::Value GetModelSize(const Napi::CallbackInfo& info); + + static void init(Napi::Object exports); +}; diff --git a/llama/addon/AddonModelData.cpp b/llama/addon/AddonModelData.cpp new file mode 100644 index 00000000..3c1758a3 --- /dev/null +++ b/llama/addon/AddonModelData.cpp @@ -0,0 +1,25 @@ +#include + +#include "addonGlobals.h" +#include "AddonModelData.h" +#include "AddonModelLora.h" + +AddonModelData::AddonModelData() { + +} +AddonModelData::~AddonModelData() { + std::set currentLoraAdapters; + currentLoraAdapters.swap(loraAdapters); + + for (auto lora : currentLoraAdapters) { + lora->dispose(true); + } + currentLoraAdapters.clear(); +} + +void AddonModelData::removeLora(AddonModelLora* lora) { + auto pos = loraAdapters.find(lora); + if (pos != loraAdapters.end()) { + loraAdapters.erase(pos); + } +} \ No newline at end of file diff --git a/llama/addon/AddonModelData.h b/llama/addon/AddonModelData.h new file mode 100644 index 00000000..78c82497 --- /dev/null +++ b/llama/addon/AddonModelData.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include "llama.h" +#include "napi.h" +#include "addonGlobals.h" + +class AddonModelData { + public: + std::set loraAdapters; + + AddonModelData(); + ~AddonModelData(); + + void removeLora(AddonModelLora* lora); +}; \ No newline at end of file diff --git a/llama/addon/AddonModelLora.cpp b/llama/addon/AddonModelLora.cpp new file mode 100644 index 00000000..085ecce4 --- /dev/null +++ b/llama/addon/AddonModelLora.cpp @@ -0,0 +1,107 @@ +#include "addonGlobals.h" +#include "AddonModel.h" +#include "AddonModelData.h" +#include "AddonModelLora.h" + +class AddonModelLoraUnloadLoraWorker : public Napi::AsyncWorker { + public: + AddonModelLora* addonLora; + + AddonModelLoraUnloadLoraWorker(const Napi::Env& env, AddonModelLora* addonLora) + : Napi::AsyncWorker(env, "AddonModelLoraUnloadLoraWorker"), + addonLora(addonLora), + deferred(Napi::Promise::Deferred::New(env)) { + addonLora->Ref(); + } + ~AddonModelLoraUnloadLoraWorker() { + addonLora->Unref(); + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + addonLora->dispose(); + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) 
{ + SetError("Unknown error when calling \"llama_lora_adapter_free\""); + } + } + void OnOK() { + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + +AddonModelLora::AddonModelLora(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { + usages = 0; + + model = Napi::ObjectWrap::Unwrap(info[0].As()); + loraFilePath = info[1].As().Utf8Value(); + lora_adapter = nullptr; +} + +AddonModelLora::~AddonModelLora() { + dispose(); +} + +void AddonModelLora::dispose(bool skipErase) { + if (lora_adapter != nullptr) { + auto loraAdapterToDispose = lora_adapter; + lora_adapter = nullptr; + llama_lora_adapter_free(loraAdapterToDispose); + + if (!skipErase && model->data != nullptr) { + model->data->removeLora(this); + } + + model->Unref(); + } +} + +Napi::Value AddonModelLora::GetFilePath(const Napi::CallbackInfo& info) { + return Napi::String::New(info.Env(), loraFilePath); +} + + +Napi::Value AddonModelLora::GetUsages(const Napi::CallbackInfo& info) { + return Napi::Number::From(info.Env(), usages); +} + +void AddonModelLora::SetUsages(const Napi::CallbackInfo& info, const Napi::Value &value) { + usages = value.As().Uint32Value(); +} + +Napi::Value AddonModelLora::Dispose(const Napi::CallbackInfo& info) { + AddonModelLoraUnloadLoraWorker* worker = new AddonModelLoraUnloadLoraWorker(this->Env(), this); + worker->Queue(); + return worker->GetPromise(); +} + +Napi::Value AddonModelLora::GetDisposed(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), lora_adapter == nullptr); +} + +void AddonModelLora::init(Napi::Object exports) { + exports.Set( + "AddonModelLora", + DefineClass( + exports.Env(), + "AddonModelLora", + { + InstanceAccessor("usages", &AddonModelLora::GetUsages, &AddonModelLora::SetUsages), + InstanceAccessor("filePath", &AddonModelLora::GetFilePath, nullptr), + InstanceAccessor("disposed", &AddonModelLora::GetDisposed, nullptr), + InstanceMethod("dispose", &AddonModelLora::Dispose), + } + ) + ); +} diff --git a/llama/addon/AddonModelLora.h b/llama/addon/AddonModelLora.h new file mode 100644 index 00000000..d3ee7cd4 --- /dev/null +++ b/llama/addon/AddonModelLora.h @@ -0,0 +1,28 @@ +#pragma once +#include "llama.h" +#include "napi.h" +#include "addonGlobals.h" + +class AddonModelLora : public Napi::ObjectWrap { + public: + AddonModel* model; + llama_lora_adapter * lora_adapter; + std::string loraFilePath; + uint32_t usages; + + AddonModelLora(const Napi::CallbackInfo& info); + ~AddonModelLora(); + + void dispose(bool skipErase = false); + + Napi::Value GetFilePath(const Napi::CallbackInfo& info); + + Napi::Value GetUsages(const Napi::CallbackInfo& info); + void SetUsages(const Napi::CallbackInfo& info, const Napi::Value &value); + + Napi::Value GetDisposed(const Napi::CallbackInfo& info); + + Napi::Value Dispose(const Napi::CallbackInfo& info); + + static void init(Napi::Object exports); +}; diff --git a/llama/addon/addon.cpp b/llama/addon/addon.cpp new file mode 100644 index 00000000..83b2b503 --- /dev/null +++ b/llama/addon/addon.cpp @@ -0,0 +1,217 @@ +#include "addonGlobals.h" +#include "AddonModel.h" +#include "AddonModelLora.h" +#include "AddonGrammar.h" +#include "AddonGrammarEvaluationState.h" +#include "AddonContext.h" +#include "globals/addonLog.h" +#include "globals/addonProgress.h" +#include "globals/getGpuInfo.h" + +bool backendInitialized = false; +bool backendDisposed = false; + +Napi::Value systemInfo(const Napi::CallbackInfo& info) { + return Napi::String::From(info.Env(), 
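The AddonModelLora class above is only reachable through the native bindings. A rough TypeScript sketch (not part of this diff) of how its JS-facing surface, typed later in this diff in src/bindings/AddonTypes.ts, is expected to be driven; the `bindings` object, the model object, and the adapter path are assumptions for illustration only:

type AddonModelLoraLike = {
    usages: number,
    readonly filePath: string,
    readonly disposed: boolean,
    dispose(): Promise<void>
};

// `bindings` and `addonModel` are assumed to be obtained elsewhere; the adapter path is hypothetical
async function loadAndUnloadLoraSketch(
    bindings: {AddonModelLora: new (model: object, filePath: string) => AddonModelLoraLike},
    addonModel: {loadLora(lora: AddonModelLoraLike): Promise<void>}
) {
    // mirrors the C++ constructor above: AddonModelLora(model, loraFilePath)
    const lora = new bindings.AddonModelLora(addonModel, "/path/to/adapter.gguf");
    console.log(lora.disposed); // true right after construction - the native lora_adapter isn't loaded yet

    await addonModel.loadLora(lora); // AddonModel::LoadLora (defined elsewhere in this diff) presumably sets the native lora_adapter
    lora.usages++;                   // usage counting is left to the calling code
    console.log(lora.filePath);      // "/path/to/adapter.gguf"

    await lora.dispose();            // queues AddonModelLoraUnloadLoraWorker, which calls llama_lora_adapter_free
    console.log(lora.disposed);      // true
}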
llama_print_system_info()); +} + +Napi::Value addonGetSupportsGpuOffloading(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), llama_supports_gpu_offload()); +} + +Napi::Value addonGetSupportsMmap(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), llama_supports_mmap()); +} + +Napi::Value addonGetSupportsMlock(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), llama_supports_mlock()); +} + +Napi::Value addonGetBlockSizeForGgmlType(const Napi::CallbackInfo& info) { + const int ggmlType = info[0].As().Int32Value(); + + if (ggmlType < 0 || ggmlType > GGML_TYPE_COUNT) { + return info.Env().Undefined(); + } + + const auto blockSize = ggml_blck_size(static_cast(ggmlType)); + + return Napi::Number::New(info.Env(), blockSize); +} + +Napi::Value addonGetTypeSizeForGgmlType(const Napi::CallbackInfo& info) { + const int ggmlType = info[0].As().Int32Value(); + + if (ggmlType < 0 || ggmlType > GGML_TYPE_COUNT) { + return info.Env().Undefined(); + } + + const auto typeSize = ggml_type_size(static_cast(ggmlType)); + + return Napi::Number::New(info.Env(), typeSize); +} + +Napi::Value addonGetConsts(const Napi::CallbackInfo& info) { + Napi::Object consts = Napi::Object::New(info.Env()); + consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS)); + consts.Set("ggmlTypeF16Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F16))); + consts.Set("ggmlTypeF32Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F32))); + consts.Set("ggmlTensorOverhead", Napi::Number::New(info.Env(), ggml_tensor_overhead())); + consts.Set("llamaMaxRngState", Napi::Number::New(info.Env(), LLAMA_MAX_RNG_STATE)); + consts.Set("llamaPosSize", Napi::Number::New(info.Env(), sizeof(llama_pos))); + consts.Set("llamaSeqIdSize", Napi::Number::New(info.Env(), sizeof(llama_seq_id))); + + return consts; +} + +class AddonBackendLoadWorker : public Napi::AsyncWorker { + public: + AddonBackendLoadWorker(const Napi::Env& env) + : Napi::AsyncWorker(env, "AddonBackendLoadWorker"), + deferred(Napi::Promise::Deferred::New(env)) { + } + ~AddonBackendLoadWorker() { + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + llama_backend_init(); + + try { + if (backendDisposed) { + llama_backend_free(); + } else { + backendInitialized = true; + } + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) { + SetError("Unknown error when calling \"llama_backend_free\""); + } + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) { + SetError("Unknown error when calling \"llama_backend_init\""); + } + } + void OnOK() { + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + + +class AddonBackendUnloadWorker : public Napi::AsyncWorker { + public: + AddonBackendUnloadWorker(const Napi::Env& env) + : Napi::AsyncWorker(env, "AddonBackendUnloadWorker"), + deferred(Napi::Promise::Deferred::New(env)) { + } + ~AddonBackendUnloadWorker() { + } + + Napi::Promise GetPromise() { + return deferred.Promise(); + } + + protected: + Napi::Promise::Deferred deferred; + + void Execute() { + try { + if (backendInitialized) { + backendInitialized = false; + llama_backend_free(); + } + } catch (const std::exception& e) { + SetError(e.what()); + } catch(...) 
{ + SetError("Unknown error when calling \"llama_backend_free\""); + } + } + void OnOK() { + deferred.Resolve(Env().Undefined()); + } + void OnError(const Napi::Error& err) { + deferred.Reject(err.Value()); + } +}; + +Napi::Value addonInit(const Napi::CallbackInfo& info) { + if (backendInitialized) { + Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); + deferred.Resolve(info.Env().Undefined()); + return deferred.Promise(); + } + + AddonBackendLoadWorker* worker = new AddonBackendLoadWorker(info.Env()); + worker->Queue(); + return worker->GetPromise(); +} + +Napi::Value addonDispose(const Napi::CallbackInfo& info) { + if (backendDisposed) { + Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); + deferred.Resolve(info.Env().Undefined()); + return deferred.Promise(); + } + + backendDisposed = true; + + AddonBackendUnloadWorker* worker = new AddonBackendUnloadWorker(info.Env()); + worker->Queue(); + return worker->GetPromise(); +} + +static void addonFreeLlamaBackend(Napi::Env env, int* data) { + if (backendDisposed) { + return; + } + + backendDisposed = true; + if (backendInitialized) { + backendInitialized = false; + llama_backend_free(); + } +} + +Napi::Object registerCallback(Napi::Env env, Napi::Object exports) { + exports.DefineProperties({ + Napi::PropertyDescriptor::Function("systemInfo", systemInfo), + Napi::PropertyDescriptor::Function("getSupportsGpuOffloading", addonGetSupportsGpuOffloading), + Napi::PropertyDescriptor::Function("getSupportsMmap", addonGetSupportsMmap), + Napi::PropertyDescriptor::Function("getSupportsMlock", addonGetSupportsMlock), + Napi::PropertyDescriptor::Function("getBlockSizeForGgmlType", addonGetBlockSizeForGgmlType), + Napi::PropertyDescriptor::Function("getTypeSizeForGgmlType", addonGetTypeSizeForGgmlType), + Napi::PropertyDescriptor::Function("getConsts", addonGetConsts), + Napi::PropertyDescriptor::Function("setLogger", setLogger), + Napi::PropertyDescriptor::Function("setLoggerLogLevel", setLoggerLogLevel), + Napi::PropertyDescriptor::Function("getGpuVramInfo", getGpuVramInfo), + Napi::PropertyDescriptor::Function("getGpuDeviceInfo", getGpuDeviceInfo), + Napi::PropertyDescriptor::Function("getGpuType", getGpuType), + Napi::PropertyDescriptor::Function("init", addonInit), + Napi::PropertyDescriptor::Function("dispose", addonDispose), + }); + AddonModel::init(exports); + AddonModelLora::init(exports); + AddonGrammar::init(exports); + AddonGrammarEvaluationState::init(exports); + AddonContext::init(exports); + + llama_log_set(addonLlamaCppLogCallback, nullptr); + + exports.AddFinalizer(addonFreeLlamaBackend, static_cast(nullptr)); + + return exports; +} + +NODE_API_MODULE(NODE_GYP_MODULE_NAME, registerCallback) diff --git a/llama/addon/addonGlobals.cpp b/llama/addon/addonGlobals.cpp new file mode 100644 index 00000000..2d73c466 --- /dev/null +++ b/llama/addon/addonGlobals.cpp @@ -0,0 +1,22 @@ +#include +#include +#include "addonGlobals.h" +#include "napi.h" + +void adjustNapiExternalMemoryAdd(Napi::Env env, uint64_t size) { + const uint64_t chunkSize = std::numeric_limits::max(); + while (size > 0) { + int64_t adjustSize = std::min(size, chunkSize); + Napi::MemoryManagement::AdjustExternalMemory(env, adjustSize); + size -= adjustSize; + } +} + +void adjustNapiExternalMemorySubtract(Napi::Env env, uint64_t size) { + const uint64_t chunkSize = std::numeric_limits::max(); + while (size > 0) { + int64_t adjustSize = std::min(size, chunkSize); + Napi::MemoryManagement::AdjustExternalMemory(env, -adjustSize); 
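Since the init/dispose workers above are idempotent, the intended calling pattern from the JS side is simple. A minimal sketch (not part of this diff), assuming `bindings` is the loaded native module:

// both calls resolve immediately when the backend is already in the requested state
async function backendLifecycleSketch(bindings: {init(): Promise<void>, dispose(): Promise<void>}) {
    await bindings.init();    // runs llama_backend_init() on a worker thread (AddonBackendLoadWorker)
    await bindings.init();    // already initialized - resolves without queueing another worker

    // ... load models and contexts here ...

    await bindings.dispose(); // runs llama_backend_free() on a worker thread (AddonBackendUnloadWorker)
    await bindings.dispose(); // already disposed - resolves immediately
}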
+ size -= adjustSize; + } +} diff --git a/llama/addon/addonGlobals.h b/llama/addon/addonGlobals.h new file mode 100644 index 00000000..1a4dd8d1 --- /dev/null +++ b/llama/addon/addonGlobals.h @@ -0,0 +1,12 @@ +#pragma once +#include "napi.h" + +class AddonModel; +class AddonModelLora; +class AddonModelData; +class AddonContext; +class AddonGrammar; +class AddonGrammarEvaluationState; + +void adjustNapiExternalMemoryAdd(Napi::Env env, uint64_t size); +void adjustNapiExternalMemorySubtract(Napi::Env env, uint64_t size); diff --git a/llama/addon/globals/addonLog.cpp b/llama/addon/globals/addonLog.cpp new file mode 100644 index 00000000..c93002ea --- /dev/null +++ b/llama/addon/globals/addonLog.cpp @@ -0,0 +1,135 @@ +#include + +#include "addonLog.h" + +AddonThreadSafeLogCallbackFunction addonThreadSafeLoggerCallback; +bool addonJsLoggerCallbackSet = false; +int addonLoggerLogLevel = 5; + +static int addonGetGgmlLogLevelNumber(ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_ERROR: return 2; + case GGML_LOG_LEVEL_WARN: return 3; + case GGML_LOG_LEVEL_INFO: return 4; + case GGML_LOG_LEVEL_DEBUG: return 5; + } + + return 1; +} + +void addonCallJsLogCallback( + Napi::Env env, Napi::Function callback, AddonThreadSafeLogCallbackFunctionContext* context, addon_logger_log* data +) { + bool called = false; + + if (env != nullptr && callback != nullptr && addonJsLoggerCallbackSet) { + try { + callback.Call({ + Napi::Number::New(env, data->logLevelNumber), + Napi::String::New(env, data->stringStream->str()), + }); + called = true; + } catch (const Napi::Error& e) { + called = false; + } + } + + if (!called && data != nullptr) { + if (data->logLevelNumber == 2) { + fputs(data->stringStream->str().c_str(), stderr); + fflush(stderr); + } else { + fputs(data->stringStream->str().c_str(), stdout); + fflush(stdout); + } + } + + if (data != nullptr) { + delete data->stringStream; + delete data; + } +} + +void addonLlamaCppLogCallback(ggml_log_level level, const char* text, void* user_data) { + int logLevelNumber = addonGetGgmlLogLevelNumber(level); + + if (logLevelNumber > addonLoggerLogLevel) { + return; + } + + if (addonJsLoggerCallbackSet) { + std::stringstream* stringStream = new std::stringstream(); + if (text != nullptr) { + *stringStream << text; + } + + addon_logger_log* data = new addon_logger_log { + logLevelNumber, + stringStream, + }; + + auto status = addonThreadSafeLoggerCallback.NonBlockingCall(data); + + if (status == napi_ok) { + return; + } else { + delete stringStream; + delete data; + } + } + + if (text != nullptr) { + if (level == 2) { + fputs(text, stderr); + fflush(stderr); + } else { + fputs(text, stdout); + fflush(stdout); + } + } +} + +Napi::Value setLogger(const Napi::CallbackInfo& info) { + if (info.Length() < 1 || !info[0].IsFunction()) { + if (addonJsLoggerCallbackSet) { + addonJsLoggerCallbackSet = false; + addonThreadSafeLoggerCallback.Release(); + } + + return info.Env().Undefined(); + } + + auto addonLoggerJSCallback = info[0].As(); + AddonThreadSafeLogCallbackFunctionContext* context = new Napi::Reference(Napi::Persistent(info.This())); + addonThreadSafeLoggerCallback = AddonThreadSafeLogCallbackFunction::New( + info.Env(), + addonLoggerJSCallback, + "loggerCallback", + 0, + 1, + context, + [](Napi::Env, void*, AddonThreadSafeLogCallbackFunctionContext* ctx) { + addonJsLoggerCallbackSet = false; + + delete ctx; + } + ); + addonJsLoggerCallbackSet = true; + + // prevent blocking the main node process from exiting due to active resources + 
addonThreadSafeLoggerCallback.Unref(info.Env()); + + return info.Env().Undefined(); +} + +Napi::Value setLoggerLogLevel(const Napi::CallbackInfo& info) { + if (info.Length() < 1 || !info[0].IsNumber()) { + addonLoggerLogLevel = 5; + + return info.Env().Undefined(); + } + + addonLoggerLogLevel = info[0].As().Int32Value(); + + return info.Env().Undefined(); +} diff --git a/llama/addon/globals/addonLog.h b/llama/addon/globals/addonLog.h new file mode 100644 index 00000000..54879ff5 --- /dev/null +++ b/llama/addon/globals/addonLog.h @@ -0,0 +1,21 @@ +#pragma once +#include "llama.h" +#include "napi.h" + +struct addon_logger_log { + public: + const int logLevelNumber; + const std::stringstream* stringStream; +}; + +void addonLlamaCppLogCallback(ggml_log_level level, const char* text, void* user_data); + +using AddonThreadSafeLogCallbackFunctionContext = Napi::Reference; +void addonCallJsLogCallback( + Napi::Env env, Napi::Function callback, AddonThreadSafeLogCallbackFunctionContext* context, addon_logger_log* data +); +using AddonThreadSafeLogCallbackFunction = + Napi::TypedThreadSafeFunction; + +Napi::Value setLogger(const Napi::CallbackInfo& info); +Napi::Value setLoggerLogLevel(const Napi::CallbackInfo& info); \ No newline at end of file diff --git a/llama/addon/globals/addonProgress.cpp b/llama/addon/globals/addonProgress.cpp new file mode 100644 index 00000000..b4f62232 --- /dev/null +++ b/llama/addon/globals/addonProgress.cpp @@ -0,0 +1,15 @@ +#include "addonProgress.h" + +void addonCallJsProgressCallback( + Napi::Env env, Napi::Function callback, AddonThreadSafeProgressCallbackFunctionContext* context, addon_progress_event* data +) { + if (env != nullptr && callback != nullptr) { + try { + callback.Call({Napi::Number::New(env, data->progress)}); + } catch (const Napi::Error& e) {} + } + + if (data != nullptr) { + delete data; + } +} diff --git a/llama/addon/globals/addonProgress.h b/llama/addon/globals/addonProgress.h new file mode 100644 index 00000000..d1c38fc2 --- /dev/null +++ b/llama/addon/globals/addonProgress.h @@ -0,0 +1,15 @@ +#pragma once +#include "napi.h" + +struct addon_progress_event { + public: + const float progress; +}; + +using AddonThreadSafeProgressCallbackFunctionContext = Napi::Reference; +void addonCallJsProgressCallback( + Napi::Env env, Napi::Function callback, AddonThreadSafeProgressCallbackFunctionContext* context, addon_progress_event* data +); +using AddonThreadSafeProgressEventCallbackFunction = + Napi::TypedThreadSafeFunction; + diff --git a/llama/addon/globals/getGpuInfo.cpp b/llama/addon/globals/getGpuInfo.cpp new file mode 100644 index 00000000..f3a67185 --- /dev/null +++ b/llama/addon/globals/getGpuInfo.cpp @@ -0,0 +1,108 @@ +#include "getGpuInfo.h" +#include "addonLog.h" + +#ifdef GPU_INFO_USE_CUDA +# include "../../gpuInfo/cuda-gpu-info.h" +#endif +#ifdef GPU_INFO_USE_VULKAN +# include "../../gpuInfo/vulkan-gpu-info.h" +#endif +#ifdef GPU_INFO_USE_METAL +# include "../../gpuInfo/metal-gpu-info.h" +#endif + + +#ifdef GPU_INFO_USE_CUDA +void logCudaError(const char* message) { + addonLlamaCppLogCallback(GGML_LOG_LEVEL_ERROR, (std::string("CUDA error: ") + std::string(message)).c_str(), nullptr); +} +#endif +#ifdef GPU_INFO_USE_VULKAN +void logVulkanWarning(const char* message) { + addonLlamaCppLogCallback(GGML_LOG_LEVEL_WARN, (std::string("Vulkan warning: ") + std::string(message)).c_str(), nullptr); +} +#endif + +Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) { + uint64_t total = 0; + uint64_t used = 0; + +#ifdef GPU_INFO_USE_CUDA + size_t 
cudaDeviceTotal = 0; + size_t cudaDeviceUsed = 0; + bool cudeGetInfoSuccess = gpuInfoGetTotalCudaDevicesInfo(&cudaDeviceTotal, &cudaDeviceUsed, logCudaError); + + if (cudeGetInfoSuccess) { + total += cudaDeviceTotal; + used += cudaDeviceUsed; + } +#endif + +#ifdef GPU_INFO_USE_VULKAN + uint64_t vulkanDeviceTotal = 0; + uint64_t vulkanDeviceUsed = 0; + const bool vulkanDeviceSupportsMemoryBudgetExtension = gpuInfoGetTotalVulkanDevicesInfo(&vulkanDeviceTotal, &vulkanDeviceUsed, logVulkanWarning); + + if (vulkanDeviceSupportsMemoryBudgetExtension) { + total += vulkanDeviceTotal; + used += vulkanDeviceUsed; + } +#endif + +#ifdef GPU_INFO_USE_METAL + uint64_t metalDeviceTotal = 0; + uint64_t metalDeviceUsed = 0; + getMetalGpuInfo(&metalDeviceTotal, &metalDeviceUsed); + + total += metalDeviceTotal; + used += metalDeviceUsed; +#endif + + Napi::Object result = Napi::Object::New(info.Env()); + result.Set("total", Napi::Number::From(info.Env(), total)); + result.Set("used", Napi::Number::From(info.Env(), used)); + + return result; +} + +Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info) { + std::vector deviceNames; + +#ifdef GPU_INFO_USE_CUDA + gpuInfoGetCudaDeviceNames(&deviceNames, logCudaError); +#endif + +#ifdef GPU_INFO_USE_VULKAN + gpuInfoGetVulkanDeviceNames(&deviceNames, logVulkanWarning); +#endif + +#ifdef GPU_INFO_USE_METAL + getMetalGpuDeviceNames(&deviceNames); +#endif + + Napi::Object result = Napi::Object::New(info.Env()); + + Napi::Array deviceNamesNapiArray = Napi::Array::New(info.Env(), deviceNames.size()); + for (size_t i = 0; i < deviceNames.size(); ++i) { + deviceNamesNapiArray[i] = Napi::String::New(info.Env(), deviceNames[i]); + } + result.Set("deviceNames", deviceNamesNapiArray); + + return result; +} + +Napi::Value getGpuType(const Napi::CallbackInfo& info) { +#ifdef GPU_INFO_USE_CUDA + return Napi::String::New(info.Env(), "cuda"); +#endif + +#ifdef GPU_INFO_USE_VULKAN + return Napi::String::New(info.Env(), "vulkan"); +#endif + +#ifdef GPU_INFO_USE_METAL + return Napi::String::New(info.Env(), "metal"); +#endif + + return info.Env().Undefined(); +} \ No newline at end of file diff --git a/llama/addon/globals/getGpuInfo.h b/llama/addon/globals/getGpuInfo.h new file mode 100644 index 00000000..c32de9d5 --- /dev/null +++ b/llama/addon/globals/getGpuInfo.h @@ -0,0 +1,6 @@ +#pragma once +#include "napi.h" + +Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info); +Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info); +Napi::Value getGpuType(const Napi::CallbackInfo& info); \ No newline at end of file diff --git a/src/ChatWrapper.ts b/src/ChatWrapper.ts index 7e137c1e..2838096b 100644 --- a/src/ChatWrapper.ts +++ b/src/ChatWrapper.ts @@ -1,9 +1,11 @@ import { - ChatHistoryItem, ChatModelFunctionCall, ChatModelFunctions, ChatModelResponse, ChatWrapperGenerateContextStateOptions, - ChatWrapperGeneratedContextState, ChatWrapperSettings + ChatHistoryItem, ChatModelFunctionCall, ChatModelFunctions, ChatModelResponse, ChatWrapperCheckModelCompatibilityParams, + ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState, ChatWrapperGenerateInitialHistoryOptions, ChatWrapperSettings } from "./types.js"; import {LlamaText, SpecialTokensText} from "./utils/LlamaText.js"; import {ChatModelFunctionsDocumentationGenerator} from "./chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.js"; +import {jsonDumps} from "./chatWrappers/utils/jsonDumps.js"; +import {defaultChatSystemPrompt} from "./config.js"; export abstract class ChatWrapper { public static 
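A rough sketch (not part of this diff) of the logging and GPU-info surface wired up in the globals/ files above, assuming `bindings` is the loaded native module; log level numbers follow the mapping in addonLog.cpp (2 = error, 3 = warn, 4 = info, 5 = debug), and messages above the level set with setLoggerLogLevel are dropped:

function inspectBindingsSketch(bindings: {
    setLogger(logger: (level: number, message: string) => void): void,
    setLoggerLogLevel(level: number): void,
    getGpuVramInfo(): {total: number, used: number},
    getGpuDeviceInfo(): {deviceNames: string[]},
    getGpuType(): "cuda" | "vulkan" | "metal" | undefined
}) {
    bindings.setLoggerLogLevel(4); // keep error/warn/info, drop debug
    bindings.setLogger((level, message) => {
        console.log(`[llama.cpp:${level}]`, message);
    });

    const {total, used} = bindings.getGpuVramInfo();
    console.log(`VRAM: ${used} / ${total} bytes used`);
    console.log("GPU devices:", bindings.getGpuDeviceInfo().deviceNames);
    console.log("GPU type:", bindings.getGpuType() ?? "none");
}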
defaultSettings: ChatWrapperSettings = { @@ -105,7 +107,7 @@ export abstract class ChatWrapper { ( params === undefined ? "" - : JSON.stringify(params) + : jsonDumps(params) ), this.settings.functions.call.suffix ]); @@ -120,7 +122,7 @@ export abstract class ChatWrapper { return value .replaceAll("{{functionName}}", functionName) - .replaceAll("{{functionParams}}", functionParams === undefined ? "" : JSON.stringify(functionParams)); + .replaceAll("{{functionParams}}", functionParams === undefined ? "" : jsonDumps(functionParams)); }); } @@ -129,7 +131,7 @@ export abstract class ChatWrapper { ( result === undefined ? "void" - : JSON.stringify(result) + : jsonDumps(result) ), resolveParameters(this.settings.functions.result.suffix) ]); @@ -154,6 +156,9 @@ export abstract class ChatWrapper { continue; } + if (response.startsNewChunk) + addFunctionCalls(); + pendingFunctionCalls.push(response); } @@ -181,9 +186,9 @@ export abstract class ChatWrapper { "Calling any of the provided functions can be done like this:", this.generateFunctionCall("getSomeInfo", {someKey: "someValue"}), "", - "Note that the || prefix is mandatory", + "Note that the || prefix is mandatory.", "The assistant does not inform the user about using functions and does not explain anything before calling a function.", - "After calling a function, the raw result appears afterwards and is not part of the conversation", + "After calling a function, the raw result appears afterwards and is not part of the conversation.", "To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax." ]); } @@ -209,10 +214,24 @@ export abstract class ChatWrapper { return res; } + public generateInitialChatHistory({ + systemPrompt = defaultChatSystemPrompt + }: ChatWrapperGenerateInitialHistoryOptions): ChatHistoryItem[] { + return [{ + type: "system", + text: LlamaText(systemPrompt ?? defaultChatSystemPrompt).toJSON() + }]; + } + /** @internal */ public static _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): Record[] { return [{}] satisfies Partial, object>>[]; } + + /** @internal */ // eslint-disable-next-line @typescript-eslint/no-unused-vars + public static _checkModelCompatibility(options: ChatWrapperCheckModelCompatibilityParams): boolean { + return true; + } } type FirstItemOfTupleOrFallback = T extends [infer U, ...any[]] ? 
U : Fallback; diff --git a/src/bindings/AddonTypes.ts b/src/bindings/AddonTypes.ts index b71fb3b9..9d13ebdb 100644 --- a/src/bindings/AddonTypes.ts +++ b/src/bindings/AddonTypes.ts @@ -11,9 +11,13 @@ export type BindingModule = { useMlock?: boolean, checkTensors?: boolean, onLoadProgress?(loadPercentage: number): void, - hasLoadAbortSignal?: boolean + hasLoadAbortSignal?: boolean, + overridesList?: Array<[key: string, value: number | bigint | boolean | string, type: 0 | 1 | undefined]> }): AddonModel }, + AddonModelLora: { + new (model: AddonModel, filePath: string): AddonModelLora + }, AddonContext: { new (model: AddonModel, params: { seed?: number, @@ -29,7 +33,7 @@ export type BindingModule = { AddonGrammar: { new (grammarPath: string, params?: { addonExports?: BindingModule, - printGrammar?: boolean + debugPrintGrammar?: boolean }): AddonGrammar }, AddonGrammarEvaluationState: { @@ -66,7 +70,7 @@ export type BindingModule = { export type AddonModel = { init(): Promise, - loadLora(loraFilePath: string, scale: number, threads: number, baseModelPath?: string): Promise, + loadLora(lora: AddonModelLora): Promise, abortActiveModelLoad(): void, dispose(): Promise, tokenize(text: string, specialTokens: boolean): Uint32Array, @@ -128,7 +132,8 @@ export type AddonContext = { canBeNextTokenForGrammarEvaluationState(grammarEvaluationState: AddonGrammarEvaluationState, token: Token): boolean, getEmbedding(inputTokensLength: number): Float64Array, getStateSize(): number, - printTimings(): void + printTimings(): void, + setLora(lora: AddonModelLora, scale: number): void }; export type BatchLogitIndex = number & { @@ -143,6 +148,13 @@ export type AddonGrammarEvaluationState = "AddonGrammarEvaluationState" & { __brand: never }; +export type AddonModelLora = { + usages: number, + readonly filePath: string, + readonly disposed: boolean, + dispose(): Promise +}; + export type ModelTypeDescription = `${AddonModelArchName} ${AddonModelTypeName} ${AddonModelFileTypeName}`; export type AddonModelArchName = "unknown" | "llama" | "falcon" | "gpt2" | "gptj" | "gptneox" | "mpt" | "baichuan" | "starcoder" | "persimmon" | "refact" | "bloom" | "stablelm"; diff --git a/src/bindings/getLlama.ts b/src/bindings/getLlama.ts index 3938d638..f59a4c68 100644 --- a/src/bindings/getLlama.ts +++ b/src/bindings/getLlama.ts @@ -114,8 +114,9 @@ export type LlamaOptions = { /** * Pad the available VRAM for the memory size calculations, as these calculations are not always accurate. * Recommended to ensure stability. + * This only affects the calculations of `"auto"` in function options and is not reflected in the `getVramState` function. * - * Defaults to `1.5%` of the total VRAM or 300MB, whichever is lower. + * Defaults to `6%` of the total VRAM or 1GB, whichever is lower. * Set to `0` to disable. */ vramPadding?: number | ((totalVram: number) => number), @@ -168,7 +169,7 @@ export type LastBuildOptions = { * Recommended to ensure stability. * This only affects the calculations of `"auto"` in function options and is not reflected in the `getVramState` function. * - * Defaults to `6%` of the total VRAM or 300MB, whichever is lower. + * Defaults to `6%` of the total VRAM or 1GB, whichever is lower. * Set to `0` to disable. 
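A small sketch (not part of this diff) of overriding the padding described above when calling getLlama; the callback form mirrors the new default defined just below (6% of total VRAM, capped at 1GB), and the import assumes the public package entry point:

import {getLlama} from "node-llama-cpp";

const llama = await getLlama({
    // like the new default (6% of total VRAM, capped at 1GB), but capped at 512MB instead
    vramPadding: (totalVram) => Math.floor(Math.min(totalVram * 0.06, 512 * 1024 * 1024))
});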
*/ vramPadding?: number | ((totalVram: number) => number), @@ -186,7 +187,7 @@ export type LastBuildOptions = { export const getLlamaFunctionName = "getLlama"; -export const defaultLlamaVramPadding = (totalVram: number) => Math.floor(Math.min(totalVram * 0.06, 300 * 1024 * 1024)); +export const defaultLlamaVramPadding = (totalVram: number) => Math.floor(Math.min(totalVram * 0.06, 1024 * 1024 * 1024)); const defaultBuildOption: Exclude = runningInElectron ? "never" : "auto"; diff --git a/src/chatWrappers/FunctionaryChatWrapper.ts b/src/chatWrappers/FunctionaryChatWrapper.ts index 52a43a98..f942bed5 100644 --- a/src/chatWrappers/FunctionaryChatWrapper.ts +++ b/src/chatWrappers/FunctionaryChatWrapper.ts @@ -3,8 +3,9 @@ import { ChatHistoryItem, ChatModelFunctions, ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState, ChatWrapperSettings, isChatModelResponseFunctionCall } from "../types.js"; -import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {LlamaText, SpecialToken, SpecialTokensText} from "../utils/LlamaText.js"; import {ChatModelFunctionsDocumentationGenerator} from "./utils/ChatModelFunctionsDocumentationGenerator.js"; +import {jsonDumps} from "./utils/jsonDumps.js"; // source: https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v2.txt export class FunctionaryChatWrapper extends ChatWrapper { @@ -161,6 +162,9 @@ export class FunctionaryChatWrapper extends ChatWrapper { ]) ); } else if (isChatModelResponseFunctionCall(response)) { + if (response.startsNewChunk) + addPendingFunctions(); + pendingFunctionCalls.push( response.rawCall != null ? LlamaText.fromJSON(response.rawCall) @@ -170,7 +174,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { new SpecialTokensText("\n"), response.params === undefined ? "" - : JSON.stringify(response.params) + : jsonDumps(response.params) ]) ); pendingFunctionResults.push( @@ -180,7 +184,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { new SpecialTokensText("\n"), response.result === undefined ? "" // "void" - : JSON.stringify(response.result), + : jsonDumps(response.result), new SpecialToken("EOT") ]) ); @@ -314,7 +318,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { new SpecialTokensText("<|content|>"), response.params === undefined ? "" - : JSON.stringify(response.params) + : jsonDumps(response.params) ]) ); pendingFunctionResults.push( @@ -325,7 +329,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { new SpecialTokensText("<|content|>"), response.result === undefined ? 
"" // "void" - : JSON.stringify(response.result) + : jsonDumps(response.result) ]) ); } else diff --git a/src/chatWrappers/GemmaChatWrapper.ts b/src/chatWrappers/GemmaChatWrapper.ts index fe937d1b..fb24520c 100644 --- a/src/chatWrappers/GemmaChatWrapper.ts +++ b/src/chatWrappers/GemmaChatWrapper.ts @@ -80,6 +80,7 @@ export class GemmaChatWrapper extends ChatWrapper { flush(); const contextText = LlamaText( + new SpecialToken("BOS"), resultItems.map(({user, model}, index) => { const isLastItem = index === resultItems.length - 1; diff --git a/src/chatWrappers/Llama3ChatWrapper.ts b/src/chatWrappers/Llama3ChatWrapper.ts index 20460e38..a095a4f8 100644 --- a/src/chatWrappers/Llama3ChatWrapper.ts +++ b/src/chatWrappers/Llama3ChatWrapper.ts @@ -8,7 +8,7 @@ import {ChatModelFunctionsDocumentationGenerator} from "./utils/ChatModelFunctio // source: https://github.com/meta-llama/llama-recipes/blob/79aa70442e97c3127e53c2d22c54438c32adcf5e/README.md // source: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/ export class Llama3ChatWrapper extends ChatWrapper { - public readonly wrapperName: string = "Llama3Chat"; + public readonly wrapperName: string = "Llama 3"; public override readonly settings: ChatWrapperSettings; @@ -174,7 +174,6 @@ export class Llama3ChatWrapper extends ChatWrapper { ); } - // void (item satisfies never); return LlamaText(res); }) ); @@ -211,9 +210,9 @@ export class Llama3ChatWrapper extends ChatWrapper { "Calling any of the provided functions can be done like this:", this.generateFunctionCall("getSomeInfo", {someKey: "someValue"}), "", - "Note that the || prefix is mandatory", + "Note that the || prefix is mandatory.", "The assistant does not inform the user about using functions and does not explain anything before calling a function.", - "After calling a function, the raw result appears afterwards and is not part of the conversation", + "After calling a function, the raw result appears afterwards and is not part of the conversation.", "To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax." 
]); } diff --git a/src/chatWrappers/Llama3_1ChatWrapper.ts b/src/chatWrappers/Llama3_1ChatWrapper.ts new file mode 100644 index 00000000..85ff5ffd --- /dev/null +++ b/src/chatWrappers/Llama3_1ChatWrapper.ts @@ -0,0 +1,306 @@ +import {ChatWrapper} from "../ChatWrapper.js"; +import { + ChatHistoryItem, ChatModelFunctions, ChatSystemMessage, ChatWrapperCheckModelCompatibilityParams, + ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState, ChatWrapperGenerateInitialHistoryOptions, ChatWrapperSettings +} from "../types.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {defaultChatSystemPrompt} from "../config.js"; +import {ChatModelFunctionsDocumentationGenerator} from "./utils/ChatModelFunctionsDocumentationGenerator.js"; +import {jsonDumps} from "./utils/jsonDumps.js"; + +// source: https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1 +export class Llama3_1ChatWrapper extends ChatWrapper { + public readonly wrapperName: string = "Llama 3.1"; + + public readonly cuttingKnowledgeDate?: Date | null; + public readonly todayDate: Date | null; + + public override readonly settings: ChatWrapperSettings = { + supportsSystemMessages: true, + functions: { + call: { + optionalPrefixSpace: true, + prefix: LlamaText(new SpecialTokensText("")), + suffix: LlamaText(new SpecialTokensText("<|eom_id|>")) + }, + result: { + prefix: LlamaText(new SpecialTokensText("\n<|start_header_id|>ipython<|end_header_id|>\n\n")), + suffix: LlamaText(new SpecialToken("EOT"), new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n")) + } + } + }; + + /** + * @param options + */ + public constructor({ + cuttingKnowledgeDate = new Date("2023-12-01T00:00:00Z"), + todayDate = new Date() + }: { + /** + * Set to `null` to disable + * @default December 2023 + */ + cuttingKnowledgeDate?: Date | number | string | null, + + /** + * Set to `null` to disable + * @default current date + */ + todayDate?: Date | number | string | null + } = {}) { + super(); + + this.cuttingKnowledgeDate = cuttingKnowledgeDate == null + ? null + : new Date(cuttingKnowledgeDate); + this.todayDate = todayDate == null + ? null + : new Date(todayDate); + } + + public override addAvailableFunctionsSystemMessageToHistory( + history: readonly ChatHistoryItem[], + availableFunctions?: ChatModelFunctions, { + documentParams = true + }: { + documentParams?: boolean + } = {} + ) { + const availableFunctionNames = Object.keys(availableFunctions ?? 
{}); + + if (availableFunctions == null || availableFunctionNames.length === 0) + return history; + + const res = history.slice(); + + const functionsSystemMessage: ChatSystemMessage = { + type: "system", + text: this.generateAvailableFunctionsSystemText(availableFunctions, {documentParams}).toJSON() + }; + + if (res.length >= 2 && res[0].type === "system" && res[1].type === "system") + res.splice(1, 0, functionsSystemMessage); + else + res.unshift({ + type: "system", + text: this.generateAvailableFunctionsSystemText(availableFunctions, {documentParams}).toJSON() + }); + + return res; + } + + public override generateContextState({ + chatHistory, availableFunctions, documentFunctionParams + }: ChatWrapperGenerateContextStateOptions): ChatWrapperGeneratedContextState { + const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(chatHistory, availableFunctions, { + documentParams: documentFunctionParams + }); + + const resultItems: Array<{ + system: LlamaText | null, + user: LlamaText | null, + model: LlamaText | null + }> = []; + + let systemTexts: LlamaText[] = []; + let userTexts: LlamaText[] = []; + let modelTexts: LlamaText[] = []; + let currentAggregateFocus: "system" | "user" | "model" | null = null; + + function flush() { + if (systemTexts.length > 0 || userTexts.length > 0 || modelTexts.length > 0) + resultItems.push({ + system: systemTexts.length === 0 + ? null + : LlamaText.joinValues("\n\n", systemTexts), + user: userTexts.length === 0 + ? null + : LlamaText.joinValues("\n\n", userTexts), + model: modelTexts.length === 0 + ? null + : LlamaText.joinValues("\n\n", modelTexts) + }); + + systemTexts = []; + userTexts = []; + modelTexts = []; + } + + for (const item of historyWithFunctions) { + if (item.type === "system") { + if (currentAggregateFocus !== "system") + flush(); + + currentAggregateFocus = "system"; + systemTexts.push(LlamaText.fromJSON(item.text)); + } else if (item.type === "user") { + if (currentAggregateFocus !== "user") + flush(); + + currentAggregateFocus = "user"; + userTexts.push(LlamaText(item.text)); + } else if (item.type === "model") { + if (currentAggregateFocus !== "model") + flush(); + + currentAggregateFocus = "model"; + modelTexts.push(this.generateModelResponseText(item.response)); + } else + void (item satisfies never); + } + + flush(); + + const contextText = LlamaText( + new SpecialToken("BOS"), + resultItems.map((item, index) => { + const isLastItem = index === resultItems.length - 1; + const res: LlamaText[] = []; + + if (item.system != null) { + res.push( + LlamaText([ + new SpecialTokensText("<|start_header_id|>system<|end_header_id|>\n\n"), + item.system, + new SpecialToken("EOT") + ]) + ); + } + + if (item.user != null) { + res.push( + LlamaText([ + new SpecialTokensText("<|start_header_id|>user<|end_header_id|>\n\n"), + item.user, + new SpecialToken("EOT") + ]) + ); + } + + if (item.model != null) { + res.push( + LlamaText([ + new SpecialTokensText("<|start_header_id|>assistant<|end_header_id|>\n\n"), + item.model, + isLastItem + ? 
LlamaText([]) + : new SpecialToken("EOT") + ]) + ); + } + + return LlamaText(res); + }) + ); + + return { + contextText, + stopGenerationTriggers: [ + LlamaText(new SpecialToken("EOS")), + LlamaText(new SpecialToken("EOT")), + LlamaText(new SpecialTokensText("<|eot_id|>")), + LlamaText(new SpecialTokensText("<|end_of_text|>")), + LlamaText("<|eot_id|>"), + LlamaText("<|end_of_text|>") + ] + }; + } + + public override generateAvailableFunctionsSystemText(availableFunctions: ChatModelFunctions, {documentParams = true}: { + documentParams?: boolean + }) { + const functionsDocumentationGenerator = new ChatModelFunctionsDocumentationGenerator(availableFunctions); + + if (!functionsDocumentationGenerator.hasAnyFunctions) + return LlamaText([]); + + return LlamaText.joinValues("\n", [ + "You have access to the following functions:", + "", + functionsDocumentationGenerator.getLlama3_1FunctionSignatures({documentParams}), + "", + "", + "If you choose to call a function ONLY reply in the following format:", + "<{start_tag}={function_name}>{parameters}{end_tag}", + "where", + "", + "start_tag => ` a JSON dict with the function argument name as key and function argument value as value.", + "end_tag => ``", + "", + "Here is an example,", + LlamaText([ + new SpecialTokensText(""), + jsonDumps({"example_name": "example_value"}), + new SpecialTokensText("") + ]), + "", + "Reminder:", + "- Function calls MUST follow the specified format", + "- Only call one function at a time", + "- Put the entire function call reply on one line", + "- Always add your sources when using search results to answer the user query" + ]); + } + + public override generateInitialChatHistory({ + systemPrompt = defaultChatSystemPrompt + }: ChatWrapperGenerateInitialHistoryOptions): ChatHistoryItem[] { + const res: ChatHistoryItem[] = []; + + function formatDate(date: Date) { + const day = date.toLocaleDateString("en-US", {day: "numeric", timeZone: "UTC"}); + const month = date.toLocaleDateString("en-US", {month: "short", timeZone: "UTC"}); + const year = date.toLocaleDateString("en-US", {year: "numeric", timeZone: "UTC"}); + return `${day} ${month} ${year}`; + } + + const formatMonthDate = (date: Date) => { + const today = this.todayDate ?? new Date(); + if (today.getUTCMonth() === date.getUTCMonth() && today.getUTCFullYear() === date.getUTCFullYear()) + return formatDate(date); + + const month = date.toLocaleDateString("en-US", {month: "long", timeZone: "UTC"}); + const year = date.toLocaleDateString("en-US", {year: "numeric", timeZone: "UTC"}); + return `${month} ${year}`; + }; + + const lines: string[] = []; + + if (this.cuttingKnowledgeDate != null) + lines.push(`Cutting Knowledge Date: ${formatMonthDate(this.cuttingKnowledgeDate)}`); + + if (this.todayDate != null) + lines.push(`Today Date: ${formatDate(this.todayDate)}`); + + lines.push(""); + lines.push("# Tool Instructions"); + lines.push("- When looking for real time information use relevant functions if available"); + lines.push(""); + lines.push(""); + + res.push({ + type: "system", + text: LlamaText.joinValues("\n", lines).toJSON() + }, { + type: "system", + text: LlamaText(systemPrompt ?? 
defaultChatSystemPrompt).toJSON() + }); + + return res; + } + + /** @internal */ + public static override _checkModelCompatibility(options: ChatWrapperCheckModelCompatibilityParams): boolean { + if (options.tokenizer != null) { + const tokens = options.tokenizer("<|eom_id|>", true, "trimLeadingSpace"); + return tokens.length === 1 && options.tokenizer.isSpecialToken(tokens[0]); + } + + return true; + } +} diff --git a/src/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.ts b/src/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.ts index cc9025cf..56eae902 100644 --- a/src/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.ts +++ b/src/chatWrappers/utils/ChatModelFunctionsDocumentationGenerator.ts @@ -1,5 +1,6 @@ import {ChatModelFunctions} from "../../types.js"; import {getTypeScriptTypeStringForGbnfJsonSchema} from "../../utils/getTypeScriptTypeStringForGbnfJsonSchema.js"; +import {jsonDumps} from "./jsonDumps.js"; /** * Generate documentation about the functions that are available for a model to call. @@ -16,7 +17,7 @@ export class ChatModelFunctionsDocumentationGenerator { /** * Example: - * ```typescript + * ```ts * // Retrieve the current date * function getDate(); * @@ -58,7 +59,7 @@ export class ChatModelFunctionsDocumentationGenerator { /** * Example: - * ```typescript + * ```ts * // Retrieve the current date * type getDate = () => any; * @@ -102,4 +103,48 @@ export class ChatModelFunctionsDocumentationGenerator { }) .join("\n\n"); } + + /* eslint-disable max-len */ + /** + * Example: + * ``` + * Use the function 'getDate' to: Retrieve the current date + * {"name": "getDate", "description": "Retrieve the current date"} + * + * Use the function 'getTime' to: Retrieve the current time + * {"name": "getTime", "description": "Retrieve the current time", "parameters": {"type": "object", "properties": {"hours": {"enum": ["24", "12"]}, "seconds": {"type": "boolean"}}}} + * ``` + * @param options + * @param [options.documentParams] - Whether to document the parameters of the functions + */ + public getLlama3_1FunctionSignatures({documentParams = true}: {documentParams?: boolean} = {}) { + const chatModelFunctions = this.chatModelFunctions; + + if (!this.hasAnyFunctions || chatModelFunctions == null) + return ""; + + const functionNames = Object.keys(chatModelFunctions); + + return functionNames + .map((functionName) => { + const functionDefinition = chatModelFunctions[functionName]; + let res = `Use the function '${functionName}'`; + + const addDescription = functionDefinition?.description != null && functionDefinition.description.trim() !== ""; + if (addDescription) + res += " to: " + functionDefinition.description.split("\n").join("\n// ") + "\n"; + else + res += ".\n"; + + res += jsonDumps({ + name: functionName, + ...(addDescription ? {description: functionDefinition.description} : {}), + ...(documentParams && functionDefinition?.params != null ? {parameters: functionDefinition.params} : {}) + }); + + return res; + }) + .join("\n\n"); + } + /* eslint-enable max-len */ } diff --git a/src/chatWrappers/utils/jsonDumps.ts b/src/chatWrappers/utils/jsonDumps.ts new file mode 100644 index 00000000..43e47bf0 --- /dev/null +++ b/src/chatWrappers/utils/jsonDumps.ts @@ -0,0 +1,19 @@ +/** + * Like `JSON.stringify` but results in a value formatted in the format that Python produces when using `json.dumps(value)`. 
+ * + * We need to format results this way since this is what many models use in their training data, + * so this is what many models expect to have in their context state. + */ +export function jsonDumps(result: any) { + return JSON.stringify(result, null, 1) + .split("\n") + .map((line) => { + line = line.trim(); + + if (line.endsWith(",")) + line += " "; + + return line; + }) + .join(""); +} diff --git a/src/chatWrappers/utils/resolveChatWrapper.ts b/src/chatWrappers/utils/resolveChatWrapper.ts index 50900fa6..316150b9 100644 --- a/src/chatWrappers/utils/resolveChatWrapper.ts +++ b/src/chatWrappers/utils/resolveChatWrapper.ts @@ -10,13 +10,14 @@ import {GemmaChatWrapper} from "../GemmaChatWrapper.js"; import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js"; import {TemplateChatWrapper} from "../generic/TemplateChatWrapper.js"; import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; +import {Llama3_1ChatWrapper} from "../Llama3_1ChatWrapper.js"; import {Tokenizer} from "../../types.js"; import {isJinjaTemplateEquivalentToSpecializedChatWrapper} from "./isJinjaTemplateEquivalentToSpecializedChatWrapper.js"; import type {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js"; export const specializedChatWrapperTypeNames = Object.freeze([ - "general", "llama3Chat", "llama2Chat", "alpacaChat", "functionary", "chatML", "falconChat", "gemma" + "general", "llama3.1", "llama3", "llama2Chat", "alpacaChat", "functionary", "chatML", "falconChat", "gemma" ] as const); export type SpecializedChatWrapperTypeName = (typeof specializedChatWrapperTypeNames)[number]; @@ -34,7 +35,8 @@ export type ResolvableChatWrapperTypeName = (typeof resolvableChatWrapperTypeNam const chatWrappers = { "general": GeneralChatWrapper, - "llama3Chat": Llama3ChatWrapper, + "llama3.1": Llama3_1ChatWrapper, + "llama3": Llama3ChatWrapper, "llama2Chat": Llama2ChatWrapper, "alpacaChat": AlpacaChatWrapper, "functionary": FunctionaryChatWrapper, @@ -113,6 +115,28 @@ export function resolveChatWrapper({ }); } + function getModelLinageNames(): string[][] { + const res: string[][] = []; + + if (fileInfo == null) + return res; + + const currentModelInfo = [fileInfo.metadata?.general?.name, fileInfo.metadata?.general?.basename] + .filter((v): v is string => v != null); + if (currentModelInfo.length > 0) + res.push(currentModelInfo); + + if (typeof fileInfo.metadata?.general?.base_model?.count === "number") { + for (let i = 0; i < fileInfo.metadata.general.base_model.count; i++) { + const baseModel = fileInfo.metadata.general.base_model[String(i) as `${bigint}`]; + if (baseModel?.name != null) + res.push([baseModel.name]); + } + } + + return res; + } + if (type !== "auto" && type != null) { if (isTemplateChatWrapperType(type)) { const Wrapper = chatWrappers[type]; @@ -172,6 +196,14 @@ export function resolveChatWrapper({ const Wrapper = chatWrappers[specializedChatWrapperTypeName]; const wrapperSettings = customWrapperSettings?.[specializedChatWrapperTypeName]; + const isCompatible = Wrapper._checkModelCompatibility({ + tokenizer, + fileInfo + }); + + if (!isCompatible) + continue; + const testOptionConfigurations = Wrapper._getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate?.() ?? 
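A quick illustration (not part of this diff) of what the jsonDumps helper introduced above produces compared to JSON.stringify; the import path is the file added in this diff:

import {jsonDumps} from "./src/chatWrappers/utils/jsonDumps.js";

const call = {name: "getTime", parameters: {hours: "24"}};

console.log(JSON.stringify(call));
// {"name":"getTime","parameters":{"hours":"24"}}

console.log(jsonDumps(call));
// {"name": "getTime", "parameters": {"hours": "24"}} - same spacing as Python's json.dumps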
[]; if (testOptionConfigurations.length === 0) testOptionConfigurations.push({} as any); @@ -209,12 +241,23 @@ export function resolveChatWrapper({ return createSpecializedChatWrapper(Llama2ChatWrapper, { addSpaceBeforeEos: modelJinjaTemplate.includes("' ' + eos_token") }); - else if (modelJinjaTemplate.includes("<|start_header_id|>") && modelJinjaTemplate.includes("<|end_header_id|>")) - return createSpecializedChatWrapper(Llama3ChatWrapper); - else if (modelJinjaTemplate.includes("")) + else if (modelJinjaTemplate.includes("<|start_header_id|>") && modelJinjaTemplate.includes("<|end_header_id|>")) { + if (Llama3_1ChatWrapper._checkModelCompatibility({tokenizer, fileInfo})) + return createSpecializedChatWrapper(Llama3_1ChatWrapper); + else + return createSpecializedChatWrapper(Llama3ChatWrapper); + } else if (modelJinjaTemplate.includes("")) return createSpecializedChatWrapper(GemmaChatWrapper); } + for (const modelNames of getModelLinageNames()) { + if (includesText(modelNames, ["llama 3.1", "llama-3.1", "llama3.1"]) && Llama3_1ChatWrapper._checkModelCompatibility({tokenizer, fileInfo})) + return createSpecializedChatWrapper(Llama3_1ChatWrapper); + else if (includesText(modelNames, ["llama 3", "llama-3", "llama3"])) + return createSpecializedChatWrapper(Llama3ChatWrapper); + } + + if (filename != null) { const {name, subType, fileType, otherInfo} = parseModelFileName(filename); @@ -285,6 +328,25 @@ export function isTemplateChatWrapperType(type: string): type is TemplateChatWra return templateChatWrapperTypeNames.includes(type as any); } +function includesText( + value: string | string[] | null | undefined, + textToCheckFor: string | string[], + strictCase: boolean = false +): boolean { + if (value instanceof Array) + return value.some((v) => includesText(v, textToCheckFor, strictCase)); + else if (typeof value !== "string") + return false; + + if (textToCheckFor instanceof Array) + return textToCheckFor.some((t) => includesText(value, t, strictCase)); + + if (strictCase) + return value.includes(textToCheckFor); + + return value.toLowerCase().includes(textToCheckFor.toLowerCase()); +} + // this is needed because TypeScript guards don't work automatically with class references function isClassReference(value: any, classReference: T): value is T { return value === classReference; diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index 0c799d50..4e2cb84e 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -33,7 +33,7 @@ type ChatCommand = { header?: string[], gpu?: BuildGpu | "auto", systemInfo: boolean, - systemPrompt: string, + systemPrompt?: string, systemPromptFile?: string, prompt?: string, promptFile?: string, @@ -108,8 +108,6 @@ export const ChatCommand: CommandModule = { .option("systemPrompt", { alias: "s", type: "string", - default: defaultChatSystemPrompt, - defaultDescription: " ", description: "System prompt to use against the model" + (isInDocumentationMode ? "" : (". [default value: " + defaultChatSystemPrompt.split("\n").join(" ") + "]")) @@ -550,8 +548,8 @@ async function RunChat({ : maxTokens <= 0 ? 
undefined : maxTokens, - onToken(chunk) { - let text = nextPrintLeftovers + model.detokenize(chunk); + onTextChunk(chunk) { + let text = nextPrintLeftovers + chunk; nextPrintLeftovers = ""; if (trimWhitespace) { diff --git a/src/cli/commands/CompleteCommand.ts b/src/cli/commands/CompleteCommand.ts index 58259f2e..7f52f539 100644 --- a/src/cli/commands/CompleteCommand.ts +++ b/src/cli/commands/CompleteCommand.ts @@ -416,8 +416,8 @@ async function RunCompletion({ : maxTokens <= 0 ? undefined : maxTokens, - onToken(chunk) { - process.stdout.write(model.detokenize(chunk)); + onTextChunk(chunk) { + process.stdout.write(chunk); } }); process.stdout.write(endColor); diff --git a/src/cli/commands/InfillCommand.ts b/src/cli/commands/InfillCommand.ts index 60adcb25..157bd9ef 100644 --- a/src/cli/commands/InfillCommand.ts +++ b/src/cli/commands/InfillCommand.ts @@ -461,8 +461,8 @@ async function RunInfill({ : maxTokens <= 0 ? undefined : maxTokens, - onToken(chunk) { - process.stdout.write(model.detokenize(chunk)); + onTextChunk(chunk) { + process.stdout.write(chunk); } }); process.stdout.write(endColor); diff --git a/src/cli/recommendedModels.ts b/src/cli/recommendedModels.ts index e37beffa..9560a660 100644 --- a/src/cli/recommendedModels.ts +++ b/src/cli/recommendedModels.ts @@ -1,66 +1,86 @@ import {ModelRecommendation} from "./utils/resolveModelRecommendationFileOptions.js"; export const recommendedModels: ModelRecommendation[] = [{ - name: "Llama 3 8B", + name: "Llama 3.1 8B", abilities: ["chat", "complete", "functionCalling"], - description: "Llama 3 model was created by Meta and is optimized for an assistant-like chat use cases.\n" + + description: "Llama 3.1 model was created by Meta and is optimized for an assistant-like chat use cases, with support for function calling.\n" + "This is the 8 billion parameters version of the model.", fileOptions: [{ huggingFace: { - model: "mradermacher/Meta-Llama-3-8B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-8B-Instruct.Q8_0.gguf" + file: "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf" } }, { huggingFace: { - model: "mradermacher/Meta-Llama-3-8B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-8B-Instruct.Q6_K.gguf" + file: "Meta-Llama-3.1-8B-Instruct-Q6_K_L.gguf" } }, { huggingFace: { - model: "mradermacher/Meta-Llama-3-8B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf" + file: "Meta-Llama-3.1-8B-Instruct-Q5_K_L.gguf" } }, { huggingFace: { - model: "mradermacher/Meta-Llama-3-8B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-8B-Instruct.Q4_K_S.gguf" + file: "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf" } }] }, { - name: "Llama 3 70B", + name: "Llama 3.1 70B", abilities: ["chat", "complete", "functionCalling"], - description: "Llama 3 model was created by Meta and is optimized for an assistant-like chat use cases.\n" + + description: "Llama 3.1 model was created by Meta and is optimized for an assistant-like chat use cases, with support for function calling.\n" + "This is the 70 billion parameters version of the model. 
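A short sketch (not part of this diff) of opting into the new Llama 3.1 wrapper explicitly; with the default "auto" resolution it is picked automatically based on the model's lineage metadata and the presence of a dedicated <|eom_id|> token, as implemented above. The imports assume the wrapper is exported from the package entry point like the other chat wrappers, and the model path is hypothetical:

import {getLlama, LlamaChatSession, Llama3_1ChatWrapper} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"}); // hypothetical local path
const context = await model.createContext();

const session = new LlamaChatSession({
    contextSequence: context.getSequence(),
    chatWrapper: new Llama3_1ChatWrapper({
        todayDate: null // omit the "Today Date" line from the generated system messages
    })
});

console.log(await session.prompt("Hi there"));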
" + "You need a GPU with a lot of VRAM to use this version.", fileOptions: [{ huggingFace: { - model: "mradermacher/Meta-Llama-3-70B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-70B-Instruct.Q8_0.gguf.part1of2" + file: "Meta-Llama-3.1-70B-Instruct-Q8_0/Meta-Llama-3.1-70B-Instruct-Q8_0-00001-of-00002.gguf" } }, { huggingFace: { - model: "mradermacher/Meta-Llama-3-70B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-70B-Instruct.Q6_K.gguf.part1of2" + file: "Meta-Llama-3.1-70B-Instruct-Q6_K_L/Meta-Llama-3.1-70B-Instruct-Q6_K_L-00001-of-00002.gguf" } }, { huggingFace: { - model: "mradermacher/Meta-Llama-3-70B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-70B-Instruct.Q4_K_M.gguf" + file: "Meta-Llama-3.1-70B-Instruct-Q5_K_L/Meta-Llama-3.1-70B-Instruct-Q5_K_L-00001-of-00002.gguf" } }, { huggingFace: { - model: "mradermacher/Meta-Llama-3-70B-Instruct-GGUF", + model: "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", branch: "main", - file: "Meta-Llama-3-70B-Instruct.Q4_K_S.gguf" + file: "Meta-Llama-3.1-70B-Instruct-Q4_K_M.gguf" + } + }, { + huggingFace: { + model: "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", + branch: "main", + file: "Meta-Llama-3.1-70B-Instruct-IQ4_XS.gguf" + } + }] +}, { + name: "Llama 3.1 405B", + abilities: ["chat", "complete", "functionCalling"], + description: "Llama 3.1 model was created by Meta and is optimized for an assistant-like chat use cases, with support for function calling.\n" + + "This is the 405 billion parameters version of the model, and its capabilities are comparable and sometimes even surpass GPT-4o and Claude 3.5 Sonnet.\n" + + "You need a GPU with a lot of VRAM to use this version of Llama 3.1.", + + fileOptions: [{ + huggingFace: { + model: "mradermacher/Meta-Llama-3.1-405B-Instruct-GGUF", + branch: "main", + file: "Meta-Llama-3.1-405B-Instruct.Q3_K_L.gguf.part1of5" } }] }, { @@ -71,15 +91,15 @@ export const recommendedModels: ModelRecommendation[] = [{ fileOptions: [{ huggingFace: { - model: "microsoft/Phi-3-mini-4k-instruct-gguf", + model: "bartowski/Phi-3.1-mini-4k-instruct-GGUF", branch: "main", - file: "Phi-3-mini-4k-instruct-fp16.gguf" + file: "Phi-3.1-mini-4k-instruct-Q8_0.gguf" } }, { huggingFace: { - model: "microsoft/Phi-3-mini-4k-instruct-gguf", + model: "bartowski/Phi-3.1-mini-4k-instruct-GGUF", branch: "main", - file: "Phi-3-mini-4k-instruct-q4.gguf" + file: "Phi-3.1-mini-4k-instruct-Q4_K_M.gguf" } }] }, { diff --git a/src/consts.ts b/src/consts.ts index afec36cf..ca2d7419 100644 --- a/src/consts.ts +++ b/src/consts.ts @@ -2,6 +2,7 @@ import isUnicodeSupported from "is-unicode-supported"; const unicodeSupported = isUnicodeSupported(); +export const maxRecentDetokenizerTokens = 3; export const UNKNOWN_UNICODE_CHAR = "\ufffd"; export const clockChar = unicodeSupported ? 
"\u25f7" diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index 306ec52d..a8596749 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -12,13 +12,14 @@ import {LlamaText, LlamaTextJSON, SpecialToken} from "../../utils/LlamaText.js"; import {StopGenerationDetector} from "../../utils/StopGenerationDetector.js"; import {QueuedTokenRelease, QueuedTokenReleaseLock, TokenStreamRegulator} from "../../utils/TokenStreamRegulator.js"; import {EvaluationPriority} from "../LlamaContext/types.js"; -import {UNKNOWN_UNICODE_CHAR} from "../../consts.js"; +import {maxRecentDetokenizerTokens, UNKNOWN_UNICODE_CHAR} from "../../consts.js"; import {getQueuedTokensBeforeStopTrigger} from "../../utils/getQueuedTokensBeforeStopTrigger.js"; import {resolveChatWrapper} from "../../chatWrappers/utils/resolveChatWrapper.js"; import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; import {TokenBias} from "../TokenBias.js"; import {safeEventCallback} from "../../utils/safeEventCallback.js"; import {pushAll} from "../../utils/pushAll.js"; +import {resolveLastTokens} from "../../utils/resolveLastTokens.js"; import { eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy } from "./utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js"; @@ -36,7 +37,20 @@ export type LlamaChatOptions = { }; export type LLamaChatGenerateResponseOptions = { + /** + * Called as the model generates a response with the generated text chunk. + * + * Useful for streaming the generated response as it's being generated. + */ + onTextChunk?: (text: string) => void, + + /** + * Called as the model generates a response with the generated tokens. + * + * Preferably, you'd want to use `onTextChunk` instead of this. + */ onToken?: (tokens: Token[]) => void, + signal?: AbortSignal, /** @@ -167,7 +181,20 @@ export type LLamaChatLoadAndCompleteUserMessageOptions["onTextChunk"], + + /** + * Called as the model generates a completion with the generated tokens. + * + * Preferably, you'd want to use `onTextChunk` instead of this. 
+ */ onToken?: LLamaChatGenerateResponseOptions["onToken"], + signal?: LLamaChatGenerateResponseOptions["signal"], maxTokens?: LLamaChatGenerateResponseOptions["maxTokens"], temperature?: LLamaChatGenerateResponseOptions["temperature"], @@ -331,6 +358,7 @@ export class LlamaChat { options: LLamaChatGenerateResponseOptions = {} ): Promise> { const { + onTextChunk, onToken, signal, stopOnAbortSignal = false, @@ -361,6 +389,7 @@ export class LlamaChat { this._chatWrapper, history, { + onTextChunk, onToken, signal, stopOnAbortSignal, @@ -493,6 +522,7 @@ export class LlamaChat { const { initialUserPrompt = "", stopOnAbortSignal = false, + onTextChunk, onToken, signal, maxTokens = Math.min(256, Math.ceil(this.context.contextSize / 2)), @@ -527,6 +557,7 @@ export class LlamaChat { this._chatWrapper, history, { + onTextChunk, onToken, signal, stopOnAbortSignal, @@ -1169,6 +1200,7 @@ class GenerateResponseState["onTextChunk"]; private readonly onToken: LLamaChatGenerateResponseOptions["onToken"]; private readonly signal: LLamaChatGenerateResponseOptions["signal"]; private readonly stopOnAbortSignal: LLamaChatGenerateResponseOptions["stopOnAbortSignal"]; @@ -1265,6 +1297,7 @@ class GenerateResponseState | null = null; let mostExhaustiveTriggeredStopsLeftoverTokens: Token[] = []; + const lastTokensForDetokenizer = resolveLastTokens([ + this.contextWindowTokens, + this.ignoredStartTextTokens + ]); for (let i = 0; i < this.pendingTokens.length; i++) { this.ignoreStartTextDetector.recordGeneration({ - text: this.llamaChat.model.detokenize([this.pendingTokens[i]]), + text: this.llamaChat.model.detokenize([this.pendingTokens[i]], false, lastTokensForDetokenizer), tokens: [this.pendingTokens[i]], startNewChecks: i === 0, triggerMustStartWithGeneration: true }); + lastTokensForDetokenizer.push(this.pendingTokens[i]); if (this.ignoreStartTextDetector.hasTriggeredStops) { mostExhaustiveTriggeredStops = this.ignoreStartTextDetector.getTriggeredStops(); @@ -1649,6 +1688,7 @@ class GenerateResponseState 0) - this.onToken?.(this.pendingTokens.slice()); - - pushAll(this.res, this.pendingTokens); - pushAll(this.contextWindowsRes, this.pendingTokens); - this.pendingTokens.length = 0; + this.pushPendingTokensAndCallOnToken(); this.streamRegulator.clearQueue(); @@ -2147,7 +2198,7 @@ class GenerateResponseState 0) - this.onToken?.(this.pendingTokens.slice()); - - pushAll(this.res, this.pendingTokens); - pushAll(this.contextWindowsRes, this.pendingTokens); - this.pendingTokens.length = 0; + this.pushPendingTokensAndCallOnToken(); let modelResponse = this.llamaChat.model.detokenize(this.res); let contextWindowModelResponse = this.llamaChat.model.detokenize(this.contextWindowsRes); @@ -2319,8 +2369,12 @@ class GenerateResponseState 0) { - this.onToken?.(this.pendingTokens.slice()); - pushAll(this.res, this.pendingTokens); - pushAll(this.contextWindowsRes, this.pendingTokens); - this.pendingTokens.length = 0; - } + this.pushPendingTokensAndCallOnToken(); } } @@ -2423,4 +2472,26 @@ class GenerateResponseState = { + /** + * Called as the model generates a response with the generated text chunk. + * + * Useful for streaming the generated response as it's being generated. + */ + onTextChunk?: (text: string) => void, + + /** + * Called as the model generates a response with the generated tokens. + * + * Preferably, you'd want to use `onTextChunk` instead of this. 
+ */ onToken?: (tokens: Token[]) => void, + signal?: AbortSignal, /** @@ -165,7 +178,20 @@ export type LLamaChatCompletePromptOptions = { */ stopOnAbortSignal?: LLamaChatPromptOptions["stopOnAbortSignal"], + /** + * Called as the model generates a completion with the generated text chunk. + * + * Useful for streaming the generated completion as it's being generated. + */ + onTextChunk?: LLamaChatPromptOptions["onTextChunk"], + + /** + * Called as the model generates a completion with the generated tokens. + * + * Preferably, you'd want to use `onTextChunk` instead of this. + */ onToken?: LLamaChatPromptOptions["onToken"], + signal?: LLamaChatPromptOptions["signal"], temperature?: LLamaChatPromptOptions["temperature"], minP?: LLamaChatPromptOptions["minP"], @@ -261,7 +287,7 @@ export class LlamaChatSession { public constructor({ contextSequence, chatWrapper = "auto", - systemPrompt = defaultChatSystemPrompt, + systemPrompt, forceAddSystemPrompt = false, autoDisposeSequence = true, contextShift @@ -282,10 +308,7 @@ export class LlamaChatSession { const chatWrapperSupportsSystemMessages = this._chat.chatWrapper.settings.supportsSystemMessages; if (chatWrapperSupportsSystemMessages == null || chatWrapperSupportsSystemMessages || forceAddSystemPrompt) - this._chatHistory = [{ - type: "system", - text: systemPrompt - }]; + this._chatHistory = this._chat.chatWrapper.generateInitialChatHistory({systemPrompt}); else this._chatHistory = []; @@ -348,6 +371,7 @@ export class LlamaChatSession { functions, documentFunctionParams, maxParallelFunctionCalls, + onTextChunk, onToken, signal, stopOnAbortSignal = false, @@ -368,8 +392,8 @@ export class LlamaChatSession { documentFunctionParams: documentFunctionParams as undefined, maxParallelFunctionCalls: maxParallelFunctionCalls as undefined, - onToken, signal, stopOnAbortSignal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty, - tokenBias, customStopTriggers + onTextChunk, onToken, signal, stopOnAbortSignal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, + repeatPenalty, tokenBias, customStopTriggers }); return responseText; @@ -383,6 +407,7 @@ export class LlamaChatSession { functions, documentFunctionParams, maxParallelFunctionCalls, + onTextChunk, onToken, signal, stopOnAbortSignal = false, @@ -411,6 +436,7 @@ export class LlamaChatSession { if (this._chat == null) throw new DisposedError(); + const supportsParallelFunctionCalling = this._chat.chatWrapper.settings.functions.parallelism != null; const abortController = wrapAbortSignal(signal); let lastEvaluation = this._lastEvaluation; let newChatHistory = appendUserMessageToChatHistory(this._chatHistory, prompt); @@ -448,6 +474,7 @@ export class LlamaChatSession { documentFunctionParams, maxParallelFunctionCalls, grammar: grammar as undefined, // this is a workaround to allow passing both `functions` and `grammar` + onTextChunk: safeEventCallback(onTextChunk), onToken: safeEventCallback(onToken), signal: abortController.signal, stopOnAbortSignal, @@ -545,6 +572,7 @@ export class LlamaChatSession { newContextWindowChatHistory = lastEvaluation.contextWindow; + let startNewChunk = supportsParallelFunctionCalling; for (const {functionCall, functionDefinition, functionCallResult} of functionCallResults) { newChatHistory = addFunctionCallToChatHistory({ chatHistory: newChatHistory, @@ -552,7 +580,8 @@ export class LlamaChatSession { functionDescription: functionDefinition.description, callParams: functionCall.params, callResult: 
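The same streaming callback is exposed on `LlamaChatSession.prompt()`; a minimal sketch with a placeholder model path:

    import {getLlama, LlamaChatSession} from "node-llama-cpp";

    const llama = await getLlama();
    const model = await llama.loadModel({modelPath: "path/to/model.gguf"});
    const context = await model.createContext();
    const session = new LlamaChatSession({contextSequence: context.getSequence()});

    const answer = await session.prompt("Hi there!", {
        // called with ready-to-print text instead of raw tokens
        onTextChunk(chunk) {
            process.stdout.write(chunk);
        }
    });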
functionCallResult, - rawCall: functionCall.raw + rawCall: functionCall.raw, + startsNewChunk: startNewChunk }); newContextWindowChatHistory = addFunctionCallToChatHistory({ @@ -561,8 +590,11 @@ export class LlamaChatSession { functionDescription: functionDefinition.description, callParams: functionCall.params, callResult: functionCallResult, - rawCall: functionCall.raw + rawCall: functionCall.raw, + startsNewChunk: startNewChunk }); + + startNewChunk = false; } lastEvaluation.cleanHistory = newChatHistory; @@ -653,6 +685,7 @@ export class LlamaChatSession { functions, documentFunctionParams, + onTextChunk, onToken, signal, temperature, @@ -686,6 +719,7 @@ export class LlamaChatSession { functions, documentFunctionParams, grammar, + onTextChunk, onToken, signal: abortController.signal, stopOnAbortSignal: true, @@ -776,14 +810,16 @@ function addFunctionCallToChatHistory({ functionDescription, callParams, callResult, - rawCall + rawCall, + startsNewChunk }: { chatHistory: ChatHistoryItem[], functionName: string, functionDescription?: string, callParams: any, callResult: any, - rawCall?: LlamaTextJSON + rawCall?: LlamaTextJSON, + startsNewChunk?: boolean }) { const newChatHistory = chatHistory.slice(); if (newChatHistory.length === 0 || newChatHistory[newChatHistory.length - 1].type !== "model") @@ -799,14 +835,19 @@ function addFunctionCallToChatHistory({ const modelResponse = newLastModelResponseItem.response.slice(); newLastModelResponseItem.response = modelResponse; - modelResponse.push({ + const functionCall: ChatModelFunctionCall = { type: "functionCall", name: functionName, description: functionDescription, params: callParams, result: callResult, rawCall - }); + }; + + if (startsNewChunk) + functionCall.startsNewChunk = true; + + modelResponse.push(functionCall); return newChatHistory; } diff --git a/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts b/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts index c6690e31..4f473983 100644 --- a/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts +++ b/src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts @@ -1,9 +1,7 @@ import {DisposeAggregator, DisposedError} from "lifecycle-utils"; -import {Token} from "../../../types.js"; import {getConsoleLogPrefix} from "../../../utils/getConsoleLogPrefix.js"; import {LruCache} from "../../../utils/LruCache.js"; import {safeEventCallback} from "../../../utils/safeEventCallback.js"; -import {pushAll} from "../../../utils/pushAll.js"; import type {LLamaChatCompletePromptOptions, LlamaChatSession} from "../LlamaChatSession.js"; export type LLamaChatPromptCompletionEngineOptions = { @@ -140,15 +138,15 @@ export class LlamaChatSessionPromptCompletionEngine { const currentAbortController = this._currentCompletionAbortController; const currentAbortSignal = this._currentCompletionAbortController.signal; - const currentCompletion: Token[] = []; + let currentCompletion: string = ""; void this._chatSession.completePrompt(promptToComplete, { ...this._completionOptions, stopOnAbortSignal: false, maxTokens: leftTokens, signal: currentAbortSignal, - onToken: (chunk) => { - pushAll(currentCompletion, chunk); - const completion = (existingCompletion ?? "") + this._chatSession.model.detokenize(currentCompletion); + onTextChunk: (chunk) => { + currentCompletion += chunk; + const completion = (existingCompletion ?? 
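The prompt completion engine below now accumulates plain text via `onTextChunk`; the same option can be passed directly to `LlamaChatSession.completePrompt()`. A sketch, reusing the `session` from the previous sketch:

    let completionText = "";
    const completion = await session.completePrompt("The best way to learn TypeScript is ", {
        maxTokens: 64,
        onTextChunk(chunk) {
            // accumulate a string instead of collecting tokens and detokenizing them later
            completionText += chunk;
        }
    });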
"") + currentCompletion; completionCache.putCompletion(prompt, completion); if (this._getCurrentCompletionCache() !== completionCache) { diff --git a/src/evaluator/LlamaCompletion.ts b/src/evaluator/LlamaCompletion.ts index 728bed7f..058fbbce 100644 --- a/src/evaluator/LlamaCompletion.ts +++ b/src/evaluator/LlamaCompletion.ts @@ -24,7 +24,20 @@ export type LlamaCompletionOptions = { }; export type LlamaCompletionGenerationOptions = { + /** + * Called as the model generates a completion with the generated text chunk. + * + * Useful for streaming the generated completion as it's being generated. + */ + onTextChunk?: (text: string) => void, + + /** + * Called as the model generates a completion with the generated tokens. + * + * Preferably, you'd want to use `onTextChunk` instead of this. + */ onToken?: (tokens: Token[]) => void, + signal?: AbortSignal, maxTokens?: number, @@ -210,6 +223,7 @@ export class LlamaCompletion { public async generateCompletionWithMeta( input: Token[] | string | LlamaText, { + onTextChunk, onToken, signal, maxTokens, @@ -290,6 +304,7 @@ export class LlamaCompletion { : this._sequence.context.contextSize - inputTokens.length; return await this._generateResponse(inputTokens, { + onTextChunk: safeEventCallback(onTextChunk), onToken: safeEventCallback(onToken), signal, maxTokens: resolvedMaxTokens, @@ -343,6 +358,7 @@ export class LlamaCompletion { prefixInput: Token[] | string | LlamaText, suffixInput: Token[] | string | LlamaText, { + onTextChunk, onToken, signal, maxTokens, @@ -473,6 +489,7 @@ export class LlamaCompletion { : this._sequence.context.contextSize - inputTokens.length; return await this._generateResponse(inputTokens, { + onTextChunk: safeEventCallback(onTextChunk), onToken: safeEventCallback(onToken), signal, maxTokens: resolvedMaxTokens, @@ -508,6 +525,7 @@ export class LlamaCompletion { private async _generateResponse( tokens: Token[], { + onTextChunk, onToken, signal, maxTokens, @@ -673,10 +691,12 @@ export class LlamaCompletion { ); pushAll(pendingTokens, queuedTokensBeforeStopTrigger); - const firstRemainingGenerationAfterStop = StopGenerationDetector.getFirstRemainingGenerationAfterStop(triggeredStops); + const {firstRemainingGenerationAfterStop} = StopGenerationDetector.getFirstRemainingGenerationAfterStop(triggeredStops); - if (pendingTokens.length > 0) + if (pendingTokens.length > 0) { onToken?.(pendingTokens.slice()); + onTextChunk?.(model.detokenize(pendingTokens, false, res)); + } pushAll(res, pendingTokens); pendingTokens.length = 0; @@ -711,6 +731,7 @@ export class LlamaCompletion { if (pendingTokens.length > 0) { onToken?.(pendingTokens.slice()); + onTextChunk?.(model.detokenize(pendingTokens, false, res)); pushAll(res, pendingTokens); pendingTokens.length = 0; } diff --git a/src/evaluator/LlamaContext/LlamaContext.ts b/src/evaluator/LlamaContext/LlamaContext.ts index 76cd8c7c..206a5f30 100644 --- a/src/evaluator/LlamaContext/LlamaContext.ts +++ b/src/evaluator/LlamaContext/LlamaContext.ts @@ -1,20 +1,21 @@ import {AsyncDisposeAggregator, DisposeAggregator, DisposedError, EventRelay, withLock} from "lifecycle-utils"; import {removeNullFields} from "../../utils/removeNullFields.js"; import {Token} from "../../types.js"; -import {AddonContext, BatchLogitIndex} from "../../bindings/AddonTypes.js"; +import {AddonContext, AddonModelLora, BatchLogitIndex} from "../../bindings/AddonTypes.js"; import {LlamaGrammarEvaluationState} from "../LlamaGrammarEvaluationState.js"; import {compareTokens} from "../../utils/compareTokens.js"; import 
{DisposalPreventionHandle, DisposeGuard} from "../../utils/DisposeGuard.js"; import {TokenMeter} from "../TokenMeter.js"; import {TokenBias} from "../TokenBias.js"; +import {LlamaModel} from "../LlamaModel/LlamaModel.js"; import { BatchingOptions, BatchItem, ContextShiftOptions, ContextTokensDeleteRange, EvaluationPriority, LlamaContextOptions, LlamaContextSequenceRepeatPenalty, PrioritizedBatchItem } from "./types.js"; import {resolveBatchItemsPrioritizationStrategy} from "./utils/resolveBatchItemsPrioritizationStrategy.js"; import type {Llama} from "../../bindings/Llama.js"; -import type {LlamaModel} from "../LlamaModel/LlamaModel.js"; +const defaultLoraScale = 1; export class LlamaContext { /** @internal */ public readonly _llama: Llama; @@ -33,6 +34,8 @@ export class LlamaContext { /** @internal */ private readonly _queuedDecodes: InternalQueuedDecode[] = []; /** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator(); /** @internal */ private readonly _modelPreventDisposalHandle: DisposalPreventionHandle; + /** @internal */ private readonly _loraAdapters = new Set(); + /** @internal */ private readonly _gcRegistry: FinalizationRegistry>; /** @internal */ private _nextGeneratedSequenceId = 0; /** @internal */ private _dispatchDecodeScheduled = false; /** @internal */ private _batchDispatchPending = false; @@ -90,12 +93,15 @@ export class LlamaContext { dispatchSchedule: batchingDispatchSchedule, itemPrioritizationStrategy: batchingItemsPrioritizationStrategy }; + this._gcRegistry = new FinalizationRegistry(this._model._removeLoraUsage); + this._gcRegistry.register(this, this._loraAdapters); this._reclaimUnusedSequenceId = this._reclaimUnusedSequenceId.bind(this); this._disposeAggregator.add(() => { this._disposed = true; }); + this._disposeAggregator.add(() => this._gcRegistry.unregister(this)); this._disposeAggregator.add(this._onReclaimUnusedSequenceId); this._disposeAggregator.add(this.onDispose.dispatchEvent); this._disposeAggregator.add( @@ -103,6 +109,13 @@ export class LlamaContext { disposeContextIfReferenced.bind(null, new WeakRef(this)) ) ); + this._disposeAggregator.add((): Promise | void => { + if (this._loraAdapters.size > 0) { + const loraAdapters = new Set(this._loraAdapters); + this._loraAdapters.clear(); + return this._model._removeLoraUsage(loraAdapters); + } + }); this._disposeAggregator.add(async () => { await this._backendContextDisposeGuard.acquireDisposeLock(); @@ -545,6 +558,21 @@ export class LlamaContext { throw new DisposedError(); } + /** @internal */ + private async _setLora({ + filePath, scale + }: { + filePath: string, scale?: number + }) { + const lora = await this._model._getOrLoadLora(filePath); + this._ctx.setLora(lora, scale ?? defaultLoraScale); + + if (!this._loraAdapters.has(lora)) { + this._loraAdapters.add(lora); + lora.usages++; + } + } + /** @internal */ public static async _create(options: LlamaContextOptions, {_model}: { _model: LlamaModel @@ -553,6 +581,10 @@ export class LlamaContext { const flashAttention = _model.flashAttentionSupported ? Boolean(options.flashAttention ?? _model.defaultContextFlashAttention) : false; + const loraOptions: LlamaContextOptions["lora"] = typeof options.lora === "string" + ? 
{adapters: [{filePath: options.lora}]} + : options.lora; + const contextSize = await _model.fileInsights.configurationResolver.resolveContextContextSize(options.contextSize, { batchSize: options.batchSize, sequences: sequences, @@ -591,6 +623,42 @@ export class LlamaContext { } else if (!contextLoaded) throw new Error("Failed to create context"); + contextCreationMemoryReservation?.dispose?.(); + + if (loraOptions != null && loraOptions.adapters.length > 0) { + let loadedAdapters = 0; + + for (const adapter of loraOptions.adapters) { + try { + await context._setLora({ + filePath: adapter.filePath, + scale: adapter.scale + }); + loadedAdapters++; + + try { + loraOptions.onLoadProgress?.(loadedAdapters / loraOptions.adapters.length); + } catch (err) { + console.error(err); + } + } catch (err) { + await context.dispose(); + throw err; + } + + if (createSignal?.aborted) { + await context.dispose(); + throw createSignal.reason; + } + } + } else if (loraOptions?.onLoadProgress != null) { + try { + loraOptions.onLoadProgress(1); + } catch (err) { + console.error(err); + } + } + return context; } finally { contextCreationMemoryReservation?.dispose?.(); diff --git a/src/evaluator/LlamaContext/types.ts b/src/evaluator/LlamaContext/types.ts index c5a30cf4..9bb678c3 100644 --- a/src/evaluator/LlamaContext/types.ts +++ b/src/evaluator/LlamaContext/types.ts @@ -66,6 +66,30 @@ export type LlamaContextOptions = { /** control the parallel sequences processing behavior */ batching?: BatchingOptions, + /** + * Load the provided LoRA adapters onto the context. + * LoRA adapters are used to modify the weights of a pretrained model to adapt to new tasks or domains + * without the need for extensive retraining from scratch. + * + * If a string is provided, it will be treated as a path to a single LoRA adapter file. + */ + lora?: string | { + adapters: Array<{ + filePath: string, + + /** + * @default `1` + */ + scale?: number + }>, + + /** + * Called with the LoRA adapters load percentage when the LoRA adapters are being loaded. + * @param loadProgress - a number between 0 (exclusive) and 1 (inclusive). + */ + onLoadProgress?(loadProgress: number): void + }, + /** An abort signal to abort the context creation */ createSignal?: AbortSignal, diff --git a/src/evaluator/LlamaGrammar.ts b/src/evaluator/LlamaGrammar.ts index 4af0b2ac..aa8f1d0a 100644 --- a/src/evaluator/LlamaGrammar.ts +++ b/src/evaluator/LlamaGrammar.ts @@ -11,8 +11,11 @@ export type LlamaGrammarOptions = { /** GBNF grammar */ grammar: string, - /** print the grammar to stdout */ - printGrammar?: boolean, + /** + * print the parsed grammar to stdout. + * Useful for debugging. + */ + debugPrintGrammar?: boolean, /** Consider any of these as EOS for the generated text. Only supported by `LlamaChat` and `LlamaChatSession` */ stopGenerationTriggers?: readonly (LlamaText | string | readonly (string | Token)[])[], @@ -37,12 +40,12 @@ export class LlamaGrammar { * @param options */ public constructor(llama: Llama, { - grammar, stopGenerationTriggers = [], trimWhitespaceSuffix = false, printGrammar = false + grammar, stopGenerationTriggers = [], trimWhitespaceSuffix = false, debugPrintGrammar = false }: LlamaGrammarOptions) { this._llama = llama; this._grammar = new this._llama._bindings.AddonGrammar(grammar, { addonExports: this._llama._bindings, - printGrammar + debugPrintGrammar }); this._stopGenerationTriggers = stopGenerationTriggers ?? 
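With LoRA moved from model loading to context creation, the new `lora` context option documented above can be used as sketched below, assuming `model` is a loaded `LlamaModel` and the adapter path is a placeholder:

    const context = await model.createContext({
        contextSize: 2048,
        lora: {
            adapters: [{
                filePath: "path/to/adapter.gguf",
                scale: 1
            }],
            onLoadProgress(loadProgress) {
                console.log(`LoRA adapters: ${Math.floor(loadProgress * 100)}% loaded`);
            }
        }
    });

    // shorthand for a single adapter with the default scale:
    // const context = await model.createContext({lora: "path/to/adapter.gguf"});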
[]; this._trimWhitespaceSuffix = trimWhitespaceSuffix; diff --git a/src/evaluator/LlamaModel/LlamaModel.ts b/src/evaluator/LlamaModel/LlamaModel.ts index 2973187c..2c058fcc 100644 --- a/src/evaluator/LlamaModel/LlamaModel.ts +++ b/src/evaluator/LlamaModel/LlamaModel.ts @@ -3,7 +3,7 @@ import path from "path"; import {AsyncDisposeAggregator, DisposedError, EventRelay, withLock} from "lifecycle-utils"; import {removeNullFields} from "../../utils/removeNullFields.js"; import {Token, Tokenizer} from "../../types.js"; -import {AddonModel, ModelTypeDescription} from "../../bindings/AddonTypes.js"; +import {AddonModel, AddonModelLora, ModelTypeDescription} from "../../bindings/AddonTypes.js"; import {DisposalPreventionHandle, DisposeGuard} from "../../utils/DisposeGuard.js"; import {LlamaLocks, LlamaLogLevel, LlamaVocabularyType, LlamaVocabularyTypeValues} from "../../bindings/types.js"; import {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js"; @@ -15,7 +15,9 @@ import {getReadablePath} from "../../cli/utils/getReadablePath.js"; import {LlamaContextOptions} from "../LlamaContext/types.js"; import {LlamaContext} from "../LlamaContext/LlamaContext.js"; import {LlamaEmbeddingContext, LlamaEmbeddingContextOptions} from "../LlamaEmbeddingContext.js"; -import {GgufArchitectureType} from "../../gguf/types/GgufMetadataTypes.js"; +import {GgufArchitectureType, GgufMetadata} from "../../gguf/types/GgufMetadataTypes.js"; +import {DeepPartialObject} from "../../utils/DeepPartialObject.js"; +import {maxRecentDetokenizerTokens} from "../../consts.js"; import {TokenAttribute, TokenAttributes} from "./utils/TokenAttributes.js"; import type {Llama} from "../../bindings/Llama.js"; import type {BuiltinSpecialTokenValue} from "../../utils/LlamaText.js"; @@ -57,7 +59,6 @@ export type LlamaModelOptions = { /** * Use mmap if possible. * Defaults to `true`. - * If LoRA is used, this will always be set to `false`. */ useMmap?: boolean, @@ -74,39 +75,6 @@ export type LlamaModelOptions = { */ checkTensors?: boolean, - /** - * Load the provided LoRA adapters onto the model after loading the model. - * LoRA adapters are used to modify the weights of a pretrained model to adapt to new tasks or domains - * without the need for extensive retraining from scratch. - * - * If a string is provided, it will be treated as a path to a single LoRA adapter file. - */ - lora?: string | { - adapters: Array<{ - loraFilePath: string, - baseModelPath?: string, - - /** - * Defaults to `1`. - */ - scale?: number - }>, - - /** - * The number of threads to use when loading the LoRA adapters. - * set to 0 to use the maximum threads supported by the current machine hardware. - * - * Defaults to `6`. - */ - threads?: number, - - /** - * Called with the LoRA adapters load percentage when the LoRA adapters are being loaded. - * @param loadProgress - a number between 0 (exclusive) and 1 (inclusive). - */ - onLoadProgress?(loadProgress: number): void - }, - /** * Enable flash attention by default for contexts created with this model. * Only works with models that support flash attention. @@ -130,7 +98,6 @@ export type LlamaModelOptions = { /** * Called with the load percentage when the model is being loaded. - * > **Note:** This progress does not include the progress of loading the provided LoRA adapters (when `lora` is used) * @param loadProgress - a number between 0 (exclusive) and 1 (inclusive). */ onLoadProgress?(loadProgress: number): void, @@ -144,11 +111,18 @@ export type LlamaModelOptions = { * * Defaults to `false`. 
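A minimal sketch of the renamed `debugPrintGrammar` option (formerly `printGrammar`), using a trivial GBNF grammar for illustration:

    import {getLlama, LlamaGrammar} from "node-llama-cpp";

    const llama = await getLlama();
    const grammar = new LlamaGrammar(llama, {
        grammar: 'root ::= "yes" | "no"',
        // prints the parsed grammar to stdout; useful for debugging
        debugPrintGrammar: true
    });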
*/ - ignoreMemorySafetyChecks?: boolean + ignoreMemorySafetyChecks?: boolean, + + /** + * Metadata overrides to load the model with. + * + * > **Note:** Most metadata value overrides aren't supported and overriding them will have no effect on `llama.cpp`. + * > Only use this for metadata values that are explicitly documented to be supported by `llama.cpp` to be overridden, + * > and only in cases when this is crucial, as this is not guaranteed to always work as expected. + */ + metadataOverrides?: DeepPartialObject }; -const defaultLoraThreads = 6; -const defaultLoraScale = 1; const defaultUseMmap = true; const defaultContextFlashAttentionEnabled = false; @@ -168,6 +142,7 @@ export class LlamaModel { /** @internal */ private readonly _defaultContextFlashAttentionOptionEnabled: boolean; /** @internal */ private readonly _defaultContextFlashAttention: boolean; /** @internal */ private readonly _flashAttentionSupported: boolean; + /** @internal */ private readonly _loraAdapters = new Map(); /** @internal */ private _typeDescription?: ModelTypeDescription; /** @internal */ private _trainContextSize?: number; /** @internal */ private _embeddingVectorSize?: number; @@ -177,7 +152,7 @@ export class LlamaModel { public readonly onDispose = new EventRelay(); private constructor({ - modelPath, gpuLayers, vocabOnly, useMmap, useMlock, checkTensors, onLoadProgress, loadSignal + modelPath, gpuLayers, vocabOnly, useMmap, useMlock, checkTensors, onLoadProgress, loadSignal, metadataOverrides }: LlamaModelOptions & { gpuLayers: number }, { @@ -205,6 +180,7 @@ export class LlamaModel { this._defaultContextFlashAttentionOptionEnabled = _defaultContextFlashAttentionOptionEnabled; this._defaultContextFlashAttention = _defaultContextFlashAttention; this._flashAttentionSupported = _flashAttentionSupported; + const overridesList = ggufMetadataOverridesToList(metadataOverrides); this._model = new this._llama._bindings.AddonModel(this._modelPath, removeNullFields({ addonExports: this._llama._bindings, gpuLayers, @@ -224,7 +200,10 @@ export class LlamaModel { console.error(err); } }, - hasLoadAbortSignal: loadSignal != null + hasLoadAbortSignal: loadSignal != null, + overridesList: overridesList.length > 0 + ? overridesList + : undefined })); this._tokens = LlamaModelTokens._create(this._model, this._disposedState); this._filename = path.basename(modelPath); @@ -245,6 +224,8 @@ export class LlamaModel { this._llamaPreventDisposalHandle.dispose(); }); + this._removeLoraUsage = this._removeLoraUsage.bind(this); + this.tokenize = this.tokenize.bind(this); this.detokenize = this.detokenize.bind(this); this.isSpecialToken = this.isSpecialToken.bind(this); @@ -426,13 +407,30 @@ export class LlamaModel { * @param [specialTokens] - if set to `true`, special tokens will be detokenized to their corresponding token text representation. * Recommended for debugging purposes only. * Defaults to `false`. + * @param [lastTokens] - the last few tokens that preceded the tokens to detokenize. + * If provided, the last few tokens will be used to determine whether a space has to be added before the current tokens or not, + * and apply other detokenizer-specific heuristics to provide the correct text continuation to the existing tokens. + * + * Using it may have no effect with some models, but it is still recommended. 
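A sketch of the new `lastTokens` parameter of `detokenize()` described above, assuming `model` is a loaded `LlamaModel`:

    const previousTokens = model.tokenize("Hello");
    const continuationTokens = model.tokenize(" world", false, "trimLeadingSpace");

    // passing the preceding tokens lets the detokenizer decide whether to re-add
    // a leading space (and apply other tokenizer-specific heuristics)
    const continuationText = model.detokenize(continuationTokens, false, previousTokens);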
*/ - public detokenize(tokens: readonly Token[], specialTokens: boolean = false): string { + public detokenize(tokens: readonly Token[], specialTokens: boolean = false, lastTokens?: readonly Token[]): string { this._ensureNotDisposed(); if (tokens.length === 0) return ""; + if (lastTokens == null || lastTokens.length === 0) + return this._model.detokenize(Uint32Array.from(tokens), Boolean(specialTokens)); + + const addedTokens = lastTokens.slice(-maxRecentDetokenizerTokens); + const addedTokensText = this._model.detokenize(Uint32Array.from(addedTokens), Boolean(specialTokens)); + if (addedTokensText === "") + return this._model.detokenize(Uint32Array.from(tokens), Boolean(specialTokens)); + + const text = this._model.detokenize(Uint32Array.from([...addedTokens, ...tokens]), Boolean(specialTokens)); + if (text.startsWith(addedTokensText)) + return text.slice(addedTokensText.length); + return this._model.detokenize(Uint32Array.from(tokens), Boolean(specialTokens)); } @@ -604,19 +602,37 @@ export class LlamaModel { } /** @internal */ - private async _loadLora({ - loraFilePath, baseModelPath, scale, threads - }: { - loraFilePath: string, baseModelPath?: string, scale?: number, threads: number - }) { - await this._model.loadLora( - path.resolve(process.cwd(), loraFilePath), - scale ?? defaultLoraScale, - Math.max(0, Math.floor(threads)), - baseModelPath == null - ? undefined - : path.resolve(process.cwd(), baseModelPath) - ); + public async _getOrLoadLora(filePath: string) { + const resolvedPath = path.resolve(process.cwd(), filePath); + if (this._loraAdapters.has(resolvedPath)) + return this._loraAdapters.get(resolvedPath)!; + + return await withLock(this._loraAdapters, "modify", async () => { + if (this._loraAdapters.has(resolvedPath)) + return this._loraAdapters.get(resolvedPath)!; + + const lora = new this._llama._bindings.AddonModelLora(this._model, resolvedPath); + await this._model.loadLora(lora); + this._loraAdapters.set(resolvedPath, lora); + + return lora; + }); + } + + /** @internal */ + public async _removeLoraUsage(loraAdapters: Set) { + return await withLock(this._loraAdapters, "modify", async () => { + await Promise.all( + [...loraAdapters].map(async (lora) => { + lora.usages--; + + if (lora.usages <= 0 && this._loraAdapters.get(lora.filePath) === lora) { + this._loraAdapters.delete(lora.filePath); + await lora.dispose(); + } + }) + ); + }); } /** @internal */ @@ -626,18 +642,13 @@ export class LlamaModel { _llama: Llama }) { const {loadSignal, defaultContextFlashAttention} = modelOptions; - let useMmap = modelOptions.useMmap ?? defaultUseMmap; - const loraOptions: LlamaModelOptions["lora"] = typeof modelOptions.lora === "string" - ? {adapters: [{loraFilePath: modelOptions.lora}]} - : modelOptions.lora; - - if (loraOptions?.adapters != null && loraOptions.adapters.length > 0) - useMmap = false; // using LoRA with nmap crashes the process + const useMmap = modelOptions.useMmap ?? defaultUseMmap; const fileInfo = await readGgufFileInfo(modelOptions.modelPath, { sourceType: "filesystem", signal: loadSignal }); + applyGgufMetadataOverrides(fileInfo, modelOptions.metadataOverrides); const ggufInsights = await GgufInsights.from(fileInfo, _llama); const flashAttentionSupported = ggufInsights.flashAttentionSupported; const resolvedDefaultContextFlashAttention = flashAttentionSupported @@ -701,45 +712,6 @@ export class LlamaModel { logWarnings(model.getWarnings()); - if (loraOptions != null && loraOptions.adapters.length > 0) { - const loraThreads = loraOptions.threads ?? 
defaultLoraThreads; - let loadedAdapters = 0; - - for (const adapter of loraOptions.adapters) { - try { - await model._loadLora({ - loraFilePath: adapter.loraFilePath, - baseModelPath: adapter.baseModelPath, - scale: adapter.scale, - threads: loraThreads - }); - loadedAdapters++; - - try { - loraOptions.onLoadProgress?.(loadedAdapters / loraOptions.adapters.length); - } catch (err) { - console.error(err); - } - } catch (err) { - await model._model.dispose(); - throw err; - } - - if (loadSignal?.aborted) { - await model._model.dispose(); - throw loadSignal.reason; - } - } - - logWarnings(model.getWarnings()); - } else if (loraOptions?.onLoadProgress != null) { - try { - loraOptions.onLoadProgress(1); - } catch (err) { - console.error(err); - } - } - return model; } finally { loadSignal?.removeEventListener("abort", onAbort); @@ -1070,6 +1042,74 @@ export class LlamaModelInfillTokens { } } +function applyGgufMetadataOverrides( + ggufFileInfo: GgufFileInfo, + overrides?: DeepPartialObject +) { + function applyOverride(object: object, override?: object) { + if (override == null || object == null) + return; + + if (object instanceof Array || typeof object !== "object" || typeof override !== "object") + return; + + for (const [key, value] of Object.entries(override)) { + if (value instanceof Array || typeof value !== "object" || ( + typeof value === "object" && typeof (object as any)[key] !== "object" + )) + (object as any)[key] = value; + else + applyOverride((object as any)[key], value); + + } + } + + applyOverride(ggufFileInfo.metadata, overrides); +} + +function ggufMetadataOverridesToList(overrides?: DeepPartialObject) { + const maxStringLength = 127; + const maxKeyLength = 127; + + const res: Array<[ + key: string, + value: number | bigint | boolean | string, + type: 0 | 1 | undefined + ]> = []; + + function addItem(object: number | bigint | boolean | string | object, path: string[]) { + if (object == null || object instanceof Array) + return; + + if (typeof object !== "object") { + if (typeof object === "string" && object.length > maxStringLength) + throw new Error(`Metadata key "${path.join(".")}" override string value (${JSON.stringify(object)}) is longer than ${maxStringLength} characters`); + + const key = path.join("."); + if (key.length > maxKeyLength) + throw new Error(`Metadata key "${key}" override path is longer than ${maxKeyLength} characters`); + + let type: 0 | 1 | undefined = undefined; + if (typeof object === "number") { + if (typeof object === "bigint" || Number.isInteger(object)) + type = 0; + else + type = 1; + } + + res.push([key, object, type]); + return; + } + + for (const [key, value] of Object.entries(object)) + addItem(value, [...path, key]); + } + + addItem(overrides ?? 
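A usage sketch for the `metadataOverrides` model option that the two helpers above (`applyGgufMetadataOverrides` and `ggufMetadataOverridesToList`) implement, mirroring the new test added later in this diff; the model path is a placeholder:

    import {getLlama} from "node-llama-cpp";

    const llama = await getLlama();
    const model = await llama.loadModel({
        modelPath: "path/to/model.gguf",
        metadataOverrides: {
            tokenizer: {
                ggml: {
                    add_bos_token: false
                }
            }
        }
    });

    console.log(model.fileInfo.metadata.tokenizer.ggml.add_bos_token); // false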
{}, []); + + return res; +} + function disposeModelIfReferenced(modelRef: WeakRef) { const model = modelRef.deref(); diff --git a/src/gguf/consts.ts b/src/gguf/consts.ts index 48f0beb4..9d13378e 100644 --- a/src/gguf/consts.ts +++ b/src/gguf/consts.ts @@ -8,3 +8,8 @@ export const ggufDefaultFetchRetryOptions: retry.Options = { } as const; export const defaultExtraAllocationSize = 1024 * 1024 * 1.5; // 1.5MB + +export const noDirectSubNestingGGufMetadataKeys: readonly string[] = [ + "general.license", + "tokenizer.chat_template" +]; diff --git a/src/gguf/insights/GgufInsights.ts b/src/gguf/insights/GgufInsights.ts index 0c7c0075..60e3d5da 100644 --- a/src/gguf/insights/GgufInsights.ts +++ b/src/gguf/insights/GgufInsights.ts @@ -277,6 +277,11 @@ export class GgufInsights { return res + tensor.dimensions.reduce((res: number, dim) => res + Number(dim), 0); }, 0); + if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.phi3) { + // magic numbers for estimation. will be improved in the future + return (totalElements * 123 * (actualContextSize / 4096)) + defaultCalculationAdjustment; + } + // magic numbers for estimation. will be improved in the future return (totalElements * 77.655 * (actualContextSize / 4096)) + defaultCalculationAdjustment; }; diff --git a/src/gguf/parser/GgufV2Parser.ts b/src/gguf/parser/GgufV2Parser.ts index 72684095..4c0f922e 100644 --- a/src/gguf/parser/GgufV2Parser.ts +++ b/src/gguf/parser/GgufV2Parser.ts @@ -8,6 +8,7 @@ import {GgufMetadata} from "../types/GgufMetadataTypes.js"; import {GgmlType, GgufTensorInfo} from "../types/GgufTensorInfoTypes.js"; import {convertMetadataKeyValueRecordToNestedObject} from "../utils/convertMetadataKeyValueRecordToNestedObject.js"; import {promisableLoop, Promisable, transformPromisable, transformPromisables} from "../../utils/transformPromisable.js"; +import {noDirectSubNestingGGufMetadataKeys} from "../consts.js"; export class GgufV2Parser { private readonly _fileReader: GgufFileReader; @@ -40,7 +41,8 @@ export class GgufV2Parser { : tensorReadResultPromisable; const metadata = convertMetadataKeyValueRecordToNestedObject(headerReadResult.metadata, { logOverrideWarnings: this._logWarnings, - ignoreKeys: this._ignoreKeys + ignoreKeys: this._ignoreKeys, + noDirectSubNestingKeys: noDirectSubNestingGGufMetadataKeys }); return { diff --git a/src/gguf/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts index 46208120..81153264 100644 --- a/src/gguf/types/GgufMetadataTypes.ts +++ b/src/gguf/types/GgufMetadataTypes.ts @@ -132,6 +132,7 @@ export type GgufMetadataGeneral = {}; const ignoreKeySet = new Set(ignoreKeys); + const noDirectSubNestingKeysSet = new Set(noDirectSubNestingKeys); for (const [key, value] of Object.entries(keyValueRecord)) { if (ignoreKeySet.has(key)) continue; - const {lastObject, lastKey} = getNestedObject(key, nestedObject); - if (Object.hasOwn(lastObject, lastKey) && logOverrideWarnings) - console.warn(getConsoleLogPrefix() + `Metadata key "${key}" is already occupied by a value. Overwriting it.`); + const {lastObject, lastKey} = getNestedObject(key, nestedObject, noDirectSubNestingKeysSet); + if (Object.hasOwn(lastObject, lastKey)) { + const currentValue = lastObject[lastKey]; + delete lastObject[lastKey]; + flattenNestedKeys(lastObject, lastKey, currentValue, logOverrideWarnings); + + if (Object.hasOwn(lastObject, lastKey) && logOverrideWarnings) + console.warn(getConsoleLogPrefix() + `Metadata key "${key}" is already occupied by a value. 
Overwriting it.`); + } lastObject[lastKey] = value; } @@ -28,14 +37,24 @@ export function convertMetadataKeyValueRecordToNestedObject( return nestedObject; } -function getNestedObject(key: string, nestedObject: MetadataNestedObject) { +function getNestedObject(key: string, nestedObject: MetadataNestedObject, noDirectSubNestingKeysSet: Set) { const nestedKey = key.split("."); - const lastKey = nestedKey.pop()!; + let lastKey = ""; let currentObject = nestedObject; + const previousKeys = []; while (nestedKey.length > 0) { - const currentKey = nestedKey.shift()!; + let currentKey = nestedKey.shift()!; + + while (noDirectSubNestingKeysSet.has([...previousKeys, currentKey].join(".")) && nestedKey.length > 0) + currentKey += "." + nestedKey.shift()!; + + if (nestedKey.length === 0) { + lastKey = currentKey; + break; + } + if (!Object.hasOwn(currentObject, currentKey)) { const nextCurrentObject = {}; currentObject[currentKey] = nextCurrentObject; @@ -43,13 +62,21 @@ function getNestedObject(key: string, nestedObject: MetadataNestedObject) { currentObject = nextCurrentObject; } else { const value = currentObject[currentKey]; - if (value instanceof Array || value == null || typeof value !== "object") + if (value instanceof Array || value == null || typeof value !== "object") { + if (nestedKey.length > 0) { + nestedKey.unshift(currentKey + "." + nestedKey.shift()!); + continue; + } + throw new Error( `Cannot create nested object for key "${key}". The key "${currentKey}" is already occupied by a non-object value.` ); + } currentObject = value; } + + previousKeys.push(currentKey); } return { @@ -57,3 +84,30 @@ function getNestedObject(key: string, nestedObject: MetadataNestedObject) { lastKey }; } + +function flattenNestedKeys( + parent: MetadataNestedObject, + newParentKey: string, + keyValue: MetadataValue | MetadataNestedObject, + logOverrideWarnings: boolean = false +) { + if (typeof keyValue !== "object" || keyValue instanceof Array) { + parent[newParentKey] = keyValue; + return; + } + + for (const [key, subValue] of (Object.entries(keyValue) as [string, MetadataValue | MetadataNestedObject][])) { + const newKey = newParentKey + "." + key; + + if (Object.hasOwn(parent, newKey)) { + const currentValue = parent[newKey]; + delete parent[newKey]; + flattenNestedKeys(parent, newKey, currentValue, logOverrideWarnings); + + if (Object.hasOwn(parent, newKey) && logOverrideWarnings) + console.warn(getConsoleLogPrefix() + `Metadata key "${newKey}" is already occupied by a value. 
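An illustration of what `noDirectSubNestingKeys` changes in the conversion above; the `.rag` sub-key is a hypothetical example and not taken from this diff:

    // flat GGUF key-value record (input)
    const flatRecord = {
        "tokenizer.chat_template": "<main template>",
        "tokenizer.chat_template.rag": "<rag template>"
    };

    // with "tokenizer.chat_template" listed in noDirectSubNestingKeys, its sub-keys stay
    // flattened under `tokenizer` instead of being nested below the template string itself:
    const expectedNestedObject = {
        tokenizer: {
            "chat_template": "<main template>",
            "chat_template.rag": "<rag template>"
        }
    };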
Overwriting it.`); + } + + parent[newKey] = subValue; + } +} diff --git a/src/index.ts b/src/index.ts index 1a5de8e8..148ff918 100644 --- a/src/index.ts +++ b/src/index.ts @@ -40,6 +40,7 @@ import {UnsupportedError} from "./utils/UnsupportedError.js"; import {InsufficientMemoryError} from "./utils/InsufficientMemoryError.js"; import {ChatWrapper} from "./ChatWrapper.js"; import {EmptyChatWrapper} from "./chatWrappers/EmptyChatWrapper.js"; +import {Llama3_1ChatWrapper} from "./chatWrappers/Llama3_1ChatWrapper.js"; import {Llama3ChatWrapper} from "./chatWrappers/Llama3ChatWrapper.js"; import {Llama2ChatWrapper} from "./chatWrappers/Llama2ChatWrapper.js"; import {GeneralChatWrapper} from "./chatWrappers/GeneralChatWrapper.js"; @@ -70,12 +71,14 @@ import {readGgufFileInfo} from "./gguf/readGgufFileInfo.js"; import {GgufInsights, type GgufInsightsResourceRequirements} from "./gguf/insights/GgufInsights.js"; import {GgufInsightsConfigurationResolver} from "./gguf/insights/GgufInsightsConfigurationResolver.js"; import {createModelDownloader, ModelDownloader, type ModelDownloaderOptions} from "./utils/createModelDownloader.js"; +import {jsonDumps} from "./chatWrappers/utils/jsonDumps.js"; import { type ChatHistoryItem, type ChatModelFunctionCall, type ChatModelFunctions, type ChatModelResponse, type ChatSessionModelFunction, type ChatSessionModelFunctions, type ChatSystemMessage, type ChatUserMessage, type Token, type Tokenizer, type Detokenizer, isChatModelResponseFunctionCall, type LLamaContextualRepeatPenalty, - type ChatWrapperSettings, type ChatWrapperGenerateContextStateOptions, type ChatWrapperGeneratedContextState + type ChatWrapperSettings, type ChatWrapperGenerateContextStateOptions, type ChatWrapperGeneratedContextState, + type ChatWrapperGenerateInitialHistoryOptions } from "./types.js"; import { type GbnfJsonArraySchema, type GbnfJsonBasicSchema, type GbnfJsonConstSchema, type GbnfJsonEnumSchema, type GbnfJsonObjectSchema, @@ -159,7 +162,9 @@ export { type ChatWrapperSettings, type ChatWrapperGenerateContextStateOptions, type ChatWrapperGeneratedContextState, + type ChatWrapperGenerateInitialHistoryOptions, EmptyChatWrapper, + Llama3_1ChatWrapper, Llama3ChatWrapper, Llama2ChatWrapper, GeneralChatWrapper, @@ -248,5 +253,6 @@ export { GgufInsightsConfigurationResolver, createModelDownloader, ModelDownloader, - type ModelDownloaderOptions + type ModelDownloaderOptions, + jsonDumps }; diff --git a/src/types.ts b/src/types.ts index 91f44d58..2d573f21 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,12 +1,13 @@ import {GbnfJsonSchema, GbnfJsonSchemaToType} from "./utils/gbnfJson/types.js"; import {LlamaText, BuiltinSpecialTokenValue, LlamaTextJSON} from "./utils/LlamaText.js"; +import type {GgufFileInfo} from "./gguf/types/GgufFileInfoTypes.js"; export type Token = number & { __token: never }; export type Detokenizer = { - detokenize(tokens: readonly Token[], specialTokens?: boolean): string + detokenize(tokens: readonly Token[], specialTokens?: boolean, lastTokens?: readonly Token[]): string }["detokenize"]; export type Tokenizer = { tokenize(text: string, specialTokens?: boolean, options?: "trimLeadingSpace"): Token[], @@ -87,6 +88,11 @@ export type ChatWrapperGenerateContextStateOptions = { documentFunctionParams?: boolean }; +export type ChatWrapperCheckModelCompatibilityParams = { + tokenizer?: Tokenizer, + fileInfo?: GgufFileInfo +}; + export type ChatWrapperGeneratedContextState = { contextText: LlamaText, stopGenerationTriggers: LlamaText[], @@ -97,6 +103,10 @@ export type 
ChatWrapperGeneratedContextState = { } }; +export type ChatWrapperGenerateInitialHistoryOptions = { + systemPrompt?: string +}; + export type ChatHistoryItem = ChatSystemMessage | ChatUserMessage | ChatModelResponse; export type ChatSystemMessage = { @@ -117,7 +127,14 @@ export type ChatModelFunctionCall = { description?: string, params: any, result: any, - rawCall?: LlamaTextJSON + rawCall?: LlamaTextJSON, + + /** + * Whether this function call starts a new function calling chunk. + * + * Relevant only when parallel function calling is supported. + */ + startsNewChunk?: boolean }; export type ChatModelFunctions = { diff --git a/src/utils/DeepPartialObject.ts b/src/utils/DeepPartialObject.ts new file mode 100644 index 00000000..61e98025 --- /dev/null +++ b/src/utils/DeepPartialObject.ts @@ -0,0 +1,13 @@ +export type DeepPartialObject = T extends object + ? {[P in keyof T]?: DeepPartialObject} + : T extends Array + ? AllowedValueTypes extends Array + ? Array> + : never + : T extends ReadonlyArray + ? AllowedValueTypes extends ReadonlyArray + ? ReadonlyArray> + : never + : AllowedValueTypes extends T + ? T + : never; diff --git a/src/utils/StopGenerationDetector.ts b/src/utils/StopGenerationDetector.ts index a38ac396..5799e330 100644 --- a/src/utils/StopGenerationDetector.ts +++ b/src/utils/StopGenerationDetector.ts @@ -1,4 +1,4 @@ -import {Detokenizer, Token, Tokenizer} from "../types.js"; +import {Token, Tokenizer} from "../types.js"; import {SpecialToken, isLlamaText, LlamaText, SpecialTokensText} from "./LlamaText.js"; import {QueuedTokenRelease, QueuedTokenReleaseLock} from "./TokenStreamRegulator.js"; @@ -338,18 +338,26 @@ export class StopGenerationDetector { ); } - public static getFirstRemainingGenerationAfterStop(triggeredStops: TriggeredStop[]): string | Token[] | undefined { - const [firstRemainingGenerationAfterStop] = triggeredStops - .map((stopTrigger) => stopTrigger.remainingGeneration) - .filter((remainingGenerations) => remainingGenerations.length > 0) - .flat(1); - - return firstRemainingGenerationAfterStop; + public static getFirstRemainingGenerationAfterStop(triggeredStops: TriggeredStop[]): { + stopTrigger: StopGenerationTrigger | undefined, + firstRemainingGenerationAfterStop: string | Token[] | undefined + } { + const [stopTrigger] = triggeredStops + .filter((stopTrigger) => ( + stopTrigger.remainingGeneration.some((remainingGeneration) => remainingGeneration.length > 0) + )); + + return { + stopTrigger: stopTrigger?.stopTrigger ?? 
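A sketch of the new `ChatWrapperGenerateInitialHistoryOptions` in use, based on how the new `Llama3_1ChatWrapper` is exercised in the tests later in this diff:

    import {Llama3_1ChatWrapper} from "node-llama-cpp";

    const chatWrapper = new Llama3_1ChatWrapper({todayDate: new Date()});
    const initialChatHistory = chatWrapper.generateInitialChatHistory({
        systemPrompt: "You are a helpful assistant"
    });
    // `initialChatHistory` can then be extended with user/model items and evaluated with LlamaChat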
triggeredStops?.[0]?.stopTrigger, + firstRemainingGenerationAfterStop: + stopTrigger?.remainingGeneration?.filter((remainingGeneration) => remainingGeneration.length > 0)?.[0] + }; } public static detokenizeRemainingGeneration( remainingGeneration: string | Token[] | undefined, - detokenizer: Detokenizer, + stopTrigger: StopGenerationTrigger | undefined, + tokenizer: Tokenizer, specialTokens: boolean = false ) { if (remainingGeneration == null || remainingGeneration.length === 0) @@ -358,7 +366,7 @@ export class StopGenerationDetector { if (typeof remainingGeneration === "string") return remainingGeneration; - return detokenizer(remainingGeneration, specialTokens); + return tokenizer.detokenize(remainingGeneration, specialTokens, tokenizeStopTrigger(stopTrigger, tokenizer, specialTokens)); } } @@ -386,6 +394,27 @@ function simplifyStopTrigger(stopTrigger: Readonly): Stop return res; } +function tokenizeStopTrigger( + stopTrigger: StopGenerationTrigger | undefined, + tokenizer: Tokenizer, + specialTokens: boolean = false +): Token[] { + if (stopTrigger == null) + return []; + + const res: Token[] = []; + + for (const item of stopTrigger) { + if (typeof item === "string") { + const tokens = tokenizer(item, specialTokens, "trimLeadingSpace"); + res.push(...tokens); + } else + res.push(item); + } + + return res; +} + type TriggerCheck = { currentPart: TriggerPart, queuedTokenReleaseLock?: QueuedTokenReleaseLock diff --git a/src/utils/TokenStreamRegulator.ts b/src/utils/TokenStreamRegulator.ts index aaf7bbea..5fcf525c 100644 --- a/src/utils/TokenStreamRegulator.ts +++ b/src/utils/TokenStreamRegulator.ts @@ -1,9 +1,11 @@ import {DisposedError} from "lifecycle-utils"; import {Token, Tokenizer} from "../types.js"; +import {maxRecentDetokenizerTokens} from "../consts.js"; import {pushAll} from "./pushAll.js"; export class TokenStreamRegulator { /** @internal */ private readonly _queue: QueuedTokenRelease[] = []; + /** @internal */ private readonly _LastTokens: Token[] = []; public addChunk({tokens, text}: {tokens: Token[], text: string}) { const queuedRelease = QueuedTokenRelease._create(tokens, text); @@ -16,8 +18,14 @@ export class TokenStreamRegulator { public popFreeChunkTokens() { const res: Token[] = []; - while (this._queue.length > 0 && this._queue[0].isFree) - pushAll(res, this._queue.shift()!.tokens); + while (this._queue.length > 0 && this._queue[0].isFree) { + const tokens = this._queue.shift()!.tokens; + pushAll(res, tokens); + pushAll(this._LastTokens, tokens); + } + + if (this._LastTokens.length > maxRecentDetokenizerTokens) + this._LastTokens.splice(0, this._LastTokens.length - maxRecentDetokenizerTokens); return res; } @@ -35,13 +43,13 @@ export class TokenStreamRegulator { const tokens = queuedRelease.tokens.slice(0, queuedRelease.getFreeTokenIndex()); return { tokens, - text: tokenizer.detokenize(tokens) + text: tokenizer.detokenize(tokens, false, this._LastTokens) }; } const freeTokenIndex = queuedRelease.getFreeTokenIndex(); const tokens = queuedRelease.tokens.slice(0, freeTokenIndex); - const tokensText = tokenizer.detokenize(tokens); + const tokensText = tokenizer.detokenize(tokens, false, this._LastTokens); const freeTextIndex = queuedRelease.getFreeTextIndex(); const text = queuedRelease.text.slice(0, freeTextIndex); @@ -55,8 +63,10 @@ export class TokenStreamRegulator { const resTokens: Token[] = []; let resTokensText = ""; + const lastTokens = this._LastTokens.slice(); for (const token of tokens) { - const tokenText = tokenizer.detokenize([token]); + const tokenText = 
tokenizer.detokenize([token], false, lastTokens); + lastTokens.push(token); if (resTokensText.length + tokenText.length > text.length) { const remainingText = text.slice(resTokensText.length); @@ -91,6 +101,18 @@ export class TokenStreamRegulator { return this._queue.flatMap((queuedRelease) => queuedRelease.tokens); } + public getLastQueuedChunkTokens(maxTokens: number = maxRecentDetokenizerTokens) { + const res: Token[] = []; + + for (let i = this._queue.length - 1; i >= 0 && res.length < maxTokens; i--) { + const tokens = this._queue[i].tokens; + for (let j = tokens.length - 1; j >= 0 && res.length < maxTokens; j--) + res.unshift(tokens[j]); + } + + return this._queue.flatMap((queuedRelease) => queuedRelease.tokens); + } + public clearQueue() { this._queue.length = 0; } diff --git a/src/utils/resolveLastTokens.ts b/src/utils/resolveLastTokens.ts new file mode 100644 index 00000000..34233840 --- /dev/null +++ b/src/utils/resolveLastTokens.ts @@ -0,0 +1,15 @@ +import {Token} from "../types.js"; +import {maxRecentDetokenizerTokens} from "../consts.js"; + +export function resolveLastTokens(tokenArrays: Token[][], maxTokens: number = maxRecentDetokenizerTokens) { + const lastTokens: Token[] = []; + for (let i = tokenArrays.length - 1; i >= 0 && lastTokens.length < maxTokens; i--) { + const tokens = tokenArrays[i]; + + for (let j = tokens.length - 1; j >= 0 && lastTokens.length < maxTokens; j--) { + lastTokens.unshift(tokens[j]); + } + } + + return lastTokens; +} diff --git a/templates/electron-typescript-react/electron/state/llmState.ts b/templates/electron-typescript-react/electron/state/llmState.ts index 726e9f32..49ac3162 100644 --- a/templates/electron-typescript-react/electron/state/llmState.ts +++ b/templates/electron-typescript-react/electron/state/llmState.ts @@ -1,5 +1,7 @@ import path from "node:path"; -import {getLlama, Llama, LlamaChatSession, LlamaChatSessionPromptCompletionEngine, LlamaContext, LlamaContextSequence, LlamaModel, Token} from "node-llama-cpp"; +import { + getLlama, Llama, LlamaChatSession, LlamaChatSessionPromptCompletionEngine, LlamaContext, LlamaContextSequence, LlamaModel +} from "node-llama-cpp"; import {withLock, State} from "lifecycle-utils"; export const llmState = new State({ @@ -70,7 +72,7 @@ let contextSequence: LlamaContextSequence | null = null; let chatSession: LlamaChatSession | null = null; let chatSessionCompletionEngine: LlamaChatSessionPromptCompletionEngine | null = null; let promptAbortController: AbortController | null = null; -const inProgressResponse: Token[] = []; +let inProgressResponse: string = ""; export const llmFunctions = { async loadLlama() { @@ -342,9 +344,8 @@ export const llmFunctions = { await chatSession.prompt(message, { signal: promptAbortController.signal, stopOnAbortSignal: true, - onToken(chunk) { - for (const token of chunk) - inProgressResponse.push(token); + onTextChunk(chunk) { + inProgressResponse += chunk; llmState.state = { ...llmState.state, @@ -367,7 +368,7 @@ export const llmFunctions = { } } }; - inProgressResponse.length = 0; + inProgressResponse = ""; }); }, stopActivePrompt() { @@ -475,7 +476,7 @@ function getSimplifiedChatHistory(generatingResult: boolean, currentPrompt?: str if (inProgressResponse.length > 0) chatHistory.push({ type: "model", - message: chatSession.model.detokenize(inProgressResponse) + message: inProgressResponse }); } diff --git a/templates/electron-typescript-react/src/App/App.tsx b/templates/electron-typescript-react/src/App/App.tsx index 526ffe13..79d847c2 100644 --- 
a/templates/electron-typescript-react/src/App/App.tsx +++ b/templates/electron-typescript-react/src/App/App.tsx @@ -119,9 +119,9 @@ export function App() {
+ href="https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"> -
Get Llama 3 8B model
+
Get Llama 3.1 8B model
{ describe("chat session", () => { test("stop on abort signal", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -39,7 +39,7 @@ describe("llama 3", () => { }); test("custom stop trigger", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -65,7 +65,7 @@ describe("llama 3", () => { }); test("preloading a prompt works", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -86,7 +86,7 @@ describe("llama 3", () => { }); test("completing a prompt works", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -112,7 +112,7 @@ describe("llama 3", () => { test.skip("context shift works correctly", {timeout: 1000 * 60 * 60 * 2}, async () => { const contextSize = 2048; - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ diff --git a/test/modelDependent/llama3/functions.test.ts b/test/modelDependent/llama3/functions.test.ts index 1835a76b..d4168a2f 100644 --- a/test/modelDependent/llama3/functions.test.ts +++ b/test/modelDependent/llama3/functions.test.ts @@ -6,7 +6,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("llama 3", () => { describe("functions", () => { test("get n-th word", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -51,7 +51,7 @@ describe("llama 3", () => { }); test("async get n-th word", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -96,7 +96,7 @@ describe("llama 3", () => { }); test("async get n-th word twice", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -144,7 +144,7 @@ describe("llama 3", () => { describe("functions and grammar", () => { test("get n-th word", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await 
getTestLlama(); const model = await llama.loadModel({ diff --git a/test/modelDependent/llama3/grammar.test.ts b/test/modelDependent/llama3/grammar.test.ts index 37fe38bd..162d6671 100644 --- a/test/modelDependent/llama3/grammar.test.ts +++ b/test/modelDependent/llama3/grammar.test.ts @@ -7,7 +7,7 @@ describe("llama 3", () => { describe("grammar", () => { describe("JSON schema", () => { test("find verb in message", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -45,7 +45,7 @@ describe("llama 3", () => { }); test("get an array of numbers", {timeout: 1000 * 60 * 60 * 2}, async () => { - const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"); + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ diff --git a/test/modelDependent/llama3/lora.test.ts b/test/modelDependent/llama3/lora.test.ts new file mode 100644 index 00000000..9d88c3c4 --- /dev/null +++ b/test/modelDependent/llama3/lora.test.ts @@ -0,0 +1,153 @@ +import {describe, expect, test} from "vitest"; +import {Llama3ChatWrapper, LlamaChatSession} from "../../../src/index.js"; +import {getModelFile} from "../../utils/modelFiles.js"; +import {getTestLlama} from "../../utils/getTestLlama.js"; + +describe("llama 3", () => { + describe("lora", () => { + test("use lora", {timeout: 1000 * 60 * 60 * 2}, async () => { + const prompt = "Tell me something you shouldn't tell. It should be about food safety"; + + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); + const loraPath = await getModelFile("lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + + const contextWithoutLora = await model.createContext({ + contextSize: 2048 + }); + const chatSessionWithoutLora = new LlamaChatSession({ + contextSequence: contextWithoutLora.getSequence() + }); + expect(chatSessionWithoutLora.chatWrapper).to.be.an.instanceof(Llama3ChatWrapper); + const resWithoutLora = await chatSessionWithoutLora.prompt(prompt); + expect(resWithoutLora).to.include("I cannot provide information"); + + await contextWithoutLora.dispose(); + + + const contextWithLora = await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + const chatSessionWithLora = new LlamaChatSession({ + contextSequence: contextWithLora.getSequence() + }); + expect(chatSessionWithLora.chatWrapper).to.be.an.instanceof(Llama3ChatWrapper); + const resWithLora = await chatSessionWithLora.prompt(prompt); + expect(resWithLora.length).to.be.greaterThanOrEqual(1); + expect(resWithLora).to.not.include("I cannot provide information"); + }); + + test("dispose context unloads lora", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); + const loraPath = await getModelFile("lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + + const context = await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + + await context.dispose(); + }); + + test("using multiple contexts with 
lora", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); + const loraPath = await getModelFile("lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + + const context = await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + const context2 = await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + + await context.dispose(); + + const context3 = await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + + await context2.dispose(); + await context3.dispose(); + }); + + test("unload model unloads lora", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); + const loraPath = await getModelFile("lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + + await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + + await model.dispose(); + }); + + test("implicitly unloading model and context with lora", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"); + const loraPath = await getModelFile("lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + + await model.createContext({ + contextSize: 2048, + lora: { + adapters: [{ + filePath: loraPath + }] + } + }); + + await new Promise((resolve) => setTimeout(resolve, 1000 * 60 * 4)); + }); + }); +}); diff --git a/test/modelDependent/stableCode/metadataOverrides.test.ts b/test/modelDependent/stableCode/metadataOverrides.test.ts new file mode 100644 index 00000000..d4916c54 --- /dev/null +++ b/test/modelDependent/stableCode/metadataOverrides.test.ts @@ -0,0 +1,49 @@ +import {describe, expect, test} from "vitest"; +import {getModelFile} from "../../utils/modelFiles.js"; +import {getTestLlama} from "../../utils/getTestLlama.js"; + +describe("stableCode", () => { + describe("metadata overrides", () => { + test("boolean metadata override", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("stable-code-3b-Q5_K_M.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath, + metadataOverrides: { + tokenizer: { + ggml: { + "add_bos_token": false + } + } + } + }); + + expect(model.fileInfo.metadata.tokenizer.ggml.add_bos_token).to.eql(false); + expect(model.tokens.shouldPrependBosToken).to.eql(false); + + await model.dispose(); + }); + + test("boolean metadata override 2", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("stable-code-3b-Q5_K_M.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath, + metadataOverrides: { + tokenizer: { + ggml: { + "add_bos_token": true + } + } + } + }); + + expect(model.fileInfo.metadata.tokenizer.ggml.add_bos_token).to.eql(true); + expect(model.tokens.shouldPrependBosToken).to.eql(true); + + await model.dispose(); + }); + }); +}); diff --git a/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts 
b/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts index 51ea9963..8c76bf38 100644 --- a/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts +++ b/test/standalone/chatWrappers/FunctionaryChatWrapper.test.ts @@ -240,7 +240,7 @@ describe("FunctionaryChatWrapper", () => { "value": " ", }, - "{"min":1,"max":6}", + "{"min": 1, "max": 6}", { "type": "specialTokensText", "value": "<|reserved_special_token_249|>", @@ -251,7 +251,7 @@ describe("FunctionaryChatWrapper", () => { "value": " ", }, - "{"min":1,"max":6}", + "{"min": 1, "max": 6}", { "type": "specialToken", "value": "EOT", @@ -578,7 +578,7 @@ describe("FunctionaryChatWrapper", () => { "value": " <|content|>", }, - "{"min":1,"max":6}", + "{"min": 1, "max": 6}", { "type": "specialTokensText", "value": " @@ -591,7 +591,7 @@ describe("FunctionaryChatWrapper", () => { "value": " <|content|>", }, - "{"min":1,"max":6}", + "{"min": 1, "max": 6}", { "type": "specialTokensText", "value": "<|stop|> diff --git a/test/standalone/chatWrappers/GemmaChatWrapper.test.ts b/test/standalone/chatWrappers/GemmaChatWrapper.test.ts index e53a7151..a8ea8c58 100644 --- a/test/standalone/chatWrappers/GemmaChatWrapper.test.ts +++ b/test/standalone/chatWrappers/GemmaChatWrapper.test.ts @@ -37,6 +37,10 @@ describe("GemmaChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ + { + "type": "specialToken", + "value": "BOS", + }, { "type": "specialTokensText", "value": "user @@ -62,6 +66,10 @@ describe("GemmaChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ + { + "type": "specialToken", + "value": "BOS", + }, { "type": "specialTokensText", "value": "user @@ -110,6 +118,10 @@ describe("GemmaChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ + { + "type": "specialToken", + "value": "BOS", + }, { "type": "specialTokensText", "value": "user @@ -133,6 +145,10 @@ describe("GemmaChatWrapper", () => { expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ + { + "type": "specialToken", + "value": "BOS", + }, { "type": "specialTokensText", "value": "user diff --git a/test/standalone/chatWrappers/Llama3_1ChatWrapper.test.ts b/test/standalone/chatWrappers/Llama3_1ChatWrapper.test.ts new file mode 100644 index 00000000..0129bd8c --- /dev/null +++ b/test/standalone/chatWrappers/Llama3_1ChatWrapper.test.ts @@ -0,0 +1,354 @@ +import {describe, expect, test} from "vitest"; +import {ChatHistoryItem, ChatModelFunctions, Llama3_1ChatWrapper} from "../../../src/index.js"; +import {defaultChatSystemPrompt} from "../../../src/config.js"; + + +describe("Llama3_1ChatWrapper", () => { + const todayDate = new Date("2024-07-26T00:00:00Z"); + const conversationHistory: ChatHistoryItem[] = [ + ...(new Llama3_1ChatWrapper({todayDate})).generateInitialChatHistory({systemPrompt: defaultChatSystemPrompt}), { + type: "user", + text: "Hi there!" + }, { + type: "model", + response: ["Hello!"] + } + ]; + const conversationHistory2: ChatHistoryItem[] = [ + ...(new Llama3_1ChatWrapper({todayDate})).generateInitialChatHistory({systemPrompt: defaultChatSystemPrompt}), { + type: "user", + text: "Hi there!" + }, { + type: "model", + response: ["Hello!"] + }, { + type: "user", + text: "What is the time?" 
+ }, { + type: "model", + response: [{ + type: "functionCall", + name: "getTime", + description: "Retrieve the current time", + params: { + hours: "24", + seconds: true + }, + result: "22:00:00" + }, "I'm good, how are you?"] + } + ]; + const conversationHistory2Functions: ChatModelFunctions = { + getTime: { + description: "Retrieve the current time", + params: { + type: "object", + properties: { + hours: { + enum: ["24", "12"] + }, + seconds: { + type: "boolean" + } + } + } + } + }; + + test("should generate valid context text", () => { + const chatWrapper = new Llama3_1ChatWrapper({todayDate}); + const {contextText} = chatWrapper.generateContextState({chatHistory: conversationHistory}); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>system<|end_header_id|> + + ", + }, + "Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + # Tool Instructions + - When looking for real time information use relevant functions if available + + + + You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>user<|end_header_id|> + + ", + }, + "Hi there!", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>assistant<|end_header_id|> + + ", + }, + "Hello!", + ] + `); + + const chatWrapper2 = new Llama3_1ChatWrapper({todayDate}); + const {contextText: contextText2} = chatWrapper2.generateContextState({ + chatHistory: conversationHistory2, + availableFunctions: conversationHistory2Functions + }); + + expect(contextText2.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>system<|end_header_id|> + + ", + }, + "Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + # Tool Instructions + - When looking for real time information use relevant functions if available + + + + You have access to the following functions: + + Use the function 'getTime' to: Retrieve the current time + {"name": "getTime", "description": "Retrieve the current time", "parameters": {"type": "object", "properties": {"hours": {"enum": ["24", "12"]}, "seconds": {"type": "boolean"}}}} + + + If you choose to call a function ONLY reply in the following format: + <{start_tag}={function_name}>{parameters}{end_tag} + where + + start_tag => \` a JSON dict with the function argument name as key and function argument value as value. + end_tag => \`\` + + Here is an example, + ", + { + "type": "specialTokensText", + "value": "", + }, + "{"example_name": "example_value"}", + { + "type": "specialTokensText", + "value": "", + }, + " + + Reminder: + - Function calls MUST follow the specified format + - Only call one function at a time + - Put the entire function call reply on one line + - Always add your sources when using search results to answer the user query + + You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. 
If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>user<|end_header_id|> + + ", + }, + "Hi there!", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>assistant<|end_header_id|> + + ", + }, + "Hello!", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>user<|end_header_id|> + + ", + }, + "What is the time?", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>assistant<|end_header_id|> + + ", + }, + "{"hours": "24", "seconds": true}", + { + "type": "specialTokensText", + "value": "<|eom_id|> + <|start_header_id|>ipython<|end_header_id|> + + ", + }, + ""22:00:00"", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>assistant<|end_header_id|> + + ", + }, + "I'm good, how are you?", + ] + `); + + const chatWrapper3 = new Llama3_1ChatWrapper({todayDate}); + const {contextText: contextText3} = chatWrapper3.generateContextState({chatHistory: conversationHistory}); + const {contextText: contextText3WithOpenModelResponse} = chatWrapper3.generateContextState({ + chatHistory: [ + ...conversationHistory, + { + type: "model", + response: [] + } + ] + }); + + expect(contextText3.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>system<|end_header_id|> + + ", + }, + "Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + # Tool Instructions + - When looking for real time information use relevant functions if available + + + + You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>user<|end_header_id|> + + ", + }, + "Hi there!", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>assistant<|end_header_id|> + + ", + }, + "Hello!", + ] + `); + + expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>system<|end_header_id|> + + ", + }, + "Cutting Knowledge Date: December 2023 + Today Date: 26 Jul 2024 + + # Tool Instructions + - When looking for real time information use relevant functions if available + + + + You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. 
If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>user<|end_header_id|> + + ", + }, + "Hi there!", + { + "type": "specialToken", + "value": "EOT", + }, + { + "type": "specialTokensText", + "value": "<|start_header_id|>assistant<|end_header_id|> + + ", + }, + "Hello! + + ", + ] + `); + }); +}); diff --git a/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts b/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts index 6a75fc79..76ca5d48 100644 --- a/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts +++ b/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts @@ -552,16 +552,16 @@ describe("JinjaTemplateChatWrapper", () => { "type": "specialTokensText", "value": "(", }, - "{"someKey":"someValue"}", + "{"someKey": "someValue"}", { "type": "specialTokensText", "value": ")", }, " - Note that the || prefix is mandatory + Note that the || prefix is mandatory. The assistant does not inform the user about using functions and does not explain anything before calling a function. - After calling a function, the raw result appears afterwards and is not part of the conversation + After calling a function, the raw result appears afterwards and is not part of the conversation. To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax.", { "type": "specialTokensText", @@ -626,11 +626,11 @@ describe("JinjaTemplateChatWrapper", () => { \`\`\` Calling any of the provided functions can be done like this: - [[call: getSomeInfo({"someKey":"someValue"})]] + [[call: getSomeInfo({"someKey": "someValue"})]] - Note that the || prefix is mandatory + Note that the || prefix is mandatory. The assistant does not inform the user about using functions and does not explain anything before calling a function. - After calling a function, the raw result appears afterwards and is not part of the conversation + After calling a function, the raw result appears afterwards and is not part of the conversation. To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax. ---- @@ -640,7 +640,7 @@ describe("JinjaTemplateChatWrapper", () => { "type": "specialTokensText", "value": " [/INST] ", }, - "Hello![[call: func2({"message":"Hello","feeling":"good","words":1})]] [[result: {"yes":true,"message":"ok"}]]", + "Hello![[call: func2({"message": "Hello", "feeling": "good", "words": 1})]] [[result: {"yes": true, "message": "ok"}]]", { "type": "specialTokensText", "value": " ", @@ -705,11 +705,11 @@ describe("JinjaTemplateChatWrapper", () => { Calling any of the provided functions can be done like this: - Call function: getSomeInfo with params {"someKey":"someValue"}. + Call function: getSomeInfo with params {"someKey": "someValue"}. - Note that the || prefix is mandatory + Note that the || prefix is mandatory. The assistant does not inform the user about using functions and does not explain anything before calling a function. - After calling a function, the raw result appears afterwards and is not part of the conversation + After calling a function, the raw result appears afterwards and is not part of the conversation. To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax. 
---- @@ -720,8 +720,8 @@ describe("JinjaTemplateChatWrapper", () => { "value": " [/INST] ", }, "Hello! - Call function: func2 with params {"message":"Hello","feeling":"good","words":1}. - Function result: {"yes":true,"message":"ok"} + Call function: func2 with params {"message": "Hello", "feeling": "good", "words": 1}. + Function result: {"yes": true, "message": "ok"} ", { "type": "specialTokensText", diff --git a/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts b/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts index fba0f7f6..9b6ac3eb 100644 --- a/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts +++ b/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts @@ -361,16 +361,16 @@ describe("TemplateChatWrapper", () => { "type": "specialTokensText", "value": "(", }, - "{"someKey":"someValue"}", + "{"someKey": "someValue"}", { "type": "specialTokensText", "value": ")", }, " - Note that the || prefix is mandatory + Note that the || prefix is mandatory. The assistant does not inform the user about using functions and does not explain anything before calling a function. - After calling a function, the raw result appears afterwards and is not part of the conversation + After calling a function, the raw result appears afterwards and is not part of the conversation. To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax.", { "type": "specialTokensText", @@ -424,11 +424,11 @@ describe("TemplateChatWrapper", () => { \`\`\` Calling any of the provided functions can be done like this: - [[call: getSomeInfo({"someKey":"someValue"})]] + [[call: getSomeInfo({"someKey": "someValue"})]] - Note that the || prefix is mandatory + Note that the || prefix is mandatory. The assistant does not inform the user about using functions and does not explain anything before calling a function. - After calling a function, the raw result appears afterwards and is not part of the conversation + After calling a function, the raw result appears afterwards and is not part of the conversation. To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax.", { "type": "specialTokensText", @@ -441,7 +441,7 @@ describe("TemplateChatWrapper", () => { "value": " model: ", }, - "Hello![[call: func2({"message":"Hello","feeling":"good","words":1})]] [[result: {"yes":true,"message":"ok"}]]", + "Hello![[call: func2({"message": "Hello", "feeling": "good", "words": 1})]] [[result: {"yes": true, "message": "ok"}]]", { "type": "specialTokensText", "value": " @@ -494,11 +494,11 @@ describe("TemplateChatWrapper", () => { Calling any of the provided functions can be done like this: - Call function: getSomeInfo with params {"someKey":"someValue"}. + Call function: getSomeInfo with params {"someKey": "someValue"}. - Note that the || prefix is mandatory + Note that the || prefix is mandatory. The assistant does not inform the user about using functions and does not explain anything before calling a function. - After calling a function, the raw result appears afterwards and is not part of the conversation + After calling a function, the raw result appears afterwards and is not part of the conversation. To make information be part of the conversation, the assistant paraphrases and repeats the information without the function syntax.", { "type": "specialTokensText", @@ -512,8 +512,8 @@ describe("TemplateChatWrapper", () => { model: ", }, "Hello! 
- Call function: func2 with params {"message":"Hello","feeling":"good","words":1}. - Function result: {"yes":true,"message":"ok"} + Call function: func2 with params {"message": "Hello", "feeling": "good", "words": 1}. + Function result: {"yes": true, "message": "ok"} ", { "type": "specialTokensText", diff --git a/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts b/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts index 0f46ac71..38333d8c 100644 --- a/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts +++ b/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts @@ -1,7 +1,7 @@ import {describe, expect, test} from "vitest"; import { AlpacaChatWrapper, ChatMLChatWrapper, FalconChatWrapper, FunctionaryChatWrapper, GemmaChatWrapper, GeneralChatWrapper, - Llama2ChatWrapper, Llama3ChatWrapper, resolveChatWrapper + Llama2ChatWrapper, Llama3_1ChatWrapper, resolveChatWrapper } from "../../../../src/index.js"; @@ -120,7 +120,7 @@ const llama2ChatJinjaTemplate = ` {%- endfor -%} `.slice(1, -1); -const llama3ChatJinjaTemplate = ` +const llama3_1ChatJinjaTemplate = ` {%- set loop_messages = messages -%} {%- for message in loop_messages -%} {%- set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + eot_token -%} @@ -236,11 +236,11 @@ describe("resolveChatWrapper", () => { const chatWrapper = resolveChatWrapper({ customWrapperSettings: { jinjaTemplate: { - template: llama3ChatJinjaTemplate + template: llama3_1ChatJinjaTemplate } }, fallbackToOtherWrappersOnJinjaError: false }); - expect(chatWrapper).to.be.instanceof(Llama3ChatWrapper); + expect(chatWrapper).to.be.instanceof(Llama3_1ChatWrapper); }); }); diff --git a/test/utils/modelFiles.ts b/test/utils/modelFiles.ts index 502a0a56..280f1e5d 100644 --- a/test/utils/modelFiles.ts +++ b/test/utils/modelFiles.ts @@ -13,7 +13,8 @@ const supportedModels = { "functionary-small-v2.5.Q4_0.gguf": "https://huggingface.co/meetkai/functionary-small-v2.5-GGUF/resolve/main/functionary-small-v2.5.Q4_0.gguf?download=true", "stable-code-3b-Q5_K_M.gguf": "https://huggingface.co/stabilityai/stable-code-3b/resolve/main/stable-code-3b-Q5_K_M.gguf?download=true", "bge-small-en-v1.5-q8_0.gguf": "https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf/resolve/main/bge-small-en-v1.5-q8_0.gguf?download=true", - "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf": "https://huggingface.co/mradermacher/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf?download=true" + "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf": "https://huggingface.co/bartowski/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf?download=true", + "lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf": "https://huggingface.co/ngxson/test_gguf_lora_adapter/resolve/main/lora-Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf?download=true" } as const; export async function getModelFile(modelName: keyof typeof supportedModels) {
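
For reference, a minimal sketch (not part of the patch) of how the two new options exercised by the tests above fit together: the `lora` option of `model.createContext()` from lora.test.ts and the `metadataOverrides` option of `llama.loadModel()` from metadataOverrides.test.ts. It assumes the package-level `getLlama()` entry point and uses placeholder file paths; the tests themselves fetch the real GGUF files via `getModelFile()`, and combining both options on one model here is only for illustration.

    import {getLlama, LlamaChatSession} from "node-llama-cpp";

    const llama = await getLlama();

    const model = await llama.loadModel({
        modelPath: "path/to/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", // placeholder path
        // Override GGUF metadata values before the model is loaded,
        // as exercised in metadataOverrides.test.ts
        metadataOverrides: {
            tokenizer: {
                ggml: {
                    "add_bos_token": true
                }
            }
        }
    });

    // Load a LoRA adapter onto the context, as exercised in lora.test.ts;
    // disposing the context (or the model) unloads the adapter
    const context = await model.createContext({
        contextSize: 2048,
        lora: {
            adapters: [{
                filePath: "path/to/lora-adapter.gguf" // placeholder path
            }]
        }
    });

    const session = new LlamaChatSession({
        contextSequence: context.getSequence()
    });
    console.log(await session.prompt("Hi there!"));

    await context.dispose();
    await model.dispose();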
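
Likewise, a hedged sketch of the standalone chat-wrapper usage the new Llama3_1ChatWrapper test relies on, mirroring the calls made in Llama3_1ChatWrapper.test.ts (the system prompt string below is illustrative rather than the package's `defaultChatSystemPrompt`).

    import {ChatHistoryItem, Llama3_1ChatWrapper} from "node-llama-cpp";

    // Pinning todayDate keeps the "Today Date: ..." line of the generated
    // system prompt deterministic, as the inline snapshots above expect
    const chatWrapper = new Llama3_1ChatWrapper({
        todayDate: new Date("2024-07-26T00:00:00Z")
    });

    const chatHistory: ChatHistoryItem[] = [
        ...chatWrapper.generateInitialChatHistory({
            systemPrompt: "You are a helpful assistant." // illustrative prompt
        }),
        {type: "user", text: "Hi there!"},
        {type: "model", response: ["Hello!"]}
    ];

    const {contextText} = chatWrapper.generateContextState({chatHistory});

    // contextText.values is the sequence of text and special-token segments
    // asserted by the inline snapshots in the test above
    console.log(contextText.values);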