Skip to content

Commit 78d90cd

Browse files
tzolovmarkpollack
authored andcommitted
feat(vertex-ai-gemini): enhance jsonToStruct to support JSON arrays
- Improve the jsonToStruct method in VertexAiGeminiChatModel to handle JSON arrays in addition to JSON objects. When a JSON array is detected, it's now properly converted to a Protobuf Struct with an items field containing the array elements. - Added test Resolves #2647 , #2849 Signed-off-by: Christian Tzolov <[email protected]>
1 parent 2cbfb22 commit 78d90cd

File tree

4 files changed

+107
-7
lines changed

4 files changed

+107
-7
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
package org.springframework.ai.vertexai.gemini;
1818

19-
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
2019
import java.util.ArrayList;
2120
import java.util.Collection;
2221
import java.util.List;
2322
import java.util.Map;
2423

2524
import com.fasterxml.jackson.annotation.JsonInclude;
2625
import com.fasterxml.jackson.annotation.JsonInclude.Include;
26+
import com.fasterxml.jackson.databind.JsonNode;
2727
import com.google.cloud.vertexai.VertexAI;
2828
import com.google.cloud.vertexai.api.Candidate;
2929
import com.google.cloud.vertexai.api.Candidate.FinishReason;
@@ -33,15 +33,16 @@
3333
import com.google.cloud.vertexai.api.FunctionResponse;
3434
import com.google.cloud.vertexai.api.GenerateContentResponse;
3535
import com.google.cloud.vertexai.api.GenerationConfig;
36-
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
3736
import com.google.cloud.vertexai.api.Part;
3837
import com.google.cloud.vertexai.api.SafetySetting;
3938
import com.google.cloud.vertexai.api.Schema;
4039
import com.google.cloud.vertexai.api.Tool;
40+
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
4141
import com.google.cloud.vertexai.generativeai.GenerativeModel;
4242
import com.google.cloud.vertexai.generativeai.PartMaker;
4343
import com.google.cloud.vertexai.generativeai.ResponseStream;
4444
import com.google.protobuf.Struct;
45+
import com.google.protobuf.Value;
4546
import com.google.protobuf.util.JsonFormat;
4647
import io.micrometer.observation.Observation;
4748
import io.micrometer.observation.ObservationRegistry;
@@ -226,7 +227,8 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa
226227
this.observationRegistry = observationRegistry;
227228
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
228229

229-
// Wrap the provided tool calling manager in a VertexToolCallingManager to ensure
230+
// Wrap the provided tool calling manager in a VertexToolCallingManager to
231+
// ensure
230232
// compatibility with Vertex AI's OpenAPI schema format.
231233
if (toolCallingManager instanceof VertexToolCallingManager) {
232234
this.toolCallingManager = toolCallingManager;
@@ -334,8 +336,34 @@ private static String structToJson(Struct struct) {
334336

335337
private static Struct jsonToStruct(String json) {
336338
try {
337-
var structBuilder = Struct.newBuilder();
338-
JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder);
339+
JsonNode rootNode = ModelOptionsUtils.OBJECT_MAPPER.readTree(json);
340+
341+
Struct.Builder structBuilder = Struct.newBuilder();
342+
343+
if (rootNode.isArray()) {
344+
// Handle JSON array
345+
List<Value> values = new ArrayList<>();
346+
347+
for (JsonNode element : rootNode) {
348+
String elementJson = element.toString();
349+
Struct.Builder elementBuilder = Struct.newBuilder();
350+
JsonFormat.parser().ignoringUnknownFields().merge(elementJson, elementBuilder);
351+
352+
// Add each parsed object as a value in an array field
353+
values.add(Value.newBuilder().setStructValue(elementBuilder.build()).build());
354+
}
355+
356+
// Add the array to the main struct with a field name like "items"
357+
structBuilder.putFields("items",
358+
Value.newBuilder()
359+
.setListValue(com.google.protobuf.ListValue.newBuilder().addAllValues(values).build())
360+
.build());
361+
}
362+
else {
363+
// Original behavior for single JSON object
364+
JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder);
365+
}
366+
339367
return structBuilder.build();
340368
}
341369
catch (Exception e) {

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525

2626
import com.google.cloud.vertexai.Transport;
2727
import com.google.cloud.vertexai.VertexAI;
28+
import io.micrometer.observation.ObservationRegistry;
2829
import org.jetbrains.annotations.NotNull;
2930
import org.junit.jupiter.api.Disabled;
3031
import org.junit.jupiter.api.Test;
3132
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3233

34+
import org.springframework.ai.chat.client.ChatClient;
3335
import org.springframework.ai.chat.messages.AssistantMessage;
3436
import org.springframework.ai.chat.messages.Message;
3537
import org.springframework.ai.chat.messages.UserMessage;
@@ -42,6 +44,8 @@
4244
import org.springframework.ai.converter.BeanOutputConverter;
4345
import org.springframework.ai.converter.ListOutputConverter;
4446
import org.springframework.ai.converter.MapOutputConverter;
47+
import org.springframework.ai.model.tool.ToolCallingManager;
48+
import org.springframework.ai.tool.annotation.Tool;
4549
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
4650
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
4751
import org.springframework.beans.factory.annotation.Autowired;
@@ -293,6 +297,70 @@ void multiModalityPdfTest() throws IOException {
293297
assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API");
294298
}
295299

300+
/**
301+
* Helper method to create a VertexAI instance for tests
302+
*/
303+
private VertexAI vertexAiApi() {
304+
String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
305+
String location = System.getenv("VERTEX_AI_GEMINI_LOCATION");
306+
return new VertexAI.Builder().setProjectId(projectId)
307+
.setLocation(location)
308+
.setTransport(Transport.REST)
309+
.build();
310+
}
311+
312+
@Test
313+
void jsonArrayToolCallingTest() {
314+
// Test for the improved jsonToStruct method that handles JSON arrays in tool
315+
// calling
316+
317+
ToolCallingManager toolCallingManager = ToolCallingManager.builder()
318+
.observationRegistry(ObservationRegistry.NOOP)
319+
.build();
320+
321+
VertexAiGeminiChatModel chatModelWithTools = VertexAiGeminiChatModel.builder()
322+
.vertexAI(vertexAiApi())
323+
.toolCallingManager(toolCallingManager)
324+
.defaultOptions(VertexAiGeminiChatOptions.builder()
325+
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH)
326+
.temperature(0.1)
327+
.build())
328+
.build();
329+
330+
ChatClient chatClient = ChatClient.builder(chatModelWithTools).build();
331+
332+
// Create a prompt that will trigger the tool call with a specific request that
333+
// should invoke the tool
334+
String response = chatClient.prompt()
335+
.tools(new ScientistTools())
336+
.user("List 3 famous scientists and their discoveries. Make sure to use the tool to get this information.")
337+
.call()
338+
.content();
339+
340+
assertThat(response).isNotEmpty();
341+
342+
assertThat(response).satisfiesAnyOf(content -> assertThat(content).contains("Einstein"),
343+
content -> assertThat(content).contains("Newton"), content -> assertThat(content).contains("Curie"));
344+
345+
}
346+
347+
/**
348+
* Tool class that returns a JSON array to test the jsonToStruct method's ability to
349+
* handle JSON arrays. This specifically tests the PR changes that improve the
350+
* jsonToStruct method to handle JSON arrays in addition to JSON objects.
351+
*/
352+
public static class ScientistTools {
353+
354+
@Tool(description = "Get information about famous scientists and their discoveries")
355+
public List<Map<String, String>> getScientists() {
356+
// Return a JSON array with scientist information
357+
return List.of(Map.of("name", "Albert Einstein", "discovery", "Theory of Relativity"),
358+
Map.of("name", "Isaac Newton", "discovery", "Laws of Motion"),
359+
Map.of("name", "Marie Curie", "discovery", "Radioactivity"));
360+
}
361+
362+
}
363+
296364
record ActorsFilmsRecord(String actor, List<String> movies) {
297365

298366
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ public void functionCallTestInferredOpenApiSchema() {
126126

127127
assertThat(chatResponse.getMetadata()).isNotNull();
128128
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
129-
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);
129+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(330);
130130

131131
ChatResponse response2 = this.chatModel
132132
.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
@@ -201,7 +201,7 @@ public void functionCallUsageTestInferredOpenApiSchemaStream() {
201201
assertThat(chatResponse).isNotNull();
202202
assertThat(chatResponse.getMetadata()).isNotNull();
203203
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
204-
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);
204+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(330);
205205

206206
}
207207

spring-ai-model/src/main/java/org/springframework/ai/converter/ListOutputConverter.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
*/
3232
public class ListOutputConverter extends AbstractConversionServiceOutputConverter<List<String>> {
3333

34+
public ListOutputConverter() {
35+
this(new DefaultConversionService());
36+
}
37+
3438
public ListOutputConverter(DefaultConversionService defaultConversionService) {
3539
super(defaultConversionService);
3640
}

0 commit comments

Comments
 (0)