diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 5cfe18ac9b4..2867c847aee 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -22,10 +22,14 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import io.micrometer.observation.Observation; @@ -702,6 +706,10 @@ public TemplateRenderer getTemplateRenderer() { return this.templateRenderer; } + public boolean hasToolConfiguration() { + return !this.toolNames.isEmpty() || !this.toolCallbacks.isEmpty() || !this.toolContext.isEmpty(); + } + /** * Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose * settings are replicated from this {@link ChatClientRequest}. @@ -784,7 +792,7 @@ public ChatClientRequestSpec options(T options) { public ChatClientRequestSpec toolNames(String... toolNames) { Assert.notNull(toolNames, "toolNames cannot be null"); Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); - this.toolNames.addAll(List.of(toolNames)); + Collections.addAll(this.toolNames, toolNames); return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 10f623e2b70..5c94f783634 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -94,22 +94,40 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient */ ChatOptions processedChatOptions = inputRequest.getChatOptions(); - if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { - if (!inputRequest.getToolNames().isEmpty()) { - Set toolNames = ToolCallingChatOptions - .mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); - toolCallingChatOptions.setToolNames(toolNames); + if (inputRequest.hasToolConfiguration()) { + if (processedChatOptions == null) { + ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder(); + if (!inputRequest.getToolNames().isEmpty()) { + builder.toolNames(new HashSet<>(inputRequest.getToolNames())); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = inputRequest.getToolCallbacks(); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + builder.toolCallbacks(inputRequest.getToolCallbacks()); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + builder.toolContext(inputRequest.getToolContext()); + } + + processedChatOptions = builder.build(); } - if (!inputRequest.getToolCallbacks().isEmpty()) { - List toolCallbacks = ToolCallingChatOptions - .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); - ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); - toolCallingChatOptions.setToolCallbacks(toolCallbacks); - } - if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { - Map toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(), - toolCallingChatOptions.getToolContext()); - toolCallingChatOptions.setToolContext(toolContext); + else if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { + if (!inputRequest.getToolNames().isEmpty()) { + Set toolNames = ToolCallingChatOptions.mergeToolNames( + new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); + toolCallingChatOptions.setToolNames(toolNames); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = ToolCallingChatOptions + .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + toolCallingChatOptions.setToolCallbacks(toolCallbacks); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + Map toolContext = ToolCallingChatOptions + .mergeToolContext(inputRequest.getToolContext(), toolCallingChatOptions.getToolContext()); + toolCallingChatOptions.setToolContext(toolContext); + } } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 9d4d4962069..0f8788f326b 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -322,6 +322,60 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() { .containsAllEntriesOf(toolContext2); } + @Test + void whenToolNamesWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() { + List toolNames = List.of("tool1", "tool2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolNames(toolNames.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames); + } + + @Test + void whenToolCallbacksWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() { + ToolCallback toolCallback = new TestToolCallback("tool1"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolCallbacks(toolCallback); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).contains(toolCallback); + } + + @Test + void whenToolContextWithoutChatOptionsIsProvidedThenToolCallingChatOptionsAreSet() { + Map toolContext = Map.of("key", "value"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolContext(toolContext); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext); + } + @Test void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() { Map advisorParams = Map.of("key1", "value1", "key2", "value2");