Skip to content

add support for Custom API Endpoint #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
% - ResponseFormat (response_format)
% - Seed (seed)
% - ApiKey
% - Endpoint
% - TimeOut
% - StreamFun
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
Expand Down Expand Up @@ -48,7 +49,7 @@
% apiKey = "your-api-key-here"
%
% % Send a request
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey, Endpoint=endpoint)

% Copyright 2023-2024 The MathWorks, Inc.

Expand All @@ -67,15 +68,16 @@
nvp.ResponseFormat = "text"
nvp.Seed = []
nvp.ApiKey = ""
nvp.Endpoint = ""
nvp.TimeOut = 10
nvp.StreamFun = []
end

END_POINT = "https://api.openai.com/v1/chat/completions";
% END_POINT = "https://api.openai.com/v1/chat/completions";

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, nvp.Endpoint, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
Expand Down
39 changes: 39 additions & 0 deletions +llms/+internal/getEndpointFromNvpOrEnv.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
function endpoint = getEndpointFromNvpOrEnv(nvp, type)
% This function is undocumented and will change in a future release

% getEndpointFromNvpOrEnv Retrieves an API endpoint from a Name-Value Pair struct or environment variable.
%
% This function takes a struct nvp containing name-value pairs and checks
% if it contains a non-empty field called "Endpoint". If so, that value is
% returned verbatim. Otherwise the API base URL is read from the
% environment variable "OPENAI_API_BASE_URL"; if that variable is unset or
% does not look like a URL, the official OpenAI base URL is used. The path
% suffix for the requested type ("chat", "embeddings", "image_generate",
% "image_edits", "image_variations") is then appended to the base URL.
% Unknown types fall back to the chat completions path.

% Copyright 2023 The MathWorks, Inc.

arguments
    nvp struct
    % MATLAB has no `param=value` default syntax in the function line;
    % defaults must be declared in an arguments block.
    type (1,1) string = "chat"
end

openai_api_base_url = "https://api.openai.com/v1";

% An explicitly supplied, non-empty Endpoint wins and is used verbatim
% (an empty string is treated as "not provided", matching the "" default
% used by callOpenAIChatAPI).
if isfield(nvp, "Endpoint") && strlength(string(nvp.Endpoint)) > 0
    endpoint = nvp.Endpoint;
    return
end

% Resolve the base URL from the environment, falling back to the official
% OpenAI base URL when the variable is unset or does not start with "http".
base_url = openai_api_base_url;
if isenv("OPENAI_API_BASE_URL")
    env_url = getenv("OPENAI_API_BASE_URL");
    if startsWith(env_url, "http")
        base_url = env_url;
    end
end

% Path suffix for each supported request type.
completions.chat = "/chat/completions";
completions.embeddings = "/embeddings";
completions.image_generate = "/images/generations";
completions.image_edits = "/images/edits";
completions.image_variations = "/images/variations";

if isfield(completions, type)
    endpoint = strcat(base_url, completions.(type));
else
    % Unknown type: fall back to the chat completions endpoint.
    endpoint = strcat(base_url, completions.chat);
end
end
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Set up your OpenAI API key. Create a `.env` file in the project root directory w

```
OPENAI_API_KEY=<your key>
OPENAI_API_BASE_URL=<your base URL> (optional; a base URL such as "https://api.openai.com/v1")
```

Then load your `.env` file as follows:
Expand Down
8 changes: 6 additions & 2 deletions extractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
%
% 'ApiKey' - OpenAI API token. It can also be specified by
% setting the environment variable OPENAI_API_KEY
% 'Endpoint' - OpenAI API base_url. It can also be specified by
% setting the environment variable OPENAI_API_BASE_URL
%
% 'TimeOut' - Connection Timeout in seconds (default: 10 secs)
%
Expand All @@ -29,11 +31,13 @@
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
nvp.Dimensions (1,1) {mustBeInteger,mustBePositive}
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
nvp.Endpoint (1,1) string {mustBeTextScalar}
end

END_POINT = "https://api.openai.com/v1/embeddings";
% END_POINT = "https://api.openai.com/v1/embeddings";

key = llms.internal.getApiKeyFromNvpOrEnv(nvp);
endpoint = llms.internal.getEndpointFromNvpOrEnv(nvp, 'embeddings');

parameters = struct("input",text,"model",nvp.ModelName);

Expand All @@ -47,7 +51,7 @@
end


response = llms.internal.sendRequest(parameters,key, END_POINT, nvp.TimeOut);
response = llms.internal.sendRequest(parameters,key, endpoint, nvp.TimeOut);

if isfield(response.Body.Data, "data")
emb = [response.Body.Data.data.embedding];
Expand Down
5 changes: 4 additions & 1 deletion openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
Tools
FunctionsStruct
ApiKey
Endpoint
StreamFun
end

Expand All @@ -123,6 +124,7 @@
nvp.StopSequences {mustBeValidStop} = {}
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
nvp.ApiKey {mustBeNonzeroLengthTextScalar}
nvp.Endpoint (1,1) string {mustBeTextScalar}
nvp.PresencePenalty {mustBeValidPenalty} = 0
nvp.FrequencyPenalty {mustBeValidPenalty} = 0
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
Expand Down Expand Up @@ -171,6 +173,7 @@
this.PresencePenalty = nvp.PresencePenalty;
this.FrequencyPenalty = nvp.FrequencyPenalty;
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
this.Endpoint = llms.internal.getEndpointFromNvpOrEnv(nvp, "chat");
this.TimeOut = nvp.TimeOut;
end

Expand Down Expand Up @@ -233,7 +236,7 @@
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
ApiKey=this.ApiKey,Endpoint=this.Endpoint,TimeOut=this.TimeOut, StreamFun=this.StreamFun);

if isfield(response.Body.Data,"error")
err = response.Body.Data.error.message;
Expand Down
11 changes: 8 additions & 3 deletions openAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@

properties (Access=private)
ApiKey
Endpoint
end

methods
function this = openAIImages(nvp)
arguments
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["dall-e-2", "dall-e-3"])} = "dall-e-2"
nvp.ApiKey {mustBeNonzeroLengthTextScalar}
nvp.Endpoint (1,1) string {mustBeTextScalar}
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
end

Expand Down Expand Up @@ -96,7 +98,8 @@
nvp.Style (1,1) string {mustBeMember(nvp.Style,["vivid", "natural"])}
end

endpoint = "https://api.openai.com/v1/images/generations";
% endpoint = "https://api.openai.com/v1/images/generations";
endpoint = llms.internal.getEndpointFromNvpOrEnv(nvp, "image_generate");

validatePromptSize(this.ModelName, prompt)
validateSizeNVP(this.ModelName, nvp.Size)
Expand Down Expand Up @@ -194,7 +197,8 @@

validatePromptSize(this.ModelName, prompt)

endpoint = 'https://api.openai.com/v1/images/edits';
% endpoint = 'https://api.openai.com/v1/images/edits';
endpoint = llms.internal.getEndpointFromNvpOrEnv(nvp, "image_edits");

% Required params
numImages = num2str(nvp.NumImages);
Expand Down Expand Up @@ -252,7 +256,8 @@
this.ModelName));
end

endpoint = 'https://api.openai.com/v1/images/variations';
% endpoint = 'https://api.openai.com/v1/images/variations';
endpoint = llms.internal.getEndpointFromNvpOrEnv(nvp, "image_variations");

numImages = num2str(nvp.NumImages);
body = matlab.net.http.io.MultipartFormProvider(...
Expand Down