Skip to content

Commit e8a4950

Browse files
committed
Addressing comments + cleanup + optional tool
1 parent a4ec34a commit e8a4950

File tree

6 files changed

+33
-131
lines changed

6 files changed

+33
-131
lines changed

ollama/_client.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def chat(
307307
308308
Example:
309309
def add_two_numbers(a: int, b: int) -> int:
310-
\"""
310+
'''
311311
Add two numbers together.
312312
313313
Args:
@@ -316,7 +316,7 @@ def add_two_numbers(a: int, b: int) -> int:
316316
317317
Returns:
318318
int: The sum of a and b
319-
\"""
319+
'''
320320
return a + b
321321
322322
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
@@ -809,7 +809,7 @@ async def chat(
809809
810810
Example:
811811
def add_two_numbers(a: int, b: int) -> int:
812-
\"""
812+
'''
813813
Add two numbers together.
814814
815815
Args:
@@ -818,7 +818,7 @@ def add_two_numbers(a: int, b: int) -> int:
818818
819819
Returns:
820820
int: The sum of a and b
821-
\"""
821+
'''
822822
return a + b
823823
824824
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
@@ -1128,10 +1128,7 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
11281128

11291129

11301130
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
1131-
if not tools:
1132-
return []
1133-
1134-
for unprocessed_tool in tools:
1131+
for unprocessed_tool in tools or []:
11351132
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
11361133

11371134

@@ -1207,8 +1204,6 @@ def _parse_host(host: Optional[str]) -> str:
12071204
'https://[0001:002:003:0004::1]:56789/path'
12081205
>>> _parse_host('[0001:002:003:0004::1]:56789/path/')
12091206
'http://[0001:002:003:0004::1]:56789/path'
1210-
>>> _parse_host('http://host.docker.internal:11434/path')
1211-
'http://host.docker.internal:11434/path'
12121207
"""
12131208

12141209
host, port = host or '', 11434

ollama/_types.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,27 +216,27 @@ class Function(SubscriptableBaseModel):
216216

217217

218218
class Tool(SubscriptableBaseModel):
219-
type: Literal['function'] = 'function'
219+
type: Optional[Literal['function']] = 'function'
220220

221221
class Function(SubscriptableBaseModel):
222-
name: str
223-
description: str
222+
name: Optional[str] = None
223+
description: Optional[str] = None
224224

225225
class Parameters(SubscriptableBaseModel):
226-
type: Literal['object'] = 'object'
226+
type: Optional[Literal['object']] = 'object'
227227
required: Optional[Sequence[str]] = None
228228

229229
class Property(SubscriptableBaseModel):
230230
model_config = ConfigDict(arbitrary_types_allowed=True)
231231

232-
type: str
233-
description: str
232+
type: Optional[str] = None
233+
description: Optional[str] = None
234234

235235
properties: Optional[Mapping[str, Property]] = None
236236

237-
parameters: Parameters
237+
parameters: Optional[Parameters] = None
238238

239-
function: Function
239+
function: Optional[Function] = None
240240

241241

242242
class ChatRequest(BaseGenerateRequest):

ollama/_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
1212
if not doc_string:
1313
return parsed_docstring
1414

15-
lowered_doc_string = doc_string.lower()
16-
1715
key = hash(doc_string)
18-
parsed_docstring[key] = ''
19-
for line in lowered_doc_string.splitlines():
20-
if line.startswith('args:'):
16+
for line in doc_string.splitlines():
17+
lowered_line = line.lower()
18+
if lowered_line.startswith('args:'):
2119
key = 'args'
22-
elif line.startswith('returns:') or line.startswith('yields:') or line.startswith('raises:'):
20+
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
2321
key = '_'
2422

2523
else:
@@ -29,7 +27,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
2927
last_key = None
3028
for line in parsed_docstring['args'].splitlines():
3129
line = line.strip()
32-
if ':' in line and not line.startswith('args'):
30+
if ':' in line and not line.lower().startswith('args:'):
3331
# Split on first occurrence of '(' or ':' to separate arg name from description
3432
split_char = '(' if '(' in line else ':'
3533
arg_name, rest = line.split(split_char, 1)

tests/test_client.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,25 +1025,14 @@ def func2(y: str) -> int:
10251025
},
10261026
}
10271027

1028-
tool_json = json.loads(json.dumps(tool_dict))
1029-
tools = list(_copy_tools([func1, tool_dict, tool_json]))
1030-
assert len(tools) == 3
1028+
tools = list(_copy_tools([func1, tool_dict]))
1029+
assert len(tools) == 2
10311030
assert tools[0].function.name == 'func1'
10321031
assert tools[1].function.name == 'test'
1033-
assert tools[2].function.name == 'test'
10341032

10351033

10361034
def test_tool_validation():
1037-
# Test that malformed tool dictionaries are rejected
10381035
# Raises ValidationError when used as it is a generator
10391036
with pytest.raises(ValidationError):
10401037
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
10411038
list(_copy_tools([invalid_tool]))
1042-
1043-
# Test missing required fields
1044-
incomplete_tool = {
1045-
'type': 'function',
1046-
'function': {'name': 'test'}, # missing description and parameters
1047-
}
1048-
with pytest.raises(ValidationError):
1049-
list(_copy_tools([incomplete_tool]))

tests/test_type_serialization.py

Lines changed: 1 addition & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from base64 import b64decode, b64encode
22

3-
import pytest
43

5-
6-
from ollama._types import Image, Tool
4+
from ollama._types import Image
75

86

97
def test_image_serialization():
@@ -16,81 +14,3 @@ def test_image_serialization():
1614
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
1715
img = Image(value=b64_str)
1816
assert img.model_dump() == b64decode(b64_str).decode()
19-
20-
21-
def test_tool_serialization():
22-
# Test valid tool serialization
23-
tool = Tool(
24-
function=Tool.Function(
25-
name='add_two_numbers',
26-
description='Add two numbers together.',
27-
parameters=Tool.Function.Parameters(
28-
type='object',
29-
properties={
30-
'a': Tool.Function.Parameters.Property(
31-
type='integer',
32-
description='The first number',
33-
),
34-
'b': Tool.Function.Parameters.Property(
35-
type='integer',
36-
description='The second number',
37-
),
38-
},
39-
required=['a', 'b'],
40-
),
41-
)
42-
)
43-
assert tool.model_dump() == {
44-
'type': 'function',
45-
'function': {
46-
'name': 'add_two_numbers',
47-
'description': 'Add two numbers together.',
48-
'parameters': {
49-
'type': 'object',
50-
'properties': {
51-
'a': {
52-
'type': 'integer',
53-
'description': 'The first number',
54-
},
55-
'b': {
56-
'type': 'integer',
57-
'description': 'The second number',
58-
},
59-
},
60-
'required': ['a', 'b'],
61-
},
62-
},
63-
}
64-
65-
# Test invalid type
66-
with pytest.raises(ValueError):
67-
property = Tool.Function.Parameters.Property(
68-
type=lambda x: x, # Invalid type
69-
description='Invalid type',
70-
)
71-
Tool.model_validate(
72-
Tool(
73-
function=Tool.Function(
74-
parameters=Tool.Function.Parameters(
75-
properties={
76-
'x': property,
77-
}
78-
)
79-
)
80-
)
81-
)
82-
83-
# Test invalid parameters type
84-
with pytest.raises(ValueError):
85-
Tool.model_validate(
86-
Tool(
87-
function=Tool.Function(
88-
name='test',
89-
description='Test',
90-
parameters=Tool.Function.Parameters(
91-
type='invalid_type', # Must be 'object'
92-
properties={},
93-
),
94-
)
95-
)
96-
)

tests/test_utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def test_function_to_tool_conversion():
1010
def add_numbers(x: int, y: Union[int, None] = None) -> int:
1111
"""Add two numbers together.
12-
Args:
12+
args:
1313
x (integer): The first number
1414
y (integer, optional): The second number
1515
@@ -22,10 +22,10 @@ def add_numbers(x: int, y: Union[int, None] = None) -> int:
2222

2323
assert tool['type'] == 'function'
2424
assert tool['function']['name'] == 'add_numbers'
25-
assert tool['function']['description'] == 'add two numbers together.'
25+
assert tool['function']['description'] == 'Add two numbers together.'
2626
assert tool['function']['parameters']['type'] == 'object'
2727
assert tool['function']['parameters']['properties']['x']['type'] == 'integer'
28-
assert tool['function']['parameters']['properties']['x']['description'] == 'the first number'
28+
assert tool['function']['parameters']['properties']['x']['description'] == 'The first number'
2929
assert tool['function']['parameters']['required'] == ['x']
3030

3131

@@ -42,7 +42,7 @@ def simple_func():
4242

4343
tool = convert_function_to_tool(simple_func).model_dump()
4444
assert tool['function']['name'] == 'simple_func'
45-
assert tool['function']['description'] == 'a simple function with no arguments.'
45+
assert tool['function']['description'] == 'A simple function with no arguments.'
4646
assert tool['function']['parameters']['properties'] == {}
4747

4848

@@ -137,9 +137,9 @@ def func_with_complex_docs(x: int, y: List[str]) -> Dict[str, Any]:
137137
pass
138138

139139
tool = convert_function_to_tool(func_with_complex_docs).model_dump()
140-
assert tool['function']['description'] == 'test function with complex docstring.'
141-
assert tool['function']['parameters']['properties']['x']['description'] == 'a number with multiple lines'
142-
assert tool['function']['parameters']['properties']['y']['description'] == 'a list with multiple lines'
140+
assert tool['function']['description'] == 'Test function with complex docstring.'
141+
assert tool['function']['parameters']['properties']['x']['description'] == 'A number with multiple lines'
142+
assert tool['function']['parameters']['properties']['y']['description'] == 'A list with multiple lines'
143143

144144

145145
def test_skewed_docstring_parsing():
@@ -159,8 +159,8 @@ def add_two_numbers(x: int, y: int) -> int:
159159
pass
160160

161161
tool = convert_function_to_tool(add_two_numbers).model_dump()
162-
assert tool['function']['parameters']['properties']['x']['description'] == ': the first number'
163-
assert tool['function']['parameters']['properties']['y']['description'] == 'the second number'
162+
assert tool['function']['parameters']['properties']['x']['description'] == ': The first number'
163+
assert tool['function']['parameters']['properties']['y']['description'] == 'The second number'
164164

165165

166166
def test_function_with_no_docstring():
@@ -187,7 +187,7 @@ def only_description():
187187
pass
188188

189189
tool = convert_function_to_tool(only_description).model_dump()
190-
assert tool['function']['description'] == 'a function with only a description.'
190+
assert tool['function']['description'] == 'A function with only a description.'
191191
assert tool['function']['parameters'] == {'type': 'object', 'properties': {}, 'required': None}
192192

193193
def only_description_with_args(x: int, y: int):
@@ -197,7 +197,7 @@ def only_description_with_args(x: int, y: int):
197197
pass
198198

199199
tool = convert_function_to_tool(only_description_with_args).model_dump()
200-
assert tool['function']['description'] == 'a function with only a description.'
200+
assert tool['function']['description'] == 'A function with only a description.'
201201
assert tool['function']['parameters'] == {
202202
'type': 'object',
203203
'properties': {
@@ -223,7 +223,7 @@ def function_with_yields(x: int, y: int):
223223
pass
224224

225225
tool = convert_function_to_tool(function_with_yields).model_dump()
226-
assert tool['function']['description'] == 'a function with yields section.'
226+
assert tool['function']['description'] == 'A function with yields section.'
227227
assert tool['function']['parameters']['properties']['x']['description'] == 'the first number'
228228
assert tool['function']['parameters']['properties']['y']['description'] == 'the second number'
229229

0 commit comments

Comments
 (0)