From c0507a2c32dd543a43e46526710fab12b81c04f8 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Tue, 6 Aug 2024 21:02:19 -0400 Subject: [PATCH] feat: allow providers to be restarted if they stop By not caching the client, gptscript is able to restart the provider daemon if it stops. If the daemon is still running, then there is little overhead because the daemon URL is cached and the tool will not be completely reprocessed. The model to provider mapping is still cached so that the client can be recreated when necessary. Signed-off-by: Donnie Adams --- pkg/remote/remote.go | 52 +++++++++++++++----------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 6d83e6cc..89863529 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -22,10 +22,9 @@ import ( ) type Client struct { - clientsLock sync.Mutex + modelsLock sync.Mutex cache *cache.Client - clients map[string]*openai.Client - models map[string]*openai.Client + modelToProvider map[string]string runner *runner.Runner envs []string credStore credentials.CredentialStore @@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent } func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { - c.clientsLock.Lock() - client, ok := c.models[messageRequest.Model] - c.clientsLock.Unlock() + c.modelsLock.Lock() + provider, ok := c.modelToProvider[messageRequest.Model] + c.modelsLock.Unlock() if !ok { return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) } + client, err := c.load(ctx, provider) + if err != nil { + return nil, err + } + toolName, modelName := types.SplitToolRef(messageRequest.Model) if modelName == "" { // modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider' @@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error) return false, nil } - client, err := c.load(ctx, providerName) + _, err := c.load(ctx, providerName) if err != nil { return false, err } - c.clientsLock.Lock() - defer c.clientsLock.Unlock() + c.modelsLock.Lock() + defer c.modelsLock.Unlock() - if c.models == nil { - c.models = map[string]*openai.Client{} + if c.modelToProvider == nil { + c.modelToProvider = map[string]string{} } - c.models[modelString] = client + c.modelToProvider[modelString] = providerName return true, nil } @@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie } func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) { - c.clientsLock.Lock() - defer c.clientsLock.Unlock() - - client, ok := c.clients[toolName] - if ok { - return client, nil - } - - if c.clients == nil { - c.clients = make(map[string]*openai.Client) - } - if isHTTPURL(toolName) { remoteClient, err := c.clientFromURL(ctx, toolName) if err != nil { return nil, err } - c.clients[toolName] = remoteClient return remoteClient, nil } @@ -174,14 +165,8 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - if strings.HasSuffix(url, "/") { - url += "v1" - } else { - url += "/v1" - } - - client, err = openai.NewClient(ctx, c.credStore, openai.Options{ - BaseURL: url, + client, err := openai.NewClient(ctx, c.credStore, openai.Options{ + BaseURL: strings.TrimSuffix(url, "/") + "/v1", Cache: c.cache, CacheKey: prg.EntryToolID, }) @@ -189,7 +174,6 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - c.clients[toolName] = client return client, nil }