Skip to content

feat(py/genkit): added the resolve_method for openai compatible plugin #3055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,30 @@


"""OpenAI OpenAI API Compatible Plugin for Genkit."""
from typing import Any

from openai import Client, OpenAI as OpenAIClient

from genkit.ai._plugin import Plugin
from genkit.ai._registry import GenkitRegistry
from genkit.core.action.types import ActionKind
from genkit.plugins.compat_oai.models import (
SUPPORTED_OPENAI_MODELS,
OpenAIModel,
OpenAIModelHandler,
)
from genkit.plugins.compat_oai.typing import OpenAIConfig
from genkit.plugins.ollama.models import ModelDefinition

OPENAI_PLUGIN_NAME = 'openai'

def default_openai_metadata(name: str) -> dict[str, Any]:
return {
'model': {
'label': f"OpenAI - {name}",
'supports': {'multiturn': True}
},
}


class OpenAI(Plugin):
Expand Down Expand Up @@ -69,6 +83,40 @@ def initialize(self, ai: GenkitRegistry) -> None:
},
)

def resolve_action( # noqa: B027
self,
ai: GenkitRegistry,
kind: ActionKind,
name: str,
) -> None:

if kind is not ActionKind.MODEL:
return None

self._define_openai_model(ai, name)
return None

def _define_openai_model(self, ai: GenkitRegistry, name: str) -> None:
"""Defines and registers an OpenAI model with Genkit.

Cleans the model name, instantiates an OpenAI, and registers it
with the provided Genkit AI registry, including metadata about its capabilities.

Args:
ai: The Genkit AI registry instance.
name: The name of the model to be registered.
"""

handler = OpenAIModelHandler(OpenAIModel(name, self._openai_client, ai)).generate
ai.define_model(
name=f'openai/{name}',
fn=handler,
config_schema=OpenAIConfig,
metadata=default_openai_metadata(name)
)




def openai_model(name: str) -> str:
"""Returns a string representing the OpenAI model name to use with Genkit.
Expand Down
52 changes: 51 additions & 1 deletion py/plugins/compat-oai/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch

import pytest

from genkit.ai._aio import Genkit
from genkit.core.action.types import ActionKind
from genkit.plugins.compat_oai import OpenAIConfig
from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS
from genkit.plugins.compat_oai.openai_plugin import OpenAI, openai_model

Expand All @@ -36,6 +40,52 @@ def test_openai_plugin_initialize() -> None:
assert registry.define_model.call_count == len(SUPPORTED_OPENAI_MODELS)



@pytest.mark.parametrize(
'kind, name',
[
(ActionKind.MODEL, 'gpt-3.5-turbo')
],
)
def test_openai_plugin_resolve_action(kind, name):
"""Unit Tests for resolve action method."""
plugin = OpenAI(api_key='test-key')
registry = MagicMock(spec=Genkit)
plugin.resolve_action(registry, kind, name)

model_info = SUPPORTED_OPENAI_MODELS[name]

registry.define_model.assert_called_once_with(
name=f'openai/{name}',
fn=ANY,
config_schema=OpenAIConfig,
metadata={
'model': {
'label': model_info.label,
'supports': {'multiturn': True}
},
},
)


@pytest.mark.parametrize(
'kind, name',
[
(ActionKind.MODEL, "model_doesnt_exist")
],
)
def test_openai_plugin_resolve_action_not_found(kind, name):
"""Unit Tests for resolve action method."""

plugin = OpenAI(api_key='test-key')
registry = MagicMock(spec=Genkit)
plugin.resolve_action(registry, kind, name)

registry.define_model.assert_called_once()




def test_openai_model_function() -> None:
"""Test openai_model function."""
assert openai_model('gpt-4') == 'openai/gpt-4'
Loading