diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index ab8bfe27317..0e6398fb64b 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -69,6 +69,8 @@ import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -122,6 +124,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode private final ToolCallingManager toolCallingManager; + private final RetryTemplate retryTemplate; + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; @Deprecated @@ -130,14 +134,15 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks), - observationRegistry, modelManagementOptions); + observationRegistry, modelManagementOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); logger.warn("This constructor is deprecated and will be removed in the next milestone. " + "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead."); } public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, - ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { + ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, + RetryTemplate retryTemplate) { // We do not pass the 'defaultOptions' to the AbstractToolSupport, // because it modifies them. We are using ToolCallingManager instead, // so we just pass empty options here. @@ -147,11 +152,13 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCa Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); Assert.notNull(observationRegistry, "observationRegistry must not be null"); Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; this.toolCallingManager = toolCallingManager; this.observationRegistry = observationRegistry; this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); + this.retryTemplate = retryTemplate; initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); } @@ -237,7 +244,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon this.observationRegistry) .observe(() -> { - OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request); + OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request)); List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() : ollamaResponse.message() @@ -543,6 +550,8 @@ public static final class Builder { private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults(); + private RetryTemplate retryTemplate; + private Builder() { } @@ -583,6 +592,11 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti return this; } + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + public OllamaChatModel build() { if (toolCallingManager != null) { Assert.isNull(functionCallbackResolver, @@ -591,7 +605,7 @@ public OllamaChatModel build() { "toolFunctionCallbacks must not be set when toolCallingManager is set"); return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager, - this.observationRegistry, this.modelManagementOptions); + this.observationRegistry, this.modelManagementOptions, this.retryTemplate); } if (functionCallbackResolver != null) { @@ -603,8 +617,12 @@ public OllamaChatModel build() { toolCallbacks, this.observationRegistry, this.modelManagementOptions); } + if (this.retryTemplate == null) { + this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + } + return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, - this.observationRegistry, this.modelManagementOptions); + this.observationRegistry, this.modelManagementOptions, this.retryTemplate); } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index 1a238b5d890..3e522a412bb 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -51,6 +51,7 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Alexandros Pappas * @since 0.8.0 */ // @formatter:off @@ -64,8 +65,6 @@ public class OllamaApi { private static final String DEFAULT_BASE_URL = "http://localhost:11434"; - private final ResponseErrorHandler responseErrorHandler; - private final RestClient restClient; private final WebClient webClient; @@ -93,14 +92,16 @@ public OllamaApi(String baseUrl) { */ public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) { - this.responseErrorHandler = new OllamaResponseErrorHandler(); + ResponseErrorHandler responseErrorHandler = new OllamaResponseErrorHandler(); Consumer defaultHeaders = headers -> { headers.setContentType(MediaType.APPLICATION_JSON); headers.setAccept(List.of(MediaType.APPLICATION_JSON)); }; - this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultStatusHandler(responseErrorHandler) + .defaultHeaders(defaultHeaders).build(); this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); } @@ -121,7 +122,6 @@ public ChatResponse chat(ChatRequest chatRequest) { .uri("/api/chat") .body(chatRequest) .retrieve() - .onStatus(this.responseErrorHandler) .body(ChatResponse.class); } @@ -188,7 +188,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) { .uri("/api/embed") .body(embeddingsRequest) .retrieve() - .onStatus(this.responseErrorHandler) .body(EmbeddingsResponse.class); } @@ -199,7 +198,6 @@ public ListModelResponse listModels() { return this.restClient.get() .uri("/api/tags") .retrieve() - .onStatus(this.responseErrorHandler) .body(ListModelResponse.class); } @@ -212,7 +210,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) { .uri("/api/show") .body(showModelRequest) .retrieve() - .onStatus(this.responseErrorHandler) .body(ShowModelResponse.class); } @@ -225,7 +222,6 @@ public ResponseEntity copyModel(CopyModelRequest copyModelRequest) { .uri("/api/copy") .body(copyModelRequest) .retrieve() - .onStatus(this.responseErrorHandler) .toBodilessEntity(); } @@ -238,7 +234,6 @@ public ResponseEntity deleteModel(DeleteModelRequest deleteModelRequest) { .uri("/api/delete") .body(deleteModelRequest) .retrieve() - .onStatus(this.responseErrorHandler) .toBodilessEntity(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index 25f800ca9fa..e206d947f77 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -36,6 +36,7 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.tool.MockWeatherService; import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -120,6 +121,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { return OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index 8709a5b8b3a..27405fb13fb 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -44,6 +44,7 @@ import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -277,6 +278,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { .pullModelStrategy(PullModelStrategy.WHEN_MISSING) .additionalModels(List.of(ADDITIONAL_MODEL)) .build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 3c136560733..fb22c022b4c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -27,6 +27,7 @@ import org.springframework.ai.model.Media; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -85,6 +86,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) { return OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java index 916a364ba65..0d8b6a0b714 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java @@ -34,6 +34,7 @@ import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -47,6 +48,7 @@ * Integration tests for observation instrumentation in {@link OllamaChatModel}. * * @author Thomas Vitale + * @author Alexandros Pappas */ @SpringBootTest(classes = OllamaChatModelObservationIT.Config.class) public class OllamaChatModelObservationIT extends BaseOllamaIT { @@ -169,7 +171,11 @@ public OllamaApi openAiApi() { @Bean public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { - return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); + return OllamaChatModel.builder() + .ollamaApi(ollamaApi) + .observationRegistry(observationRegistry) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); } } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java index dc5b9bfb725..c9f53233617 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -35,6 +35,7 @@ import org.springframework.ai.ollama.api.OllamaModel; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; +import org.springframework.ai.retry.RetryUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -65,7 +66,7 @@ void buildOllamaChatModelWithDeprecatedConstructor() { void buildOllamaChatModelWithConstructor() { ChatModel chatModel = new OllamaChatModel(this.ollamaApi, OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), ToolCallingManager.builder().build(), - ObservationRegistry.NOOP, ModelManagementOptions.builder().build()); + ObservationRegistry.NOOP, ModelManagementOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); assertThat(chatModel).isNotNull(); } @@ -81,6 +82,7 @@ void buildOllamaChatModel() { () -> OllamaChatModel.builder() .ollamaApi(this.ollamaApi) .defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .modelManagementOptions(null) .build()); assertEquals("modelManagementOptions must not be null", exception.getMessage()); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index e3221f02640..0516a0f4b11 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -26,6 +26,7 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.retry.RetryUtils; import java.util.Map; @@ -34,12 +35,14 @@ /** * @author Christian Tzolov * @author Thomas Vitale + * @author Alexandros Pappas */ class OllamaChatRequestTests { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(new OllamaApi()) .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); @Test @@ -146,6 +149,7 @@ public void createRequestWithDefaultOptionsModelOverride() { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(new OllamaApi()) .defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content")); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java new file mode 100644 index 00000000000..bb1c1bb22ee --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaRetryTests.java @@ -0,0 +1,97 @@ +package org.springframework.ai.ollama; + +import java.time.Instant; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * Tests for the OllamaRetryTests class. + * + * @author Alexandros Pappas + */ +@ExtendWith(MockitoExtension.class) +class OllamaRetryTests { + + private static final String MODEL = OllamaModel.LLAMA3_2.getName(); + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + @Mock + private OllamaApi ollamaApi; + + private OllamaChatModel chatModel; + + @BeforeEach + public void beforeEach() { + this.retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + this.retryTemplate.registerListener(this.retryListener); + + this.chatModel = OllamaChatModel.builder() + .ollamaApi(this.ollamaApi) + .defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build()) + .retryTemplate(this.retryTemplate) + .build(); + } + + @Test + void ollamaChatTransientError() { + String promptText = "What is the capital of Bulgaria and what is the size? What it the national anthem?"; + var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(), + OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Response").build(), null, true, + null, null, null, null, null, null); + + when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(expectedChatResponse); + + var result = this.chatModel.call(new Prompt(promptText)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isSameAs("Response"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java index ecde0c935a6..5675e4c4a1a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java @@ -40,6 +40,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -82,7 +83,7 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails, public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties, OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager, ObjectProvider observationRegistry, - ObjectProvider observationConvention) { + ObjectProvider observationConvention, RetryTemplate retryTemplate) { var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER; @@ -94,6 +95,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties .modelManagementOptions( new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(), initProperties.getTimeout(), initProperties.getMaxRetries())) + .retryTemplate(retryTemplate) .build(); observationConvention.ifAvailable(chatModel::setObservationConvention); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java index 14493bcb20c..47c3b8ac3df 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaChatAutoConfigurationTests.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -41,7 +42,8 @@ public void propertiesTest() { "spring.ai.ollama.chat.options.topP=0.56", "spring.ai.ollama.chat.options.topK=123") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OllamaAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(OllamaChatProperties.class); var connectionProperties = context.getBean(OllamaConnectionProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java index bd2a8bfd2df..ccaaff04df0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationTests.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -26,6 +27,7 @@ /** * @author Christian Tzolov + * @author Alexandros Pappas * @since 0.8.0 */ public class OllamaEmbeddingAutoConfigurationTests { @@ -41,7 +43,8 @@ public void propertiesTest() { "spring.ai.ollama.embedding.options.topK=13" // @formatter:on ) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OllamaAutoConfiguration.class)) .run(context -> { var embeddingProperties = context.getBean(OllamaEmbeddingProperties.class); var connectionProperties = context.getBean(OllamaConnectionProperties.class);