-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Open
Description
Question: in a Spring AI 1.0.0 chat scenario, a user uploads an image/PDF, the large model (Gemini) interprets it, and the interpretation is streamed back. The streamed response must also include the search (grounding) sources, but Spring AI 1.0.0 does not return them. I modified the source of the VertexAiGeminiChatModel class myself, and found that only the non-streaming API returns the search sources — the streaming API still does not. How should this scenario be implemented? Here is my modified source code:
protected` List<Generation> responseCandidateToGeneration(Candidate candidate) {
// TODO - The candidateIndex (e.g. choice must be assigned to the generation).
int candidateIndex = candidate.getIndex();
FinishReason candidateFinishReason = candidate.getFinishReason();
// Convert from VertexAI protobuf to VertexAiGeminiApi DTOs
List<VertexAiGeminiApi.LogProbs.TopContent> topCandidates = candidate.getLogprobsResult()
.getTopCandidatesList()
.stream()
.filter(topCandidate -> !topCandidate.getCandidatesList().isEmpty())
.map(topCandidate -> new VertexAiGeminiApi.LogProbs.TopContent(topCandidate.getCandidatesList()
.stream()
.map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId()))
.toList()))
.toList();
List<VertexAiGeminiApi.LogProbs.Content> chosenCandidates = candidate.getLogprobsResult()
.getChosenCandidatesList()
.stream()
.map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId()))
.toList();
VertexAiGeminiApi.LogProbs logprobs = new VertexAiGeminiApi.LogProbs(candidate.getAvgLogprobs(), topCandidates,
chosenCandidates);
Map<String, Object> messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason",
candidateFinishReason, "logprobs", logprobs);
ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.builder()
.finishReason(candidateFinishReason.name())
.build();
com.google.cloud.vertexai.api.GroundingMetadata vertexaiGroundingMetadata = candidate.getGroundingMetadata();
GroundingMetadata groundingMetadata = new GroundingMetadata();
groundingMetadata.setWebSearchQueries(vertexaiGroundingMetadata.getWebSearchQueriesList());
com.google.cloud.vertexai.api.SearchEntryPoint vertexaiSearchEntryPoint = vertexaiGroundingMetadata
.getSearchEntryPoint();
org.springframework.ai.chat.metadata.SearchEntryPoint searchEntryPoint = new org.springframework.ai.chat.metadata.SearchEntryPoint(
vertexaiSearchEntryPoint.getRenderedContent());
groundingMetadata.setSearchEntryPoint(searchEntryPoint);
if (!CollectionUtils.isEmpty(vertexaiGroundingMetadata.getGroundingChunksList())) {
List<GroundingChunk> groundingChunkList = vertexaiGroundingMetadata.getGroundingChunksList()
.stream()
.map(t -> {
com.google.cloud.vertexai.api.GroundingChunk.Web vertexaiWeb = t.getWeb();
GroundingChunk groundingChunk = new GroundingChunk();
Web web = new Web();
web.setUri(vertexaiWeb.getUri());
web.setTitle(vertexaiWeb.getTitle());
groundingChunk.setWeb(web);
return groundingChunk;
})
.toList();
groundingMetadata.setGroundingChunks(groundingChunkList);
}
if (!CollectionUtils.isEmpty(vertexaiGroundingMetadata.getGroundingSupportsList())) {
List<GroundingSupport> groundingSupportList = vertexaiGroundingMetadata.getGroundingSupportsList()
.stream()
.map(t -> {
GroundingSupport groundingSupport = new GroundingSupport();
com.google.cloud.vertexai.api.Segment vertexaiSegment = t.getSegment();
Segment segment = new Segment();
segment.setStartIndex(vertexaiSegment.getStartIndex());
segment.setEndIndex(vertexaiSegment.getEndIndex());
segment.setText(vertexaiSegment.getText());
groundingSupport.setSegment(segment);
groundingSupport.setGroundingChunkIndices(t.getGroundingChunkIndicesList());
return groundingSupport;
})
.toList();
groundingMetadata.setGroundingSupports(groundingSupportList);
}
boolean isFunctionCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall);
if (isFunctionCall) {
List<AssistantMessage.ToolCall> assistantToolCalls = candidate.getContent()
.getPartsList()
.stream()
.filter(part -> part.hasFunctionCall())
.map(part -> {
FunctionCall functionCall = part.getFunctionCall();
var functionName = functionCall.getName();
String functionArguments = structToJson(functionCall.getArgs());
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
})
.toList();
AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls);
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
}
else {
List<Generation> generations = candidate.getContent()
.getPartsList()
.stream()
.map(part -> new AssistantMessage(part.getText(), messageMetadata))
.map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata, groundingMetadata))
.toList();
return generations;
}
}