Skip to content

Commit efb775b

Browse files
committed
Multi-line docstrings and exhaustive tests
1 parent 8ec5123 commit efb775b

File tree

3 files changed

+120
-22
lines changed

3 files changed

+120
-22
lines changed

ollama/_utils.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Dict, Set as TypeSet
77
import sys
88

9-
# Type compatibility layer
9+
# Type compatibility layer for Union and UnionType
1010
if sys.version_info >= (3, 10):
1111
from types import UnionType
1212

@@ -30,6 +30,7 @@ def is_union(tp: Any) -> bool:
3030
bool: 'boolean',
3131
'bool': 'boolean',
3232
type(None): 'null',
33+
None: 'null',
3334
'None': 'null',
3435
# Collection types
3536
list: 'array',
@@ -83,16 +84,16 @@ def _get_json_type(python_type: Any) -> str | List[str]:
8384
if issubclass(get_origin(python_type), (dict, Mapping)):
8485
return 'object'
8586

86-
# Handle both type objects and type references
87+
# Handle both type objects and type references (older Python versions)
8788
type_key = python_type
8889
if isinstance(python_type, type):
8990
type_key = python_type
9091
elif isinstance(python_type, str):
91-
type_key = python_type.lower()
92+
type_key = python_type
9293

9394
# If type not found in map, try to get the type name
9495
if type_key not in TYPE_MAP and hasattr(python_type, '__name__'):
95-
type_key = python_type.__name__.lower()
96+
type_key = python_type.__name__
9697

9798
if type_key in TYPE_MAP:
9899
return TYPE_MAP[type_key]
@@ -138,23 +139,47 @@ def convert_function_to_tool(func: Callable) -> Tool:
138139
if param_name == 'return':
139140
continue
140141

141-
param_desc = None
142+
param_desc_lines = []
143+
found_param = False
144+
indent_level = None
145+
146+
# Process docstring lines
142147
for line in args_section.split('\n'):
143-
line = line.strip()
144-
# Check for parameter name with or without colon, space, or parentheses to mitigate formatting issues
145-
if line.startswith(param_name + ':') or line.startswith(param_name + ' ') or line.startswith(param_name + '('):
146-
param_desc = line.split(':', 1)[1].strip()
148+
stripped_line = line.strip()
149+
if not stripped_line:
150+
continue
151+
152+
# Check for parameter start
153+
if stripped_line.startswith(f'{param_name}:') or stripped_line.startswith(f'{param_name} ') or stripped_line.startswith(f'{param_name}('):
154+
found_param = True
155+
# Get the description part after the parameter name
156+
desc_part = stripped_line.split(':', 1)[1].strip() if ':' in stripped_line else ''
157+
if desc_part:
158+
param_desc_lines.append(desc_part)
159+
# Get the indentation level for continuation lines
160+
indent_level = len(line) - len(line.lstrip())
161+
continue
162+
163+
# Handle continuation lines
164+
if found_param and line.startswith(' ' * (indent_level + 4)):
165+
# Add continuation line, stripped of extra indentation
166+
param_desc_lines.append(stripped_line)
167+
elif found_param and stripped_line:
168+
# If we hit a line with different indentation, we're done with this parameter
147169
break
148170

149-
if not param_desc:
171+
if not found_param:
150172
raise ValueError(f'Parameter {param_name} must have a description in the Args section')
151173

174+
# Join all lines with spaces
175+
param_desc = ' '.join(param_desc_lines).strip()
176+
152177
parameters['properties'][param_name] = {
153178
'type': _get_json_type(param_type),
154179
'description': param_desc,
155180
}
156181

157-
# Only add to required if not optional - could capture and map earlier to save this call
182+
# Only add to required if not optional
158183
if not _is_optional_type(param_type):
159184
parameters['required'].append(param_name)
160185

tests/test_utils.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,14 @@ def test_process_tools():
9898
def func1(x: int) -> str:
9999
"""Simple function 1.
100100
Args:
101-
x: A number
101+
x (integer): A number
102102
"""
103103
pass
104104

105105
def func2(y: str) -> int:
106106
"""Simple function 2.
107107
Args:
108-
y: A string
108+
y (string): A string
109109
"""
110110
pass
111111

@@ -126,3 +126,81 @@ def func2(y: str) -> int:
126126
assert len(tools) == 2
127127
assert tools[0].function.name == 'func1'
128128
assert tools[1].function.name == 'test'
129+
130+
131+
def test_advanced_json_type_conversion():
132+
from typing import Optional, Union, List, Dict, Sequence, Mapping, Set, Tuple, Any
133+
134+
# Test nested collections
135+
assert _get_json_type(List[List[int]]) == 'array'
136+
assert _get_json_type(Dict[str, List[int]]) == 'object'
137+
138+
# Test multiple unions
139+
assert set(_get_json_type(Union[int, str, float])) == {'integer', 'string', 'number'}
140+
141+
# Test collections.abc types
142+
assert _get_json_type(Sequence[int]) == 'array'
143+
assert _get_json_type(Mapping[str, int]) == 'object'
144+
assert _get_json_type(Set[int]) == 'array'
145+
assert _get_json_type(Tuple[int, str]) == 'array'
146+
147+
# Test nested optionals
148+
assert _get_json_type(Optional[List[Optional[int]]]) == 'array'
149+
150+
# Test edge cases
151+
assert _get_json_type(Any) == 'string' # or however you want to handle Any
152+
assert _get_json_type(None) == 'null'
153+
assert _get_json_type(type(None)) == 'null'
154+
155+
# Test complex nested types
156+
complex_type = Dict[str, Union[List[int], Optional[str], Dict[str, bool]]]
157+
assert _get_json_type(complex_type) == 'object'
158+
159+
160+
def test_invalid_types():
161+
# Test that invalid types raise appropriate errors
162+
with pytest.raises(ValueError):
163+
_get_json_type(lambda x: x) # Function type
164+
165+
with pytest.raises(ValueError):
166+
_get_json_type(type) # metaclass
167+
168+
169+
def test_function_docstring_parsing():
170+
from typing import List, Dict, Any
171+
172+
def func_with_complex_docs(x: int, y: List[str]) -> Dict[str, Any]:
173+
"""
174+
Test function with complex docstring.
175+
176+
Args:
177+
x (integer): A number
178+
with multiple lines
179+
y (array of string): A list
180+
with multiple lines
181+
182+
Returns:
183+
object: A dictionary
184+
with multiple lines
185+
"""
186+
pass
187+
188+
tool = convert_function_to_tool(func_with_complex_docs)
189+
assert tool['function']['description'] == 'Test function with complex docstring.'
190+
assert tool['function']['parameters']['properties']['x']['description'] == 'A number with multiple lines'
191+
assert tool['function']['parameters']['properties']['y']['description'] == 'A list with multiple lines'
192+
193+
194+
def test_tool_validation():
195+
# Test that malformed tool dictionaries are rejected
196+
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
197+
with pytest.raises(ValueError):
198+
process_tools([invalid_tool])
199+
200+
# Test missing required fields
201+
incomplete_tool = {
202+
'type': 'function',
203+
'function': {'name': 'test'}, # missing description and parameters
204+
}
205+
with pytest.raises(ValueError):
206+
process_tools([incomplete_tool])

tests/test_utils_legacy.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,16 @@
1-
from typing import Dict, List, Mapping, Optional, Sequence, Set, Tuple, Union
1+
from typing import Dict, List, Mapping, Sequence, Set, Tuple, Union
22

33
from ollama._utils import _get_json_type, convert_function_to_tool
44

55

66
def test_json_type_conversion():
77
# Test basic types
8-
assert _get_json_type(str) == 'string'
9-
assert _get_json_type(int) == 'integer'
108
assert _get_json_type(List) == 'array'
119
assert _get_json_type(Dict) == 'object'
1210

13-
# Test Optional
14-
assert _get_json_type(Optional[str]) == 'string'
15-
1611

1712
def test_function_to_tool_conversion():
18-
def add_numbers(x: int, y: Union[int, None] = None) -> int: # Changed Optional to Union
13+
def add_numbers(x: int, y: Union[int, None] = None) -> int:
1914
"""Add two numbers together.
2015
Args:
2116
x (integer): The first number
@@ -41,13 +36,13 @@ def test_function_with_all_typing_types():
4136
def all_types(
4237
x: int,
4338
y: str,
44-
z: Sequence[int],
39+
z: Sequence,
4540
w: Mapping[str, int],
4641
d: Dict[str, int],
4742
s: Set[int],
4843
t: Tuple[int, str],
4944
l: List[int], # noqa: E741
50-
o: Union[int, None], # Changed Optional to Union
45+
o: Union[int, None],
5146
) -> Union[Mapping[str, int], str, None]:
5247
"""
5348
A function with all types.

0 commit comments

Comments
 (0)