diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py
index 26613c0b7..5fb0c1ffe 100644
--- a/src/strands/models/ollama.py
+++ b/src/strands/models/ollama.py
@@ -68,14 +68,12 @@ def __init__(
             ollama_client_args: Additional arguments for the Ollama client.
             **model_config: Configuration options for the Ollama model.
         """
+        self.host = host
+        self.client_args = ollama_client_args or {}
         self.config = OllamaModel.OllamaConfig(**model_config)
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        ollama_client_args = ollama_client_args if ollama_client_args is not None else {}
-
-        self.client = ollama.AsyncClient(host, **ollama_client_args)
-
     @override
     def update_config(self, **model_config: Unpack[OllamaConfig]) -> None:  # type: ignore
         """Update the Ollama Model configuration with the provided arguments.
@@ -306,7 +304,8 @@ async def stream(
         logger.debug("invoking model")
         tool_requested = False
 
-        response = await self.client.chat(**request)
+        client = ollama.AsyncClient(self.host, **self.client_args)
+        response = await client.chat(**request)
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -346,7 +345,9 @@ async def structured_output(
         formatted_request = self.format_request(messages=prompt)
         formatted_request["format"] = output_model.model_json_schema()
         formatted_request["stream"] = False
-        response = await self.client.chat(**formatted_request)
+
+        client = ollama.AsyncClient(self.host, **self.client_args)
+        response = await client.chat(**formatted_request)
 
         try:
             content = response.message.content.strip()
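
The net effect of the diff: the model now stores `host` and `client_args` in `__init__` and constructs a fresh `ollama.AsyncClient` inside each of `stream` and `structured_output`, rather than holding one client for the instance's lifetime. A plausible motivation, not stated in the diff itself, is event-loop affinity: `ollama.AsyncClient` wraps an `httpx.AsyncClient`, which becomes tied to the event loop it is used on, so a client built once in `__init__` can fail if the model instance is reused across event loops (for example, across successive `asyncio.run` calls). Below is a minimal, self-contained sketch of the pattern under that assumption; `LazyClientModel` and `ask` are illustrative names, not strands APIs, and only `ollama.AsyncClient` and its `chat` method come from the real library.

    import asyncio

    import ollama


    class LazyClientModel:
        """Illustrative stand-in for the pattern adopted in this diff."""

        def __init__(self, host=None, client_args=None):
            # Store connection details only; defer client construction,
            # mirroring the new __init__ in the diff.
            self.host = host
            self.client_args = client_args or {}

        async def ask(self, model_id, text):
            # Build a fresh client per call, as the diff now does in
            # stream() and structured_output().
            client = ollama.AsyncClient(self.host, **self.client_args)
            response = await client.chat(
                model=model_id,
                messages=[{"role": "user", "content": text}],
            )
            return response.message.content


    model = LazyClientModel(host="http://localhost:11434")

    # Each asyncio.run call creates and closes its own event loop. A client
    # constructed once in __init__ could remain bound to the first, closed
    # loop; per-call construction works in both runs.
    print(asyncio.run(model.ask("llama3", "hello")))
    print(asyncio.run(model.ask("llama3", "hello again")))

The trade-off of this design is a small amount of per-request construction overhead in exchange for making the model instance safe to share across event loops; since the heavy work is the network round-trip to the Ollama server, the extra client setup is unlikely to be noticeable.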