From eb3a91d04e7dc7a645eb8f4c801c9ae248b98d13 Mon Sep 17 00:00:00 2001 From: Rafael Cunha <12313126+rafaelrddc@users.noreply.github.com> Date: Wed, 27 Aug 2025 20:20:19 -0300 Subject: [PATCH 1/3] Parallel Tool Execution Closes gh-4254 Signed-off-by: Rafael Cunha <12313126+rafaelrddc@users.noreply.github.com> --- .../model/tool/DefaultToolCallingManager.java | 142 ++++++++++-------- 1 file changed, 83 insertions(+), 59 deletions(-) diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 5149a98a85c..beb976bdc1e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -20,7 +20,10 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedDeque; import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; @@ -44,6 +47,10 @@ import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.core.task.TaskExecutor; +import org.springframework.core.task.support.ContextPropagatingTaskDecorator; +import org.springframework.lang.Nullable; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -71,6 +78,8 @@ public final class DefaultToolCallingManager implements ToolCallingManager { private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR = DefaultToolExecutionExceptionProcessor.builder().build(); + private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor(); + // @formatter:on private final ObservationRegistry observationRegistry; @@ -79,10 +88,12 @@ public final class DefaultToolCallingManager implements ToolCallingManager { private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor; + private final TaskExecutor taskExecutor; + private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver, - ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) { + ToolExecutionExceptionProcessor toolExecutionExceptionProcessor, @Nullable TaskExecutor taskExecutor) { Assert.notNull(observationRegistry, "observationRegistry cannot be null"); Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null"); Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null"); @@ -90,6 +101,7 @@ public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCa this.observationRegistry = observationRegistry; this.toolCallbackResolver = toolCallbackResolver; this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor; + this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor(); } @Override @@ -173,64 +185,59 @@ private static List buildConversationHistoryBeforeToolExecution(Prompt */ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage, ToolContext toolContext) { - List toolCallbacks = List.of(); - if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { - toolCallbacks = toolCallingChatOptions.getToolCallbacks(); - } - - List toolResponses = new ArrayList<>(); - - Boolean returnDirect = null; - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - logger.debug("Executing tool call: {}", toolCall.name()); - - String toolName = toolCall.name(); - String toolInputArguments = toolCall.arguments(); - - ToolCallback toolCallback = toolCallbacks.stream() - .filter(tool -> toolName.equals(tool.getToolDefinition().name())) - .findFirst() - .orElseGet(() -> this.toolCallbackResolver.resolve(toolName)); - - if (toolCallback == null) { - throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); - } + final List toolCallbacks = (prompt + .getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) + ? toolCallingChatOptions.getToolCallbacks() : List.of(); - if (returnDirect == null) { - returnDirect = toolCallback.getToolMetadata().returnDirect(); - } - else { - returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect(); - } - - ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder() - .toolDefinition(toolCallback.getToolDefinition()) - .toolMetadata(toolCallback.getToolMetadata()) - .toolCallArguments(toolInputArguments) - .build(); - - String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - String toolResult; - try { - toolResult = toolCallback.call(toolInputArguments, toolContext); - } - catch (ToolExecutionException ex) { - toolResult = this.toolExecutionExceptionProcessor.process(ex); - } - observationContext.setToolCallResult(toolResult); - return toolResult; - }); - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, - toolCallResult != null ? toolCallResult : "")); - } - - return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect); + final Queue toolsReturnDirect = new ConcurrentLinkedDeque<>(); + List toolResponses = assistantMessage.getToolCalls() + .stream() + .map(toolCall -> CompletableFuture.supplyAsync(() -> { + logger.debug("Executing tool call: {}", toolCall.name()); + + String toolName = toolCall.name(); + String toolInputArguments = toolCall.arguments(); + + ToolCallback toolCallback = toolCallbacks.stream() + .filter(tool -> toolName.equals(tool.getToolDefinition().name())) + .findFirst() + .orElseGet(() -> this.toolCallbackResolver.resolve(toolName)); + + if (toolCallback == null) { + throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); + } + + toolsReturnDirect.add(toolCallback.getToolMetadata().returnDirect()); + + ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder() + .toolDefinition(toolCallback.getToolDefinition()) + .toolMetadata(toolCallback.getToolMetadata()) + .toolCallArguments(toolInputArguments) + .build(); + + String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + String toolResult; + try { + toolResult = toolCallback.call(toolInputArguments, toolContext); + } + catch (ToolExecutionException ex) { + toolResult = this.toolExecutionExceptionProcessor.process(ex); + } + observationContext.setToolCallResult(toolResult); + return toolResult; + }); + + return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, + toolCallResult != null ? toolCallResult : ""); + }, this.taskExecutor)) + .map(CompletableFuture::join) + .toList(); + + return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), + toolsReturnDirect.stream().allMatch(Boolean::booleanValue)); } private List buildConversationHistoryAfterToolExecution(List previousMessages, @@ -245,6 +252,16 @@ public void setObservationConvention(ToolCallingObservationConvention observatio this.observationConvention = observationConvention; } + private static TaskExecutor buildDefaultTaskExecutor() { + ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); + taskExecutor.setThreadNamePrefix("ai-toll-calling-"); + taskExecutor.setCorePoolSize(4); + taskExecutor.setMaxPoolSize(16); + taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator()); + taskExecutor.initialize(); + return taskExecutor; + } + public static Builder builder() { return new Builder(); } @@ -260,6 +277,8 @@ public final static class Builder { private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR; + private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR; + private Builder() { } @@ -279,9 +298,14 @@ public Builder toolExecutionExceptionProcessor( return this; } + public Builder taskExecutor(TaskExecutor taskExecutor) { + this.taskExecutor = taskExecutor; + return this; + } + public DefaultToolCallingManager build() { return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver, - this.toolExecutionExceptionProcessor); + this.toolExecutionExceptionProcessor, taskExecutor); } } From bca6afd05c5fc169c52720fcfbb40486a0820f02 Mon Sep 17 00:00:00 2001 From: Rafael Cunha <12313126+rafaelrddc@users.noreply.github.com> Date: Wed, 27 Aug 2025 20:50:44 -0300 Subject: [PATCH 2/3] Fix checkstyle Closes gh-4254 Signed-off-by: Rafael Cunha <12313126+rafaelrddc@users.noreply.github.com> --- .../ai/model/tool/DefaultToolCallingManager.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index beb976bdc1e..872ad807864 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -20,8 +20,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Queue; import java.util.Optional; +import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedDeque; @@ -305,7 +305,7 @@ public Builder taskExecutor(TaskExecutor taskExecutor) { public DefaultToolCallingManager build() { return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver, - this.toolExecutionExceptionProcessor, taskExecutor); + this.toolExecutionExceptionProcessor, this.taskExecutor); } } From c13f222f65d31e551c0c66c561fa1fa3df5941cd Mon Sep 17 00:00:00 2001 From: Rafael Cunha <12313126+rafaelrddc@users.noreply.github.com> Date: Thu, 28 Aug 2025 20:27:21 -0300 Subject: [PATCH 3/3] Code review improvements Closes gh-4254 Signed-off-by: Rafael Cunha <12313126+rafaelrddc@users.noreply.github.com> --- .../model/tool/DefaultToolCallingManager.java | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 872ad807864..53d07bf16c5 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -78,8 +78,6 @@ public final class DefaultToolCallingManager implements ToolCallingManager { private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR = DefaultToolExecutionExceptionProcessor.builder().build(); - private static final TaskExecutor DEFAULT_TASK_EXECUTOR = buildDefaultTaskExecutor(); - // @formatter:on private final ObservationRegistry observationRegistry; @@ -101,7 +99,7 @@ public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCa this.observationRegistry = observationRegistry; this.toolCallbackResolver = toolCallbackResolver; this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor; - this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor(); + this.taskExecutor = taskExecutor != null ? taskExecutor : this.buildDefaultTaskExecutor(); } @Override @@ -190,7 +188,7 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess ? toolCallingChatOptions.getToolCallbacks() : List.of(); final Queue toolsReturnDirect = new ConcurrentLinkedDeque<>(); - List toolResponses = assistantMessage.getToolCalls() + List> futuresToolResponses = assistantMessage.getToolCalls() .stream() .map(toolCall -> CompletableFuture.supplyAsync(() -> { logger.debug("Executing tool call: {}", toolCall.name()); @@ -233,9 +231,13 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolCallResult != null ? toolCallResult : ""); }, this.taskExecutor)) - .map(CompletableFuture::join) .toList(); + final List toolResponses = CompletableFuture + .allOf(futuresToolResponses.toArray(new CompletableFuture[0])) + .thenApply(result -> futuresToolResponses.stream().map(CompletableFuture::join).toList()) + .join(); + return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), toolsReturnDirect.stream().allMatch(Boolean::booleanValue)); } @@ -252,9 +254,9 @@ public void setObservationConvention(ToolCallingObservationConvention observatio this.observationConvention = observationConvention; } - private static TaskExecutor buildDefaultTaskExecutor() { + private TaskExecutor buildDefaultTaskExecutor() { ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); - taskExecutor.setThreadNamePrefix("ai-toll-calling-"); + taskExecutor.setThreadNamePrefix("ai-tool-calling-"); taskExecutor.setCorePoolSize(4); taskExecutor.setMaxPoolSize(16); taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator()); @@ -277,7 +279,7 @@ public final static class Builder { private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR; - private TaskExecutor taskExecutor = DEFAULT_TASK_EXECUTOR; + private TaskExecutor taskExecutor; private Builder() { }