diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 6d83e6cc..89863529 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -22,10 +22,9 @@ import ( ) type Client struct { - clientsLock sync.Mutex + modelsLock sync.Mutex cache *cache.Client - clients map[string]*openai.Client - models map[string]*openai.Client + modelToProvider map[string]string runner *runner.Runner envs []string credStore credentials.CredentialStore @@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent } func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { - c.clientsLock.Lock() - client, ok := c.models[messageRequest.Model] - c.clientsLock.Unlock() + c.modelsLock.Lock() + provider, ok := c.modelToProvider[messageRequest.Model] + c.modelsLock.Unlock() if !ok { return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) } + client, err := c.load(ctx, provider) + if err != nil { + return nil, err + } + toolName, modelName := types.SplitToolRef(messageRequest.Model) if modelName == "" { // modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider' @@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error) return false, nil } - client, err := c.load(ctx, providerName) + _, err := c.load(ctx, providerName) if err != nil { return false, err } - c.clientsLock.Lock() - defer c.clientsLock.Unlock() + c.modelsLock.Lock() + defer c.modelsLock.Unlock() - if c.models == nil { - c.models = map[string]*openai.Client{} + if c.modelToProvider == nil { + c.modelToProvider = map[string]string{} } - c.models[modelString] = client + c.modelToProvider[modelString] = providerName return true, nil } @@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie } func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) { - c.clientsLock.Lock() - defer c.clientsLock.Unlock() - - client, ok := c.clients[toolName] - if ok { - return client, nil - } - - if c.clients == nil { - c.clients = make(map[string]*openai.Client) - } - if isHTTPURL(toolName) { remoteClient, err := c.clientFromURL(ctx, toolName) if err != nil { return nil, err } - c.clients[toolName] = remoteClient return remoteClient, nil } @@ -174,14 +165,8 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - if strings.HasSuffix(url, "/") { - url += "v1" - } else { - url += "/v1" - } - - client, err = openai.NewClient(ctx, c.credStore, openai.Options{ - BaseURL: url, + client, err := openai.NewClient(ctx, c.credStore, openai.Options{ + BaseURL: strings.TrimSuffix(url, "/") + "/v1", Cache: c.cache, CacheKey: prg.EntryToolID, }) @@ -189,7 +174,6 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - c.clients[toolName] = client return client, nil }