-
Notifications
You must be signed in to change notification settings - Fork 761
Passing Functions as Tools #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
4383603
WIP tool parsing
ParthSareen afe7db6
Managing multiple type options
ParthSareen 8fee892
Add tool parsing and processing
ParthSareen 0e5a940
Formatting and todos
ParthSareen 1ef75a7
TODOs
ParthSareen 93c7a63
wip
ParthSareen e5dc2b8
add annotations import for old tests
ParthSareen aa20015
Exhaustive type matching
ParthSareen d79538e
Ruff fix
ParthSareen 97aa167
WIP trying tests out
ParthSareen 8ec5123
Trying stuff out
ParthSareen efb775b
Multi-line docstrings and exhaustive tests
ParthSareen 2efa54a
Walrus op for cleanup
ParthSareen 1f089f7
Stringify return type arrays to not break server
ParthSareen fe8d143
WIP
ParthSareen 67321a8
Organization, cleanup, pydantic serialization, update tests
ParthSareen 2cc0b40
Typing fix
ParthSareen e68700c
Python3.8+ compatibility
ParthSareen f452fab
Add str -> str valid json mapping and add test
ParthSareen ca16670
Code cleanup and organization
ParthSareen 7dcb598
Test unhappy parse path
ParthSareen 7c5c294
Code cleanup + organize and add tests for type serialization
ParthSareen 16c868a
Update to have graceful handling and not raise - added tests as well
ParthSareen 718412a
Making good use of pydantic
ParthSareen e7bb55f
Add yields and test
ParthSareen 7396ab6
Simplified parsing and fixed required - added tests
ParthSareen 0d9eec0
Add tool.model_validate
ParthSareen ed3ba8a
Code style updates
ParthSareen a4ec34a
Add better messaging for chat
ParthSareen 6d9c156
Addressing comments + cleanup + optional tool
ParthSareen c5c61a3
Better docstring parsing and some fixes
ParthSareen b0e0409
Bugfix/image encoding (#327)
ParthSareen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Callable, List, Mapping, Optional, Union, get_args, get_origin | ||
from ollama._types import Tool | ||
from collections.abc import Sequence, Set | ||
from typing import Dict, Set as TypeSet | ||
import sys | ||
|
||
# Type compatibility layer | ||
if sys.version_info >= (3, 10): | ||
from types import UnionType | ||
|
||
def is_union(tp: Any) -> bool: | ||
return get_origin(tp) in (Union, UnionType) | ||
else: | ||
|
||
def is_union(tp: Any) -> bool: | ||
return get_origin(tp) is Union | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
# Map both the type and the type reference to the same JSON type | ||
TYPE_MAP = { | ||
# Basic types | ||
int: 'integer', | ||
'int': 'integer', | ||
str: 'string', | ||
'str': 'string', | ||
float: 'number', | ||
'float': 'number', | ||
bool: 'boolean', | ||
'bool': 'boolean', | ||
type(None): 'null', | ||
'None': 'null', | ||
# Collection types | ||
list: 'array', | ||
'list': 'array', | ||
List: 'array', | ||
'List': 'array', | ||
Sequence: 'array', | ||
'Sequence': 'array', | ||
tuple: 'array', | ||
'tuple': 'array', | ||
set: 'array', | ||
'set': 'array', | ||
Set: 'array', | ||
TypeSet: 'array', | ||
'Set': 'array', | ||
# Mapping types | ||
dict: 'object', | ||
'dict': 'object', | ||
Dict: 'object', | ||
'Dict': 'object', | ||
Mapping: 'object', | ||
'Mapping': 'object', | ||
Any: 'string', | ||
'Any': 'string', | ||
} | ||
|
||
|
||
def _get_json_type(python_type: Any) -> str | List[str]: | ||
# Handle Optional types (Union[type, None] and type | None) | ||
if is_union(python_type): | ||
args = get_args(python_type) | ||
# Filter out None/NoneType from union args | ||
non_none_args = [arg for arg in args if arg not in (None, type(None))] | ||
if non_none_args: | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if len(non_none_args) == 1: | ||
return _get_json_type(non_none_args[0]) | ||
# For multiple types (e.g., int | str | None), return array of types | ||
return [_get_json_type(arg) for arg in non_none_args] | ||
return 'null' | ||
|
||
# Handle generic types (List[int], Dict[str, int], etc.) | ||
if get_origin(python_type) is not None: | ||
# Get the base type (List, Dict, etc.) | ||
base_type = TYPE_MAP.get(get_origin(python_type), None) | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if base_type: | ||
return base_type | ||
# If it's a subclass of known abstract base classes, map to appropriate type | ||
if isinstance(get_origin(python_type), type): | ||
if issubclass(get_origin(python_type), (list, Sequence, tuple, set, Set)): | ||
return 'array' | ||
if issubclass(get_origin(python_type), (dict, Mapping)): | ||
return 'object' | ||
|
||
# Handle both type objects and type references | ||
type_key = python_type | ||
if isinstance(python_type, type): | ||
type_key = python_type | ||
elif isinstance(python_type, str): | ||
type_key = python_type.lower() | ||
|
||
# If type not found in map, try to get the type name | ||
if type_key not in TYPE_MAP and hasattr(python_type, '__name__'): | ||
type_key = python_type.__name__.lower() | ||
|
||
if type_key in TYPE_MAP: | ||
return TYPE_MAP[type_key] | ||
|
||
raise ValueError(f'Could not map Python type {python_type} to a valid JSON type') | ||
|
||
|
||
def _is_optional_type(python_type: Any) -> bool: | ||
if is_union(python_type): | ||
args = get_args(python_type) | ||
return any(arg in (None, type(None)) for arg in args) | ||
return False | ||
|
||
|
||
def convert_function_to_tool(func: Callable) -> Tool: | ||
doc_string = func.__doc__ | ||
if not doc_string: | ||
raise ValueError(f'Function {func.__name__} must have a docstring in Google format. Example:\n' '"""Add two numbers.\n\n' 'Args:\n' ' a: First number\n' ' b: Second number\n\n' 'Returns:\n' ' int: Sum of the numbers\n' '"""') | ||
|
||
# Extract description from docstring - get all lines before Args: | ||
description_lines = [] | ||
for line in doc_string.split('\n'): | ||
line = line.strip() | ||
if line.startswith('Args:'): | ||
break | ||
if line: | ||
description_lines.append(line) | ||
|
||
description = ' '.join(description_lines).strip() | ||
|
||
# Parse Args section | ||
if 'Args:' not in doc_string: | ||
raise ValueError(f'Function {func.__name__} docstring must have an Args section in Google format') | ||
|
||
args_section = doc_string.split('Args:')[1] | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if 'Returns:' in args_section: | ||
args_section = args_section.split('Returns:')[0] | ||
|
||
parameters = {'type': 'object', 'properties': {}, 'required': []} | ||
|
||
# Build parameters dict | ||
for param_name, param_type in func.__annotations__.items(): | ||
if param_name == 'return': | ||
continue | ||
|
||
param_desc = None | ||
for line in args_section.split('\n'): | ||
line = line.strip() | ||
# Check for parameter name with or without colon, space, or parentheses to mitigate formatting issues | ||
if line.startswith(param_name + ':') or line.startswith(param_name + ' ') or line.startswith(param_name + '('): | ||
param_desc = line.split(':', 1)[1].strip() | ||
break | ||
|
||
if not param_desc: | ||
raise ValueError(f'Parameter {param_name} must have a description in the Args section') | ||
|
||
parameters['properties'][param_name] = { | ||
'type': _get_json_type(param_type), | ||
'description': param_desc, | ||
} | ||
|
||
# Only add to required if not optional - could capture and map earlier to save this call | ||
if not _is_optional_type(param_type): | ||
parameters['required'].append(param_name) | ||
|
||
tool_dict = { | ||
'type': 'function', | ||
'function': { | ||
'name': func.__name__, | ||
'description': description, | ||
'parameters': parameters, | ||
'return_type': None, | ||
}, | ||
} | ||
|
||
if 'return' in func.__annotations__ and func.__annotations__['return'] is not None: | ||
tool_dict['function']['return_type'] = _get_json_type(func.__annotations__['return']) | ||
|
||
return Tool.model_validate(tool_dict) | ||
|
||
|
||
def process_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Sequence[Tool]: | ||
if not tools: | ||
return [] | ||
|
||
processed_tools = [] | ||
for tool in tools: | ||
if callable(tool): | ||
processed_tools.append(convert_function_to_tool(tool)) | ||
else: | ||
processed_tools.append(Tool.model_validate(tool)) | ||
|
||
return processed_tools |
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import sys | ||
|
||
import pytest | ||
|
||
if sys.version_info < (3, 10): | ||
pytest.skip('Python 3.10 or higher is required', allow_module_level=True) | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
from ollama._utils import _get_json_type, convert_function_to_tool, process_tools | ||
|
||
|
||
def test_json_type_conversion(): | ||
from typing import Optional | ||
|
||
# Test basic types | ||
assert _get_json_type(str) == 'string' | ||
assert _get_json_type(int) == 'integer' | ||
assert _get_json_type(list) == 'array' | ||
assert _get_json_type(dict) == 'object' | ||
|
||
# Test Optional | ||
assert _get_json_type(Optional[str]) == 'string' | ||
|
||
|
||
def test_function_to_tool_conversion(): | ||
from typing import Optional | ||
|
||
def add_numbers(x: int, y: Optional[int] = None) -> int: | ||
"""Add two numbers together. | ||
Args: | ||
x (integer): The first number | ||
y (integer, optional): The second number | ||
|
||
Returns: | ||
integer: The sum of x and y | ||
""" | ||
|
||
return x + y | ||
|
||
tool = convert_function_to_tool(add_numbers) | ||
|
||
assert tool['type'] == 'function' | ||
assert tool['function']['name'] == 'add_numbers' | ||
assert tool['function']['description'] == 'Add two numbers together.' | ||
assert tool['function']['parameters']['type'] == 'object' | ||
assert tool['function']['parameters']['properties']['x']['type'] == 'integer' | ||
assert tool['function']['parameters']['properties']['x']['description'] == 'The first number' | ||
assert tool['function']['parameters']['required'] == ['x'] | ||
|
||
|
||
def test_function_with_no_args(): | ||
def simple_func(): | ||
""" | ||
A simple function with no arguments. | ||
Args: | ||
None | ||
Returns: | ||
None | ||
""" | ||
pass | ||
|
||
tool = convert_function_to_tool(simple_func) | ||
assert tool.function.name == 'simple_func' | ||
assert tool.function.description == 'A simple function with no arguments.' | ||
assert tool.function.parameters.properties == {} | ||
assert tool.function.return_type is None | ||
|
||
|
||
def test_function_with_all_types(): | ||
def all_types( | ||
x: int, | ||
y: str, | ||
z: list[int], | ||
w: dict[str, int], | ||
v: int | str | None, | ||
) -> int | dict[str, int] | str | list[int] | None: | ||
""" | ||
A function with all types. | ||
Args: | ||
x (integer): The first number | ||
y (string): The second number | ||
z (array): The third number | ||
w (object): The fourth number | ||
v (integer | string | None): The fifth number | ||
""" | ||
pass | ||
|
||
tool = convert_function_to_tool(all_types) | ||
assert tool.function.parameters.properties['x']['type'] == 'integer' | ||
assert tool.function.parameters.properties['y']['type'] == 'string' | ||
assert tool.function.parameters.properties['z']['type'] == 'array' | ||
assert tool.function.parameters.properties['w']['type'] == 'object' | ||
assert set(tool.function.parameters.properties['v']['type']) == {'string', 'integer'} | ||
assert set(tool.function.return_type) == {'string', 'integer', 'array', 'object'} | ||
|
||
|
||
def test_process_tools(): | ||
def func1(x: int) -> str: | ||
"""Simple function 1. | ||
Args: | ||
x: A number | ||
""" | ||
pass | ||
|
||
def func2(y: str) -> int: | ||
"""Simple function 2. | ||
Args: | ||
y: A string | ||
""" | ||
pass | ||
|
||
# Test with list of functions | ||
tools = process_tools([func1, func2]) | ||
assert len(tools) == 2 | ||
assert tools[0].function.name == 'func1' | ||
assert tools[1].function.name == 'func2' | ||
|
||
# Test with empty input | ||
assert process_tools() == [] | ||
assert process_tools(None) == [] | ||
assert process_tools([]) == [] | ||
|
||
# Test with mix of functions and tool dicts | ||
tool_dict = {'type': 'function', 'function': {'name': 'test', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'string', 'description': 'A string'}}, 'required': ['x']}}} | ||
tools = process_tools([func1, tool_dict]) | ||
assert len(tools) == 2 | ||
assert tools[0].function.name == 'func1' | ||
assert tools[1].function.name == 'test' |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.