diff --git a/pkg/repos/get.go b/pkg/repos/get.go index b43bc63b..416f4c61 100644 --- a/pkg/repos/get.go +++ b/pkg/repos/get.go @@ -27,6 +27,7 @@ const credentialHelpersRepo = "github.com/gptscript-ai/gptscript-credential-help type Runtime interface { ID() string Supports(tool types.Tool, cmd []string) bool + Binary(ctx context.Context, tool types.Tool, dataRoot, toolSource string, env []string) (bool, []string, error) Setup(ctx context.Context, tool types.Tool, dataRoot, toolSource string, env []string) ([]string, error) GetHash(tool types.Tool) (string, error) } @@ -46,6 +47,10 @@ func (n noopRuntime) Supports(_ types.Tool, _ []string) bool { return false } +func (n noopRuntime) Binary(_ context.Context, _ types.Tool, _, _ string, _ []string) (bool, []string, error) { + return false, nil, nil +} + func (n noopRuntime) Setup(_ context.Context, _ types.Tool, _, _ string, _ []string) ([]string, error) { return nil, nil } @@ -211,21 +216,30 @@ func (m *Manager) setup(ctx context.Context, runtime Runtime, tool types.Tool, e _ = os.RemoveAll(doneFile) _ = os.RemoveAll(target) - if tool.Source.Repo.VCS == "git" { - if err := git.Checkout(ctx, m.gitDir, tool.Source.Repo.Root, tool.Source.Repo.Revision, target); err != nil { - return "", nil, err + var ( + newEnv []string + isBinary bool + ) + + if isBinary, newEnv, err = runtime.Binary(ctx, tool, m.runtimeDir, targetFinal, env); err != nil { + return "", nil, err + } else if !isBinary { + if tool.Source.Repo.VCS == "git" { + if err := git.Checkout(ctx, m.gitDir, tool.Source.Repo.Root, tool.Source.Repo.Revision, target); err != nil { + return "", nil, err + } + } else { + if err := os.MkdirAll(target, 0755); err != nil { + return "", nil, err + } } - } else { - if err := os.MkdirAll(target, 0755); err != nil { + + newEnv, err = runtime.Setup(ctx, tool, m.runtimeDir, targetFinal, env) + if err != nil { return "", nil, err } } - newEnv, err := runtime.Setup(ctx, tool, m.runtimeDir, targetFinal, env) - if err != nil { - return "", nil, err - } - out, err := os.Create(doneFile + ".tmp") if err != nil { return "", nil, err diff --git a/pkg/repos/runtimes/busybox/busybox.go b/pkg/repos/runtimes/busybox/busybox.go index 481ed1fe..e4604b06 100644 --- a/pkg/repos/runtimes/busybox/busybox.go +++ b/pkg/repos/runtimes/busybox/busybox.go @@ -49,6 +49,10 @@ func (r *Runtime) Supports(_ types.Tool, cmd []string) bool { return false } +func (r *Runtime) Binary(_ context.Context, _ types.Tool, _, _ string, _ []string) (bool, []string, error) { + return false, nil, nil +} + func (r *Runtime) Setup(ctx context.Context, _ types.Tool, dataRoot, _ string, env []string) ([]string, error) { binPath, err := r.getRuntime(ctx, dataRoot) if err != nil { diff --git a/pkg/repos/runtimes/golang/golang.go b/pkg/repos/runtimes/golang/golang.go index 882e8a0b..9e472e90 100644 --- a/pkg/repos/runtimes/golang/golang.go +++ b/pkg/repos/runtimes/golang/golang.go @@ -4,10 +4,14 @@ import ( "bufio" "bytes" "context" + "crypto/sha256" _ "embed" + "encoding/hex" "errors" "fmt" + "io" "io/fs" + "net/http" "os" "path/filepath" "runtime" @@ -44,6 +48,183 @@ func (r *Runtime) Supports(tool types.Tool, cmd []string) bool { len(cmd) > 0 && cmd[0] == "${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool" } +type release struct { + account, repo, label string +} + +func (r release) checksumTxt() string { + return fmt.Sprintf( + "https://github.com/%s/%s/releases/download/%s/checksums.txt", + r.account, + r.repo, + r.label) +} + +func (r release) binURL() string { + return fmt.Sprintf( + "https://github.com/%s/%s/releases/download/%s/%s", + r.account, + r.repo, + r.label, + r.srcBinName()) +} + +func (r release) targetBinName() string { + suffix := "" + if runtime.GOOS == "windows" { + suffix = ".exe" + } + + return "gptscript-go-tool" + suffix +} + +func (r release) srcBinName() string { + suffix := "" + if runtime.GOOS == "windows" { + suffix = ".exe" + } + + return r.repo + "-" + + runtime.GOOS + "-" + + runtime.GOARCH + suffix +} + +func getLatestRelease(tool types.Tool) (*release, bool) { + if tool.Source.Repo == nil || !strings.HasPrefix(tool.Source.Repo.Root, "https://github.com/") { + return nil, false + } + + parts := strings.Split(strings.TrimPrefix(strings.TrimSuffix(tool.Source.Repo.Root, ".git"), "https://"), "/") + if len(parts) != 3 { + return nil, false + } + + client := http.Client{ + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(fmt.Sprintf("https://github.com/%s/%s/releases/latest", parts[1], parts[2])) + if err != nil || resp.StatusCode != http.StatusFound { + // ignore error + return nil, false + } + defer resp.Body.Close() + + target := resp.Header.Get("Location") + if target == "" { + return nil, false + } + + account, repo := parts[1], parts[2] + parts = strings.Split(target, "/") + label := parts[len(parts)-1] + + return &release{ + account: account, + repo: repo, + label: label, + }, true +} + +func get(ctx context.Context, url string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } else if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, fmt.Errorf("bad HTTP status code: %d", resp.StatusCode) + } + + return resp, nil +} + +func downloadBin(ctx context.Context, checksum, src, url, bin string) error { + resp, err := get(ctx, url) + if err != nil { + return err + } + defer resp.Body.Close() + + if err := os.MkdirAll(filepath.Join(src, "bin"), 0755); err != nil { + return err + } + + targetFile, err := os.Create(filepath.Join(src, "bin", bin)) + if err != nil { + return err + } + + digest := sha256.New() + + if _, err := io.Copy(io.MultiWriter(targetFile, digest), resp.Body); err != nil { + return err + } + + if err := targetFile.Close(); err != nil { + return nil + } + + if got := hex.EncodeToString(digest.Sum(nil)); got != checksum { + return fmt.Errorf("checksum mismatch %s != %s", got, checksum) + } + + if err := os.Chmod(targetFile.Name(), 0755); err != nil { + return err + } + + return nil +} + +func getChecksum(ctx context.Context, rel *release) string { + resp, err := get(ctx, rel.checksumTxt()) + if err != nil { + // ignore error + return "" + } + defer resp.Body.Close() + + scan := bufio.NewScanner(resp.Body) + for scan.Scan() { + fields := strings.Fields(scan.Text()) + if len(fields) != 2 || fields[1] != rel.srcBinName() { + continue + } + return fields[0] + } + + return "" +} + +func (r *Runtime) Binary(ctx context.Context, tool types.Tool, _, toolSource string, env []string) (bool, []string, error) { + if !tool.Source.IsGit() { + return false, nil, nil + } + + rel, ok := getLatestRelease(tool) + if !ok { + return false, nil, nil + } + + checksum := getChecksum(ctx, rel) + if checksum == "" { + return false, nil, nil + } + + if err := downloadBin(ctx, checksum, toolSource, rel.binURL(), rel.targetBinName()); err != nil { + // ignore error + return false, nil, nil + } + + return true, env, nil +} + func (r *Runtime) Setup(ctx context.Context, _ types.Tool, dataRoot, toolSource string, env []string) ([]string, error) { binPath, err := r.getRuntime(ctx, dataRoot) if err != nil { diff --git a/pkg/repos/runtimes/node/node.go b/pkg/repos/runtimes/node/node.go index d0a9d8cb..01a752e6 100644 --- a/pkg/repos/runtimes/node/node.go +++ b/pkg/repos/runtimes/node/node.go @@ -39,6 +39,10 @@ func (r *Runtime) ID() string { return "node" + r.Version } +func (r *Runtime) Binary(_ context.Context, _ types.Tool, _, _ string, _ []string) (bool, []string, error) { + return false, nil, nil +} + func (r *Runtime) Supports(_ types.Tool, cmd []string) bool { for _, testCmd := range []string{"node", "npx", "npm"} { if r.supports(testCmd, cmd) { diff --git a/pkg/repos/runtimes/python/python.go b/pkg/repos/runtimes/python/python.go index 87b072e5..ee4bf571 100644 --- a/pkg/repos/runtimes/python/python.go +++ b/pkg/repos/runtimes/python/python.go @@ -175,6 +175,10 @@ func (r *Runtime) getReleaseAndDigest() (string, string, error) { return "", "", fmt.Errorf("failed to find an python runtime for %s", r.Version) } +func (r *Runtime) Binary(_ context.Context, _ types.Tool, _, _ string, _ []string) (bool, []string, error) { + return false, nil, nil +} + func (r *Runtime) GetHash(tool types.Tool) (string, error) { if !tool.Source.IsGit() && tool.WorkingDir != "" { if _, ok := tool.MetaData[requirementsTxt]; ok {