diff --git a/Playground/Playground/Views/AppView.swift b/Playground/Playground/Views/AppView.swift
index b5745a1..90335eb 100644
--- a/Playground/Playground/Views/AppView.swift
+++ b/Playground/Playground/Views/AppView.swift
@@ -34,6 +34,12 @@ struct AppView: View {
                         ResponseFormatView(provider: provider)
                     }
                 }
+                
+                if provider == .openRouter {
+                    NavigationLink("Fallback Model") {
+                        FallbackModelView(provider: provider)
+                    }
+                }
             }
             .disabled(provider == .openai && viewModel.openaiAPIKey.isEmpty)
             .disabled(provider == .openRouter && viewModel.openRouterAPIKey.isEmpty)
diff --git a/Playground/Playground/Views/FallbackModelView.swift b/Playground/Playground/Views/FallbackModelView.swift
new file mode 100644
index 0000000..6bd9ba5
--- /dev/null
+++ b/Playground/Playground/Views/FallbackModelView.swift
@@ -0,0 +1,150 @@
+//
+//  FallbackModelView.swift
+//  Playground
+//
+//  Created by Kevin Hermawan on 10/19/24.
+//
+
+import SwiftUI
+import LLMChatOpenAI
+
+struct FallbackModelView: View {
+    let provider: ServiceProvider
+    
+    @Environment(AppViewModel.self) private var viewModel
+    @State private var isPreferencesPresented: Bool = false
+    
+    @State private var fallbackModel: String = ""
+    
+    @State private var prompt: String = "Hi!"
+    @State private var response: String = ""
+    @State private var inputTokens: Int = 0
+    @State private var outputTokens: Int = 0
+    @State private var totalTokens: Int = 0
+    
+    var body: some View {
+        @Bindable var viewModelBindable = viewModel
+        
+        VStack {
+            Form {
+                Section("Models") {
+                    Picker("Primary Model", selection: $viewModelBindable.selectedModel) {
+                        ForEach(viewModelBindable.models, id: \.self) { model in
+                            Text(model).tag(model)
+                        }
+                    }
+                    
+                    Picker("Fallback Model", selection: $fallbackModel) {
+                        ForEach(viewModelBindable.models, id: \.self) { model in
+                            Text(model).tag(model)
+                        }
+                    }
+                }
+                .disabled(viewModel.models.isEmpty)
+                
+                Section("Prompt") {
+                    TextField("Prompt", text: $prompt)
+                }
+                
+                Section("Response") {
+                    Text(response)
+                }
+                
+                UsageSection(inputTokens: inputTokens, outputTokens: outputTokens, totalTokens: totalTokens)
+            }
+            
+            VStack {
+                SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream)
+                    .disabled(viewModel.models.isEmpty)
+            }
+        }
+        .toolbar {
+            ToolbarItem(placement: .principal) {
+                NavigationTitle("Fallback Model")
+            }
+            
+            ToolbarItem(placement: .primaryAction) {
+                Button("Preferences", systemImage: "gearshape", action: { isPreferencesPresented.toggle() })
+            }
+        }
+        .sheet(isPresented: $isPreferencesPresented) {
+            PreferencesView()
+        }
+        .onAppear {
+            viewModel.setup(for: provider)
+        }
+        .onDisappear {
+            viewModel.selectedModel = ""
+        }
+        .onChange(of: viewModel.models) { _, models in
+            if let firstModel = models.first {
+                fallbackModel = firstModel
+            }
+        }
+    }
+    
+    private func onSend() {
+        clear()
+        
+        let messages = [
+            ChatMessage(role: .system, content: viewModel.systemPrompt),
+            ChatMessage(role: .user, content: prompt)
+        ]
+        
+        let options = ChatOptions(temperature: viewModel.temperature)
+        
+        Task {
+            do {
+                let completion = try await viewModel.chat.send(models: [viewModel.selectedModel, fallbackModel], messages: messages, options: options)
+                
+                if let content = completion.choices.first?.message.content {
+                    self.response = content
+                }
+                
+                if let usage = completion.usage {
+                    self.inputTokens = usage.promptTokens
+                    self.outputTokens = usage.completionTokens
+                    self.totalTokens = usage.totalTokens
+                }
+            } catch {
+                print(String(describing: error))
+            }
+        }
+    }
+    
+    private func onStream() {
+        clear()
+        
+        let messages = [
+            ChatMessage(role: .system, content: viewModel.systemPrompt),
+            ChatMessage(role: .user, content: prompt)
+        ]
+        
+        let options = ChatOptions(temperature: viewModel.temperature)
+        
+        Task {
+            do {
+                for try await chunk in viewModel.chat.stream(models: [viewModel.selectedModel, fallbackModel], messages: messages, options: options) {
+                    if let content = chunk.choices.first?.delta.content {
+                        self.response += content
+                    }
+                    
+                    if let usage = chunk.usage {
+                        self.inputTokens = usage.promptTokens ?? 0
+                        self.outputTokens = usage.completionTokens ?? 0
+                        self.totalTokens = usage.totalTokens ?? 0
+                    }
+                }
+            } catch {
+                print(String(describing: error))
+            }
+        }
+    }
+    
+    private func clear() {
+        response = ""
+        inputTokens = 0
+        outputTokens = 0
+        totalTokens = 0
+    }
+}
diff --git a/README.md b/README.md
index 7b0fd95..dfda5ff 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,34 @@ let task = Task {
 task.cancel()
 ```
 
+#### Using Fallback Models (OpenRouter only)
+
+```swift
+Task {
+    do {
+        let completion = try await chat.send(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages)
+        
+        print(completion.choices.first?.message.content ?? "No response")
+    } catch {
+        print(String(describing: error))
+    }
+}
+
+Task {
+    do {
+        for try await chunk in chat.stream(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages) {
+            if let content = chunk.choices.first?.delta.content {
+                print(content, terminator: "")
+            }
+        }
+    } catch {
+        print(String(describing: error))
+    }
+}
+```
+
+> **Note**: Fallback model functionality is only supported when using OpenRouter. If you use the fallback model methods (`send(models:)` or `stream(models:)`) with other providers, only the first model in the array will be used, and the rest will be ignored. To learn more about fallback models, check out the [OpenRouter documentation](https://openrouter.ai/docs/model-routing).
+
 ### Advanced Usage
 
 #### Vision
diff --git a/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md b/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md
index e263671..2b9ee50 100644
--- a/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md
+++ b/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md
@@ -70,6 +70,45 @@ let task = Task {
 task.cancel()
 ```
 
+#### Using Fallback Models (OpenRouter only)
+
+```swift
+Task {
+    do {
+        let completion = try await chat.send(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages)
+        
+        print(completion.choices.first?.message.content ?? "No response")
+    } catch {
+        print(String(describing: error))
+    }
+}
+
+Task {
+    do {
+        for try await chunk in chat.stream(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages) {
+            if let content = chunk.choices.first?.delta.content {
+                print(content, terminator: "")
+            }
+        }
+    } catch {
+        print(String(describing: error))
+    }
+}
+```
+
+> **Note**: Fallback model functionality is only supported when using OpenRouter. If you use the fallback model methods (`send(models:)` or `stream(models:)`) with other providers, only the first model in the array will be used, and the rest will be ignored. To learn more about fallback models, check out the [OpenRouter documentation](https://openrouter.ai/docs/model-routing).
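+
+Under the hood, `send(models:)` and `stream(models:)` encode the `models` array together with `"route": "fallback"` in place of the single `model` field. An OpenRouter request body therefore takes roughly the following shape (illustrative values):
+
+```json
+{
+    "stream": false,
+    "route": "fallback",
+    "models": ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"],
+    "messages": [{"role": "user", "content": "Hi!"}]
+}
+```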
+
 ### Advanced Usage
 
 #### Vision
diff --git a/Sources/LLMChatOpenAI/LLMChatOpenAI.swift b/Sources/LLMChatOpenAI/LLMChatOpenAI.swift
index 44d0e39..45693ca 100644
--- a/Sources/LLMChatOpenAI/LLMChatOpenAI.swift
+++ b/Sources/LLMChatOpenAI/LLMChatOpenAI.swift
@@ -13,6 +13,10 @@ public struct LLMChatOpenAI {
     private let endpoint: URL
     private var headers: [String: String]? = nil
     
+    private var isSupportFallbackModel: Bool {
+        endpoint.host == "openrouter.ai"
+    }
+    
     /// Creates a new instance of ``LLMChatOpenAI``.
     ///
     /// - Parameters:
@@ -36,15 +40,33 @@ extension LLMChatOpenAI {
     ///   - messages: An array of ``ChatMessage`` objects that represent the conversation history.
     ///   - options: Optional ``ChatOptions`` that customize the completion request.
     ///
-    /// - Returns: A `ChatCompletion` object that contains the API's response.
+    /// - Returns: A ``ChatCompletion`` object that contains the API's response.
     public func send(model: String, messages: [ChatMessage], options: ChatOptions? = nil) async throws -> ChatCompletion {
         let body = RequestBody(stream: false, model: model, messages: messages, options: options)
-        let request = try createRequest(for: endpoint, with: body)
-        let (data, response) = try await URLSession.shared.data(for: request)
-        try validateHTTPResponse(response)
+        return try await performRequest(with: body)
+    }
+    
+    /// Sends a chat completion request using fallback models (OpenRouter only).
+    ///
+    /// - Parameters:
+    ///   - models: An array of models to use for completion, in order of preference.
+    ///   - messages: An array of ``ChatMessage`` objects that represent the conversation history.
+    ///   - options: Optional ``ChatOptions`` that customize the completion request.
+    ///
+    /// - Returns: A ``ChatCompletion`` object that contains the API's response.
+    ///
+    /// - Note: This method enables fallback functionality when using OpenRouter. For other providers, only the first model in the array will be used.
+    public func send(models: [String], messages: [ChatMessage], options: ChatOptions? = nil) async throws -> ChatCompletion {
+        let body: RequestBody
         
-        return try JSONDecoder().decode(ChatCompletion.self, from: data)
+        if isSupportFallbackModel {
+            body = RequestBody(stream: false, models: models, messages: messages, options: options)
+        } else {
+            body = RequestBody(stream: false, model: models.first ?? "", messages: messages, options: options)
+        }
+        
+        return try await performRequest(with: body)
     }
     
     /// Streams a chat completion request.
@@ -56,12 +78,54 @@ extension LLMChatOpenAI {
     ///
     /// - Returns: An `AsyncThrowingStream` of ``ChatCompletionChunk`` objects.
     public func stream(model: String, messages: [ChatMessage], options: ChatOptions? = nil) -> AsyncThrowingStream<ChatCompletionChunk, Error> {
+        let body = RequestBody(stream: true, model: model, messages: messages, options: options)
+        
+        return performStreamRequest(with: body)
+    }
+    
+    /// Streams a chat completion request using fallback models (OpenRouter only).
+    ///
+    /// - Parameters:
+    ///   - models: An array of models to use for completion, in order of preference.
+    ///   - messages: An array of ``ChatMessage`` objects that represent the conversation history.
+    ///   - options: Optional ``ChatOptions`` that customize the completion request.
+    ///
+    /// - Returns: An `AsyncThrowingStream` of ``ChatCompletionChunk`` objects.
+    ///
+    /// - Note: This method enables fallback functionality when using OpenRouter. For other providers, only the first model in the array will be used.
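+    ///
+    /// For example, assuming a configured `chat` instance:
+    ///
+    /// ```swift
+    /// for try await chunk in chat.stream(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages) {
+    ///     print(chunk.choices.first?.delta.content ?? "", terminator: "")
+    /// }
+    /// ```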
+    public func stream(models: [String], messages: [ChatMessage], options: ChatOptions? = nil) -> AsyncThrowingStream<ChatCompletionChunk, Error> {
+        let body: RequestBody
+        
+        if isSupportFallbackModel {
+            body = RequestBody(stream: true, models: models, messages: messages, options: options)
+        } else {
+            body = RequestBody(stream: true, model: models.first ?? "", messages: messages, options: options)
+        }
+        
+        return performStreamRequest(with: body)
+    }
+    
+    private func performRequest(with body: RequestBody) async throws -> ChatCompletion {
+        let request = try createRequest(for: endpoint, with: body)
+        let (data, response) = try await URLSession.shared.data(for: request)
+        try validateHTTPResponse(response)
+        
+        return try JSONDecoder().decode(ChatCompletion.self, from: data)
+    }
+    
+    private func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream<ChatCompletionChunk, Error> {
         AsyncThrowingStream { continuation in
             Task {
                 do {
-                    let body = RequestBody(stream: true, model: model, messages: messages, options: options)
                     let request = try createRequest(for: endpoint, with: body)
-                    
                     let (bytes, response) = try await URLSession.shared.bytes(for: request)
                     try validateHTTPResponse(response)
@@ -74,12 +138,9 @@ extension LLMChatOpenAI {
                         }
                         
                         if let data = jsonString.data(using: .utf8) {
-                            do {
-                                let chunk = try JSONDecoder().decode(ChatCompletionChunk.self, from: data)
-                                continuation.yield(chunk)
-                            } catch {
-                                continuation.finish(throwing: error)
-                            }
+                            let chunk = try JSONDecoder().decode(ChatCompletionChunk.self, from: data)
+                            
+                            continuation.yield(chunk)
                         }
                     }
                 }
@@ -128,27 +189,50 @@
 private extension LLMChatOpenAI {
     struct RequestBody: Encodable {
         let stream: Bool
-        let model: String
+        let model: String?
+        let models: [String]?
         let messages: [ChatMessage]
         let options: ChatOptions?
         
+        init(stream: Bool, model: String, messages: [ChatMessage], options: ChatOptions?) {
+            self.stream = stream
+            self.model = model
+            self.models = nil
+            self.messages = messages
+            self.options = options
+        }
+        
+        init(stream: Bool, models: [String], messages: [ChatMessage], options: ChatOptions?) {
+            self.stream = stream
+            self.model = nil
+            self.models = models
+            self.messages = messages
+            self.options = options
+        }
+        
         func encode(to encoder: Encoder) throws {
             var container = encoder.container(keyedBy: CodingKeys.self)
             try container.encode(stream, forKey: .stream)
-            try container.encode(model, forKey: .model)
             try container.encode(messages, forKey: .messages)
             
             if stream {
                 try container.encode(["include_usage": true], forKey: .streamOptions)
             }
             
+            if let model = model {
+                try container.encode(model, forKey: .model)
+            } else if let models = models {
+                try container.encode(models, forKey: .models)
+                try container.encode("fallback", forKey: .route)
+            }
+            
             if let options {
                 try options.encode(to: encoder)
             }
         }
         
         enum CodingKeys: String, CodingKey {
-            case stream, model, messages
+            case stream, model, models, route, messages
             case streamOptions = "stream_options"
         }
     }
 }
diff --git a/Tests/LLMChatOpenAITests/ChatCompletionTests.swift b/Tests/LLMChatOpenAITests/ChatCompletionTests.swift
index 43b5c7a..e4d8102 100644
--- a/Tests/LLMChatOpenAITests/ChatCompletionTests.swift
+++ b/Tests/LLMChatOpenAITests/ChatCompletionTests.swift
@@ -97,4 +97,66 @@ final class ChatCompletionTests: XCTestCase {
         
         XCTAssertEqual(receivedContent, "The capital of Indonesia is Jakarta.")
     }
+    
+    func testSendChatCompletionWithFallbackModels() async throws {
+        let mockResponseString = """
+        {
+            "id": "chatcmpl-123",
+            "object": "chat.completion",
+            "created": 1694268190,
+            "model": "openai/gpt-4o",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": "The capital of Indonesia is Jakarta."
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 5,
+                "completion_tokens": 10,
+                "total_tokens": 15
+            }
+        }
+        """
+        
+        URLProtocolMock.mockData = mockResponseString.data(using: .utf8)
+        let completion = try await chat.send(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages)
+        let choice = completion.choices.first
+        let message = choice?.message
+        
+        XCTAssertEqual(completion.id, "chatcmpl-123")
+        XCTAssertEqual(completion.model, "openai/gpt-4o")
+        
+        // Content
+        XCTAssertEqual(message?.role, "assistant")
+        XCTAssertEqual(message?.content, "The capital of Indonesia is Jakarta.")
+        
+        // Usage
+        XCTAssertEqual(completion.usage?.promptTokens, 5)
+        XCTAssertEqual(completion.usage?.completionTokens, 10)
+        XCTAssertEqual(completion.usage?.totalTokens, 15)
+    }
+    
+    func testStreamChatCompletionWithFallbackModels() async throws {
+        URLProtocolMock.mockStreamData = [
+            "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1694268190,\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"The capital\"},\"finish_reason\":null}]}\n\n",
+            "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1694268190,\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" of Indonesia\"},\"finish_reason\":null}]}\n\n",
+            "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1694268190,\"model\":\"openai/gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" is Jakarta.\"},\"finish_reason\":\"stop\"}]}\n\n",
+            "data: [DONE]\n\n"
+        ]
+        
+        var receivedContent = ""
+        
+        for try await chunk in chat.stream(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages) {
+            if let content = chunk.choices.first?.delta.content {
+                receivedContent += content
+            }
+        }
+        
+        XCTAssertEqual(receivedContent, "The capital of Indonesia is Jakarta.")
+    }
 }