From b08b1bc08d7c161249a236a761a5926cca43120c Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Tue, 7 May 2024 15:28:12 -0700 Subject: [PATCH] chore: cache github and http lookups in loader --- Makefile | 3 ++ pkg/cache/cache.go | 70 +++++++++++++++++++++++++++++++------ pkg/cli/eval.go | 8 +++-- pkg/cli/gptscript.go | 14 +++++--- pkg/debugcmd/debug.go | 14 ++++++++ pkg/gptscript/gptscript.go | 2 ++ pkg/hash/seed.go | 10 ++++-- pkg/hash/sha256.go | 30 +++------------- pkg/loader/github/github.go | 24 ++++++++----- pkg/loader/loader.go | 68 +++++++++++++++++++++++------------ pkg/loader/url.go | 58 +++++++++++++++++++++++++----- pkg/openai/client.go | 51 ++++++++------------------- pkg/remote/remote.go | 4 ++- pkg/repos/git/cmd.go | 19 ++++++++++ pkg/server/server.go | 8 +++-- pkg/types/completion.go | 4 +++ 16 files changed, 262 insertions(+), 125 deletions(-) diff --git a/Makefile b/Makefile index ebee145d..4e5c9796 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,9 @@ build-ui: touch static/ui/placeholder static/ui/_nuxt/_placeholder cp -rp ui/.output/public/* static/ui/ +build-exe: + GOOS=windows go build -o bin/gptscript.exe -tags "${GO_TAGS}" . + build: CGO_ENABLED=0 go build -o bin/gptscript -tags "${GO_TAGS}" -ldflags "-s -w" . 
diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index bb2941f3..581332a3 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -2,12 +2,17 @@ package cache import ( "context" + "crypto/sha256" + "encoding/gob" + "encoding/hex" "errors" "io/fs" "os" "path/filepath" "github.com/adrg/xdg" + "github.com/getkin/kin-openapi/openapi3" + openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) @@ -22,6 +27,11 @@ type Options struct { CacheDir string `usage:"Directory to store cache (default: $XDG_CACHE_HOME/gptscript)"` } +func init() { + gob.Register(openai.ChatCompletionRequest{}) + gob.Register(openapi3.Schema{}) +} + func Complete(opts ...Options) (result Options) { for _, opt := range opts { result.CacheDir = types.FirstSet(opt.CacheDir, result.CacheDir) @@ -59,22 +69,62 @@ func (c *Client) CacheDir() string { return c.dir } -func (c *Client) Store(key string, content []byte) error { - if c == nil || c.noop { +func (c *Client) cacheKey(key any) (string, error) { + hash := sha256.New() + if err := gob.NewEncoder(hash).Encode(key); err != nil { + return "", err + } + digest := hash.Sum(nil) + return hex.EncodeToString(digest), nil +} + +func (c *Client) Store(ctx context.Context, key, value any) error { + if c == nil { return nil } - return os.WriteFile(filepath.Join(c.dir, key), content, 0644) + + if c.noop || IsNoCache(ctx) { + keyValue, err := c.cacheKey(key) + if err == nil { + p := filepath.Join(c.dir, keyValue) + if _, err := os.Stat(p); err == nil { + _ = os.Remove(p) + } + } + return nil + } + + keyValue, err := c.cacheKey(key) + if err != nil { + return err + } + + f, err := os.Create(filepath.Join(c.dir, keyValue)) + if err != nil { + return err + } + defer f.Close() + + return gob.NewEncoder(f).Encode(value) } -func (c *Client) Get(key string) ([]byte, bool, error) { - if c == nil || c.noop { - return nil, false, nil +func (c *Client) Get(ctx 
context.Context, key, out any) (bool, error) { + if c == nil || c.noop || IsNoCache(ctx) { + return false, nil } - data, err := os.ReadFile(filepath.Join(c.dir, key)) + + keyValue, err := c.cacheKey(key) + if err != nil { + return false, err + } + + f, err := os.Open(filepath.Join(c.dir, keyValue)) if errors.Is(err, fs.ErrNotExist) { - return nil, false, nil + return false, nil } else if err != nil { - return nil, false, err + return false, err } - return data, true, nil + defer f.Close() + + return gob.NewDecoder(f).Decode(out) == nil, nil } diff --git a/pkg/cli/eval.go b/pkg/cli/eval.go index be59ece5..a51b7f3b 100644 --- a/pkg/cli/eval.go +++ b/pkg/cli/eval.go @@ -49,17 +49,19 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error { tool.Temperature = &temp32 } - prg, err := loader.ProgramFromSource(cmd.Context(), tool.String(), "") + opts, err := e.gptscript.NewGPTScriptOpts() if err != nil { return err } - opts, err := e.gptscript.NewGPTScriptOpts() + runner, err := gptscript.New(&opts) if err != nil { return err } - runner, err := gptscript.New(&opts) + prg, err := loader.ProgramFromSource(cmd.Context(), tool.String(), "", loader.Options{ + Cache: runner.Cache, + }) if err != nil { return err } diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 5d6a9e5d..10eb4179 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -259,7 +259,7 @@ func (r *GPTScript) listModels(ctx context.Context, gptScript *gptscript.GPTScri return nil } -func (r *GPTScript) readProgram(ctx context.Context, args []string) (prg types.Program, err error) { +func (r *GPTScript) readProgram(ctx context.Context, runner *gptscript.GPTScript, args []string) (prg types.Program, err error) { if len(args) == 0 { return } @@ -278,10 +278,14 @@ func (r *GPTScript) readProgram(ctx context.Context, args []string) (prg types.P } r.readData = data } - return loader.ProgramFromSource(ctx, string(data), r.SubTool) + return loader.ProgramFromSource(ctx, string(data), r.SubTool, 
loader.Options{ + Cache: runner.Cache, + }) } - return loader.Program(ctx, args[0], r.SubTool) + return loader.Program(ctx, args[0], r.SubTool, loader.Options{ + Cache: runner.Cache, + }) } func (r *GPTScript) PrintOutput(toolInput, toolOutput string) (err error) { @@ -337,7 +341,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { return r.listModels(ctx, gptScript, args) } - prg, err := r.readProgram(ctx, args) + prg, err := r.readProgram(ctx, gptScript, args) if err != nil { return err } @@ -392,7 +396,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { if prg.IsChat() || r.ForceChat { return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) { - return r.readProgram(ctx, args) + return r.readProgram(ctx, gptScript, args) }, os.Environ(), toolInput) } diff --git a/pkg/debugcmd/debug.go b/pkg/debugcmd/debug.go index 645910a4..e6e38d7e 100644 --- a/pkg/debugcmd/debug.go +++ b/pkg/debugcmd/debug.go @@ -16,6 +16,10 @@ type WrappedCmd struct { Dir string } +func (w *WrappedCmd) Stdout() string { + return w.r.Stdout() +} + func (w *WrappedCmd) Run() error { if len(w.Env) > 0 { w.c.Env = w.Env @@ -51,6 +55,16 @@ type recorder struct { entries []entry } +func (r *recorder) Stdout() string { + buf := strings.Builder{} + for _, e := range r.entries { + if !e.err { + buf.Write(e.data) + } + } + return buf.String() +} + func (r *recorder) dump() string { var errMessage strings.Builder for _, entry := range r.entries { diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 4f638ea9..931f6f1f 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -24,6 +24,7 @@ var log = mvl.Package() type GPTScript struct { Registry *llm.Registry Runner *runner.Runner + Cache *cache.Client WorkspacePath string DeleteWorkspaceOnClose bool } @@ -99,6 +100,7 @@ func New(opts *Options) (*GPTScript, error) { return &GPTScript{ Registry: registry, Runner: runner, + Cache: 
cacheClient, WorkspacePath: opts.Workspace, DeleteWorkspaceOnClose: opts.Workspace == "", }, nil diff --git a/pkg/hash/seed.go b/pkg/hash/seed.go index b70475e1..d5e7ab8c 100644 --- a/pkg/hash/seed.go +++ b/pkg/hash/seed.go @@ -1,10 +1,14 @@ package hash -import "hash/fnv" +import ( + "encoding/gob" + "hash/fnv" +) func Seed(input any) int { - s := Encode(input) h := fnv.New32a() - _, _ = h.Write([]byte(s)) + if err := gob.NewEncoder(h).Encode(input); err != nil { + panic(err) + } return int(h.Sum32()) } diff --git a/pkg/hash/sha256.go b/pkg/hash/sha256.go index fc7aa3f4..1bbb6117 100644 --- a/pkg/hash/sha256.go +++ b/pkg/hash/sha256.go @@ -2,8 +2,8 @@ package hash import ( "crypto/sha256" + "encoding/gob" "encoding/hex" - "encoding/json" ) func ID(parts ...string) string { @@ -19,31 +19,9 @@ func ID(parts ...string) string { } func Digest(obj any) string { - data, err := json.Marshal(obj) - if err != nil { + hash := sha256.New() + if err := gob.NewEncoder(hash).Encode(obj); err != nil { panic(err) } - - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]) -} - -func Encode(obj any) string { - data, err := json.Marshal(obj) - if err != nil { - panic(err) - } - - asMap := map[string]any{} - if err := json.Unmarshal(data, &asMap); err != nil { - panic(err) - } - - data, err = json.Marshal(asMap) - if err != nil { - panic(err) - } - - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]) + return hex.EncodeToString(hash.Sum(nil)) } diff --git a/pkg/loader/github/github.go b/pkg/loader/github/github.go index 7caaf0a9..499f0484 100644 --- a/pkg/loader/github/github.go +++ b/pkg/loader/github/github.go @@ -1,6 +1,7 @@ package github import ( + "context" "encoding/json" "fmt" "io" @@ -9,7 +10,9 @@ import ( "path/filepath" "strings" + "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/repos/git" "github.com/gptscript-ai/gptscript/pkg/system" 
"github.com/gptscript-ai/gptscript/pkg/types" ) @@ -29,11 +32,14 @@ func init() { loader.AddVSC(Load) } -func getCommit(account, repo, ref string) (string, error) { - url := fmt.Sprintf(githubCommitURL, account, repo, ref) - client := &http.Client{} +func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string, error) { + url := fmt.Sprintf(githubRepoURL, account, repo) + return git.LsRemote(ctx, url, ref) +} - req, err := http.NewRequest(http.MethodGet, url, nil) +func getCommit(ctx context.Context, account, repo, ref string) (string, error) { + url := fmt.Sprintf(githubCommitURL, account, repo, ref) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return "", fmt.Errorf("failed to create request of %s/%s at %s: %w", account, repo, url, err) } @@ -42,13 +48,15 @@ func getCommit(account, repo, ref string) (string, error) { req.Header.Add("Authorization", "Bearer "+githubAuthToken) } - resp, err := client.Do(req) - + resp, err := http.DefaultClient.Do(req) if err != nil { return "", err } else if resp.StatusCode != http.StatusOK { c, _ := io.ReadAll(resp.Body) resp.Body.Close() + if commit, err := getCommitLsRemote(ctx, account, repo, ref); err == nil { + return commit, nil + } return "", fmt.Errorf("failed to get GitHub commit of %s/%s at %s: %s %s", account, repo, ref, resp.Status, c) } @@ -68,7 +76,7 @@ func getCommit(account, repo, ref string) (string, error) { return commit.SHA, nil } -func Load(urlName string) (string, *types.Repo, bool, error) { +func Load(ctx context.Context, _ *cache.Client, urlName string) (string, *types.Repo, bool, error) { if !strings.HasPrefix(urlName, GithubPrefix) { return "", nil, false, nil } @@ -93,7 +101,7 @@ func Load(urlName string) (string, *types.Repo, bool, error) { path += "/tool.gpt" } - ref, err := getCommit(account, repo, ref) + ref, err := getCommit(ctx, account, repo, ref) if err != nil { return "", nil, false, err } diff --git a/pkg/loader/loader.go 
b/pkg/loader/loader.go index 250276d0..2159e0b3 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -13,20 +13,24 @@ import ( "path/filepath" "slices" "strings" + "time" "unicode/utf8" "github.com/getkin/kin-openapi/openapi3" "github.com/gptscript-ai/gptscript/pkg/assemble" "github.com/gptscript-ai/gptscript/pkg/builtin" + "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/parser" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" "gopkg.in/yaml.v3" ) +const CacheTimeout = time.Hour + type source struct { // Content The content of the source - Content io.ReadCloser + Content []byte // Remote indicates that this file was loaded from a remote source (not local disk) Remote bool // Path is the path of this source used to find any relative references to this source @@ -68,8 +72,15 @@ func loadLocal(base *source, name string) (*source, bool, error) { } log.Debugf("opened %s", path) + defer content.Close() + + data, err := io.ReadAll(content) + if err != nil { + return nil, false, err + } + return &source{ - Content: content, + Content: data, Remote: false, Path: filepath.Dir(path), Name: filepath.Base(path), @@ -109,12 +120,8 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types return tool, nil } -func readTool(ctx context.Context, prg *types.Program, base *source, targetToolName string) (types.Tool, error) { - data, err := io.ReadAll(base.Content) - if err != nil { - return types.Tool{}, err - } - _ = base.Content.Close() +func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) (types.Tool, error) { + data := base.Content if bytes.HasPrefix(data, assemble.Header) { return loadProgram(data, prg, targetToolName) @@ -147,6 +154,7 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN // If we didn't get any tools from trying to parse it as OpenAPI, try to parse it as a 
GPTScript if len(tools) == 0 { + var err error tools, err = parser.ParseTools(bytes.NewReader(data), parser.Options{ AssignGlobals: true, }) @@ -193,10 +201,10 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN localTools[strings.ToLower(tool.Parameters.Name)] = tool } - return link(ctx, prg, base, mainTool, localTools) + return link(ctx, cache, prg, base, mainTool, localTools) } -func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) { +func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) { if existing, ok := prg.ToolSet[tool.ID]; ok { return existing, nil } @@ -225,7 +233,7 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool linkedTool = existing } else { var err error - linkedTool, err = link(ctx, prg, base, localTool, localTools) + linkedTool, err = link(ctx, cache, prg, base, localTool, localTools) if err != nil { return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err) } @@ -235,7 +243,7 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool toolNames[targetToolName] = struct{}{} } else { toolName, subTool := SplitToolRef(targetToolName) - resolvedTool, err := resolve(ctx, prg, base, toolName, subTool) + resolvedTool, err := resolve(ctx, cache, prg, base, toolName, subTool) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err) } @@ -254,12 +262,14 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool return tool, nil } -func ProgramFromSource(ctx context.Context, content, subToolName string) (types.Program, error) { +func ProgramFromSource(ctx context.Context, content, subToolName string, opts ...Options) (types.Program, error) { + opt := complete(opts...) 
+ prg := types.Program{ ToolSet: types.ToolSet{}, } - tool, err := readTool(ctx, &prg, &source{ - Content: io.NopCloser(strings.NewReader(content)), + tool, err := readTool(ctx, opt.Cache, &prg, &source{ + Content: []byte(content), Location: "inline", }, subToolName) if err != nil { @@ -269,7 +279,21 @@ func ProgramFromSource(ctx context.Context, content, subToolName string) (types. return prg, nil } -func Program(ctx context.Context, name, subToolName string) (types.Program, error) { +type Options struct { + Cache *cache.Client +} + +func complete(opts ...Options) (result Options) { + for _, opt := range opts { + result.Cache = types.FirstSet(opt.Cache, result.Cache) + } + + return +} + +func Program(ctx context.Context, name, subToolName string, opts ...Options) (types.Program, error) { + opt := complete(opts...) + if subToolName == "" { name, subToolName = SplitToolRef(name) } @@ -277,7 +301,7 @@ func Program(ctx context.Context, name, subToolName string) (types.Program, erro Name: name, ToolSet: types.ToolSet{}, } - tool, err := resolve(ctx, &prg, &source{}, name, subToolName) + tool, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName) if err != nil { return types.Program{}, err } @@ -285,7 +309,7 @@ func Program(ctx context.Context, name, subToolName string) (types.Program, erro return prg, nil } -func resolve(ctx context.Context, prg *types.Program, base *source, name, subTool string) (types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) (types.Tool, error) { if subTool == "" { t, ok := builtin.Builtin(name) if ok { @@ -294,15 +318,15 @@ func resolve(ctx context.Context, prg *types.Program, base *source, name, subToo } } - s, err := input(ctx, base, name) + s, err := input(ctx, cache, base, name) if err != nil { return types.Tool{}, err } - return readTool(ctx, prg, s, subTool) + return readTool(ctx, cache, prg, s, subTool) } -func input(ctx context.Context, base 
*source, name string) (*source, error) { +func input(ctx context.Context, cache *cache.Client, base *source, name string) (*source, error) { if strings.HasPrefix(name, "http://") || strings.HasPrefix(name, "https://") { base.Remote = true } @@ -314,7 +338,7 @@ func input(ctx context.Context, base *source, name string) (*source, error) { } } - s, ok, err := loadURL(ctx, base, name) + s, ok, err := loadURL(ctx, cache, base, name) if err != nil || ok { return s, err } diff --git a/pkg/loader/url.go b/pkg/loader/url.go index 9a6ef19b..4a26b5de 100644 --- a/pkg/loader/url.go +++ b/pkg/loader/url.go @@ -3,15 +3,18 @@ package loader import ( "context" "fmt" + "io" "net/http" url2 "net/url" "path" "strings" + "time" + "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/types" ) -type VCSLookup func(string) (string, *types.Repo, bool, error) +type VCSLookup func(context.Context, *cache.Client, string) (string, *types.Repo, bool, error) var vcsLookups []VCSLookup @@ -19,13 +22,36 @@ func AddVSC(lookup VCSLookup) { vcsLookups = append(vcsLookups, lookup) } -func loadURL(ctx context.Context, base *source, name string) (*source, bool, error) { +type cacheKey struct { + Name string + Path string + Repo *types.Repo +} + +type cacheValue struct { + Source *source + Time time.Time +} + +func loadURL(ctx context.Context, cache *cache.Client, base *source, name string) (*source, bool, error) { var ( - repo *types.Repo - url = name - relative = strings.HasPrefix(name, ".") || !strings.Contains(name, "/") + repo *types.Repo + url = name + relative = strings.HasPrefix(name, ".") || !strings.Contains(name, "/") + cachedKey = cacheKey{ + Name: name, + Path: base.Path, + Repo: base.Repo, + } + cachedValue cacheValue ) + if ok, err := cache.Get(ctx, cachedKey, &cachedValue); err != nil { + return nil, false, err + } else if ok && time.Since(cachedValue.Time) < CacheTimeout { + return cachedValue.Source, true, nil + } + if base.Path != "" && relative { // 
Don't use path.Join because this is a URL and will break the :// protocol by cleaning it url = base.Path + "/" + name @@ -41,7 +67,7 @@ func loadURL(ctx context.Context, base *source, name string) (*source, bool, err if repo == nil || !relative { for _, vcs := range vcsLookups { - newURL, newRepo, ok, err := vcs(name) + newURL, newRepo, ok, err := vcs(ctx, cache, name) if err != nil { return nil, false, err } else if ok { @@ -88,12 +114,26 @@ func loadURL(ctx context.Context, base *source, name string) (*source, bool, err log.Debugf("opened %s", url) - return &source{ - Content: resp.Body, + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, false, fmt.Errorf("error loading %s: %v", url, err) + } + + result := &source{ + Content: data, Remote: true, Path: pathString, Name: name, Location: url, Repo: repo, - }, true, nil + } + + if err := cache.Store(ctx, cachedKey, cacheValue{ + Source: result, + Time: time.Now(), + }); err != nil { + return nil, false, err + } + + return result, true, nil } diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 8cc7c656..1b324ab7 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -1,10 +1,7 @@ package openai import ( - "bytes" - "compress/gzip" "context" - "encoding/json" "fmt" "io" "log/slog" @@ -165,11 +162,11 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result [] return result, nil } -func (c *Client) cacheKey(request openai.ChatCompletionRequest) string { - return hash.Encode(map[string]any{ +func (c *Client) cacheKey(request openai.ChatCompletionRequest) any { + return map[string]any{ "base": c.cacheKeyBase, "request": request, - }) + } } func (c *Client) seed(request openai.ChatCompletionRequest) int { @@ -192,25 +189,16 @@ func (c *Client) seed(request openai.ChatCompletionRequest) int { } func (c *Client) fromCache(ctx context.Context, messageRequest types.CompletionRequest, request openai.ChatCompletionRequest) (result []openai.ChatCompletionStreamResponse, 
_ bool, _ error) { - if cache.IsNoCache(ctx) { - return nil, false, nil - } - if messageRequest.Cache != nil && !*messageRequest.Cache { + if !messageRequest.GetCache() { return nil, false, nil } - - cache, found, err := c.cache.Get(c.cacheKey(request)) + found, err := c.cache.Get(ctx, c.cacheKey(request), &result) if err != nil { return nil, false, err } else if !found { return nil, false, nil } - - gz, err := gzip.NewReader(bytes.NewReader(cache)) - if err != nil { - return nil, false, err - } - return result, true, json.NewDecoder(gz).Decode(&result) + return result, true, nil } func toToolCall(call types.CompletionToolCall) openai.ToolCall { @@ -249,6 +237,14 @@ func toMessages(request types.CompletionRequest) (result []openai.ChatCompletion }) } + // Never send only a system message or a system message not followed by a user message + if len(msgs) > 0 && msgs[0].Role == types.CompletionMessageRoleTypeSystem { + if len(msgs) == 1 || + (len(msgs) > 1 && msgs[1].Role != types.CompletionMessageRoleTypeUser) { + msgs[0].Role = types.CompletionMessageRoleTypeUser + } + } + for _, message := range msgs { chatMessage := openai.ChatCompletionMessage{ Role: string(message.Role), @@ -446,24 +442,7 @@ func override(left, right string) string { return left } -func (c *Client) store(ctx context.Context, key string, responses []openai.ChatCompletionStreamResponse) error { - if cache.IsNoCache(ctx) { - return nil - } - buf := &bytes.Buffer{} - gz := gzip.NewWriter(buf) - err := json.NewEncoder(gz).Encode(responses) - if err != nil { - return err - } - if err := gz.Close(); err != nil { - return err - } - return c.cache.Store(key, buf.Bytes()) -} - func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, partial chan<- types.CompletionStatus) (responses []openai.ChatCompletionStreamResponse, _ error) { - cacheKey := c.cacheKey(request) request.Stream = os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false" partial <- 
types.CompletionStatus{ @@ -513,7 +492,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, for { response, err := stream.Recv() if err == io.EOF { - return responses, c.store(ctx, cacheKey, responses) + return responses, c.cache.Store(ctx, c.cacheKey(request), responses) } else if err != nil { return nil, err } diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 754d72e0..35131efa 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -137,7 +137,9 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return remoteClient, nil } - prg, err := loader.Program(ctx, toolName, "") + prg, err := loader.Program(ctx, toolName, "", loader.Options{ + Cache: c.cache, + }) if err != nil { return nil, err } diff --git a/pkg/repos/git/cmd.go b/pkg/repos/git/cmd.go index e3662771..a9eba5fe 100644 --- a/pkg/repos/git/cmd.go +++ b/pkg/repos/git/cmd.go @@ -2,6 +2,8 @@ package git import ( "context" + "fmt" + "strings" "github.com/gptscript-ai/gptscript/pkg/debugcmd" ) @@ -11,6 +13,23 @@ func newGitCommand(ctx context.Context, args ...string) *debugcmd.WrappedCmd { return cmd } +func LsRemote(ctx context.Context, repo, ref string) (string, error) { + cmd := newGitCommand(ctx, "ls-remote", repo, ref) + if err := cmd.Run(); err != nil { + return "", err + } + for _, line := range strings.Split(cmd.Stdout(), "\n") { + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + if fields[1] == ref { + return fields[0], nil + } + } + return "", fmt.Errorf("failed to find remote %q as %q", repo, ref) +} + func cloneBare(ctx context.Context, repo, toDir string) error { cmd := newGitCommand(ctx, "clone", "--bare", "--depth", "1", repo, toDir) return cmd.Run() diff --git a/pkg/server/server.go b/pkg/server/server.go index 9e86adbf..27ce49b1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -110,7 +110,9 @@ func (s *Server) list(rw http.ResponseWriter, req *http.Request) { _ = 
enc.Encode(builtin.SysProgram()) return } else if strings.HasSuffix(path, system.Suffix) { - prg, err := loader.Program(req.Context(), path, req.URL.Query().Get("tool")) + prg, err := loader.Program(req.Context(), path, req.URL.Query().Get("tool"), loader.Options{ + Cache: s.runner.Cache, + }) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return @@ -151,7 +153,9 @@ func (s *Server) run(rw http.ResponseWriter, req *http.Request) { path += system.Suffix } - prg, err := loader.Program(req.Context(), path, req.URL.Query().Get("tool")) + prg, err := loader.Program(req.Context(), path, req.URL.Query().Get("tool"), loader.Options{ + Cache: s.runner.Cache, + }) if errors.Is(err, fs.ErrNotExist) { http.NotFound(rw, req) return diff --git a/pkg/types/completion.go b/pkg/types/completion.go index e4bb92a5..b0b81774 100644 --- a/pkg/types/completion.go +++ b/pkg/types/completion.go @@ -20,6 +20,10 @@ type CompletionRequest struct { Cache *bool } +func (r *CompletionRequest) GetCache() bool { + return r.Cache == nil || *r.Cache +} + type CompletionTool struct { Function CompletionFunctionDefinition `json:"function,omitempty"` }