
Commit 6cf8f67

Fix ollama arguments
1 parent fa7cc94 commit 6cf8f67

2 files changed: +86 -13 lines changed

src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 15 additions & 1 deletion
@@ -14,6 +14,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
 
 from pydantic import ValidationError
@@ -59,6 +60,19 @@ def __init__(
         self.async_client = ollama.AsyncClient(
             **kwargs,
         )
+        if "stream" in self.model_params:
+            raise ValueError("Streaming is not supported by the OllamaLLM wrapper")
+        # bug-fix with backward compatibility:
+        # we mistakenly passed all "model_params" under the options argument
+        # next two lines to be removed in 2.0
+        if not any(
+            key in self.model_params for key in ("options", "format", "keep_alive")
+        ):
+            warnings.warn(
+                """Passing options directly without including them in an 'options' key is deprecated. Ie you must use model_params={"options": {"temperature": 0}}""",
+                DeprecationWarning,
+            )
+            self.model_params = {"options": self.model_params}
 
     def get_messages(
         self,
@@ -104,7 +118,7 @@ def invoke(
         response = self.client.chat(
             model=self.model_name,
             messages=self.get_messages(input, message_history, system_instruction),
-            options=self.model_params,
+            **self.model_params,
         )
         content = response.message.content or ""
         return LLMResponse(content=content)
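For context, a minimal caller-side sketch of the behavior this change establishes. It is not part of the diff; it assumes the OllamaLLM export from neo4j_graphrag.llm, an installed ollama Python package, and an illustrative model name ("llama3"):

from neo4j_graphrag.llm import OllamaLLM

# New form: Ollama-native generation options go under the "options" key,
# while top-level chat arguments such as "format" stay at the top level.
llm = OllamaLLM(
    model_name="llama3",  # illustrative model name, not from the commit
    model_params={"options": {"temperature": 0.0}, "format": "json"},
)

# Deprecated flat form, kept for backward compatibility until 2.0:
# the dict is wrapped into {"options": ...} and a DeprecationWarning is emitted.
legacy_llm = OllamaLLM(model_name="llama3", model_params={"temperature": 0.0})

# Streaming is rejected explicitly by the wrapper.
try:
    OllamaLLM(model_name="llama3", model_params={"stream": True})
except ValueError:
    pass  # "Streaming is not supported by the OllamaLLM wrapper"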

tests/unit/llm/test_ollama_llm.py

Lines changed: 71 additions & 12 deletions
@@ -35,28 +35,79 @@ def test_ollama_llm_missing_dependency(mock_import: Mock) -> None:
 
 
 @patch("builtins.__import__")
-def test_ollama_llm_happy_path(mock_import: Mock) -> None:
+def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None:
     mock_ollama = get_mock_ollama()
     mock_import.return_value = mock_ollama
     mock_ollama.Client.return_value.chat.return_value = MagicMock(
         message=MagicMock(content="ollama chat response"),
     )
     model = "gpt"
     model_params = {"temperature": 0.3}
+    with pytest.warns(DeprecationWarning) as record:
+        llm = OllamaLLM(
+            model,
+            model_params=model_params,
+        )
+    assert len(record) == 1
+    assert (
+        'you must use model_params={"options": {"temperature": 0}}'
+        in record[0].message.args[0]
+    )
+
+    question = "What is graph RAG?"
+    res = llm.invoke(question)
+    assert isinstance(res, LLMResponse)
+    assert res.content == "ollama chat response"
+    messages = [
+        {"role": "user", "content": question},
+    ]
+    llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
+        model=model, messages=messages, options={"temperature": 0.3}
+    )
+
+
+@patch("builtins.__import__")
+def test_ollama_llm_unsupported_streaming(mock_import: Mock) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama chat response"),
+    )
+    model = "gpt"
+    model_params = {"stream": True}
+    with pytest.raises(ValueError):
+        OllamaLLM(
+            model,
+            model_params=model_params,
+        )
+
+
+@patch("builtins.__import__")
+def test_ollama_llm_happy_path(mock_import: Mock) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+    mock_ollama.Client.return_value.chat.return_value = MagicMock(
+        message=MagicMock(content="ollama chat response"),
+    )
+    model = "gpt"
+    options = {"temperature": 0.3}
+    model_params = {"options": options, "format": "json"}
     question = "What is graph RAG?"
     llm = OllamaLLM(
-        model,
+        model_name=model,
         model_params=model_params,
     )
-
     res = llm.invoke(question)
     assert isinstance(res, LLMResponse)
     assert res.content == "ollama chat response"
     messages = [
         {"role": "user", "content": question},
     ]
     llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
-        model=model, messages=messages, options=model_params
+        model=model,
+        messages=messages,
+        options=options,
+        format="json",
     )
 
 
@@ -68,7 +119,8 @@ def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) ->
         message=MagicMock(content="ollama chat response"),
     )
     model = "gpt"
-    model_params = {"temperature": 0.3}
+    options = {"temperature": 0.3}
+    model_params = {"options": options, "format": "json"}
     llm = OllamaLLM(
         model,
         model_params=model_params,
@@ -81,7 +133,10 @@ def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) ->
     messages = [{"role": "system", "content": system_instruction}]
     messages.append({"role": "user", "content": question})
     llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
-        model=model, messages=messages, options=model_params
+        model=model,
+        messages=messages,
+        options=options,
+        format="json",
     )
 
 
@@ -93,7 +148,8 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
         message=MagicMock(content="ollama chat response"),
     )
     model = "gpt"
-    model_params = {"temperature": 0.3}
+    options = {"temperature": 0.3}
+    model_params = {"options": options}
     llm = OllamaLLM(
         model,
         model_params=model_params,
@@ -109,7 +165,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non
     messages = [m for m in message_history]
     messages.append({"role": "user", "content": question})
     llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
-        model=model, messages=messages, options=model_params
+        model=model, messages=messages, options=options
     )
 
 
@@ -123,7 +179,8 @@ def test_ollama_invoke_with_message_history_and_system_instruction(
         message=MagicMock(content="ollama chat response"),
     )
     model = "gpt"
-    model_params = {"temperature": 0.3}
+    options = {"temperature": 0.3}
+    model_params = {"options": options}
     system_instruction = "You are a helpful assistant."
     llm = OllamaLLM(
         model,
@@ -145,7 +202,7 @@ def test_ollama_invoke_with_message_history_and_system_instruction(
     messages.extend(message_history)
     messages.append({"role": "user", "content": question})
     llm.client.chat.assert_called_once_with(  # type: ignore[attr-defined]
-        model=model, messages=messages, options=model_params
+        model=model, messages=messages, options=options
     )
     assert llm.client.chat.call_count == 1  # type: ignore
 
@@ -156,7 +213,8 @@ def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock)
     mock_import.return_value = mock_ollama
     mock_ollama.ResponseError = ollama.ResponseError
     model = "gpt"
-    model_params = {"temperature": 0.3}
+    options = {"temperature": 0.3}
+    model_params = {"options": options}
     system_instruction = "You are a helpful assistant."
     llm = OllamaLLM(
         model,
@@ -187,7 +245,8 @@ async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock:
 
     mock_ollama.AsyncClient.return_value.chat = mock_chat_async
     model = "gpt"
-    model_params = {"temperature": 0.3}
+    options = {"temperature": 0.3}
+    model_params = {"options": options}
     question = "What is graph RAG?"
     llm = OllamaLLM(
         model,
