diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 40eda367..c9c11396 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -400,7 +400,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { if prg.IsChat() || r.ForceChat { return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) { - return prg, nil + return r.readProgram(ctx, gptScript, args) }, os.Environ(), toolInput) } diff --git a/pkg/counter/counter.go b/pkg/counter/counter.go new file mode 100644 index 00000000..8ba329b7 --- /dev/null +++ b/pkg/counter/counter.go @@ -0,0 +1,17 @@ +package counter + +import ( + "fmt" + "sync/atomic" + "time" +) + +var counter = int32(time.Now().Unix()) + +func Reset(i int32) { + atomic.StoreInt32(&counter, i) +} + +func Next() string { + return fmt.Sprint(atomic.AddInt32(&counter, 1)) +} diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 65c249b4..e1bcd620 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -12,17 +12,17 @@ import ( "runtime" "sort" "strings" - "sync/atomic" "github.com/google/shlex" context2 "github.com/gptscript-ai/gptscript/pkg/context" + "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) { - id := fmt.Sprint(atomic.AddInt64(&completionID, 1)) + id := counter.Next() defer func() { e.Progress <- types.CompletionStatus{ diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 72d72827..d081de41 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -6,15 +6,13 @@ import ( "fmt" "strings" "sync" - "sync/atomic" + "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) -var completionID int64 - type Model interface { Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) } @@ -123,12 +121,10 @@ func (c *Context) MarshalJSON() ([]byte, error) { return json.Marshal(c.GetCallContext()) } -var execID int32 - func NewContext(ctx context.Context, prg *types.Program) Context { callCtx := Context{ commonContext: commonContext{ - ID: fmt.Sprint(atomic.AddInt32(&execID, 1)), + ID: counter.Next(), Tool: prg.ToolSet[prg.EntryToolID], }, Ctx: ctx, @@ -144,7 +140,7 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCatego } if callID == "" { - callID = fmt.Sprint(atomic.AddInt32(&execID, 1)) + callID = counter.Next() } return Context{ diff --git a/pkg/engine/print.go b/pkg/engine/print.go index 8d8e3c77..06d76fce 100644 --- a/pkg/engine/print.go +++ b/pkg/engine/print.go @@ -1,15 +1,14 @@ package engine import ( - "fmt" "strings" - "sync/atomic" + "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/types" ) func (e *Engine) runEcho(tool types.Tool) (cmdOut *Return, cmdErr error) { - id := fmt.Sprint(atomic.AddInt64(&completionID, 1)) + id := counter.Next() out := strings.TrimPrefix(tool.Instructions, types.EchoPrefix+"\n") e.Progress <- types.CompletionStatus{ diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index d8cc3563..ce59d3d1 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -221,6 +221,14 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name")) } + if i != 0 && tool.Parameters.GlobalModelName != "" { + return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name")) + } + + if i != 0 && len(tool.Parameters.GlobalTools) > 0 { + return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools")) + } + if targetToolName != "" && strings.EqualFold(tool.Parameters.Name, targetToolName) { mainTool = tool } diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index 041c366f..f10ef770 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -13,6 +13,7 @@ import ( "time" "github.com/fatih/color" + "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/types" @@ -40,12 +41,11 @@ type Console struct { } var ( - runID int64 prettyIDCounter int64 ) func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) { - id := atomic.AddInt64(&runID, 1) + id := counter.Next() mon := newDisplay(c.dumpState, c.displayProgress, c.printMessages) mon.dump.ID = fmt.Sprint(id) mon.dump.Program = prg diff --git a/pkg/openai/client.go b/pkg/openai/client.go index d7a70d5b..6779d73b 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -9,11 +9,11 @@ import ( "slices" "sort" "strings" - "sync/atomic" "github.com/getkin/kin-openapi/openapi3" openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/hash" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" @@ -24,10 +24,9 @@ const ( ) var ( - key = os.Getenv("OPENAI_API_KEY") - url = os.Getenv("OPENAI_URL") - azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT") - completionID int64 + key = os.Getenv("OPENAI_API_KEY") + url = os.Getenv("OPENAI_URL") + azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT") ) type Client struct { @@ -332,7 +331,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques }) } - id := fmt.Sprint(atomic.AddInt64(&completionID, 1)) + id := counter.Next() status <- types.CompletionStatus{ CompletionID: id, Request: request, diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 07ccb626..ccc3dcb8 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -100,13 +100,13 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { return false, err } tool.Parameters.Chat = v - case "export": + case "export", "exporttool", "exports", "exporttools": tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...) case "tool", "tools": tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...) case "globaltool", "globaltools": tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(value)...) - case "exportcontext": + case "exportcontext", "exportcontexts": tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(value)...) case "context": tool.Parameters.Context = append(tool.Parameters.Context, csv(value)...) diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 5f05d6ec..ed87b683 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -669,6 +669,15 @@ func TestCase2(t *testing.T) { assert.Equal(t, "TEST RESULT CALL: 1", x) } +func TestGlobalErr(t *testing.T) { + runner := tester.NewRunner(t) + _, err := runner.Run("", "") + autogold.Expect("line testdata/TestGlobalErr/test.gpt:4: only the first tool in a file can have global model name").Equal(t, err.Error()) + + _, err = runner.Run("test2.gpt", "") + autogold.Expect("line testdata/TestGlobalErr/test2.gpt:4: only the first tool in a file can have global tools").Equal(t, err.Error()) +} + func TestContextArg(t *testing.T) { runner := tester.NewRunner(t) x, err := runner.Run("", `{ diff --git a/pkg/tests/testdata/TestExportContext/call1.golden b/pkg/tests/testdata/TestExportContext/call1.golden index 026f573a..1012a0ec 100644 --- a/pkg/tests/testdata/TestExportContext/call1.golden +++ b/pkg/tests/testdata/TestExportContext/call1.golden @@ -4,7 +4,7 @@ "Tools": [ { "function": { - "toolID": "testdata/TestExportContext/test.gpt:21", + "toolID": "testdata/TestExportContext/test.gpt:22", "name": "subtool", "parameters": { "properties": { @@ -22,7 +22,7 @@ }, { "function": { - "toolID": "testdata/TestExportContext/test.gpt:14", + "toolID": "testdata/TestExportContext/test.gpt:15", "name": "sampletool", "description": "sample", "parameters": { diff --git a/pkg/tests/testdata/TestExportContext/test.gpt b/pkg/tests/testdata/TestExportContext/test.gpt index f015b223..02daec54 100644 --- a/pkg/tests/testdata/TestExportContext/test.gpt +++ b/pkg/tests/testdata/TestExportContext/test.gpt @@ -6,6 +6,7 @@ This is from tool --- name: fromcontext export: sampletool +export context: fromexportcontext #!/bin/bash echo this is from context @@ -19,7 +20,6 @@ Dummy body --- name: subtool -export context: fromexportcontext Dummy body diff --git a/pkg/tests/testdata/TestGlobalErr/test.gpt b/pkg/tests/testdata/TestGlobalErr/test.gpt new file mode 100644 index 00000000..81607071 --- /dev/null +++ b/pkg/tests/testdata/TestGlobalErr/test.gpt @@ -0,0 +1,7 @@ +first + +--- +name: second +global model name: foo + +second \ No newline at end of file diff --git a/pkg/tests/testdata/TestGlobalErr/test2.gpt b/pkg/tests/testdata/TestGlobalErr/test2.gpt new file mode 100644 index 00000000..7b09599e --- /dev/null +++ b/pkg/tests/testdata/TestGlobalErr/test2.gpt @@ -0,0 +1,7 @@ +first + +--- +name: second +global tools: asdf + +second \ No newline at end of file diff --git a/pkg/types/set.go b/pkg/types/set.go new file mode 100644 index 00000000..e8467b3a --- /dev/null +++ b/pkg/types/set.go @@ -0,0 +1,45 @@ +package types + +type toolRefKey struct { + name string + toolID string + arg string +} + +type toolRefSet struct { + set map[toolRefKey]ToolReference + order []toolRefKey + err error +} + +func (t *toolRefSet) List() (result []ToolReference, err error) { + for _, k := range t.order { + result = append(result, t.set[k]) + } + return result, t.err +} + +func (t *toolRefSet) AddAll(values []ToolReference, err error) { + if t.err != nil { + t.err = err + } + for _, v := range values { + t.Add(v) + } +} + +func (t *toolRefSet) Add(value ToolReference) { + key := toolRefKey{ + name: value.Named, + toolID: value.ToolID, + arg: value.Arg, + } + + if _, ok := t.set[key]; !ok { + if t.set == nil { + t.set = map[toolRefKey]ToolReference{} + } + t.set[key] = value + t.order = append(t.order, key) + } +} diff --git a/pkg/types/tool.go b/pkg/types/tool.go index e8814b35..feb3ab09 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -63,59 +63,8 @@ type ToolReference struct { ToolID string } -func (p Program) GetContextToolRefs(toolID string) (result []ToolReference, _ error) { - seen := map[struct { - toolID string - arg string - }]struct{}{} - tool := p.ToolSet[toolID] - - subToolRefs, err := tool.GetToolRefsFromNames(tool.Tools) - if err != nil { - return nil, err - } - - for _, subToolRef := range subToolRefs { - subTool := p.ToolSet[subToolRef.ToolID] - exportContextToolRefs, err := subTool.GetToolRefsFromNames(subTool.ExportContext) - if err != nil { - return nil, err - } - for _, exportContextToolRef := range exportContextToolRefs { - key := struct { - toolID string - arg string - }{ - toolID: exportContextToolRef.ToolID, - arg: exportContextToolRef.Arg, - } - if _, ok := seen[key]; !ok { - seen[key] = struct{}{} - result = append(result, exportContextToolRef) - } - } - } - - contextToolRefs, err := p.ToolSet[toolID].GetToolRefsFromNames(p.ToolSet[toolID].Context) - if err != nil { - return nil, err - } - - for _, contextToolRef := range contextToolRefs { - key := struct { - toolID string - arg string - }{ - toolID: contextToolRef.ToolID, - arg: contextToolRef.Arg, - } - if _, ok := seen[key]; !ok { - seen[key] = struct{}{} - result = append(result, contextToolRef) - } - } - - return +func (p Program) GetContextToolRefs(toolID string) ([]ToolReference, error) { + return p.ToolSet[toolID].GetContextTools(p) } func (p Program) GetCompletionTools() (result []CompletionTool, err error) { @@ -295,105 +244,139 @@ func (t Tool) String() string { return buf.String() } -func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) { - toolNames := map[string]struct{}{} +func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) { + result := &toolRefSet{} - subToolRefs, err := t.GetToolRefsFromNames(t.Parameters.Tools) + exportRefs, err := t.GetToolRefsFromNames(t.ExportContext) if err != nil { return nil, err } - for _, subToolRef := range subToolRefs { - result, err = appendTool(result, prg, t, subToolRef.Reference, toolNames, subToolRef.Named) - if err != nil { - return nil, err - } + for _, exportRef := range exportRefs { + result.Add(exportRef) + + tool := prg.ToolSet[exportRef.ToolID] + result.AddAll(tool.GetExportedContext(prg)) } - for _, subToolName := range t.Parameters.Context { - result, err = appendExports(result, prg, t, subToolName, toolNames) - if err != nil { - return nil, err - } + return result.List() +} + +func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) { + result := &toolRefSet{} + + exportRefs, err := t.GetToolRefsFromNames(t.Export) + if err != nil { + return nil, err + } + + for _, exportRef := range exportRefs { + result.Add(exportRef) + result.AddAll(prg.ToolSet[exportRef.ToolID].GetExportedTools(prg)) } - return result, nil + return result.List() } -func getTool(prg Program, parent Tool, name string) (Tool, error) { - toolID, ok := parent.ToolMapping[name] - if !ok { - return Tool{}, &ErrToolNotFound{ - ToolName: name, - } +func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) { + result := &toolRefSet{} + + contextRefs, err := t.GetToolRefsFromNames(t.Context) + if err != nil { + return nil, err } - tool, ok := prg.ToolSet[toolID] - if !ok { - return Tool{}, &ErrToolNotFound{ - ToolName: name, - } + + for _, contextRef := range contextRefs { + result.AddAll(prg.ToolSet[contextRef.ToolID].GetExportedContext(prg)) + result.Add(contextRef) } - return tool, nil + + return result.List() } -func appendExports(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) { - subTool, err := getTool(prg, parentTool, subToolName) +func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) { + refs, err := t.getCompletionToolRefs(prg) if err != nil { return nil, err } + return toolRefsToCompletionTools(refs, prg), nil +} + +func (t Tool) addReferencedTools(prg Program, result *toolRefSet) error { + subToolRefs, err := t.GetToolRefsFromNames(t.Parameters.Tools) + if err != nil { + return err + } - for _, export := range subTool.Export { - completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames, "") - if err != nil { - return nil, err - } + for _, subToolRef := range subToolRefs { + // Add the tool + result.Add(subToolRef) + + // Get all tools exports + result.AddAll(prg.ToolSet[subToolRef.ToolID].GetExportedTools(prg)) } - return completionTools, nil + return nil } -func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}, asName string) ([]CompletionTool, error) { - subTool, err := getTool(prg, parentTool, subToolName) +func (t Tool) addContextExportedTools(prg Program, result *toolRefSet) error { + contextTools, err := t.GetContextTools(prg) if err != nil { + return err + } + + for _, contextTool := range contextTools { + result.AddAll(prg.ToolSet[contextTool.ToolID].GetExportedTools(prg)) + } + + return nil +} + +func (t Tool) getCompletionToolRefs(prg Program) ([]ToolReference, error) { + result := toolRefSet{} + + if err := t.addReferencedTools(prg, &result); err != nil { return nil, err } - args := subTool.Parameters.Arguments - if args == nil && !subTool.IsCommand() && !subTool.Chat { - args = &system.DefaultToolSchema + if err := t.addContextExportedTools(prg, &result); err != nil { + return nil, err } - for _, existingTool := range completionTools { - if existingTool.Function.ToolID == subTool.ID { - return completionTools, nil + return result.List() +} + +func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []CompletionTool) { + toolNames := map[string]struct{}{} + + for _, subToolRef := range completionTools { + subTool := prg.ToolSet[subToolRef.ToolID] + + subToolName := subToolRef.Reference + if subToolRef.Named != "" { + subToolName = subToolRef.Named } - } - if subTool.Instructions == "" { - log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID) - } else { - name := subToolName - if asName != "" { - name = asName + args := subTool.Parameters.Arguments + if args == nil && !subTool.IsCommand() && !subTool.Chat { + args = &system.DefaultToolSchema } - completionTools = append(completionTools, CompletionTool{ - Function: CompletionFunctionDefinition{ - ToolID: subTool.ID, - Name: PickToolName(name, toolNames), - Description: subTool.Parameters.Description, - Parameters: args, - }, - }) - } - for _, export := range subTool.Export { - completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames, "") - if err != nil { - return nil, err + if subTool.Instructions == "" { + log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID) + } else { + result = append(result, CompletionTool{ + Function: CompletionFunctionDefinition{ + ToolID: subTool.ID, + Name: PickToolName(subToolName, toolNames), + Description: subTool.Parameters.Description, + Parameters: args, + }, + }) } } - return completionTools, nil + return } type Repo struct {