From cdbce1620bf31c4043b2ef4a5441c709617681c9 Mon Sep 17 00:00:00 2001 From: Filip Hrisafov Date: Fri, 30 May 2025 14:08:06 +0200 Subject: [PATCH 1/2] Create new ToolChatOptions if input request has tool configuration without options Signed-off-by: Filip Hrisafov --- .../ai/chat/client/DefaultChatClient.java | 20 ++++--- .../chat/client/DefaultChatClientUtils.java | 48 +++++++++++------ .../client/DefaultChatClientUtilsTests.java | 54 +++++++++++++++++++ 3 files changed, 101 insertions(+), 21 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 5cfe18ac9b4..0fc0b8181fe 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -22,10 +22,14 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import io.micrometer.observation.Observation; @@ -571,7 +575,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final List media = new ArrayList<>(); - private final List toolNames = new ArrayList<>(); + private final Set toolNames = new LinkedHashSet<>(); private final List toolCallbacks = new ArrayList<>(); @@ -607,9 +611,9 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, List toolNames, List media, - @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, - ObservationRegistry observationRegistry, + List toolCallbacks, List messages, Collection toolNames, + List media, @Nullable ChatOptions chatOptions, List advisors, + Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { @@ -686,7 +690,7 @@ public List getMedia() { return this.media; } - public List getToolNames() { + public Set getToolNames() { return this.toolNames; } @@ -702,6 +706,10 @@ public TemplateRenderer getTemplateRenderer() { return this.templateRenderer; } + public boolean hasToolConfiguration() { + return !this.toolNames.isEmpty() || !this.toolCallbacks.isEmpty() || !this.toolContext.isEmpty(); + } + /** * Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose * settings are replicated from this {@link ChatClientRequest}. @@ -784,7 +792,7 @@ public ChatClientRequestSpec options(T options) { public ChatClientRequestSpec toolNames(String... toolNames) { Assert.notNull(toolNames, "toolNames cannot be null"); Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); - this.toolNames.addAll(List.of(toolNames)); + Collections.addAll(this.toolNames, toolNames); return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 10f623e2b70..2b1894712e2 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -94,22 +94,40 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient */ ChatOptions processedChatOptions = inputRequest.getChatOptions(); - if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { - if (!inputRequest.getToolNames().isEmpty()) { - Set toolNames = ToolCallingChatOptions - .mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); - toolCallingChatOptions.setToolNames(toolNames); + if (inputRequest.hasToolConfiguration()) { + if (processedChatOptions == null) { + ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder(); + if (!inputRequest.getToolNames().isEmpty()) { + builder.toolNames(inputRequest.getToolNames()); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = inputRequest.getToolCallbacks(); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + builder.toolCallbacks(inputRequest.getToolCallbacks()); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + builder.toolContext(inputRequest.getToolContext()); + } + + processedChatOptions = builder.build(); } - if (!inputRequest.getToolCallbacks().isEmpty()) { - List toolCallbacks = ToolCallingChatOptions - .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); - ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); - toolCallingChatOptions.setToolCallbacks(toolCallbacks); - } - if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { - Map toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(), - toolCallingChatOptions.getToolContext()); - toolCallingChatOptions.setToolContext(toolContext); + else if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { + if (!inputRequest.getToolNames().isEmpty()) { + Set toolNames = ToolCallingChatOptions.mergeToolNames( + new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); + toolCallingChatOptions.setToolNames(toolNames); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = ToolCallingChatOptions + .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + toolCallingChatOptions.setToolCallbacks(toolCallbacks); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + Map toolContext = ToolCallingChatOptions + .mergeToolContext(inputRequest.getToolContext(), toolCallingChatOptions.getToolContext()); + toolCallingChatOptions.setToolContext(toolContext); + } } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 9d4d4962069..0f8788f326b 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -322,6 +322,60 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() { .containsAllEntriesOf(toolContext2); } + @Test + void whenToolNamesWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() { + List toolNames = List.of("tool1", "tool2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolNames(toolNames.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames); + } + + @Test + void whenToolCallbacksWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() { + ToolCallback toolCallback = new TestToolCallback("tool1"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolCallbacks(toolCallback); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).contains(toolCallback); + } + + @Test + void whenToolContextWithoutChatOptionsIsProvidedThenToolCallingChatOptionsAreSet() { + Map toolContext = Map.of("key", "value"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolContext(toolContext); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext); + } + @Test void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() { Map advisorParams = Map.of("key1", "value1", "key2", "value2"); From 67b962f978f9e93df73d6c81f5fb7e59556b46b9 Mon Sep 17 00:00:00 2001 From: Filip Hrisafov Date: Thu, 17 Jul 2025 19:10:45 +0200 Subject: [PATCH 2/2] Revert change of type Signed-off-by: Filip Hrisafov --- .../ai/chat/client/DefaultChatClient.java | 10 +++++----- .../ai/chat/client/DefaultChatClientUtils.java | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 0fc0b8181fe..2867c847aee 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -575,7 +575,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final List media = new ArrayList<>(); - private final Set toolNames = new LinkedHashSet<>(); + private final List toolNames = new ArrayList<>(); private final List toolCallbacks = new ArrayList<>(); @@ -611,9 +611,9 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, Collection toolNames, - List media, @Nullable ChatOptions chatOptions, List advisors, - Map advisorParams, ObservationRegistry observationRegistry, + List toolCallbacks, List messages, List toolNames, List media, + @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, + ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { @@ -690,7 +690,7 @@ public List getMedia() { return this.media; } - public Set getToolNames() { + public List getToolNames() { return this.toolNames; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 2b1894712e2..5c94f783634 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -98,7 +98,7 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient if (processedChatOptions == null) { ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder(); if (!inputRequest.getToolNames().isEmpty()) { - builder.toolNames(inputRequest.getToolNames()); + builder.toolNames(new HashSet<>(inputRequest.getToolNames())); } if (!inputRequest.getToolCallbacks().isEmpty()) { List toolCallbacks = inputRequest.getToolCallbacks();