Skip to content

Commit ccfe43f

Browse files
committed
feat(anthropic): Add support for streaming thinking events
Add necessary types and update stream processing to handle Anthropic's 'thinking' content blocks and deltas in streaming responses. This resolves an issue where an IllegalArgumentException was thrown for unhandled thinking event types. format Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 3756e16 commit ccfe43f

File tree

4 files changed

+265
-44
lines changed

4 files changed

+265
-44
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -1226,8 +1226,11 @@ public record ContentBlockStartEvent(
12261226

12271227
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type",
12281228
visible = true)
1229-
@JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockToolUse.class, name = "tool_use"),
1230-
@JsonSubTypes.Type(value = ContentBlockText.class, name = "text") })
1229+
@JsonSubTypes({
1230+
@JsonSubTypes.Type(value = ContentBlockToolUse.class, name = "tool_use"),
1231+
@JsonSubTypes.Type(value = ContentBlockText.class, name = "text"),
1232+
@JsonSubTypes.Type(value = ContentBlockThinking.class, name = "thinking")
1233+
})
12311234
public interface ContentBlockBody {
12321235
String type();
12331236
}
@@ -1257,6 +1260,18 @@ public record ContentBlockText(
12571260
@JsonProperty("type") String type,
12581261
@JsonProperty("text") String text) implements ContentBlockBody {
12591262
}
1263+
1264+
/**
1265+
* Thinking content block.
1266+
* @param type The content block type.
1267+
* @param thinking The thinking content.
1268+
*/
1269+
@JsonInclude(Include.NON_NULL)
1270+
public record ContentBlockThinking(
1271+
@JsonProperty("type") String type,
1272+
@JsonProperty("thinking") String thinking) implements ContentBlockBody {
1273+
}
1274+
12601275
}
12611276
// @formatter:on
12621277

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

+61-30
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,32 @@
2626
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent;
2727
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaJson;
2828
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaText;
29+
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaThinking;
30+
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaSignature;
2931
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent;
3032
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockText;
3133
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockToolUse;
34+
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockThinking;
3235
import org.springframework.ai.anthropic.api.AnthropicApi.EventType;
3336
import org.springframework.ai.anthropic.api.AnthropicApi.MessageDeltaEvent;
3437
import org.springframework.ai.anthropic.api.AnthropicApi.MessageStartEvent;
3538
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
3639
import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent;
3740
import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent;
3841
import org.springframework.ai.anthropic.api.AnthropicApi.Usage;
39-
import org.springframework.util.Assert;
4042
import org.springframework.util.CollectionUtils;
4143
import org.springframework.util.StringUtils;
4244

4345
/**
44-
* Helper class to support streaming function calling.
46+
* Helper class to support streaming function calling and thinking events.
4547
* <p>
4648
* It can merge the streamed {@link StreamEvent} chunks in case of function calling
47-
* message.
49+
* message. It passes through other events like text, thinking, and signature deltas.
4850
*
4951
* @author Mariusz Bernacki
5052
* @author Christian Tzolov
5153
* @author Jihoon Kim
54+
* @author Alexandros Pappas
5255
* @since 1.0.0
5356
*/
5457
public class StreamHelper {
@@ -61,13 +64,16 @@ public boolean isToolUseStart(StreamEvent event) {
6164
}
6265

6366
public boolean isToolUseFinish(StreamEvent event) {
64-
65-
if (event == null || event.type() == null || event.type() != EventType.CONTENT_BLOCK_STOP) {
66-
return false;
67-
}
68-
return true;
67+
// Tool use streaming sequence ends with a CONTENT_BLOCK_STOP event.
68+
// The logic relies on the state machine (isInsideTool flag) managed in
69+
// chatCompletionStream to know if this stop event corresponds to a tool use.
70+
return event != null && event.type() != null && event.type() == EventType.CONTENT_BLOCK_STOP;
6971
}
7072

73+
/**
74+
* Merge the tool‑use related streaming events into one aggregate event so that the
75+
* upper layers see a single ContentBlock with the full JSON input.
76+
*/
7177
public StreamEvent mergeToolUseEvents(StreamEvent previousEvent, StreamEvent event) {
7278

7379
ToolUseAggregationEvent eventAggregator = (ToolUseAggregationEvent) previousEvent;
@@ -76,8 +82,7 @@ public StreamEvent mergeToolUseEvents(StreamEvent previousEvent, StreamEvent eve
7682
ContentBlockStartEvent contentBlockStart = (ContentBlockStartEvent) event;
7783

7884
if (ContentBlock.Type.TOOL_USE.getValue().equals(contentBlockStart.contentBlock().type())) {
79-
ContentBlockStartEvent.ContentBlockToolUse cbToolUse = (ContentBlockToolUse) contentBlockStart
80-
.contentBlock();
85+
ContentBlockToolUse cbToolUse = (ContentBlockToolUse) contentBlockStart.contentBlock();
8186

8287
return eventAggregator.withIndex(contentBlockStart.index())
8388
.withId(cbToolUse.id())
@@ -102,6 +107,14 @@ else if (event.type() == EventType.CONTENT_BLOCK_STOP) {
102107
return event;
103108
}
104109

110+
/**
111+
* Converts a raw {@link StreamEvent} potentially containing tool use aggregates or
112+
* other block types (text, thinking) into a {@link ChatCompletionResponse} chunk.
113+
* @param event The incoming StreamEvent.
114+
* @param contentBlockReference Holds the state of the response being built across
115+
* multiple events.
116+
* @return A ChatCompletionResponse representing the processed chunk.
117+
*/
105118
public ChatCompletionResponse eventToChatCompletionResponse(StreamEvent event,
106119
AtomicReference<ChatCompletionResponseBuilder> contentBlockReference) {
107120

@@ -135,28 +148,41 @@ else if (event.type().equals(EventType.TOOL_USE_AGGREGATE)) {
135148
else if (event.type().equals(EventType.CONTENT_BLOCK_START)) {
136149
ContentBlockStartEvent contentBlockStartEvent = (ContentBlockStartEvent) event;
137150

138-
Assert.isTrue(contentBlockStartEvent.contentBlock().type().equals("text"),
139-
"The json content block should have been aggregated. Unsupported content block type: "
140-
+ contentBlockStartEvent.contentBlock().type());
141-
142-
ContentBlockText contentBlockText = (ContentBlockText) contentBlockStartEvent.contentBlock();
143-
ContentBlock contentBlock = new ContentBlock(Type.TEXT, null, contentBlockText.text(),
144-
contentBlockStartEvent.index());
145-
contentBlockReference.get().withType(event.type().name()).withContent(List.of(contentBlock));
151+
if (contentBlockStartEvent.contentBlock() instanceof ContentBlockText textBlock) {
152+
ContentBlock cb = new ContentBlock(Type.TEXT, null, textBlock.text(), contentBlockStartEvent.index());
153+
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
154+
}
155+
else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) {
156+
ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null,
157+
null, null, null, null, null, thinkingBlock.thinking(), null);
158+
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
159+
}
160+
else {
161+
throw new IllegalArgumentException(
162+
"Unsupported content block type: " + contentBlockStartEvent.contentBlock().type());
163+
}
146164
}
147165
else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) {
148-
149166
ContentBlockDeltaEvent contentBlockDeltaEvent = (ContentBlockDeltaEvent) event;
150167

151-
Assert.isTrue(contentBlockDeltaEvent.delta().type().equals("text_delta"),
152-
"The json content block delta should have been aggregated. Unsupported content block type: "
153-
+ contentBlockDeltaEvent.delta().type());
154-
155-
ContentBlockDeltaText deltaTxt = (ContentBlockDeltaText) contentBlockDeltaEvent.delta();
156-
157-
var contentBlock = new ContentBlock(Type.TEXT_DELTA, null, deltaTxt.text(), contentBlockDeltaEvent.index());
158-
159-
contentBlockReference.get().withType(event.type().name()).withContent(List.of(contentBlock));
168+
if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaText txt) {
169+
ContentBlock cb = new ContentBlock(Type.TEXT_DELTA, null, txt.text(), contentBlockDeltaEvent.index());
170+
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
171+
}
172+
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) {
173+
ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(),
174+
null, null, null, null, null, null, thinking.thinking(), null);
175+
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
176+
}
177+
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) {
178+
ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(),
179+
null, null, null, null, null, sig.signature(), null, null);
180+
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
181+
}
182+
else {
183+
throw new IllegalArgumentException(
184+
"Unsupported content block delta type: " + contentBlockDeltaEvent.delta().type());
185+
}
160186
}
161187
else if (event.type().equals(EventType.MESSAGE_DELTA)) {
162188

@@ -173,21 +199,26 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
173199
}
174200

175201
if (messageDeltaEvent.usage() != null) {
176-
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
202+
Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
177203
messageDeltaEvent.usage().outputTokens());
178204
contentBlockReference.get().withUsage(totalUsage);
179205
}
180206
}
181207
else if (event.type().equals(EventType.MESSAGE_STOP)) {
182-
// pass through
208+
// pass through as‑is
183209
}
184210
else {
211+
// Any other event types that should propagate upwards without content
185212
contentBlockReference.get().withType(event.type().name()).withContent(List.of());
186213
}
187214

188215
return contentBlockReference.get().build();
189216
}
190217

218+
/**
219+
* Builder for {@link ChatCompletionResponse}. Used internally by {@link StreamHelper}
220+
* to aggregate stream events.
221+
*/
191222
public static class ChatCompletionResponseBuilder {
192223

193224
private String type;

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

+33-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ void functionCallTest() {
288288
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
289289
assertThat(response.getMetadata()).isNotNull();
290290
assertThat(response.getMetadata().getUsage()).isNotNull();
291-
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
291+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(100);
292292
}
293293

294294
@Test
@@ -415,6 +415,38 @@ else if (message.getMetadata().containsKey("data")) { // redacted thinking
415415
}
416416
}
417417

418+
@Test
419+
void thinkingWithStreamingTest() {
420+
UserMessage userMessage = new UserMessage(
421+
"Are there an infinite number of prime numbers such that n mod 4 == 3?");
422+
423+
var promptOptions = AnthropicChatOptions.builder()
424+
.model(AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getName())
425+
.temperature(1.0) // Temperature should be set to 1 when thinking is enabled
426+
.maxTokens(8192)
427+
.thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && <
428+
// max_tokens
429+
.build();
430+
431+
Flux<ChatResponse> responseFlux = this.streamingChatModel
432+
.stream(new Prompt(List.of(userMessage), promptOptions));
433+
434+
String content = responseFlux.collectList()
435+
.block()
436+
.stream()
437+
.map(ChatResponse::getResults)
438+
.flatMap(List::stream)
439+
.map(Generation::getOutput)
440+
.map(AssistantMessage::getText)
441+
.filter(text -> text != null && !text.isBlank())
442+
.collect(Collectors.joining());
443+
444+
logger.info("Response: {}", content);
445+
446+
assertThat(content).isNotBlank();
447+
assertThat(content).contains("prime numbers");
448+
}
449+
418450
@Test
419451
void testToolUseContentBlock() {
420452
UserMessage userMessage = new UserMessage(

0 commit comments

Comments
 (0)