diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index adccf5ad..9bb0c634 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -3,6 +3,8 @@ package remote import ( "context" "fmt" + "net/url" + "os" "slices" "sort" "strings" @@ -87,6 +89,28 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { return true, nil } +func isHTTPURL(toolName string) bool { + return strings.HasPrefix(toolName, "http://") || + strings.HasPrefix(toolName, "https://") +} + +func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) { + parsed, err := url.Parse(apiURL) + if err != nil { + return nil, err + } + env := strings.ToUpper(strings.ReplaceAll(parsed.Hostname(), ".", "_")) + "_API_KEY" + apiKey := os.Getenv(env) + if apiKey == "" { + apiKey = "" + } + return openai.NewClient(openai.Options{ + BaseURL: apiURL, + Cache: c.cache, + APIKey: apiKey, + }) +} + func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) { c.clientsLock.Lock() defer c.clientsLock.Unlock() @@ -96,6 +120,19 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return client, nil } + if c.clients == nil { + c.clients = make(map[string]*openai.Client) + } + + if isHTTPURL(toolName) { + remoteClient, err := c.clientFromURL(toolName) + if err != nil { + return nil, err + } + c.clients[toolName] = remoteClient + return remoteClient, nil + } + prg, err := loader.Program(ctx, toolName, "") if err != nil { return nil, err @@ -120,10 +157,6 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - if c.clients == nil { - c.clients = make(map[string]*openai.Client) - } - c.clients[toolName] = client return client, nil }