From d38f247e23340c491e599f9467921cd18ca44d74 Mon Sep 17 00:00:00 2001 From: Sameel Date: Wed, 23 Apr 2025 14:23:43 -0400 Subject: [PATCH 01/16] make ai sdk native --- evals/index.eval.ts | 2 +- examples/ai_sdk_example.ts | 7 +- examples/external_clients/aisdk.ts | 122 ------------ lib/index.ts | 12 +- lib/llm/LLMProvider.ts | 47 +++-- lib/llm/aisdk.ts | 285 +++++++++++++++++++++++++++++ package-lock.json | 99 ++++------ package.json | 2 +- types/model.ts | 3 +- types/stagehand.ts | 3 +- 10 files changed, 377 insertions(+), 205 deletions(-) delete mode 100644 examples/external_clients/aisdk.ts create mode 100644 lib/llm/aisdk.ts diff --git a/evals/index.eval.ts b/evals/index.eval.ts index 768a193f9..9f0dc474c 100644 --- a/evals/index.eval.ts +++ b/evals/index.eval.ts @@ -34,12 +34,12 @@ import { StagehandEvalError } from "@/types/stagehandErrors"; import { CustomOpenAIClient } from "@/examples/external_clients/customOpenAI"; import OpenAI from "openai"; import { initStagehand } from "./initStagehand"; -import { AISdkClient } from "@/examples/external_clients/aisdk"; import { google } from "@ai-sdk/google"; import { anthropic } from "@ai-sdk/anthropic"; import { groq } from "@ai-sdk/groq"; import { cerebras } from "@ai-sdk/cerebras"; import { openai } from "@ai-sdk/openai"; +import { AISdkClient } from "@/lib/llm/aisdk"; dotenv.config(); /** diff --git a/examples/ai_sdk_example.ts b/examples/ai_sdk_example.ts index b650de5c1..d03a4560c 100644 --- a/examples/ai_sdk_example.ts +++ b/examples/ai_sdk_example.ts @@ -1,15 +1,12 @@ -import { openai } from "@ai-sdk/openai"; import { Stagehand } from "@/dist"; -import { AISdkClient } from "./external_clients/aisdk"; import StagehandConfig from "@/stagehand.config"; +import { openai } from "@ai-sdk/openai"; import { z } from "zod"; async function example() { const stagehand = new Stagehand({ ...StagehandConfig, - llmClient: new AISdkClient({ - model: openai("gpt-4o"), - }), + modelName: openai("gpt-4o"), }); await stagehand.init(); diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts deleted file mode 100644 index 1d72d984f..000000000 --- a/examples/external_clients/aisdk.ts +++ /dev/null @@ -1,122 +0,0 @@ -import { - CoreAssistantMessage, - CoreMessage, - CoreSystemMessage, - CoreTool, - CoreUserMessage, - generateObject, - generateText, - ImagePart, - LanguageModel, - TextPart, -} from "ai"; -import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist"; -import { ChatCompletion } from "openai/resources"; - -export class AISdkClient extends LLMClient { - public type = "aisdk" as const; - private model: LanguageModel; - - constructor({ model }: { model: LanguageModel }) { - super(model.modelId as AvailableModel); - this.model = model; - } - - async createChatCompletion({ - options, - }: CreateChatCompletionOptions): Promise { - const formattedMessages: CoreMessage[] = options.messages.map((message) => { - if (Array.isArray(message.content)) { - if (message.role === "system") { - const systemMessage: CoreSystemMessage = { - role: "system", - content: message.content - .map((c) => ("text" in c ? c.text : "")) - .join("\n"), - }; - return systemMessage; - } - - const contentParts = message.content.map((content) => { - if ("image_url" in content) { - const imageContent: ImagePart = { - type: "image", - image: content.image_url.url, - }; - return imageContent; - } else { - const textContent: TextPart = { - type: "text", - text: content.text, - }; - return textContent; - } - }); - - if (message.role === "user") { - const userMessage: CoreUserMessage = { - role: "user", - content: contentParts, - }; - return userMessage; - } else { - const textOnlyParts = contentParts.map((part) => ({ - type: "text" as const, - text: part.type === "image" ? "[Image]" : part.text, - })); - const assistantMessage: CoreAssistantMessage = { - role: "assistant", - content: textOnlyParts, - }; - return assistantMessage; - } - } - - return { - role: message.role, - content: message.content, - }; - }); - - if (options.response_model) { - const response = await generateObject({ - model: this.model, - messages: formattedMessages, - schema: options.response_model.schema, - }); - - return { - data: response.object, - usage: { - prompt_tokens: response.usage.promptTokens ?? 0, - completion_tokens: response.usage.completionTokens ?? 0, - total_tokens: response.usage.totalTokens ?? 0, - }, - } as T; - } - - const tools: Record = {}; - - for (const rawTool of options.tools) { - tools[rawTool.name] = { - description: rawTool.description, - parameters: rawTool.parameters, - }; - } - - const response = await generateText({ - model: this.model, - messages: formattedMessages, - tools, - }); - - return { - data: response.text, - usage: { - prompt_tokens: response.usage.promptTokens ?? 0, - completion_tokens: response.usage.completionTokens ?? 0, - total_tokens: response.usage.totalTokens ?? 0, - }, - } as T; - } -} diff --git a/lib/index.ts b/lib/index.ts index a12b63a72..e118e57db 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -45,6 +45,7 @@ import { MissingEnvironmentVariableError, UnsupportedModelError, } from "../types/stagehandErrors"; +import { LanguageModel } from "ai"; dotenv.config({ path: ".env" }); @@ -384,7 +385,7 @@ export class Stagehand { public llmClient: LLMClient; public readonly userProvidedInstructions?: string; private usingAPI: boolean; - private modelName: AvailableModel; + private modelName: AvailableModel | LanguageModel; public apiClient: StagehandAPI | undefined; public readonly waitForCaptchaSolves: boolean; private localBrowserLaunchOptions?: LocalBrowserLaunchOptions; @@ -656,18 +657,23 @@ export class Stagehand { projectId: this.projectId, logger: this.logger, }); + const modelApiKey = + // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel LLMProvider.getModelProvider(this.modelName) === "openai" ? process.env.OPENAI_API_KEY || this.llmClient.clientOptions.apiKey - : LLMProvider.getModelProvider(this.modelName) === "anthropic" + : // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel + LLMProvider.getModelProvider(this.modelName) === "anthropic" ? process.env.ANTHROPIC_API_KEY || this.llmClient.clientOptions.apiKey - : LLMProvider.getModelProvider(this.modelName) === "google" + : // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel + LLMProvider.getModelProvider(this.modelName) === "google" ? process.env.GOOGLE_API_KEY || this.llmClient.clientOptions.apiKey : undefined; const { sessionId } = await this.apiClient.init({ + // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel modelName: this.modelName, modelApiKey: modelApiKey, domSettleTimeoutMs: this.domSettleTimeoutMs, diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index f8c1acd01..c3f270658 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -1,3 +1,8 @@ +import { + UnsupportedModelError, + UnsupportedModelProviderError, +} from "@/types/stagehandErrors"; +import { LanguageModel } from "ai"; import { LogLine } from "../../types/log"; import { AvailableModel, @@ -5,16 +10,26 @@ import { ModelProvider, } from "../../types/model"; import { LLMCache } from "../cache/LLMCache"; +import { AISdkClient } from "./aisdk"; import { AnthropicClient } from "./AnthropicClient"; import { CerebrasClient } from "./CerebrasClient"; import { GoogleClient } from "./GoogleClient"; import { GroqClient } from "./GroqClient"; import { LLMClient } from "./LLMClient"; import { OpenAIClient } from "./OpenAIClient"; -import { - UnsupportedModelError, - UnsupportedModelProviderError, -} from "@/types/stagehandErrors"; + +function modelToProvider( + modelName: AvailableModel | LanguageModel, +): ModelProvider { + if (typeof modelName === "string") { + const provider = modelToProviderMap[modelName]; + if (!provider) { + throw new UnsupportedModelError(Object.keys(modelToProviderMap)); + } + return provider; + } + return "aisdk"; +} const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gpt-4.1": "openai", @@ -81,21 +96,31 @@ export class LLMProvider { } getClient( - modelName: AvailableModel, + modelName: AvailableModel | LanguageModel, clientOptions?: ClientOptions, ): LLMClient { - const provider = modelToProviderMap[modelName]; + const provider = modelToProvider(modelName); if (!provider) { throw new UnsupportedModelError(Object.keys(modelToProviderMap)); } + if (provider === "aisdk") { + return new AISdkClient({ + model: modelName as LanguageModel, + logger: this.logger, + enableCaching: this.enableCaching, + cache: this.cache, + }); + } + + const availableModel = modelName as AvailableModel; switch (provider) { case "openai": return new OpenAIClient({ logger: this.logger, enableCaching: this.enableCaching, cache: this.cache, - modelName, + modelName: availableModel, clientOptions, }); case "anthropic": @@ -103,7 +128,7 @@ export class LLMProvider { logger: this.logger, enableCaching: this.enableCaching, cache: this.cache, - modelName, + modelName: availableModel, clientOptions, }); case "cerebras": @@ -111,7 +136,7 @@ export class LLMProvider { logger: this.logger, enableCaching: this.enableCaching, cache: this.cache, - modelName, + modelName: availableModel, clientOptions, }); case "groq": @@ -119,7 +144,7 @@ export class LLMProvider { logger: this.logger, enableCaching: this.enableCaching, cache: this.cache, - modelName, + modelName: availableModel, clientOptions, }); case "google": @@ -127,7 +152,7 @@ export class LLMProvider { logger: this.logger, enableCaching: this.enableCaching, cache: this.cache, - modelName, + modelName: availableModel, clientOptions, }); default: diff --git a/lib/llm/aisdk.ts b/lib/llm/aisdk.ts new file mode 100644 index 000000000..d4c8a218a --- /dev/null +++ b/lib/llm/aisdk.ts @@ -0,0 +1,285 @@ +import { + CoreAssistantMessage, + CoreMessage, + CoreSystemMessage, + CoreTool, + CoreUserMessage, + generateObject, + generateText, + ImagePart, + LanguageModel, + TextPart, +} from "ai"; +import { + CreateChatCompletionOptions, + LLMClient, + AvailableModel, + LogLine, +} from "@/dist"; +import { ChatCompletion } from "openai/resources"; +import { LLMCache } from "../cache/LLMCache"; + +export class AISdkClient extends LLMClient { + public type = "aisdk" as const; + private model: LanguageModel; + private logger?: (message: LogLine) => void; + private cache: LLMCache | undefined; + private enableCaching: boolean; + + constructor({ + model, + logger, + enableCaching = false, + cache, + }: { + model: LanguageModel; + logger?: (message: LogLine) => void; + enableCaching?: boolean; + cache?: LLMCache; + }) { + super(model.modelId as AvailableModel); + this.model = model; + this.logger = logger; + this.cache = cache; + this.enableCaching = enableCaching; + } + + async createChatCompletion({ + options, + }: CreateChatCompletionOptions): Promise { + this.logger({ + category: "aisdk", + message: "creating chat completion", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(options), + type: "object", + }, + modelName: { + value: this.model.modelId, + type: "string", + }, + }, + }); + + const cacheOptions = { + model: this.model.modelId, + messages: options.messages, + response_model: options.response_model, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + this.logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + }, + }); + return cachedResponse; + } else { + this.logger({ + category: "llm_cache", + message: "LLM cache miss - no cached response found", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + } + } + + const formattedMessages: CoreMessage[] = options.messages.map((message) => { + if (Array.isArray(message.content)) { + if (message.role === "system") { + const systemMessage: CoreSystemMessage = { + role: "system", + content: message.content + .map((c) => ("text" in c ? c.text : "")) + .join("\n"), + }; + return systemMessage; + } + + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ImagePart = { + type: "image", + image: content.image_url.url, + }; + return imageContent; + } else { + const textContent: TextPart = { + type: "text", + text: content.text, + }; + return textContent; + } + }); + + if (message.role === "user") { + const userMessage: CoreUserMessage = { + role: "user", + content: contentParts, + }; + return userMessage; + } else { + const textOnlyParts = contentParts.map((part) => ({ + type: "text" as const, + text: part.type === "image" ? "[Image]" : part.text, + })); + const assistantMessage: CoreAssistantMessage = { + role: "assistant", + content: textOnlyParts, + }; + return assistantMessage; + } + } + + return { + role: message.role, + content: message.content, + }; + }); + + if (options.response_model) { + const response = await generateObject({ + model: this.model, + messages: formattedMessages, + schema: options.response_model.schema, + }); + + const result = { + data: response.object, + usage: { + prompt_tokens: response.usage.promptTokens ?? 0, + completion_tokens: response.usage.completionTokens ?? 0, + total_tokens: response.usage.totalTokens ?? 0, + }, + } as T; + + if (this.enableCaching) { + this.logger({ + category: "llm_cache", + message: "caching response", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + response: { + value: JSON.stringify(result), + type: "object", + }, + }, + }); + this.cache.set(cacheOptions, result, options.requestId); + } + + this.logger({ + category: "aisdk", + message: "response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + + return result; + } + + const tools: Record = {}; + + for (const rawTool of options.tools) { + tools[rawTool.name] = { + description: rawTool.description, + parameters: rawTool.parameters, + }; + } + + const response = await generateText({ + model: this.model, + messages: formattedMessages, + tools, + }); + + const result = { + data: response.text, + usage: { + prompt_tokens: response.usage.promptTokens ?? 0, + completion_tokens: response.usage.completionTokens ?? 0, + total_tokens: response.usage.totalTokens ?? 0, + }, + } as T; + + if (this.enableCaching) { + this.logger({ + category: "llm_cache", + message: "caching response", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + response: { + value: JSON.stringify(result), + type: "object", + }, + }, + }); + this.cache.set(cacheOptions, result, options.requestId); + } + + this.logger({ + category: "aisdk", + message: "response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + + return result; + } +} diff --git a/package-lock.json b/package-lock.json index d6eb1a79f..696ced0a9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@anthropic-ai/sdk": "0.39.0", "@browserbasehq/sdk": "^2.4.0", "@google/genai": "^0.8.0", + "ai": "^4.3.9", "openai": "^4.87.1", "pino": "^9.6.0", "pino-pretty": "^13.0.0", @@ -36,7 +37,6 @@ "@types/node": "^20.11.30", "@types/ws": "^8.5.13", "adm-zip": "^0.5.16", - "ai": "^4.3.0", "autoevals": "^0.0.64", "braintrust": "^0.0.171", "chalk": "^5.4.1", @@ -357,14 +357,13 @@ } }, "node_modules/@ai-sdk/react": { - "version": "1.2.6", - "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-1.2.6.tgz", - "integrity": "sha512-5BFChNbcYtcY9MBStcDev7WZRHf0NpTrk8yfSoedWctB3jfWkFd1HECBvdc8w3mUQshF2MumLHtAhRO7IFtGGQ==", - "dev": true, + "version": "1.2.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-1.2.9.tgz", + "integrity": "sha512-/VYm8xifyngaqFDLXACk/1czDRCefNCdALUyp+kIX6DUIYUWTM93ISoZ+qJ8+3E+FiJAKBQz61o8lIIl+vYtzg==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider-utils": "2.2.4", - "@ai-sdk/ui-utils": "1.2.5", + "@ai-sdk/provider-utils": "2.2.7", + "@ai-sdk/ui-utils": "1.2.8", "swr": "^2.2.5", "throttleit": "2.1.0" }, @@ -382,10 +381,9 @@ } }, "node_modules/@ai-sdk/react/node_modules/@ai-sdk/provider": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", - "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -395,13 +393,12 @@ } }, "node_modules/@ai-sdk/react/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.4", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", - "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", + "@ai-sdk/provider": "1.1.3", "nanoid": "^3.3.8", "secure-json-parse": "^2.7.0" }, @@ -651,14 +648,13 @@ } }, "node_modules/@ai-sdk/ui-utils": { - "version": "1.2.5", - "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-1.2.5.tgz", - "integrity": "sha512-XDgqnJcaCkDez7qolvk+PDbs/ceJvgkNkxkOlc9uDWqxfDJxtvCZ+14MP/1qr4IBwGIgKVHzMDYDXvqVhSWLzg==", - "dev": true, + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-1.2.8.tgz", + "integrity": "sha512-nls/IJCY+ks3Uj6G/agNhXqQeLVqhNfoJbuNgCny+nX2veY5ADB91EcZUqVeQ/ionul2SeUswPY6Q/DxteY29Q==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", - "@ai-sdk/provider-utils": "2.2.4", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7", "zod-to-json-schema": "^3.24.1" }, "engines": { @@ -669,10 +665,9 @@ } }, "node_modules/@ai-sdk/ui-utils/node_modules/@ai-sdk/provider": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", - "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -682,13 +677,12 @@ } }, "node_modules/@ai-sdk/ui-utils/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.4", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", - "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", + "@ai-sdk/provider": "1.1.3", "nanoid": "^3.3.8", "secure-json-parse": "^2.7.0" }, @@ -2312,7 +2306,6 @@ "version": "1.9.0", "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", - "dev": true, "license": "Apache-2.0", "engines": { "node": ">=8.0.0" @@ -2667,7 +2660,6 @@ "version": "1.0.36", "resolved": "https://registry.npmjs.org/@types/diff-match-patch/-/diff-match-patch-1.0.36.tgz", "integrity": "sha512-xFdR6tkm0MWvBfO8xXCSsinYxHcqkQUlcHeSpMC2ukzOb6lwQAfDmW+Qt0AvlGd8HpsS28qKsB+oPeJn9I39jg==", - "dev": true, "license": "MIT" }, "node_modules/@types/estree": { @@ -3201,16 +3193,15 @@ } }, "node_modules/ai": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/ai/-/ai-4.3.0.tgz", - "integrity": "sha512-PxyQYKhWaU3LiZEpeKRaekVonZIbWdKAwgnqm0CSAxy1MFufmYEC5SM5Mc9uiK2DoHcbAL3d1jyaQ2fSDAJL8w==", - "dev": true, + "version": "4.3.9", + "resolved": "https://registry.npmjs.org/ai/-/ai-4.3.9.tgz", + "integrity": "sha512-P2RpV65sWIPdUlA4f1pcJ11pB0N1YmqPVLEmC4j8WuBwKY0L3q9vGhYPh0Iv+spKHKyn0wUbMfas+7Z6nTfS0g==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", - "@ai-sdk/provider-utils": "2.2.4", - "@ai-sdk/react": "1.2.6", - "@ai-sdk/ui-utils": "1.2.5", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7", + "@ai-sdk/react": "1.2.9", + "@ai-sdk/ui-utils": "1.2.8", "@opentelemetry/api": "1.9.0", "jsondiffpatch": "0.6.0" }, @@ -3228,10 +3219,9 @@ } }, "node_modules/ai/node_modules/@ai-sdk/provider": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", - "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -3241,13 +3231,12 @@ } }, "node_modules/ai/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.4", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", - "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", + "@ai-sdk/provider": "1.1.3", "nanoid": "^3.3.8", "secure-json-parse": "^2.7.0" }, @@ -4422,7 +4411,6 @@ "version": "5.4.1", "resolved": "https://registry.npmjs.org/chalk/-/chalk-5.4.1.tgz", "integrity": "sha512-zgVZuo2WcZgfUEmsn6eO3kINexW8RAE4maiQ8QNs8CtpPCSyMiYsULR3HQYkm3w8FIA3SberyMJMSldGsW+U3w==", - "dev": true, "license": "MIT", "engines": { "node": "^12.17.0 || ^14.13 || >=16.0.0" @@ -4895,7 +4883,6 @@ "version": "2.0.3", "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz", "integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==", - "dev": true, "license": "MIT", "engines": { "node": ">=6" @@ -4934,7 +4921,6 @@ "version": "1.0.5", "resolved": "https://registry.npmjs.org/diff-match-patch/-/diff-match-patch-1.0.5.tgz", "integrity": "sha512-IayShXAgj/QMXgB0IWmKx+rOPuGMhqm5w6jvFxmVenXKIzRqTAAsbBPT3kWQeGANj3jGgvcvv4yK6SxqYmikgw==", - "dev": true, "license": "Apache-2.0" }, "node_modules/digest-fetch": { @@ -6637,7 +6623,6 @@ "version": "0.4.0", "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", - "dev": true, "license": "(AFL-2.1 OR BSD-3-Clause)" }, "node_modules/json-schema-traverse": { @@ -6658,7 +6643,6 @@ "version": "0.6.0", "resolved": "https://registry.npmjs.org/jsondiffpatch/-/jsondiffpatch-0.6.0.tgz", "integrity": "sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==", - "dev": true, "license": "MIT", "dependencies": { "@types/diff-match-patch": "^1.0.36", @@ -7170,7 +7154,6 @@ "version": "3.3.9", "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.9.tgz", "integrity": "sha512-SppoicMGpZvbF1l3z4x7No3OlIjP7QJvC9XR7AhZr1kL133KHnKPztkKDc+Ir4aJ/1VhTySrtKhrsycmrMQfvg==", - "dev": true, "funding": [ { "type": "github", @@ -8049,7 +8032,6 @@ "version": "19.0.0", "resolved": "https://registry.npmjs.org/react/-/react-19.0.0.tgz", "integrity": "sha512-V8AVnmPIICiWpGfm6GLzCR/W5FXLchHop40W4nXBmdlEceh16rCN8O8LNWm5bh5XUX91fh7KpA+W0TgMKmgTpQ==", - "dev": true, "license": "MIT", "peer": true, "engines": { @@ -8777,7 +8759,6 @@ "version": "2.3.3", "resolved": "https://registry.npmjs.org/swr/-/swr-2.3.3.tgz", "integrity": "sha512-dshNvs3ExOqtZ6kJBaAsabhPdHyeY4P2cKwRCniDVifBMoG/SVI7tfLWqPXriVspf2Rg4tPzXJTnwaihIeFw2A==", - "dev": true, "license": "MIT", "dependencies": { "dequal": "^2.0.3", @@ -8853,7 +8834,6 @@ "version": "2.1.0", "resolved": "https://registry.npmjs.org/throttleit/-/throttleit-2.1.0.tgz", "integrity": "sha512-nt6AMGKW1p/70DF/hGBdJB57B8Tspmbp5gfJ8ilhLnt7kkr2ye7hzD6NVG8GGErk2HWF34igrL2CXmNIkzKqKw==", - "dev": true, "license": "MIT", "engines": { "node": ">=18" @@ -10077,7 +10057,6 @@ "version": "1.4.0", "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.4.0.tgz", "integrity": "sha512-9WXSPC5fMv61vaupRkCKCxsPxBocVnwakBEkMIHHpkTTg6icbJtg6jzgtLDm4bl3cSHAca52rYWih0k4K3PfHw==", - "dev": true, "license": "MIT", "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" diff --git a/package.json b/package.json index ad675a846..dc06924a6 100644 --- a/package.json +++ b/package.json @@ -63,7 +63,6 @@ "@types/node": "^20.11.30", "@types/ws": "^8.5.13", "adm-zip": "^0.5.16", - "ai": "^4.3.0", "autoevals": "^0.0.64", "braintrust": "^0.0.171", "chalk": "^5.4.1", @@ -91,6 +90,7 @@ "@anthropic-ai/sdk": "0.39.0", "@browserbasehq/sdk": "^2.4.0", "@google/genai": "^0.8.0", + "ai": "^4.3.9", "openai": "^4.87.1", "pino": "^9.6.0", "pino-pretty": "^13.0.0", diff --git a/types/model.ts b/types/model.ts index 447994b6d..cfb1aac0b 100644 --- a/types/model.ts +++ b/types/model.ts @@ -41,7 +41,8 @@ export type ModelProvider = | "anthropic" | "cerebras" | "groq" - | "google"; + | "google" + | "aisdk"; export type ClientOptions = OpenAIClientOptions | AnthropicClientOptions; diff --git a/types/stagehand.ts b/types/stagehand.ts index 7fdc49a1d..cae8c40e1 100644 --- a/types/stagehand.ts +++ b/types/stagehand.ts @@ -6,6 +6,7 @@ import { AvailableModel, ClientOptions } from "./model"; import { LLMClient } from "../lib/llm/LLMClient"; import { Cookie } from "@playwright/test"; import { AgentProviderType } from "./agent"; +import { LanguageModel } from "ai"; export interface ConstructorParams { /** @@ -58,7 +59,7 @@ export interface ConstructorParams { /** * The model to use for Stagehand */ - modelName?: AvailableModel; + modelName?: AvailableModel | LanguageModel; /** * The LLM client to use for Stagehand */ From 12b889be772e99d89dced76105cdf287200203e1 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 23 Apr 2025 11:44:08 -0700 Subject: [PATCH 02/16] fix external client eval --- evals/llm_clients/hn_aisdk.ts | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/evals/llm_clients/hn_aisdk.ts b/evals/llm_clients/hn_aisdk.ts index 51ab4937e..fbf1d0ebc 100644 --- a/evals/llm_clients/hn_aisdk.ts +++ b/evals/llm_clients/hn_aisdk.ts @@ -1,5 +1,4 @@ import { Stagehand } from "@/dist"; -import { AISdkClient } from "@/examples/external_clients/aisdk"; import { EvalFunction } from "@/types/evals"; import { openai } from "@ai-sdk/openai/dist"; import { z } from "zod"; @@ -12,9 +11,7 @@ export const hn_aisdk: EvalFunction = async ({ }) => { const stagehand = new Stagehand({ ...stagehandConfig, - llmClient: new AISdkClient({ - model: openai("gpt-4o-mini"), - }), + modelName: openai("gpt-4o-mini"), }); await stagehand.init(); await stagehand.page.goto( From 07a271e90c56012c3ad6d47c09c6898b5ef9e901 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 23 Apr 2025 12:04:56 -0700 Subject: [PATCH 03/16] fix build --- lib/llm/aisdk.ts | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/llm/aisdk.ts b/lib/llm/aisdk.ts index d4c8a218a..a908cf92d 100644 --- a/lib/llm/aisdk.ts +++ b/lib/llm/aisdk.ts @@ -10,12 +10,9 @@ import { LanguageModel, TextPart, } from "ai"; -import { - CreateChatCompletionOptions, - LLMClient, - AvailableModel, - LogLine, -} from "@/dist"; +import { CreateChatCompletionOptions, LLMClient } from "./LLMClient"; +import { LogLine } from "../../types/log"; +import { AvailableModel } from "../../types/model"; import { ChatCompletion } from "openai/resources"; import { LLMCache } from "../cache/LLMCache"; From d1010caf92f9f70349150f4466b3cd739109fa95 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 23 Apr 2025 14:14:27 -0700 Subject: [PATCH 04/16] optional logging --- lib/llm/aisdk.ts | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/llm/aisdk.ts b/lib/llm/aisdk.ts index a908cf92d..3fc0a6453 100644 --- a/lib/llm/aisdk.ts +++ b/lib/llm/aisdk.ts @@ -44,7 +44,7 @@ export class AISdkClient extends LLMClient { async createChatCompletion({ options, }: CreateChatCompletionOptions): Promise { - this.logger({ + this.logger?.({ category: "aisdk", message: "creating chat completion", level: 2, @@ -72,7 +72,7 @@ export class AISdkClient extends LLMClient { options.requestId, ); if (cachedResponse) { - this.logger({ + this.logger?.({ category: "llm_cache", message: "LLM cache hit - returning cached response", level: 1, @@ -89,7 +89,7 @@ export class AISdkClient extends LLMClient { }); return cachedResponse; } else { - this.logger({ + this.logger?.({ category: "llm_cache", message: "LLM cache miss - no cached response found", level: 1, @@ -173,7 +173,7 @@ export class AISdkClient extends LLMClient { } as T; if (this.enableCaching) { - this.logger({ + this.logger?.({ category: "llm_cache", message: "caching response", level: 1, @@ -195,7 +195,7 @@ export class AISdkClient extends LLMClient { this.cache.set(cacheOptions, result, options.requestId); } - this.logger({ + this.logger?.({ category: "aisdk", message: "response", level: 2, @@ -239,7 +239,7 @@ export class AISdkClient extends LLMClient { } as T; if (this.enableCaching) { - this.logger({ + this.logger?.({ category: "llm_cache", message: "caching response", level: 1, @@ -261,7 +261,7 @@ export class AISdkClient extends LLMClient { this.cache.set(cacheOptions, result, options.requestId); } - this.logger({ + this.logger?.({ category: "aisdk", message: "response", level: 2, From ebd9b357330a30612af3d43cb40918111712bb19 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 23 Apr 2025 16:47:22 -0700 Subject: [PATCH 05/16] slash formatting for aisdk models --- evals/llm_clients/hn_aisdk.ts | 3 +- examples/ai_sdk_example.ts | 3 +- lib/index.ts | 11 ++---- lib/llm/LLMProvider.ts | 69 +++++++++++++++++++++++++++-------- types/model.ts | 25 +++++++++++++ types/stagehand.ts | 3 +- 6 files changed, 84 insertions(+), 30 deletions(-) diff --git a/evals/llm_clients/hn_aisdk.ts b/evals/llm_clients/hn_aisdk.ts index fbf1d0ebc..31248d7d8 100644 --- a/evals/llm_clients/hn_aisdk.ts +++ b/evals/llm_clients/hn_aisdk.ts @@ -1,6 +1,5 @@ import { Stagehand } from "@/dist"; import { EvalFunction } from "@/types/evals"; -import { openai } from "@ai-sdk/openai/dist"; import { z } from "zod"; export const hn_aisdk: EvalFunction = async ({ @@ -11,7 +10,7 @@ export const hn_aisdk: EvalFunction = async ({ }) => { const stagehand = new Stagehand({ ...stagehandConfig, - modelName: openai("gpt-4o-mini"), + modelName: "aisdk/openai/gpt-4o-mini", }); await stagehand.init(); await stagehand.page.goto( diff --git a/examples/ai_sdk_example.ts b/examples/ai_sdk_example.ts index d03a4560c..4161fbb0b 100644 --- a/examples/ai_sdk_example.ts +++ b/examples/ai_sdk_example.ts @@ -1,12 +1,11 @@ import { Stagehand } from "@/dist"; import StagehandConfig from "@/stagehand.config"; -import { openai } from "@ai-sdk/openai"; import { z } from "zod"; async function example() { const stagehand = new Stagehand({ ...StagehandConfig, - modelName: openai("gpt-4o"), + modelName: "aisdk/openai/gpt-4o", }); await stagehand.init(); diff --git a/lib/index.ts b/lib/index.ts index e118e57db..23b34728c 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -45,7 +45,6 @@ import { MissingEnvironmentVariableError, UnsupportedModelError, } from "../types/stagehandErrors"; -import { LanguageModel } from "ai"; dotenv.config({ path: ".env" }); @@ -385,7 +384,7 @@ export class Stagehand { public llmClient: LLMClient; public readonly userProvidedInstructions?: string; private usingAPI: boolean; - private modelName: AvailableModel | LanguageModel; + private modelName: AvailableModel; public apiClient: StagehandAPI | undefined; public readonly waitForCaptchaSolves: boolean; private localBrowserLaunchOptions?: LocalBrowserLaunchOptions; @@ -659,21 +658,17 @@ export class Stagehand { }); const modelApiKey = - // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel LLMProvider.getModelProvider(this.modelName) === "openai" ? process.env.OPENAI_API_KEY || this.llmClient.clientOptions.apiKey - : // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel - LLMProvider.getModelProvider(this.modelName) === "anthropic" + : LLMProvider.getModelProvider(this.modelName) === "anthropic" ? process.env.ANTHROPIC_API_KEY || this.llmClient.clientOptions.apiKey - : // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel - LLMProvider.getModelProvider(this.modelName) === "google" + : LLMProvider.getModelProvider(this.modelName) === "google" ? process.env.GOOGLE_API_KEY || this.llmClient.clientOptions.apiKey : undefined; const { sessionId } = await this.apiClient.init({ - // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel modelName: this.modelName, modelApiKey: modelApiKey, domSettleTimeoutMs: this.domSettleTimeoutMs, diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index c3f270658..04051561e 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -17,19 +17,9 @@ import { GoogleClient } from "./GoogleClient"; import { GroqClient } from "./GroqClient"; import { LLMClient } from "./LLMClient"; import { OpenAIClient } from "./OpenAIClient"; - -function modelToProvider( - modelName: AvailableModel | LanguageModel, -): ModelProvider { - if (typeof modelName === "string") { - const provider = modelToProviderMap[modelName]; - if (!provider) { - throw new UnsupportedModelError(Object.keys(modelToProviderMap)); - } - return provider; - } - return "aisdk"; -} +import { openai } from "@ai-sdk/openai"; +import { anthropic } from "@ai-sdk/anthropic"; +import { google } from "@ai-sdk/google"; const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gpt-4.1": "openai", @@ -63,6 +53,31 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gemini-2.0-flash": "google", "gemini-2.5-flash-preview-04-17": "google", "gemini-2.5-pro-preview-03-25": "google", + "aisdk/anthropic/claude-3-5-sonnet-latest": "aisdk", + "aisdk/anthropic/claude-3-5-sonnet-20240620": "aisdk", + "aisdk/anthropicclaude-3-5-sonnet-20241022": "aisdk", + "aisdk/anthropic/claude-3-7-sonnet-20250219": "aisdk", + "aisdk/anthropic/claude-3-7-sonnet-latest": "aisdk", + "aisdk/google/gemini-1.5-flash": "aisdk", + "aisdk/google/gemini-1.5-pro": "aisdk", + "aisdk/google/gemini-1.5-flash-8b": "aisdk", + "aisdk/google/gemini-2.0-flash-lite": "aisdk", + "aisdk/google/gemini-2.0-flash": "aisdk", + "aisdk/google/gemini-2.5-flash-preview-04-17": "aisdk", + "aisdk/google/gemini-2.5-pro-preview-03-25": "aisdk", + "aisdk/openai/gpt-4.1": "aisdk", + "aisdk/openai/gpt-4.1-mini": "aisdk", + "aisdk/openai/gpt-4.1-nano": "aisdk", + "aisdk/openai/o4-mini": "aisdk", + "aisdk/openai/o3": "aisdk", + "aisdk/openai/o3-mini": "aisdk", + "aisdk/openai/o1": "aisdk", + "aisdk/openai/o1-mini": "aisdk", + "aisdk/openai/gpt-4o": "aisdk", + "aisdk/openai/gpt-4o-mini": "aisdk", + "aisdk/openai/gpt-4o-2024-08-06": "aisdk", + "aisdk/openai/gpt-4.5-preview": "aisdk", + "aisdk/openai/o1-preview": "aisdk", }; export class LLMProvider { @@ -96,17 +111,39 @@ export class LLMProvider { } getClient( - modelName: AvailableModel | LanguageModel, + modelName: AvailableModel, clientOptions?: ClientOptions, ): LLMClient { - const provider = modelToProvider(modelName); + const provider = modelToProviderMap[modelName]; if (!provider) { throw new UnsupportedModelError(Object.keys(modelToProviderMap)); } if (provider === "aisdk") { + const parts = modelName.split("/"); + if (parts.length !== 3) { + throw new Error(`Invalid aisdk model format: ${modelName}`); + } + + const [, subProvider, subModelName] = parts; + let languageModel: LanguageModel; + + switch (subProvider) { + case "openai": + languageModel = openai(subModelName); + break; + case "anthropic": + languageModel = anthropic(subModelName); + break; + case "google": + languageModel = google(subModelName); + break; + default: + throw new Error(`Unsupported aisdk sub-provider: ${subProvider}`); + } + return new AISdkClient({ - model: modelName as LanguageModel, + model: languageModel, logger: this.logger, enableCaching: this.enableCaching, cache: this.cache, diff --git a/types/model.ts b/types/model.ts index cfb1aac0b..0d2fae8f2 100644 --- a/types/model.ts +++ b/types/model.ts @@ -32,6 +32,31 @@ export const AvailableModelSchema = z.enum([ "gemini-2.0-flash", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", + "aisdk/anthropic/claude-3-5-sonnet-latest", + "aisdk/anthropic/claude-3-5-sonnet-20240620", + "aisdk/anthropicclaude-3-5-sonnet-20241022", + "aisdk/anthropic/claude-3-7-sonnet-20250219", + "aisdk/anthropic/claude-3-7-sonnet-latest", + "aisdk/google/gemini-1.5-flash", + "aisdk/google/gemini-1.5-pro", + "aisdk/google/gemini-1.5-flash-8b", + "aisdk/google/gemini-2.0-flash-lite", + "aisdk/google/gemini-2.0-flash", + "aisdk/google/gemini-2.5-flash-preview-04-17", + "aisdk/google/gemini-2.5-pro-preview-03-25", + "aisdk/openai/gpt-4.1", + "aisdk/openai/gpt-4.1-mini", + "aisdk/openai/gpt-4.1-nano", + "aisdk/openai/o4-mini", + "aisdk/openai/o3", + "aisdk/openai/o3-mini", + "aisdk/openai/o1", + "aisdk/openai/o1-mini", + "aisdk/openai/gpt-4o", + "aisdk/openai/gpt-4o-mini", + "aisdk/openai/gpt-4o-2024-08-06", + "aisdk/openai/gpt-4.5-preview", + "aisdk/openai/o1-preview", ]); export type AvailableModel = z.infer; diff --git a/types/stagehand.ts b/types/stagehand.ts index cae8c40e1..7fdc49a1d 100644 --- a/types/stagehand.ts +++ b/types/stagehand.ts @@ -6,7 +6,6 @@ import { AvailableModel, ClientOptions } from "./model"; import { LLMClient } from "../lib/llm/LLMClient"; import { Cookie } from "@playwright/test"; import { AgentProviderType } from "./agent"; -import { LanguageModel } from "ai"; export interface ConstructorParams { /** @@ -59,7 +58,7 @@ export interface ConstructorParams { /** * The model to use for Stagehand */ - modelName?: AvailableModel | LanguageModel; + modelName?: AvailableModel; /** * The LLM client to use for Stagehand */ From 6bdd0a952d7d681c3dc1a53c01e134578a21ed07 Mon Sep 17 00:00:00 2001 From: miguel Date: Wed, 23 Apr 2025 17:18:32 -0700 Subject: [PATCH 06/16] start proxying aisdk --- lib/llm/LLMClient.ts | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index e48f1117b..f5ce7781e 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -2,6 +2,7 @@ import { ZodType } from "zod"; import { LLMTool } from "../../types/llm"; import { LogLine } from "../../types/log"; import { AvailableModel, ClientOptions } from "../../types/model"; +import { generateObject, generateText, streamText, streamObject, experimental_generateImage, embed, embedMany, experimental_transcribe, experimental_generateSpeech } from "ai"; export interface ChatMessage { role: "system" | "user" | "assistant"; @@ -102,4 +103,14 @@ export abstract class LLMClient { usage?: LLMResponse["usage"]; }, >(options: CreateChatCompletionOptions): Promise; + + public generateObject = generateObject; + public generateText = generateText; + public streamText = streamText; + public streamObject = streamObject; + public generateImage = experimental_generateImage; + public embed = embed; + public embedMany = embedMany; + public transcribe = experimental_transcribe; + public generateSpeech = experimental_generateSpeech; } From 80caff55a1d3fe31018328f2dca94d6b17d92b1a Mon Sep 17 00:00:00 2001 From: Miguel <36487034+miguelg719@users.noreply.github.com> Date: Thu, 24 Apr 2025 10:29:47 -0700 Subject: [PATCH 07/16] change model-provider routing (#701) * avoid using model-provider map * prettier * remove duplicate dependency --- lib/llm/LLMClient.ts | 12 +- lib/llm/LLMProvider.ts | 77 ++++++---- package-lock.json | 341 ++++++++++++++++++++++++----------------- package.json | 20 ++- types/model.ts | 27 +--- 5 files changed, 267 insertions(+), 210 deletions(-) diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index f5ce7781e..524a71896 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -2,7 +2,17 @@ import { ZodType } from "zod"; import { LLMTool } from "../../types/llm"; import { LogLine } from "../../types/log"; import { AvailableModel, ClientOptions } from "../../types/model"; -import { generateObject, generateText, streamText, streamObject, experimental_generateImage, embed, embedMany, experimental_transcribe, experimental_generateSpeech } from "ai"; +import { + generateObject, + generateText, + streamText, + streamObject, + experimental_generateImage, + embed, + embedMany, + experimental_transcribe, + experimental_generateSpeech, +} from "ai"; export interface ChatMessage { role: "system" | "user" | "assistant"; diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index 04051561e..f74a7b100 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -20,6 +20,15 @@ import { OpenAIClient } from "./OpenAIClient"; import { openai } from "@ai-sdk/openai"; import { anthropic } from "@ai-sdk/anthropic"; import { google } from "@ai-sdk/google"; +import { xai } from "@ai-sdk/xai"; +import { azure } from "@ai-sdk/azure"; +import { groq } from "@ai-sdk/groq"; +import { cerebras } from "@ai-sdk/cerebras"; +import { togetherai } from "@ai-sdk/togetherai"; +import { mistral } from "@ai-sdk/mistral"; +import { deepseek } from "@ai-sdk/deepseek"; +import { perplexity } from "@ai-sdk/perplexity"; +import { ollama } from "ollama-ai-provider"; const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gpt-4.1": "openai", @@ -53,31 +62,6 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gemini-2.0-flash": "google", "gemini-2.5-flash-preview-04-17": "google", "gemini-2.5-pro-preview-03-25": "google", - "aisdk/anthropic/claude-3-5-sonnet-latest": "aisdk", - "aisdk/anthropic/claude-3-5-sonnet-20240620": "aisdk", - "aisdk/anthropicclaude-3-5-sonnet-20241022": "aisdk", - "aisdk/anthropic/claude-3-7-sonnet-20250219": "aisdk", - "aisdk/anthropic/claude-3-7-sonnet-latest": "aisdk", - "aisdk/google/gemini-1.5-flash": "aisdk", - "aisdk/google/gemini-1.5-pro": "aisdk", - "aisdk/google/gemini-1.5-flash-8b": "aisdk", - "aisdk/google/gemini-2.0-flash-lite": "aisdk", - "aisdk/google/gemini-2.0-flash": "aisdk", - "aisdk/google/gemini-2.5-flash-preview-04-17": "aisdk", - "aisdk/google/gemini-2.5-pro-preview-03-25": "aisdk", - "aisdk/openai/gpt-4.1": "aisdk", - "aisdk/openai/gpt-4.1-mini": "aisdk", - "aisdk/openai/gpt-4.1-nano": "aisdk", - "aisdk/openai/o4-mini": "aisdk", - "aisdk/openai/o3": "aisdk", - "aisdk/openai/o3-mini": "aisdk", - "aisdk/openai/o1": "aisdk", - "aisdk/openai/o1-mini": "aisdk", - "aisdk/openai/gpt-4o": "aisdk", - "aisdk/openai/gpt-4o-mini": "aisdk", - "aisdk/openai/gpt-4o-2024-08-06": "aisdk", - "aisdk/openai/gpt-4.5-preview": "aisdk", - "aisdk/openai/o1-preview": "aisdk", }; export class LLMProvider { @@ -114,18 +98,12 @@ export class LLMProvider { modelName: AvailableModel, clientOptions?: ClientOptions, ): LLMClient { - const provider = modelToProviderMap[modelName]; - if (!provider) { - throw new UnsupportedModelError(Object.keys(modelToProviderMap)); - } - - if (provider === "aisdk") { + if (modelName.includes("/")) { const parts = modelName.split("/"); - if (parts.length !== 3) { + if (parts.length !== 2) { throw new Error(`Invalid aisdk model format: ${modelName}`); } - - const [, subProvider, subModelName] = parts; + const [subProvider, subModelName] = parts; let languageModel: LanguageModel; switch (subProvider) { @@ -138,6 +116,33 @@ export class LLMProvider { case "google": languageModel = google(subModelName); break; + case "xai": + languageModel = xai(subModelName); + break; + case "azure": + languageModel = azure(subModelName); + break; + case "groq": + languageModel = groq(subModelName); + break; + case "cerebras": + languageModel = cerebras(subModelName); + break; + case "togetherai": + languageModel = togetherai(subModelName); + break; + case "mistral": + languageModel = mistral(subModelName); + break; + case "deepseek": + languageModel = deepseek(subModelName); + break; + case "perplexity": + languageModel = perplexity(subModelName); + break; + case "ollama": + languageModel = ollama(subModelName); + break; default: throw new Error(`Unsupported aisdk sub-provider: ${subProvider}`); } @@ -150,6 +155,10 @@ export class LLMProvider { }); } + const provider = modelToProviderMap[modelName]; + if (!provider) { + throw new UnsupportedModelError(Object.keys(modelToProviderMap)); + } const availableModel = modelName as AvailableModel; switch (provider) { case "openai": diff --git a/package-lock.json b/package-lock.json index 696ced0a9..8c8bf63f8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -20,12 +20,6 @@ "zod-to-json-schema": "^3.23.5" }, "devDependencies": { - "@ai-sdk/anthropic": "^1.2.6", - "@ai-sdk/cerebras": "^0.2.6", - "@ai-sdk/google": "^1.2.6", - "@ai-sdk/groq": "^1.2.4", - "@ai-sdk/openai": "^1.0.14", - "@ai-sdk/togetherai": "^0.2.6", "@changesets/changelog-github": "^0.5.0", "@changesets/cli": "^2.27.9", "@eslint/js": "^9.16.0", @@ -54,6 +48,20 @@ "typescript": "^5.2.2", "typescript-eslint": "^8.17.0" }, + "optionalDependencies": { + "@ai-sdk/anthropic": "^1.2.6", + "@ai-sdk/azure": "^1.3.19", + "@ai-sdk/cerebras": "^0.2.6", + "@ai-sdk/deepseek": "^0.2.13", + "@ai-sdk/google": "^1.2.6", + "@ai-sdk/groq": "^1.2.4", + "@ai-sdk/mistral": "^1.2.7", + "@ai-sdk/openai": "^1.0.14", + "@ai-sdk/perplexity": "^1.1.7", + "@ai-sdk/togetherai": "^0.2.6", + "@ai-sdk/xai": "^1.2.15", + "ollama-ai-provider": "^1.2.0" + }, "peerDependencies": { "@playwright/test": "^1.42.1", "deepmerge": "^4.3.1", @@ -65,8 +73,8 @@ "version": "1.2.6", "resolved": "https://registry.npmjs.org/@ai-sdk/anthropic/-/anthropic-1.2.6.tgz", "integrity": "sha512-Mt8ZSkhwnKHfwPPIviv3xgRE/nch2Mu4Fdh7oJDJvPDRJ6tNidCJd3TMwdlrlzPskF7hxCmXmd36yBgZZgt4cA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "@ai-sdk/provider-utils": "2.2.4" @@ -82,8 +90,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "json-schema": "^0.4.0" }, @@ -95,8 +103,8 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "nanoid": "^3.3.8", @@ -109,12 +117,30 @@ "zod": "^3.23.8" } }, + "node_modules/@ai-sdk/azure": { + "version": "1.3.19", + "resolved": "https://registry.npmjs.org/@ai-sdk/azure/-/azure-1.3.19.tgz", + "integrity": "sha512-XYEa2r7/4UzuXvoTulTRpQ8QWcG5TxO32l0hM7JXz9w4FGXxVGVP4JVKH9p8M9ZLHhDx60JtoL9vg+Yvv8wwPg==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/openai": "1.3.18", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, "node_modules/@ai-sdk/cerebras": { "version": "0.2.6", "resolved": "https://registry.npmjs.org/@ai-sdk/cerebras/-/cerebras-0.2.6.tgz", "integrity": "sha512-XpVqq5462HyQPT55Ptpdb0pwti3fbLOZGDeWgxkSwJTzC+fCzDLZFLJKDCINhiMzD+8CAQPQ/qm9+inFZaF0Og==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/openai-compatible": "0.2.6", "@ai-sdk/provider": "1.1.0", @@ -131,8 +157,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "json-schema": "^0.4.0" }, @@ -144,8 +170,8 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "nanoid": "^3.3.8", @@ -158,12 +184,47 @@ "zod": "^3.23.8" } }, + "node_modules/@ai-sdk/deepseek": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/@ai-sdk/deepseek/-/deepseek-0.2.13.tgz", + "integrity": "sha512-+Vw+nMdypupRfQb37wT1YmNNAROhCBqVRhHule3dk8A26N/W8GlAfVwLiae1/fK275UddbQWpUTMzvrx7n9HDg==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/openai-compatible": "0.2.13", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, + "node_modules/@ai-sdk/deepseek/node_modules/@ai-sdk/openai-compatible": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai-compatible/-/openai-compatible-0.2.13.tgz", + "integrity": "sha512-tB+lL8Z3j0qDod/mvxwjrPhbLUHp/aQW+NvMoJaqeTtP+Vmv5qR800pncGczxn5WN0pllQm+7aIRDnm69XeSbg==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, "node_modules/@ai-sdk/google": { "version": "1.2.6", "resolved": "https://registry.npmjs.org/@ai-sdk/google/-/google-1.2.6.tgz", "integrity": "sha512-e6vl+hmz7xZzWmsZZkLv89TZc19Vjgqj+RgvJNg03npRiuG4f1R/He1PD/JX6f0az//Y55CcozCcaj4vnMz6gQ==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "@ai-sdk/provider-utils": "2.2.4" @@ -179,8 +240,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "json-schema": "^0.4.0" }, @@ -192,8 +253,8 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "nanoid": "^3.3.8", @@ -210,8 +271,8 @@ "version": "1.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/groq/-/groq-1.2.4.tgz", "integrity": "sha512-jeO/tO8lGpk7L/zpPPSeZ6tMYwcXq2LbPEgmvTDhdrSj4AwYhL9WiEmZerixsKzNuOxjAzP0QZOns1SpsGaC2A==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "@ai-sdk/provider-utils": "2.2.4" @@ -227,8 +288,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "json-schema": "^0.4.0" }, @@ -240,8 +301,8 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "nanoid": "^3.3.8", @@ -254,15 +315,32 @@ "zod": "^3.23.8" } }, + "node_modules/@ai-sdk/mistral": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/mistral/-/mistral-1.2.7.tgz", + "integrity": "sha512-MbOMGfnHKcsvjbv4d6OT7Oaz+Wp4jD8yityqC4hASoKoW1s7L52woz25ES8RgAgTRlfbEZ3MOxEzLu58I228bQ==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, "node_modules/@ai-sdk/openai": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.2.3.tgz", - "integrity": "sha512-iIJKMjKYZN3XWVECDassufz3X7rq/b5BQ6Uhnp03i06T8E4QwCamwtLitJXtrqQ+OxL33Ugn3EIZKGaSBo/+qw==", - "dev": true, + "version": "1.3.18", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.3.18.tgz", + "integrity": "sha512-gqOHTOu62Tm2r4yDQx/Z5tWAgUrcTK8wXnC4A8zF/VOCzIjJDxxPsqJRTtQTMgIdGXhwmsv2sZ2PzvvuLeZeEg==", "license": "Apache-2.0", + "optional": true, "dependencies": { - "@ai-sdk/provider": "1.0.10", - "@ai-sdk/provider-utils": "2.1.12" + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" }, "engines": { "node": ">=18" @@ -275,8 +353,8 @@ "version": "0.2.6", "resolved": "https://registry.npmjs.org/@ai-sdk/openai-compatible/-/openai-compatible-0.2.6.tgz", "integrity": "sha512-UOPEWIqG3l5K9O+p7gqiCOWzx66JtmG9v9Mab+S4E7WE34EN6u1QS1pX+RDlRDhZ0/8gNJif0r4Xlc+Ti03yNA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "@ai-sdk/provider-utils": "2.2.4" @@ -292,8 +370,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "json-schema": "^0.4.0" }, @@ -305,8 +383,8 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "nanoid": "^3.3.8", @@ -319,11 +397,27 @@ "zod": "^3.23.8" } }, + "node_modules/@ai-sdk/perplexity": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/perplexity/-/perplexity-1.1.7.tgz", + "integrity": "sha512-FH2zEADLU/NTuRkQXMbZkUZ0qSsJ5qhufQ+7IsFMuhhKShGt0M8gOZlnkxuolnIjDrOdD3r1r59nZKMsFHuwqw==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, "node_modules/@ai-sdk/provider": { - "version": "1.0.10", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.0.10.tgz", - "integrity": "sha512-pco8Zl9U0xwXI+nCLc0woMtxbvjU8hRmGTseAUiPHFLYAAL8trRPCukg69IDeinOvIeo1SmXxAIdWWPZOLb4Cg==", - "dev": true, + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -333,14 +427,12 @@ } }, "node_modules/@ai-sdk/provider-utils": { - "version": "2.1.12", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.1.12.tgz", - "integrity": "sha512-NLm2Ypkv419jR5TNOvZ057ciSYFKzSDEIIwE8cRyeR1Y5RbuX+auZveqGg6GWsDzvUnn6Xra7BJmr0422v60UA==", - "dev": true, + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.0.10", - "eventsource-parser": "^3.0.0", + "@ai-sdk/provider": "1.1.3", "nanoid": "^3.3.8", "secure-json-parse": "^2.7.0" }, @@ -348,12 +440,7 @@ "node": ">=18" }, "peerDependencies": { - "zod": "^3.0.0" - }, - "peerDependenciesMeta": { - "zod": { - "optional": true - } + "zod": "^3.23.8" } }, "node_modules/@ai-sdk/react": { @@ -380,35 +467,6 @@ } } }, - "node_modules/@ai-sdk/react/node_modules/@ai-sdk/provider": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", - "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@ai-sdk/react/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.7", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", - "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "1.1.3", - "nanoid": "^3.3.8", - "secure-json-parse": "^2.7.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.23.8" - } - }, "node_modules/@ai-sdk/solid": { "version": "0.0.54", "resolved": "https://registry.npmjs.org/@ai-sdk/solid/-/solid-0.0.54.tgz", @@ -602,8 +660,8 @@ "version": "0.2.6", "resolved": "https://registry.npmjs.org/@ai-sdk/togetherai/-/togetherai-0.2.6.tgz", "integrity": "sha512-AV3CABMKlIniCe5owr6H/kSirfk3Y/MeBAetrNJxRDhmrxa5VXzbWeMxS5xeS8crqFXWJPPLcwPiwOuFtpQMrA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/openai-compatible": "0.2.6", "@ai-sdk/provider": "1.1.0", @@ -620,8 +678,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "json-schema": "^0.4.0" }, @@ -633,8 +691,8 @@ "version": "2.2.4", "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.4.tgz", "integrity": "sha512-13sEGBxB6kgaMPGOgCLYibF6r8iv8mgjhuToFrOTU09bBxbFQd8ZoARarCfJN6VomCUbUvMKwjTBLb1vQnN+WA==", - "dev": true, "license": "Apache-2.0", + "optional": true, "dependencies": { "@ai-sdk/provider": "1.1.0", "nanoid": "^3.3.8", @@ -664,35 +722,6 @@ "zod": "^3.23.8" } }, - "node_modules/@ai-sdk/ui-utils/node_modules/@ai-sdk/provider": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", - "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/@ai-sdk/ui-utils/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.7", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", - "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "1.1.3", - "nanoid": "^3.3.8", - "secure-json-parse": "^2.7.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.23.8" - } - }, "node_modules/@ai-sdk/vue": { "version": "0.0.59", "resolved": "https://registry.npmjs.org/@ai-sdk/vue/-/vue-0.0.59.tgz", @@ -788,6 +817,41 @@ "node": ">=14.18" } }, + "node_modules/@ai-sdk/xai": { + "version": "1.2.15", + "resolved": "https://registry.npmjs.org/@ai-sdk/xai/-/xai-1.2.15.tgz", + "integrity": "sha512-18qEYyVHIqTiOMePE00bfx4kJrTHM4dV3D3Rpe+eBISlY80X1FnzZRnRTJo3Q6MOSmW5+ZKVaX9jtryhoFpn0A==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/openai-compatible": "0.2.13", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, + "node_modules/@ai-sdk/xai/node_modules/@ai-sdk/openai-compatible": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai-compatible/-/openai-compatible-0.2.13.tgz", + "integrity": "sha512-tB+lL8Z3j0qDod/mvxwjrPhbLUHp/aQW+NvMoJaqeTtP+Vmv5qR800pncGczxn5WN0pllQm+7aIRDnm69XeSbg==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, "node_modules/@ampproject/remapping": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", @@ -3218,35 +3282,6 @@ } } }, - "node_modules/ai/node_modules/@ai-sdk/provider": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", - "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "node_modules/ai/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.7", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", - "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "1.1.3", - "nanoid": "^3.3.8", - "secure-json-parse": "^2.7.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.23.8" - } - }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -5570,16 +5605,6 @@ "dev": true, "license": "MIT" }, - "node_modules/eventsource-parser": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.0.tgz", - "integrity": "sha512-T1C0XCUimhxVQzW4zFipdx0SficT651NnkR0ZSH3yQwh+mFMdLfgjABVi4YtMTtaL4s168593DaoaRLMqryavA==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=18.0.0" - } - }, "node_modules/express": { "version": "4.21.2", "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", @@ -7260,6 +7285,29 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/ollama-ai-provider": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/ollama-ai-provider/-/ollama-ai-provider-1.2.0.tgz", + "integrity": "sha512-jTNFruwe3O/ruJeppI/quoOUxG7NA6blG3ZyQj3lei4+NnJo7bi3eIRWqlVpRlu/mbzbFXeJSBuYQWF6pzGKww==", + "license": "Apache-2.0", + "optional": true, + "dependencies": { + "@ai-sdk/provider": "^1.0.0", + "@ai-sdk/provider-utils": "^2.0.0", + "partial-json": "0.1.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, "node_modules/on-exit-leak-free": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/on-exit-leak-free/-/on-exit-leak-free-2.1.2.tgz", @@ -7577,6 +7625,13 @@ "node": ">= 0.8" } }, + "node_modules/partial-json": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/partial-json/-/partial-json-0.1.7.tgz", + "integrity": "sha512-Njv/59hHaokb/hRUjce3Hdv12wd60MtM9Z5Olmn+nehe0QDAsRtRbJPvJ0Z91TusF0SuZRIvnM+S4l6EIP8leA==", + "license": "MIT", + "optional": true + }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", diff --git a/package.json b/package.json index dc06924a6..575fe1d17 100644 --- a/package.json +++ b/package.json @@ -46,12 +46,6 @@ "author": "Browserbase", "license": "MIT", "devDependencies": { - "@ai-sdk/anthropic": "^1.2.6", - "@ai-sdk/cerebras": "^0.2.6", - "@ai-sdk/google": "^1.2.6", - "@ai-sdk/groq": "^1.2.4", - "@ai-sdk/openai": "^1.0.14", - "@ai-sdk/togetherai": "^0.2.6", "@changesets/changelog-github": "^0.5.0", "@changesets/cli": "^2.27.9", "@eslint/js": "^9.16.0", @@ -97,6 +91,20 @@ "ws": "^8.18.0", "zod-to-json-schema": "^3.23.5" }, + "optionalDependencies": { + "@ai-sdk/anthropic": "^1.2.6", + "@ai-sdk/azure": "^1.3.19", + "@ai-sdk/cerebras": "^0.2.6", + "@ai-sdk/deepseek": "^0.2.13", + "@ai-sdk/google": "^1.2.6", + "@ai-sdk/groq": "^1.2.4", + "@ai-sdk/mistral": "^1.2.7", + "@ai-sdk/openai": "^1.0.14", + "@ai-sdk/perplexity": "^1.1.7", + "@ai-sdk/togetherai": "^0.2.6", + "@ai-sdk/xai": "^1.2.15", + "ollama-ai-provider": "^1.2.0" + }, "directories": { "doc": "docs", "example": "examples", diff --git a/types/model.ts b/types/model.ts index 0d2fae8f2..bdd9324b2 100644 --- a/types/model.ts +++ b/types/model.ts @@ -32,34 +32,9 @@ export const AvailableModelSchema = z.enum([ "gemini-2.0-flash", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", - "aisdk/anthropic/claude-3-5-sonnet-latest", - "aisdk/anthropic/claude-3-5-sonnet-20240620", - "aisdk/anthropicclaude-3-5-sonnet-20241022", - "aisdk/anthropic/claude-3-7-sonnet-20250219", - "aisdk/anthropic/claude-3-7-sonnet-latest", - "aisdk/google/gemini-1.5-flash", - "aisdk/google/gemini-1.5-pro", - "aisdk/google/gemini-1.5-flash-8b", - "aisdk/google/gemini-2.0-flash-lite", - "aisdk/google/gemini-2.0-flash", - "aisdk/google/gemini-2.5-flash-preview-04-17", - "aisdk/google/gemini-2.5-pro-preview-03-25", - "aisdk/openai/gpt-4.1", - "aisdk/openai/gpt-4.1-mini", - "aisdk/openai/gpt-4.1-nano", - "aisdk/openai/o4-mini", - "aisdk/openai/o3", - "aisdk/openai/o3-mini", - "aisdk/openai/o1", - "aisdk/openai/o1-mini", - "aisdk/openai/gpt-4o", - "aisdk/openai/gpt-4o-mini", - "aisdk/openai/gpt-4o-2024-08-06", - "aisdk/openai/gpt-4.5-preview", - "aisdk/openai/o1-preview", ]); -export type AvailableModel = z.infer; +export type AvailableModel = z.infer | string; export type ModelProvider = | "openai" From 29f6634d6d66c150fd81a63da25e657d96bc8d58 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Fri, 25 Apr 2025 15:34:38 -0700 Subject: [PATCH 08/16] add aisdk provider map --- lib/index.ts | 6 +++- lib/llm/LLMProvider.ts | 71 ++++++++++++++++------------------------ types/llm.ts | 4 +++ types/stagehandErrors.ts | 8 +++++ 4 files changed, 46 insertions(+), 43 deletions(-) diff --git a/lib/index.ts b/lib/index.ts index 23b34728c..57c0b2734 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -44,6 +44,7 @@ import { StagehandEnvironmentError, MissingEnvironmentVariableError, UnsupportedModelError, + UnsupportedAISDKModelProviderError, } from "../types/stagehandErrors"; dotenv.config({ path: ".env" }); @@ -543,7 +544,10 @@ export class Stagehand { modelName ?? DEFAULT_MODEL_NAME, modelClientOptions, ); - } catch { + } catch (error) { + if (error instanceof UnsupportedAISDKModelProviderError) { + throw error; + } this.llmClient = undefined; } } diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index f74a7b100..5d6698d47 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -1,8 +1,8 @@ import { + UnsupportedAISDKModelProviderError, UnsupportedModelError, UnsupportedModelProviderError, } from "@/types/stagehandErrors"; -import { LanguageModel } from "ai"; import { LogLine } from "../../types/log"; import { AvailableModel, @@ -29,6 +29,22 @@ import { mistral } from "@ai-sdk/mistral"; import { deepseek } from "@ai-sdk/deepseek"; import { perplexity } from "@ai-sdk/perplexity"; import { ollama } from "ollama-ai-provider"; +import { AISDKProvider } from "@/types/llm"; + +const AISDKProviders: Record = { + openai, + anthropic, + google, + xai, + azure, + groq, + cerebras, + togetherai, + mistral, + deepseek, + perplexity, + ollama, +}; const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = { "gpt-4.1": "openai", @@ -104,48 +120,8 @@ export class LLMProvider { throw new Error(`Invalid aisdk model format: ${modelName}`); } const [subProvider, subModelName] = parts; - let languageModel: LanguageModel; - switch (subProvider) { - case "openai": - languageModel = openai(subModelName); - break; - case "anthropic": - languageModel = anthropic(subModelName); - break; - case "google": - languageModel = google(subModelName); - break; - case "xai": - languageModel = xai(subModelName); - break; - case "azure": - languageModel = azure(subModelName); - break; - case "groq": - languageModel = groq(subModelName); - break; - case "cerebras": - languageModel = cerebras(subModelName); - break; - case "togetherai": - languageModel = togetherai(subModelName); - break; - case "mistral": - languageModel = mistral(subModelName); - break; - case "deepseek": - languageModel = deepseek(subModelName); - break; - case "perplexity": - languageModel = perplexity(subModelName); - break; - case "ollama": - languageModel = ollama(subModelName); - break; - default: - throw new Error(`Unsupported aisdk sub-provider: ${subProvider}`); - } + const languageModel = getAISDKLanguageModel(subProvider, subModelName); return new AISdkClient({ model: languageModel, @@ -155,6 +131,17 @@ export class LLMProvider { }); } + function getAISDKLanguageModel(subProvider: string, subModelName: string) { + const aiSDKLanguageModel = AISDKProviders[subProvider]; + if (!aiSDKLanguageModel) { + throw new UnsupportedAISDKModelProviderError( + subProvider, + Object.keys(AISDKProviders), + ); + } + return aiSDKLanguageModel(subModelName); + } + const provider = modelToProviderMap[modelName]; if (!provider) { throw new UnsupportedModelError(Object.keys(modelToProviderMap)); diff --git a/types/llm.ts b/types/llm.ts index f383b97ea..738b73a72 100644 --- a/types/llm.ts +++ b/types/llm.ts @@ -1,6 +1,10 @@ +import { LanguageModel } from "ai"; + export interface LLMTool { type: "function"; name: string; description: string; parameters: Record; } + +export type AISDKProvider = (modelName: string) => LanguageModel; diff --git a/types/stagehandErrors.ts b/types/stagehandErrors.ts index 99b4ad5b0..208b8816d 100644 --- a/types/stagehandErrors.ts +++ b/types/stagehandErrors.ts @@ -57,6 +57,14 @@ export class UnsupportedModelProviderError extends StagehandError { } } +export class UnsupportedAISDKModelProviderError extends StagehandError { + constructor(provider: string, supportedProviders: string[]) { + super( + `${provider} is not currently supported for aiSDK. please use one of the supported model providers: ${supportedProviders}`, + ); + } +} + export class StagehandNotInitializedError extends StagehandError { constructor(prop: string) { super( From 4d5b913a79d83cabc2fd43be67f529b49e412e2d Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Fri, 25 Apr 2025 15:45:05 -0700 Subject: [PATCH 09/16] add informed error message --- lib/index.ts | 6 +++++- lib/llm/LLMProvider.ts | 3 ++- types/stagehandErrors.ts | 8 ++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/lib/index.ts b/lib/index.ts index 57c0b2734..ff2765e99 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -45,6 +45,7 @@ import { MissingEnvironmentVariableError, UnsupportedModelError, UnsupportedAISDKModelProviderError, + InvalidAISDKModelFormatError, } from "../types/stagehandErrors"; dotenv.config({ path: ".env" }); @@ -545,7 +546,10 @@ export class Stagehand { modelClientOptions, ); } catch (error) { - if (error instanceof UnsupportedAISDKModelProviderError) { + if ( + error instanceof UnsupportedAISDKModelProviderError || + error instanceof InvalidAISDKModelFormatError + ) { throw error; } this.llmClient = undefined; diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index 5d6698d47..f2cf8bf64 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -1,4 +1,5 @@ import { + InvalidAISDKModelFormatError, UnsupportedAISDKModelProviderError, UnsupportedModelError, UnsupportedModelProviderError, @@ -117,7 +118,7 @@ export class LLMProvider { if (modelName.includes("/")) { const parts = modelName.split("/"); if (parts.length !== 2) { - throw new Error(`Invalid aisdk model format: ${modelName}`); + throw new InvalidAISDKModelFormatError(modelName); } const [subProvider, subModelName] = parts; diff --git a/types/stagehandErrors.ts b/types/stagehandErrors.ts index 208b8816d..5c15bbc28 100644 --- a/types/stagehandErrors.ts +++ b/types/stagehandErrors.ts @@ -65,6 +65,14 @@ export class UnsupportedAISDKModelProviderError extends StagehandError { } } +export class InvalidAISDKModelFormatError extends StagehandError { + constructor(modelName: string) { + super( + `${modelName} does not follow correct format for specifying aiSDK models. Please define your modelName as 'provider/model-name'. For example: \`modelName: 'openai/gpt-4o-mini'\``, + ); + } +} + export class StagehandNotInitializedError extends StagehandError { constructor(prop: string) { super( From 90a61cc15229d34d1c52b402f66e4ccbda9f7688 Mon Sep 17 00:00:00 2001 From: Sean McGuire <75873287+seanmcguire12@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:46:22 -0700 Subject: [PATCH 10/16] Apply suggestions from code review Co-authored-by: Miguel <36487034+miguelg719@users.noreply.github.com> --- evals/llm_clients/hn_aisdk.ts | 2 +- examples/ai_sdk_example.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/evals/llm_clients/hn_aisdk.ts b/evals/llm_clients/hn_aisdk.ts index 31248d7d8..68ef89027 100644 --- a/evals/llm_clients/hn_aisdk.ts +++ b/evals/llm_clients/hn_aisdk.ts @@ -10,7 +10,7 @@ export const hn_aisdk: EvalFunction = async ({ }) => { const stagehand = new Stagehand({ ...stagehandConfig, - modelName: "aisdk/openai/gpt-4o-mini", + modelName: "openai/gpt-4o-mini", }); await stagehand.init(); await stagehand.page.goto( diff --git a/examples/ai_sdk_example.ts b/examples/ai_sdk_example.ts index 4161fbb0b..f403926c7 100644 --- a/examples/ai_sdk_example.ts +++ b/examples/ai_sdk_example.ts @@ -5,7 +5,7 @@ import { z } from "zod"; async function example() { const stagehand = new Stagehand({ ...StagehandConfig, - modelName: "aisdk/openai/gpt-4o", + modelName: "openai/gpt-4o", }); await stagehand.init(); From c1c16a44ce741abd11025000cdb8019627f76482 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Mon, 28 Apr 2025 11:53:15 -0700 Subject: [PATCH 11/16] null checking --- lib/llm/aisdk.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/llm/aisdk.ts b/lib/llm/aisdk.ts index 3fc0a6453..2448f052b 100644 --- a/lib/llm/aisdk.ts +++ b/lib/llm/aisdk.ts @@ -66,7 +66,7 @@ export class AISdkClient extends LLMClient { response_model: options.response_model, }; - if (this.enableCaching) { + if (this.enableCaching && this.cache) { const cachedResponse = await this.cache.get( cacheOptions, options.requestId, @@ -216,7 +216,7 @@ export class AISdkClient extends LLMClient { const tools: Record = {}; - for (const rawTool of options.tools) { + for (const rawTool of options.tools ?? []) { tools[rawTool.name] = { description: rawTool.description, parameters: rawTool.parameters, From 1848be58e5ed871611f8f7df7ca1d55eb6e9b43e Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 28 Apr 2025 11:55:58 -0700 Subject: [PATCH 12/16] changeset --- .changeset/mean-plums-sin.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/mean-plums-sin.md diff --git a/.changeset/mean-plums-sin.md b/.changeset/mean-plums-sin.md new file mode 100644 index 000000000..33e0cfc89 --- /dev/null +++ b/.changeset/mean-plums-sin.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": patch +--- + +Fixing LLM client support to natively integrate with AI SDK From 213be56fdd6d034fcb05fe1ec05695671dbe87ab Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 29 Apr 2025 11:19:53 -0700 Subject: [PATCH 13/16] fix for together (and other) providers model format --- lib/llm/LLMProvider.ts | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index f2cf8bf64..1a8517fcc 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -1,5 +1,4 @@ import { - InvalidAISDKModelFormatError, UnsupportedAISDKModelProviderError, UnsupportedModelError, UnsupportedModelProviderError, @@ -116,11 +115,9 @@ export class LLMProvider { clientOptions?: ClientOptions, ): LLMClient { if (modelName.includes("/")) { - const parts = modelName.split("/"); - if (parts.length !== 2) { - throw new InvalidAISDKModelFormatError(modelName); - } - const [subProvider, subModelName] = parts; + const firstSlashIndex = modelName.indexOf('/'); + const subProvider = modelName.substring(0, firstSlashIndex); + const subModelName = modelName.substring(firstSlashIndex + 1); const languageModel = getAISDKLanguageModel(subProvider, subModelName); From 205b9db03786a289b500d94ee7b75ad48ce595a4 Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 29 Apr 2025 11:20:37 -0700 Subject: [PATCH 14/16] fix for together (and other) providers model format --- lib/llm/LLMProvider.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts index 1a8517fcc..b6cf96854 100644 --- a/lib/llm/LLMProvider.ts +++ b/lib/llm/LLMProvider.ts @@ -115,7 +115,7 @@ export class LLMProvider { clientOptions?: ClientOptions, ): LLMClient { if (modelName.includes("/")) { - const firstSlashIndex = modelName.indexOf('/'); + const firstSlashIndex = modelName.indexOf("/"); const subProvider = modelName.substring(0, firstSlashIndex); const subModelName = modelName.substring(firstSlashIndex + 1); From aa9897b80374002e031a7ab43bce0d8a68c1e98c Mon Sep 17 00:00:00 2001 From: miguel Date: Tue, 29 Apr 2025 12:03:24 -0700 Subject: [PATCH 15/16] undo delete --- examples/external_clients/aisdk.ts | 122 +++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 examples/external_clients/aisdk.ts diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts new file mode 100644 index 000000000..1d72d984f --- /dev/null +++ b/examples/external_clients/aisdk.ts @@ -0,0 +1,122 @@ +import { + CoreAssistantMessage, + CoreMessage, + CoreSystemMessage, + CoreTool, + CoreUserMessage, + generateObject, + generateText, + ImagePart, + LanguageModel, + TextPart, +} from "ai"; +import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist"; +import { ChatCompletion } from "openai/resources"; + +export class AISdkClient extends LLMClient { + public type = "aisdk" as const; + private model: LanguageModel; + + constructor({ model }: { model: LanguageModel }) { + super(model.modelId as AvailableModel); + this.model = model; + } + + async createChatCompletion({ + options, + }: CreateChatCompletionOptions): Promise { + const formattedMessages: CoreMessage[] = options.messages.map((message) => { + if (Array.isArray(message.content)) { + if (message.role === "system") { + const systemMessage: CoreSystemMessage = { + role: "system", + content: message.content + .map((c) => ("text" in c ? c.text : "")) + .join("\n"), + }; + return systemMessage; + } + + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ImagePart = { + type: "image", + image: content.image_url.url, + }; + return imageContent; + } else { + const textContent: TextPart = { + type: "text", + text: content.text, + }; + return textContent; + } + }); + + if (message.role === "user") { + const userMessage: CoreUserMessage = { + role: "user", + content: contentParts, + }; + return userMessage; + } else { + const textOnlyParts = contentParts.map((part) => ({ + type: "text" as const, + text: part.type === "image" ? "[Image]" : part.text, + })); + const assistantMessage: CoreAssistantMessage = { + role: "assistant", + content: textOnlyParts, + }; + return assistantMessage; + } + } + + return { + role: message.role, + content: message.content, + }; + }); + + if (options.response_model) { + const response = await generateObject({ + model: this.model, + messages: formattedMessages, + schema: options.response_model.schema, + }); + + return { + data: response.object, + usage: { + prompt_tokens: response.usage.promptTokens ?? 0, + completion_tokens: response.usage.completionTokens ?? 0, + total_tokens: response.usage.totalTokens ?? 0, + }, + } as T; + } + + const tools: Record = {}; + + for (const rawTool of options.tools) { + tools[rawTool.name] = { + description: rawTool.description, + parameters: rawTool.parameters, + }; + } + + const response = await generateText({ + model: this.model, + messages: formattedMessages, + tools, + }); + + return { + data: response.text, + usage: { + prompt_tokens: response.usage.promptTokens ?? 0, + completion_tokens: response.usage.completionTokens ?? 0, + total_tokens: response.usage.totalTokens ?? 0, + }, + } as T; + } +} From 35dd6c710fec76fc9de7218c2e055d2dad73ade2 Mon Sep 17 00:00:00 2001 From: Miguel <36487034+miguelg719@users.noreply.github.com> Date: Tue, 29 Apr 2025 12:04:50 -0700 Subject: [PATCH 16/16] Update evals/index.eval.ts --- evals/index.eval.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evals/index.eval.ts b/evals/index.eval.ts index 9f0dc474c..90fe0dc39 100644 --- a/evals/index.eval.ts +++ b/evals/index.eval.ts @@ -39,7 +39,7 @@ import { anthropic } from "@ai-sdk/anthropic"; import { groq } from "@ai-sdk/groq"; import { cerebras } from "@ai-sdk/cerebras"; import { openai } from "@ai-sdk/openai"; -import { AISdkClient } from "@/lib/llm/aisdk"; +import { AISdkClient } from "@/examples/external_clients/aisdk"; dotenv.config(); /**