Skip to content

feat(ollama): add retry template integration to OllamaChatModel #1852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -122,6 +124,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode

private final ToolCallingManager toolCallingManager;

private final RetryTemplate retryTemplate;

private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

@Deprecated
Expand All @@ -130,14 +134,15 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
ModelManagementOptions modelManagementOptions) {
this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks),
observationRegistry, modelManagementOptions);
observationRegistry, modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);

logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
RetryTemplate retryTemplate) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
Expand All @@ -147,11 +152,13 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCa
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.chatApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.toolCallingManager = toolCallingManager;
this.observationRegistry = observationRegistry;
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
this.retryTemplate = retryTemplate;
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

Expand Down Expand Up @@ -237,7 +244,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
this.observationRegistry)
.observe(() -> {

OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));

List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
: ollamaResponse.message()
Expand Down Expand Up @@ -543,6 +550,8 @@ public static final class Builder {

private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

private RetryTemplate retryTemplate;

private Builder() {
}

Expand Down Expand Up @@ -583,6 +592,11 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti
return this;
}

public Builder retryTemplate(RetryTemplate retryTemplate) {
this.retryTemplate = retryTemplate;
return this;
}

public OllamaChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
Expand All @@ -591,7 +605,7 @@ public OllamaChatModel build() {
"toolFunctionCallbacks must not be set when toolCallingManager is set");

return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
this.observationRegistry, this.modelManagementOptions);
this.observationRegistry, this.modelManagementOptions, this.retryTemplate);
}

if (functionCallbackResolver != null) {
Expand All @@ -603,8 +617,12 @@ public OllamaChatModel build() {
toolCallbacks, this.observationRegistry, this.modelManagementOptions);
}

if (this.retryTemplate == null) {
this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
}

return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
this.observationRegistry, this.modelManagementOptions);
this.observationRegistry, this.modelManagementOptions, this.retryTemplate);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 0.8.0
*/
// @formatter:off
Expand All @@ -64,8 +65,6 @@ public class OllamaApi {

private static final String DEFAULT_BASE_URL = "http://localhost:11434";

private final ResponseErrorHandler responseErrorHandler;

private final RestClient restClient;

private final WebClient webClient;
Expand Down Expand Up @@ -93,14 +92,16 @@ public OllamaApi(String baseUrl) {
*/
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {

this.responseErrorHandler = new OllamaResponseErrorHandler();
ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler();

Consumer<HttpHeaders> defaultHeaders = headers -> {
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
};

this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultStatusHandler(responseErrorHandler)
.defaultHeaders(defaultHeaders).build();

this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
}
Expand All @@ -121,7 +122,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
.uri("/api/chat")
.body(chatRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ChatResponse.class);
}

Expand Down Expand Up @@ -188,7 +188,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
.uri("/api/embed")
.body(embeddingsRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(EmbeddingsResponse.class);
}

Expand All @@ -199,7 +198,6 @@ public ListModelResponse listModels() {
return this.restClient.get()
.uri("/api/tags")
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ListModelResponse.class);
}

Expand All @@ -212,7 +210,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
.uri("/api/show")
.body(showModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ShowModelResponse.class);
}

Expand All @@ -225,7 +222,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
.uri("/api/copy")
.body(copyModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.toBodilessEntity();
}

Expand All @@ -238,7 +234,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
.uri("/api/delete")
.body(deleteModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.toBodilessEntity();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.api.tool.MockWeatherService;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -120,6 +121,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -277,6 +278,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
.pullModelStrategy(PullModelStrategy.WHEN_MISSING)
.additionalModels(List.of(ADDITIONAL_MODEL))
.build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -85,6 +86,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand All @@ -47,6 +48,7 @@
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
*
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
public class OllamaChatModelObservationIT extends BaseOllamaIT {
Expand Down Expand Up @@ -169,7 +171,11 @@ public OllamaApi openAiApi() {

@Bean
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
return OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.observationRegistry(observationRegistry)
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.retry.RetryUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -65,7 +66,7 @@ void buildOllamaChatModelWithDeprecatedConstructor() {
void buildOllamaChatModelWithConstructor() {
ChatModel chatModel = new OllamaChatModel(this.ollamaApi,
OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), ToolCallingManager.builder().build(),
ObservationRegistry.NOOP, ModelManagementOptions.builder().build());
ObservationRegistry.NOOP, ModelManagementOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
assertThat(chatModel).isNotNull();
}

Expand All @@ -81,6 +82,7 @@ void buildOllamaChatModel() {
() -> OllamaChatModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.modelManagementOptions(null)
.build());
assertEquals("modelManagementOptions must not be null", exception.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.retry.RetryUtils;

import java.util.Map;

Expand All @@ -34,12 +35,14 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
class OllamaChatRequestTests {

OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();

@Test
Expand Down Expand Up @@ -146,6 +149,7 @@ public void createRequestWithDefaultOptionsModelOverride() {
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();

var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content"));
Expand Down
Loading