Skip to content

feat: add workspace functions #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 100 additions & 13 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ import (
)

var tools = map[string]types.Tool{
"sys.workspace.ls": {
Parameters: types.Parameters{
Description: "Lists the contents of a directory relative to the current workspace",
Arguments: types.ObjectSchema(
"dir", "The directory to list"),
},
BuiltinFunc: SysWorkspaceLs,
},
"sys.workspace.write": {
Parameters: types.Parameters{
Description: "Write the contents to a file relative to the current workspace",
Arguments: types.ObjectSchema(
"filename", "The name of the file to write to",
"content", "The content to write"),
},
BuiltinFunc: SysWorkspaceWrite,
},
"sys.workspace.read": {
Parameters: types.Parameters{
Description: "Reads the contents of a file relative to the current workspace",
Arguments: types.ObjectSchema(
"filename", "The name of the file to read"),
},
BuiltinFunc: SysWorkspaceRead,
},
"sys.ls": {
Parameters: types.Parameters{
Description: "Lists the contents of a directory",
Expand Down Expand Up @@ -297,19 +322,46 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
return string(out), err
}

func getWorkspaceDir(envs []string) (string, error) {
for _, env := range envs {
dir, ok := strings.CutPrefix(env, "GPTSCRIPT_WORKSPACE_DIR=")
if ok && dir != "" {
return dir, nil
}
}
return "", fmt.Errorf("no workspace directory found in env")
}

func SysWorkspaceLs(_ context.Context, env []string, input string) (string, error) {
dir, err := getWorkspaceDir(env)
if err != nil {
return "", err
}
return sysLs(dir, input)
}

func SysLs(_ context.Context, _ []string, input string) (string, error) {
return sysLs("", input)
}

func sysLs(base, input string) (string, error) {
var params struct {
Dir string `json:"dir,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
}

if params.Dir == "" {
params.Dir = "."
dir := params.Dir
if dir == "" {
dir = "."
}

if base != "" {
dir = filepath.Join(base, dir)
}

entries, err := os.ReadDir(params.Dir)
entries, err := os.ReadDir(dir)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Sprintf("directory does not exist: %s", params.Dir), nil
} else if err != nil {
Expand All @@ -328,20 +380,38 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysWorkspaceRead(ctx context.Context, env []string, input string) (string, error) {
dir, err := getWorkspaceDir(env)
if err != nil {
return "", err
}

return sysRead(ctx, dir, env, input)
}

func SysRead(ctx context.Context, env []string, input string) (string, error) {
return sysRead(ctx, "", env, input)
}

func sysRead(ctx context.Context, base string, env []string, input string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
}

file := params.Filename
if base != "" {
file = filepath.Join(base, file)
}

// Lock the file to prevent concurrent writes from other tool calls.
locker.RLock(params.Filename)
defer locker.RUnlock(params.Filename)
locker.RLock(file)
defer locker.RUnlock(file)

log.Debugf("Reading file %s", params.Filename)
data, err := os.ReadFile(params.Filename)
log.Debugf("Reading file %s", file)
data, err := os.ReadFile(file)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Sprintf("The file %s does not exist", params.Filename), nil
} else if err != nil {
Expand All @@ -354,7 +424,19 @@ func SysRead(ctx context.Context, env []string, input string) (string, error) {
return string(data), nil
}

func SysWorkspaceWrite(ctx context.Context, env []string, input string) (string, error) {
dir, err := getWorkspaceDir(env)
if err != nil {
return "", err
}
return sysWrite(ctx, dir, env, input)
}

func SysWrite(ctx context.Context, env []string, input string) (string, error) {
return sysWrite(ctx, "", env, input)
}

func sysWrite(ctx context.Context, base string, env []string, input string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand All @@ -363,28 +445,33 @@ func SysWrite(ctx context.Context, env []string, input string) (string, error) {
return "", err
}

file := params.Filename
if base != "" {
file = filepath.Join(base, file)
}

// Lock the file to prevent concurrent writes from other tool calls.
locker.Lock(params.Filename)
defer locker.Unlock(params.Filename)
locker.Lock(file)
defer locker.Unlock(file)

dir := filepath.Dir(params.Filename)
dir := filepath.Dir(file)
if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) {
log.Debugf("Creating dir %s", dir)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("creating dir %s: %w", dir, err)
}
}

if _, err := os.Stat(params.Filename); err == nil {
if _, err := os.Stat(file); err == nil {
if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil {
return "", err
}
}

data := []byte(params.Content)
log.Debugf("Wrote %d bytes to file %s", len(data), params.Filename)
log.Debugf("Wrote %d bytes to file %s", len(data), file)

return "", os.WriteFile(params.Filename, data, 0644)
return "", os.WriteFile(file, data, 0644)
}

func SysAppend(ctx context.Context, env []string, input string) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg
prompter Prompter
)

prompter, err := newReadlinePrompter()
prompter, err := newReadlinePrompter(prg)
if err != nil {
return err
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/chat/readline.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/adrg/xdg"
"github.com/chzyer/readline"
"github.com/fatih/color"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/mvl"
)

Expand All @@ -18,8 +19,13 @@ type readlinePrompter struct {
readliner *readline.Instance
}

func newReadlinePrompter() (*readlinePrompter, error) {
historyFile, err := xdg.CacheFile("gptscript/chat.history")
func newReadlinePrompter(prg GetProgram) (*readlinePrompter, error) {
targetProgram, err := prg()
if err != nil {
return nil, err
}

historyFile, err := xdg.CacheFile(fmt.Sprintf("gptscript/chat-%s.history", hash.ID(targetProgram.EntryToolID)))
if err != nil {
historyFile = ""
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type GPTScript struct {
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"`
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"`
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`

readData []byte
}
Expand Down Expand Up @@ -123,6 +124,7 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
Quiet: r.Quiet,
Env: os.Environ(),
CredentialContext: r.CredentialContext,
Workspace: r.Workspace,
}

if r.Ports != "" {
Expand Down
60 changes: 51 additions & 9 deletions pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,41 @@ package gptscript

import (
"context"
"fmt"
"os"

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/llm"
"github.com/gptscript-ai/gptscript/pkg/monitor"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/remote"
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
"github.com/gptscript-ai/gptscript/pkg/runner"
"github.com/gptscript-ai/gptscript/pkg/types"
)

var log = mvl.Package()

type GPTScript struct {
Registry *llm.Registry
Runner *runner.Runner
Registry *llm.Registry
Runner *runner.Runner
WorkspacePath string
DeleteWorkspaceOnClose bool
}

type Options struct {
Cache cache.Options
OpenAI openai.Options
Monitor monitor.Options
Runner runner.Options
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"`
Env []string `usage:"-"`
CredentialContext string
Quiet *bool
Workspace string
Env []string
}

func complete(opts *Options) (result *Options) {
Expand Down Expand Up @@ -89,21 +97,55 @@ func New(opts *Options) (*GPTScript, error) {
}

return &GPTScript{
Registry: registry,
Runner: runner,
Registry: registry,
Runner: runner,
WorkspacePath: opts.Workspace,
DeleteWorkspaceOnClose: opts.Workspace == "",
}, nil
}

func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (runner.ChatResponse, error) {
return g.Runner.Chat(ctx, prevState, prg, env, input)
func (g *GPTScript) getEnv(env []string) ([]string, error) {
if g.WorkspacePath == "" {
var err error
g.WorkspacePath, err = os.MkdirTemp("", "gptscript-workspace-*")
if err != nil {
return nil, err
}
}
if err := os.MkdirAll(g.WorkspacePath, 0700); err != nil {
return nil, err
}
return append([]string{
fmt.Sprintf("GPTSCRIPT_WORKSPACE_DIR=%s", g.WorkspacePath),
fmt.Sprintf("GPTSCRIPT_WORKSPACE_ID=%s", hash.ID(g.WorkspacePath)),
}, env...), nil
}

func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) {
envs, err := g.getEnv(envs)
if err != nil {
return runner.ChatResponse{}, err
}

return g.Runner.Chat(ctx, prevState, prg, envs, input)
}

func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) {
envs, err := g.getEnv(envs)
if err != nil {
return "", err
}

return g.Runner.Run(ctx, prg, envs, input)
}

func (g *GPTScript) Close() {
g.Runner.Close()
if g.DeleteWorkspaceOnClose && g.WorkspacePath != "" {
if err := os.RemoveAll(g.WorkspacePath); err != nil {
log.Errorf("failed to delete workspace %s: %s", g.WorkspacePath, err)
}
}
}

func (g *GPTScript) GetModel() engine.Model {
Expand Down
2 changes: 1 addition & 1 deletion pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
value = strings.TrimSpace(value)
switch normalize(key) {
case "name":
tool.Parameters.Name = strings.ToLower(value)
tool.Parameters.Name = value
case "modelprovider":
tool.Parameters.ModelProvider = true
case "model", "modelname":
Expand Down
Loading