From e7006ac01d9e508630df9da355a974e3d2ae234c Mon Sep 17 00:00:00 2001 From: hhyqhh Date: Tue, 23 Jul 2024 16:48:16 +0800 Subject: [PATCH] Add Custom API Endpoint for Ollama --- +llms/+internal/callOllamaChatAPI.m | 4 +++- ollamaChat.m | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/+llms/+internal/callOllamaChatAPI.m b/+llms/+internal/callOllamaChatAPI.m index a7e6436..ad0ab02 100644 --- a/+llms/+internal/callOllamaChatAPI.m +++ b/+llms/+internal/callOllamaChatAPI.m @@ -37,9 +37,11 @@ nvp.Seed nvp.TimeOut nvp.StreamFun + nvp.URL = "http://localhost:11434/api/chat" % Add the URL parameter with default value end -URL = "http://localhost:11434/api/chat"; +% Move URL to input argument +URL = nvp.URL; % The JSON for StopSequences must have an array, and cannot say "stop": "foo". % The easiest way to ensure that is to never pass in a scalar … diff --git a/ollamaChat.m b/ollamaChat.m index 2538038..befe07a 100644 --- a/ollamaChat.m +++ b/ollamaChat.m @@ -67,6 +67,7 @@ Model (1,1) string TopK (1,1) {mustBeReal,mustBePositive} = Inf TailFreeSamplingZ (1,1) {mustBeReal} = 1 + Endpoint (1,1) string = "http://localhost:11434/api/chat" % Add default URL property end methods @@ -82,6 +83,7 @@ nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 120 nvp.TailFreeSamplingZ (1,1) {mustBeReal} = 1 nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} + nvp.Endpoint (1,1) string = "http://localhost:11434/api/chat" % Add Endpoint argument end if isfield(nvp,"StreamFun") @@ -105,6 +107,7 @@ this.TailFreeSamplingZ = nvp.TailFreeSamplingZ; this.StopSequences = nvp.StopSequences; this.TimeOut = nvp.TimeOut; + this.Endpoint = nvp.Endpoint; end function [text, message, response] = generate(this, messages, nvp) @@ -147,7 +150,8 @@ TailFreeSamplingZ=this.TailFreeSamplingZ,... StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ... - TimeOut=this.TimeOut, StreamFun=this.StreamFun); + TimeOut=this.TimeOut, StreamFun=this.StreamFun, ... + URL=this.Endpoint); if isfield(response.Body.Data,"error") err = response.Body.Data.error;