Skip to content

Passing Functions as Tools #321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4383603
WIP tool parsing
ParthSareen Nov 8, 2024
afe7db6
Managing multiple type options
ParthSareen Nov 9, 2024
8fee892
Add tool parsing and processing
ParthSareen Nov 11, 2024
0e5a940
Formatting and todos
ParthSareen Nov 11, 2024
1ef75a7
TODOs
ParthSareen Nov 11, 2024
93c7a63
wip
ParthSareen Nov 11, 2024
e5dc2b8
add annotations import for old tests
ParthSareen Nov 11, 2024
aa20015
Exhaustive type matching
ParthSareen Nov 11, 2024
d79538e
Ruff fix
ParthSareen Nov 11, 2024
97aa167
WIP trying tests out
ParthSareen Nov 11, 2024
8ec5123
Trying stuff out
ParthSareen Nov 11, 2024
efb775b
Multi-line docstrings and exhaustive tests
ParthSareen Nov 12, 2024
2efa54a
Walrus op for cleanup
ParthSareen Nov 12, 2024
1f089f7
Stringify return type arrays to not break server
ParthSareen Nov 13, 2024
fe8d143
WIP
ParthSareen Nov 14, 2024
67321a8
Organization, cleanup, pydantic serialization, update tests
ParthSareen Nov 14, 2024
2cc0b40
Typing fix
ParthSareen Nov 14, 2024
e68700c
Python3.8+ compatibility
ParthSareen Nov 14, 2024
f452fab
Add str -> str valid json mapping and add test
ParthSareen Nov 14, 2024
ca16670
Code cleanup and organization
ParthSareen Nov 14, 2024
7dcb598
Test unhappy parse path
ParthSareen Nov 14, 2024
7c5c294
Code cleanup + organize and add tests for type serialization
ParthSareen Nov 14, 2024
16c868a
Update to have graceful handling and not raise - added tests as well
ParthSareen Nov 15, 2024
718412a
Making good use of pydantic
ParthSareen Nov 18, 2024
e7bb55f
Add yields and test
ParthSareen Nov 18, 2024
7396ab6
Simplified parsing and fixed required - added tests
ParthSareen Nov 18, 2024
0d9eec0
Add tool.model_validate
ParthSareen Nov 18, 2024
ed3ba8a
Code style updates
ParthSareen Nov 19, 2024
a4ec34a
Add better messaging for chat
ParthSareen Nov 19, 2024
6d9c156
Addressing comments + cleanup + optional tool
ParthSareen Nov 19, 2024
c5c61a3
Better docstring parsing and some fixes
ParthSareen Nov 20, 2024
b0e0409
Bugfix/image encoding (#327)
ParthSareen Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing import (
Any,
Callable,
Literal,
Mapping,
Optional,
Expand All @@ -22,6 +23,9 @@

import sys


from ollama._utils import process_tools

if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator
else:
Expand Down Expand Up @@ -284,7 +288,7 @@ def chat(
model: str = '',
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
Expand All @@ -300,6 +304,7 @@ def chat(
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
"""

tools = process_tools(tools)
return self._request(
ChatResponse,
'POST',
Expand Down
6 changes: 5 additions & 1 deletion ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import (
Any,
List,
Literal,
Mapping,
Optional,
Expand Down Expand Up @@ -229,12 +230,14 @@ class Function(SubscriptableBaseModel):
description: str

class Parameters(SubscriptableBaseModel):
type: str
type: Union[str, List[str]]
required: Optional[Sequence[str]] = None
properties: Optional[JsonSchemaValue] = None

parameters: Parameters

return_type: Optional[Union[str, List[str]]] = None

function: Function


Expand Down Expand Up @@ -335,6 +338,7 @@ class ModelDetails(SubscriptableBaseModel):

class ListResponse(SubscriptableBaseModel):
class Model(SubscriptableBaseModel):
name: Optional[str] = None
modified_at: Optional[datetime] = None
digest: Optional[str] = None
size: Optional[ByteSize] = None
Expand Down
188 changes: 188 additions & 0 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

from typing import Any, Callable, List, Mapping, Optional, Union, get_args, get_origin
from ollama._types import Tool
from collections.abc import Sequence, Set
from typing import Dict, Set as TypeSet
import sys

# Type compatibility layer
if sys.version_info >= (3, 10):
from types import UnionType

def is_union(tp: Any) -> bool:
return get_origin(tp) in (Union, UnionType)
else:

def is_union(tp: Any) -> bool:
return get_origin(tp) is Union


# Map both the type and the type reference to the same JSON type
TYPE_MAP = {
# Basic types
int: 'integer',
'int': 'integer',
str: 'string',
'str': 'string',
float: 'number',
'float': 'number',
bool: 'boolean',
'bool': 'boolean',
type(None): 'null',
'None': 'null',
# Collection types
list: 'array',
'list': 'array',
List: 'array',
'List': 'array',
Sequence: 'array',
'Sequence': 'array',
tuple: 'array',
'tuple': 'array',
set: 'array',
'set': 'array',
Set: 'array',
TypeSet: 'array',
'Set': 'array',
# Mapping types
dict: 'object',
'dict': 'object',
Dict: 'object',
'Dict': 'object',
Mapping: 'object',
'Mapping': 'object',
Any: 'string',
'Any': 'string',
}


def _get_json_type(python_type: Any) -> str | List[str]:
# Handle Optional types (Union[type, None] and type | None)
if is_union(python_type):
args = get_args(python_type)
# Filter out None/NoneType from union args
non_none_args = [arg for arg in args if arg not in (None, type(None))]
if non_none_args:
if len(non_none_args) == 1:
return _get_json_type(non_none_args[0])
# For multiple types (e.g., int | str | None), return array of types
return [_get_json_type(arg) for arg in non_none_args]
return 'null'

# Handle generic types (List[int], Dict[str, int], etc.)
if get_origin(python_type) is not None:
# Get the base type (List, Dict, etc.)
base_type = TYPE_MAP.get(get_origin(python_type), None)
if base_type:
return base_type
# If it's a subclass of known abstract base classes, map to appropriate type
if isinstance(get_origin(python_type), type):
if issubclass(get_origin(python_type), (list, Sequence, tuple, set, Set)):
return 'array'
if issubclass(get_origin(python_type), (dict, Mapping)):
return 'object'

# Handle both type objects and type references
type_key = python_type
if isinstance(python_type, type):
type_key = python_type
elif isinstance(python_type, str):
type_key = python_type.lower()

# If type not found in map, try to get the type name
if type_key not in TYPE_MAP and hasattr(python_type, '__name__'):
type_key = python_type.__name__.lower()

if type_key in TYPE_MAP:
return TYPE_MAP[type_key]

raise ValueError(f'Could not map Python type {python_type} to a valid JSON type')


def _is_optional_type(python_type: Any) -> bool:
if is_union(python_type):
args = get_args(python_type)
return any(arg in (None, type(None)) for arg in args)
return False


def convert_function_to_tool(func: Callable) -> Tool:
doc_string = func.__doc__
if not doc_string:
raise ValueError(f'Function {func.__name__} must have a docstring in Google format. Example:\n' '"""Add two numbers.\n\n' 'Args:\n' ' a: First number\n' ' b: Second number\n\n' 'Returns:\n' ' int: Sum of the numbers\n' '"""')

# Extract description from docstring - get all lines before Args:
description_lines = []
for line in doc_string.split('\n'):
line = line.strip()
if line.startswith('Args:'):
break
if line:
description_lines.append(line)

description = ' '.join(description_lines).strip()

# Parse Args section
if 'Args:' not in doc_string:
raise ValueError(f'Function {func.__name__} docstring must have an Args section in Google format')

args_section = doc_string.split('Args:')[1]
if 'Returns:' in args_section:
args_section = args_section.split('Returns:')[0]

parameters = {'type': 'object', 'properties': {}, 'required': []}

# Build parameters dict
for param_name, param_type in func.__annotations__.items():
if param_name == 'return':
continue

param_desc = None
for line in args_section.split('\n'):
line = line.strip()
# Check for parameter name with or without colon, space, or parentheses to mitigate formatting issues
if line.startswith(param_name + ':') or line.startswith(param_name + ' ') or line.startswith(param_name + '('):
param_desc = line.split(':', 1)[1].strip()
break

if not param_desc:
raise ValueError(f'Parameter {param_name} must have a description in the Args section')

parameters['properties'][param_name] = {
'type': _get_json_type(param_type),
'description': param_desc,
}

# Only add to required if not optional - could capture and map earlier to save this call
if not _is_optional_type(param_type):
parameters['required'].append(param_name)

tool_dict = {
'type': 'function',
'function': {
'name': func.__name__,
'description': description,
'parameters': parameters,
'return_type': None,
},
}

if 'return' in func.__annotations__ and func.__annotations__['return'] is not None:
tool_dict['function']['return_type'] = _get_json_type(func.__annotations__['return'])

return Tool.model_validate(tool_dict)


def process_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Sequence[Tool]:
if not tools:
return []

processed_tools = []
for tool in tools:
if callable(tool):
processed_tools.append(convert_function_to_tool(tool))
else:
processed_tools.append(Tool.model_validate(tool))

return processed_tools
128 changes: 128 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import sys

import pytest

if sys.version_info < (3, 10):
pytest.skip('Python 3.10 or higher is required', allow_module_level=True)


from ollama._utils import _get_json_type, convert_function_to_tool, process_tools


def test_json_type_conversion():
from typing import Optional

# Test basic types
assert _get_json_type(str) == 'string'
assert _get_json_type(int) == 'integer'
assert _get_json_type(list) == 'array'
assert _get_json_type(dict) == 'object'

# Test Optional
assert _get_json_type(Optional[str]) == 'string'


def test_function_to_tool_conversion():
from typing import Optional

def add_numbers(x: int, y: Optional[int] = None) -> int:
"""Add two numbers together.
Args:
x (integer): The first number
y (integer, optional): The second number

Returns:
integer: The sum of x and y
"""

return x + y

tool = convert_function_to_tool(add_numbers)

assert tool['type'] == 'function'
assert tool['function']['name'] == 'add_numbers'
assert tool['function']['description'] == 'Add two numbers together.'
assert tool['function']['parameters']['type'] == 'object'
assert tool['function']['parameters']['properties']['x']['type'] == 'integer'
assert tool['function']['parameters']['properties']['x']['description'] == 'The first number'
assert tool['function']['parameters']['required'] == ['x']


def test_function_with_no_args():
def simple_func():
"""
A simple function with no arguments.
Args:
None
Returns:
None
"""
pass

tool = convert_function_to_tool(simple_func)
assert tool.function.name == 'simple_func'
assert tool.function.description == 'A simple function with no arguments.'
assert tool.function.parameters.properties == {}
assert tool.function.return_type is None


def test_function_with_all_types():
def all_types(
x: int,
y: str,
z: list[int],
w: dict[str, int],
v: int | str | None,
) -> int | dict[str, int] | str | list[int] | None:
"""
A function with all types.
Args:
x (integer): The first number
y (string): The second number
z (array): The third number
w (object): The fourth number
v (integer | string | None): The fifth number
"""
pass

tool = convert_function_to_tool(all_types)
assert tool.function.parameters.properties['x']['type'] == 'integer'
assert tool.function.parameters.properties['y']['type'] == 'string'
assert tool.function.parameters.properties['z']['type'] == 'array'
assert tool.function.parameters.properties['w']['type'] == 'object'
assert set(tool.function.parameters.properties['v']['type']) == {'string', 'integer'}
assert set(tool.function.return_type) == {'string', 'integer', 'array', 'object'}


def test_process_tools():
def func1(x: int) -> str:
"""Simple function 1.
Args:
x: A number
"""
pass

def func2(y: str) -> int:
"""Simple function 2.
Args:
y: A string
"""
pass

# Test with list of functions
tools = process_tools([func1, func2])
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'func2'

# Test with empty input
assert process_tools() == []
assert process_tools(None) == []
assert process_tools([]) == []

# Test with mix of functions and tool dicts
tool_dict = {'type': 'function', 'function': {'name': 'test', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'string', 'description': 'A string'}}, 'required': ['x']}}}
tools = process_tools([func1, tool_dict])
assert len(tools) == 2
assert tools[0].function.name == 'func1'
assert tools[1].function.name == 'test'
Loading