Skip to content

Commit 4f67959

Browse files
apappascstzolov
authored andcommitted
feat: Enhance Anthropic integration with Thinking
- The `thinking` option is added to `AnthropicChatOptions` and `ChatCompletionRequest`. - The `AnthropicApi` and `AnthropicChatModel` now handle `THINKING` and `REDACTED_THINKING` content blocks in responses. New tests verify parsing of these blocks. - Updated method signatures on ChatCompletionRequestBuilder, deprecating old builders with `with*` prefix in favor of those without. Signed-off-by: Alexandros Pappas <[email protected]>
1 parent fbec267 commit 4f67959

File tree

7 files changed

+414
-67
lines changed

7 files changed

+414
-67
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

+34-31
Original file line numberDiff line numberDiff line change
@@ -295,46 +295,49 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
295295
return new ChatResponse(List.of());
296296
}
297297

298-
List<Generation> generations = chatCompletion.content()
299-
.stream()
300-
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
301-
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
302-
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()))
303-
.toList();
304-
305-
List<Generation> allGenerations = new ArrayList<>(generations);
298+
List<Generation> generations = new ArrayList<>();
299+
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
300+
for (ContentBlock content : chatCompletion.content()) {
301+
switch (content.type()) {
302+
case TEXT, TEXT_DELTA:
303+
generations.add(new Generation(new AssistantMessage(content.text(), Map.of()),
304+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
305+
break;
306+
case THINKING, THINKING_DELTA:
307+
Map<String, Object> thinkingProperties = new HashMap<>();
308+
thinkingProperties.put("signature", content.signature());
309+
generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties),
310+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
311+
break;
312+
case REDACTED_THINKING:
313+
Map<String, Object> redactedProperties = new HashMap<>();
314+
redactedProperties.put("data", content.data());
315+
generations.add(new Generation(new AssistantMessage(null, redactedProperties),
316+
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
317+
break;
318+
case TOOL_USE:
319+
var functionCallId = content.id();
320+
var functionName = content.name();
321+
var functionArguments = JsonParser.toJson(content.input());
322+
toolCalls.add(
323+
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
324+
break;
325+
}
326+
}
306327

307328
if (chatCompletion.stopReason() != null && generations.isEmpty()) {
308329
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
309330
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
310-
allGenerations.add(generation);
331+
generations.add(generation);
311332
}
312333

313-
List<ContentBlock> toolToUseList = chatCompletion.content()
314-
.stream()
315-
.filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
316-
.toList();
317-
318-
if (!CollectionUtils.isEmpty(toolToUseList)) {
319-
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
320-
321-
for (ContentBlock toolToUse : toolToUseList) {
322-
323-
var functionCallId = toolToUse.id();
324-
var functionName = toolToUse.name();
325-
var functionArguments = JsonParser.toJson(toolToUse.input());
326-
327-
toolCalls
328-
.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
329-
}
330-
334+
if (!CollectionUtils.isEmpty(toolCalls)) {
331335
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
332336
Generation toolCallGeneration = new Generation(assistantMessage,
333337
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
334-
allGenerations.add(toolCallGeneration);
338+
generations.add(toolCallGeneration);
335339
}
336-
337-
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
340+
return new ChatResponse(generations, this.from(chatCompletion, usage));
338341
}
339342

340343
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
@@ -506,7 +509,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
506509
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
507510
if (!CollectionUtils.isEmpty(toolDefinitions)) {
508511
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
509-
request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build();
512+
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
510513
}
511514

512515
return request;

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

+23-2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
5757
private @JsonProperty("temperature") Double temperature;
5858
private @JsonProperty("top_p") Double topP;
5959
private @JsonProperty("top_k") Integer topK;
60+
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;
6061

6162
/**
6263
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
@@ -103,6 +104,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
103104
.temperature(fromOptions.getTemperature())
104105
.topP(fromOptions.getTopP())
105106
.topK(fromOptions.getTopK())
107+
.thinking(fromOptions.getThinking())
106108
.toolCallbacks(
107109
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
108110
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
@@ -174,6 +176,14 @@ public void setTopK(Integer topK) {
174176
this.topK = topK;
175177
}
176178

179+
public ChatCompletionRequest.ThinkingConfig getThinking() {
180+
return this.thinking;
181+
}
182+
183+
public void setThinking(ChatCompletionRequest.ThinkingConfig thinking) {
184+
this.thinking = thinking;
185+
}
186+
177187
@Override
178188
@JsonIgnore
179189
public List<FunctionCallback> getToolCallbacks() {
@@ -308,7 +318,8 @@ public boolean equals(Object o) {
308318
&& Objects.equals(this.metadata, that.metadata)
309319
&& Objects.equals(this.stopSequences, that.stopSequences)
310320
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP)
311-
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.toolCallbacks, that.toolCallbacks)
321+
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.thinking, that.thinking)
322+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
312323
&& Objects.equals(this.toolNames, that.toolNames)
313324
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
314325
&& Objects.equals(this.toolContext, that.toolContext)
@@ -317,7 +328,7 @@ public boolean equals(Object o) {
317328

318329
@Override
319330
public int hashCode() {
320-
return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, toolCallbacks,
331+
return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, thinking, toolCallbacks,
321332
toolNames, internalToolExecutionEnabled, toolContext, httpHeaders);
322333
}
323334

@@ -365,6 +376,16 @@ public Builder topK(Integer topK) {
365376
return this;
366377
}
367378

379+
public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) {
380+
this.options.thinking = thinking;
381+
return this;
382+
}
383+
384+
public Builder thinking(AnthropicApi.ThinkingType type, Integer budgetTokens) {
385+
this.options.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens);
386+
return this;
387+
}
388+
368389
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
369390
this.options.setToolCallbacks(toolCallbacks);
370391
return this;

0 commit comments

Comments
 (0)