Skip to content

Commit 55d38fd

Browse files
committed
feat(ollama): add retry template integration to OllamaChatModel
1 parent a474b12 commit 55d38fd

File tree

11 files changed

+56
-15
lines changed

11 files changed

+56
-15
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
import org.springframework.ai.ollama.management.OllamaModelManager;
6363
import org.springframework.ai.ollama.management.PullModelStrategy;
6464
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
65+
import org.springframework.ai.retry.RetryUtils;
66+
import org.springframework.retry.support.RetryTemplate;
6567
import org.springframework.util.Assert;
6668
import org.springframework.util.CollectionUtils;
6769
import org.springframework.util.StringUtils;
@@ -77,6 +79,7 @@
7779
* @author luocongqiu
7880
* @author Thomas Vitale
7981
* @author Jihoon Kim
82+
* @author Alexandros Pappas
8083
* @since 1.0.0
8184
*/
8285
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -107,20 +110,32 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
107110

108111
private final OllamaModelManager modelManager;
109112

113+
private final RetryTemplate retryTemplate;
114+
110115
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
111116

112117
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
113118
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
114119
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
120+
this(ollamaApi, defaultOptions, functionCallbackResolver, toolFunctionCallbacks, observationRegistry,
121+
modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
122+
}
123+
124+
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
125+
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
126+
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
127+
RetryTemplate retryTemplate) {
115128
super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
116129
Assert.notNull(ollamaApi, "ollamaApi must not be null");
117130
Assert.notNull(defaultOptions, "defaultOptions must not be null");
118131
Assert.notNull(observationRegistry, "observationRegistry must not be null");
119132
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
133+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
120134
this.chatApi = ollamaApi;
121135
this.defaultOptions = defaultOptions;
122136
this.observationRegistry = observationRegistry;
123137
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
138+
this.retryTemplate = retryTemplate;
124139
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
125140
}
126141

@@ -198,7 +213,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
198213
this.observationRegistry)
199214
.observe(() -> {
200215

201-
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
216+
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
202217

203218
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
204219
: ollamaResponse.message()
@@ -470,6 +485,8 @@ public static final class Builder {
470485

471486
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
472487

488+
private RetryTemplate retryTemplate;
489+
473490
private Builder() {
474491
}
475492

@@ -513,9 +530,15 @@ public Builder withModelManagementOptions(ModelManagementOptions modelManagement
513530
return this;
514531
}
515532

533+
public Builder withRetryTemplate(RetryTemplate retryTemplate) {
534+
this.retryTemplate = retryTemplate;
535+
return this;
536+
}
537+
516538
public OllamaChatModel build() {
517539
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
518-
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
540+
this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions,
541+
this.retryTemplate);
519542
}
520543

521544
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
*
5454
* @author Christian Tzolov
5555
* @author Thomas Vitale
56+
* @author Alexandros Pappas
5657
* @since 0.8.0
5758
*/
5859
// @formatter:off
@@ -66,8 +67,6 @@ public class OllamaApi {
6667

6768
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
6869

69-
private final ResponseErrorHandler responseErrorHandler;
70-
7170
private final RestClient restClient;
7271

7372
private final WebClient webClient;
@@ -95,14 +94,16 @@ public OllamaApi(String baseUrl) {
9594
*/
9695
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
9796

98-
this.responseErrorHandler = new OllamaResponseErrorHandler();
97+
ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler();
9998

10099
Consumer<HttpHeaders> defaultHeaders = headers -> {
101100
headers.setContentType(MediaType.APPLICATION_JSON);
102101
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
103102
};
104103

105-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
104+
this.restClient = restClientBuilder.baseUrl(baseUrl)
105+
.defaultStatusHandler(responseErrorHandler)
106+
.defaultHeaders(defaultHeaders).build();
106107

107108
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
108109
}
@@ -123,7 +124,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
123124
.uri("/api/chat")
124125
.body(chatRequest)
125126
.retrieve()
126-
.onStatus(this.responseErrorHandler)
127127
.body(ChatResponse.class);
128128
}
129129

@@ -190,7 +190,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
190190
.uri("/api/embed")
191191
.body(embeddingsRequest)
192192
.retrieve()
193-
.onStatus(this.responseErrorHandler)
194193
.body(EmbeddingsResponse.class);
195194
}
196195

@@ -201,7 +200,6 @@ public ListModelResponse listModels() {
201200
return this.restClient.get()
202201
.uri("/api/tags")
203202
.retrieve()
204-
.onStatus(this.responseErrorHandler)
205203
.body(ListModelResponse.class);
206204
}
207205

@@ -214,7 +212,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
214212
.uri("/api/show")
215213
.body(showModelRequest)
216214
.retrieve()
217-
.onStatus(this.responseErrorHandler)
218215
.body(ShowModelResponse.class);
219216
}
220217

@@ -227,7 +224,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
227224
.uri("/api/copy")
228225
.body(copyModelRequest)
229226
.retrieve()
230-
.onStatus(this.responseErrorHandler)
231227
.toBodilessEntity();
232228
}
233229

@@ -240,7 +236,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
240236
.uri("/api/delete")
241237
.body(deleteModelRequest)
242238
.retrieve()
243-
.onStatus(this.responseErrorHandler)
244239
.toBodilessEntity();
245240
}
246241

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.ollama.api.OllamaApi;
3838
import org.springframework.ai.ollama.api.OllamaOptions;
3939
import org.springframework.ai.ollama.api.tool.MockWeatherService;
40+
import org.springframework.ai.retry.RetryUtils;
4041
import org.springframework.beans.factory.annotation.Autowired;
4142
import org.springframework.boot.SpringBootConfiguration;
4243
import org.springframework.boot.test.context.SpringBootTest;
@@ -123,6 +124,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
123124
return OllamaChatModel.builder()
124125
.withOllamaApi(ollamaApi)
125126
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
127+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
126128
.build();
127129
}
128130

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.ollama.management.ModelManagementOptions;
4343
import org.springframework.ai.ollama.management.OllamaModelManager;
4444
import org.springframework.ai.ollama.management.PullModelStrategy;
45+
import org.springframework.ai.retry.RetryUtils;
4546
import org.springframework.beans.factory.annotation.Autowired;
4647
import org.springframework.boot.SpringBootConfiguration;
4748
import org.springframework.boot.test.context.SpringBootTest;
@@ -249,6 +250,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
249250
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
250251
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
251252
.build())
253+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
252254
.build();
253255
}
254256

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.Media;
2828
import org.springframework.ai.ollama.api.OllamaApi;
2929
import org.springframework.ai.ollama.api.OllamaOptions;
30+
import org.springframework.ai.retry.RetryUtils;
3031
import org.springframework.beans.factory.annotation.Autowired;
3132
import org.springframework.boot.SpringBootConfiguration;
3233
import org.springframework.boot.test.context.SpringBootTest;
@@ -84,6 +85,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
8485
return OllamaChatModel.builder()
8586
.withOllamaApi(ollamaApi)
8687
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
88+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8789
.build();
8890
}
8991

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.ollama.api.OllamaApi;
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.retry.RetryUtils;
3738
import org.springframework.beans.factory.annotation.Autowired;
3839
import org.springframework.boot.SpringBootConfiguration;
3940
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@
4748
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
4849
*
4950
* @author Thomas Vitale
51+
* @author Alexandros Pappas
5052
*/
5153
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
5254
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -172,6 +174,7 @@ public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegis
172174
return OllamaChatModel.builder()
173175
.withOllamaApi(ollamaApi)
174176
.withObservationRegistry(observationRegistry)
177+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
175178
.build();
176179
}
177180

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.ollama.api.OllamaApi;
3232
import org.springframework.ai.ollama.api.OllamaModel;
3333
import org.springframework.ai.ollama.api.OllamaOptions;
34+
import org.springframework.ai.retry.RetryUtils;
3435

3536
import static org.assertj.core.api.Assertions.assertThat;
3637
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -39,6 +40,7 @@
3940
/**
4041
* @author Jihoon Kim
4142
* @author Christian Tzolov
43+
* @author Alexandros Pappas
4244
* @since 1.0.0
4345
*/
4446
@ExtendWith(MockitoExtension.class)
@@ -53,6 +55,7 @@ public void buildOllamaChatModel() {
5355
() -> OllamaChatModel.builder()
5456
.withOllamaApi(this.ollamaApi)
5557
.withDefaultOptions(OllamaOptions.create().withModel(OllamaModel.LLAMA2))
58+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
5659
.withModelManagementOptions(null)
5760
.build());
5861
assertEquals("modelManagementOptions must not be null", exception.getMessage());

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,22 @@
2323
import org.springframework.ai.chat.prompt.Prompt;
2424
import org.springframework.ai.ollama.api.OllamaApi;
2525
import org.springframework.ai.ollama.api.OllamaOptions;
26+
import org.springframework.ai.retry.RetryUtils;
2627

2728
import static org.assertj.core.api.Assertions.assertThat;
2829

2930
/**
3031
* @author Christian Tzolov
3132
* @author Thomas Vitale
33+
* @author Alexandros Pappas
3234
*/
3335
public class OllamaChatRequestTests {
3436

3537
OllamaChatModel chatModel = OllamaChatModel.builder()
3638
.withOllamaApi(new OllamaApi())
3739
.withDefaultOptions(
3840
OllamaOptions.create().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6).withNumGPU(1))
41+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
3942
.build();
4043

4144
@Test
@@ -113,6 +116,7 @@ public void createRequestWithDefaultOptionsModelOverride() {
113116
OllamaChatModel chatModel = OllamaChatModel.builder()
114117
.withOllamaApi(new OllamaApi())
115118
.withDefaultOptions(OllamaOptions.create().withModel("DEFAULT_OPTIONS_MODEL"))
119+
.withRetryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
116120
.build();
117121

118122
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.boot.context.properties.EnableConfigurationProperties;
4242
import org.springframework.context.ApplicationContext;
4343
import org.springframework.context.annotation.Bean;
44+
import org.springframework.retry.support.RetryTemplate;
4445
import org.springframework.web.client.RestClient;
4546
import org.springframework.web.reactive.function.client.WebClient;
4647

@@ -82,7 +83,7 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails,
8283
public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties,
8384
OllamaInitializationProperties initProperties, List<FunctionCallback> toolFunctionCallbacks,
8485
FunctionCallbackResolver functionCallbackResolver, ObjectProvider<ObservationRegistry> observationRegistry,
85-
ObjectProvider<ChatModelObservationConvention> observationConvention) {
86+
ObjectProvider<ChatModelObservationConvention> observationConvention, RetryTemplate retryTemplate) {
8687
var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy()
8788
: PullModelStrategy.NEVER;
8889

@@ -95,6 +96,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties
9596
.withModelManagementOptions(
9697
new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(),
9798
initProperties.getTimeout(), initProperties.getMaxRetries()))
99+
.withRetryTemplate(retryTemplate)
98100
.build();
99101

100102
observationConvention.ifAvailable(chatModel::setObservationConvention);

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
2122
import org.springframework.boot.autoconfigure.AutoConfigurations;
2223
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
2324
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,7 +42,8 @@ public void propertiesTest() {
4142
"spring.ai.ollama.chat.options.topP=0.56",
4243
"spring.ai.ollama.chat.options.topK=123")
4344
// @formatter:on
44-
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
45+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
46+
RestClientAutoConfiguration.class, OllamaAutoConfiguration.class))
4547
.run(context -> {
4648
var chatProperties = context.getBean(OllamaChatProperties.class);
4749
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

0 commit comments

Comments
 (0)