|
25 | 25 |
|
26 | 26 | import com.google.cloud.vertexai.Transport;
|
27 | 27 | import com.google.cloud.vertexai.VertexAI;
|
| 28 | +import io.micrometer.observation.ObservationRegistry; |
28 | 29 | import org.jetbrains.annotations.NotNull;
|
29 | 30 | import org.junit.jupiter.api.Disabled;
|
30 | 31 | import org.junit.jupiter.api.Test;
|
31 | 32 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
32 | 33 |
|
| 34 | +import org.springframework.ai.chat.client.ChatClient; |
33 | 35 | import org.springframework.ai.chat.messages.AssistantMessage;
|
34 | 36 | import org.springframework.ai.chat.messages.Message;
|
35 | 37 | import org.springframework.ai.chat.messages.UserMessage;
|
|
42 | 44 | import org.springframework.ai.converter.BeanOutputConverter;
|
43 | 45 | import org.springframework.ai.converter.ListOutputConverter;
|
44 | 46 | import org.springframework.ai.converter.MapOutputConverter;
|
| 47 | +import org.springframework.ai.model.tool.ToolCallingManager; |
| 48 | +import org.springframework.ai.tool.annotation.Tool; |
45 | 49 | import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
|
46 | 50 | import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
|
47 | 51 | import org.springframework.beans.factory.annotation.Autowired;
|
@@ -293,6 +297,70 @@ void multiModalityPdfTest() throws IOException {
|
293 | 297 | assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API");
|
294 | 298 | }
|
295 | 299 |
|
| 300 | + /** |
| 301 | + * Helper method to create a VertexAI instance for tests |
| 302 | + */ |
| 303 | + private VertexAI vertexAiApi() { |
| 304 | + String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"); |
| 305 | + String location = System.getenv("VERTEX_AI_GEMINI_LOCATION"); |
| 306 | + return new VertexAI.Builder().setProjectId(projectId) |
| 307 | + .setLocation(location) |
| 308 | + .setTransport(Transport.REST) |
| 309 | + .build(); |
| 310 | + } |
| 311 | + |
| 312 | + @Test |
| 313 | + void jsonArrayToolCallingTest() { |
| 314 | + // Test for the improved jsonToStruct method that handles JSON arrays in tool |
| 315 | + // calling |
| 316 | + |
| 317 | + ToolCallingManager toolCallingManager = ToolCallingManager.builder() |
| 318 | + .observationRegistry(ObservationRegistry.NOOP) |
| 319 | + .build(); |
| 320 | + |
| 321 | + VertexAiGeminiChatModel chatModelWithTools = VertexAiGeminiChatModel.builder() |
| 322 | + .vertexAI(vertexAiApi()) |
| 323 | + .toolCallingManager(toolCallingManager) |
| 324 | + .defaultOptions(VertexAiGeminiChatOptions.builder() |
| 325 | + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) |
| 326 | + .temperature(0.1) |
| 327 | + .build()) |
| 328 | + .build(); |
| 329 | + |
| 330 | + ChatClient chatClient = ChatClient.builder(chatModelWithTools).build(); |
| 331 | + |
| 332 | + // Create a prompt that will trigger the tool call with a specific request that |
| 333 | + // should invoke the tool |
| 334 | + String response = chatClient.prompt() |
| 335 | + .tools(new ScientistTools()) |
| 336 | + .user("List 3 famous scientists and their discoveries. Make sure to use the tool to get this information.") |
| 337 | + .call() |
| 338 | + .content(); |
| 339 | + |
| 340 | + assertThat(response).isNotEmpty(); |
| 341 | + |
| 342 | + assertThat(response).satisfiesAnyOf(content -> assertThat(content).contains("Einstein"), |
| 343 | + content -> assertThat(content).contains("Newton"), content -> assertThat(content).contains("Curie")); |
| 344 | + |
| 345 | + } |
| 346 | + |
| 347 | + /** |
| 348 | + * Tool class that returns a JSON array to test the jsonToStruct method's ability to |
| 349 | + * handle JSON arrays. This specifically tests the PR changes that improve the |
| 350 | + * jsonToStruct method to handle JSON arrays in addition to JSON objects. |
| 351 | + */ |
| 352 | + public static class ScientistTools { |
| 353 | + |
| 354 | + @Tool(description = "Get information about famous scientists and their discoveries") |
| 355 | + public List<Map<String, String>> getScientists() { |
| 356 | + // Return a JSON array with scientist information |
| 357 | + return List.of(Map.of("name", "Albert Einstein", "discovery", "Theory of Relativity"), |
| 358 | + Map.of("name", "Isaac Newton", "discovery", "Laws of Motion"), |
| 359 | + Map.of("name", "Marie Curie", "discovery", "Radioactivity")); |
| 360 | + } |
| 361 | + |
| 362 | + } |
| 363 | + |
296 | 364 | record ActorsFilmsRecord(String actor, List<String> movies) {
|
297 | 365 |
|
298 | 366 | }
|
|
0 commit comments