Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedDeque;

import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
Expand All @@ -44,6 +47,10 @@
import org.springframework.ai.tool.observation.ToolCallingObservationDocumentation;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.lang.Nullable;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand Down Expand Up @@ -79,17 +86,20 @@ public final class DefaultToolCallingManager implements ToolCallingManager {

private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

private final TaskExecutor taskExecutor;

private ToolCallingObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor, @Nullable TaskExecutor taskExecutor) {
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

this.observationRegistry = observationRegistry;
this.toolCallbackResolver = toolCallbackResolver;
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
this.taskExecutor = taskExecutor != null ? taskExecutor : this.buildDefaultTaskExecutor();
}

@Override
Expand Down Expand Up @@ -173,64 +183,63 @@ private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt
*/
private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
ToolContext toolContext) {
List<ToolCallback> toolCallbacks = List.of();
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
toolCallbacks = toolCallingChatOptions.getToolCallbacks();
}

List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

Boolean returnDirect = null;

for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

logger.debug("Executing tool call: {}", toolCall.name());

String toolName = toolCall.name();
String toolInputArguments = toolCall.arguments();

ToolCallback toolCallback = toolCallbacks.stream()
.filter(tool -> toolName.equals(tool.getToolDefinition().name()))
.findFirst()
.orElseGet(() -> this.toolCallbackResolver.resolve(toolName));

if (toolCallback == null) {
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
}
final List<ToolCallback> toolCallbacks = (prompt
.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions)
? toolCallingChatOptions.getToolCallbacks() : List.of();

if (returnDirect == null) {
returnDirect = toolCallback.getToolMetadata().returnDirect();
}
else {
returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect();
}

ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
.toolDefinition(toolCallback.getToolDefinition())
.toolMetadata(toolCallback.getToolMetadata())
.toolCallArguments(toolInputArguments)
.build();

String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
String toolResult;
try {
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = this.toolExecutionExceptionProcessor.process(ex);
}
observationContext.setToolCallResult(toolResult);
return toolResult;
});

toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
toolCallResult != null ? toolCallResult : ""));
}

return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
final Queue<Boolean> toolsReturnDirect = new ConcurrentLinkedDeque<>();
List<CompletableFuture<ToolResponseMessage.ToolResponse>> futuresToolResponses = assistantMessage.getToolCalls()
.stream()
.map(toolCall -> CompletableFuture.supplyAsync(() -> {
logger.debug("Executing tool call: {}", toolCall.name());

String toolName = toolCall.name();
String toolInputArguments = toolCall.arguments();

ToolCallback toolCallback = toolCallbacks.stream()
.filter(tool -> toolName.equals(tool.getToolDefinition().name()))
.findFirst()
.orElseGet(() -> this.toolCallbackResolver.resolve(toolName));

if (toolCallback == null) {
throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
}

toolsReturnDirect.add(toolCallback.getToolMetadata().returnDirect());

ToolCallingObservationContext observationContext = ToolCallingObservationContext.builder()
.toolDefinition(toolCallback.getToolDefinition())
.toolMetadata(toolCallback.getToolMetadata())
.toolCallArguments(toolInputArguments)
.build();

String toolCallResult = ToolCallingObservationDocumentation.TOOL_CALL
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
String toolResult;
try {
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = this.toolExecutionExceptionProcessor.process(ex);
}
observationContext.setToolCallResult(toolResult);
return toolResult;
});

return new ToolResponseMessage.ToolResponse(toolCall.id(), toolName,
toolCallResult != null ? toolCallResult : "");
}, this.taskExecutor))
.toList();

final List<ToolResponseMessage.ToolResponse> toolResponses = CompletableFuture
.allOf(futuresToolResponses.toArray(new CompletableFuture[0]))
.thenApply(result -> futuresToolResponses.stream().map(CompletableFuture::join).toList())
.join();

return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()),
toolsReturnDirect.stream().allMatch(Boolean::booleanValue));
}

private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
Expand All @@ -245,6 +254,16 @@ public void setObservationConvention(ToolCallingObservationConvention observatio
this.observationConvention = observationConvention;
}

private TaskExecutor buildDefaultTaskExecutor() {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
taskExecutor.setThreadNamePrefix("ai-tool-calling-");
taskExecutor.setCorePoolSize(4);
taskExecutor.setMaxPoolSize(16);
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
taskExecutor.initialize();
return taskExecutor;
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -260,6 +279,8 @@ public final static class Builder {

private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;

private TaskExecutor taskExecutor;

private Builder() {
}

Expand All @@ -279,9 +300,14 @@ public Builder toolExecutionExceptionProcessor(
return this;
}

public Builder taskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
return this;
}

public DefaultToolCallingManager build() {
return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver,
this.toolExecutionExceptionProcessor);
this.toolExecutionExceptionProcessor, this.taskExecutor);
}

}
Expand Down