diff --git a/pkg/engine/openapi.go b/pkg/engine/openapi.go index e872b7d4..9af28772 100644 --- a/pkg/engine/openapi.go +++ b/pkg/engine/openapi.go @@ -13,6 +13,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/tidwall/gjson" + "golang.org/x/exp/maps" ) var ( @@ -35,6 +36,45 @@ type SecurityInfo struct { In string `json:"in"` // header, query, or cookie, for type==apiKey } +func (i SecurityInfo) GetCredentialToolStrings(hostname string) []string { + vars := i.getCredentialNamesAndEnvVars(hostname) + var tools []string + + for cred, v := range vars { + field := "value" + switch i.Type { + case "apiKey": + field = i.APIKeyName + case "http": + if i.Scheme == "bearer" { + field = "bearer token" + } else { + if strings.Contains(v, "PASSWORD") { + field = "password" + } else { + field = "username" + } + } + } + + tools = append(tools, fmt.Sprintf("github.com/gptscript-ai/credential as %s with %s as env and %q as message and %q as field", + cred, v, "Please provide a value for the "+v+" environment variable", field)) + } + return tools +} + +func (i SecurityInfo) getCredentialNamesAndEnvVars(hostname string) map[string]string { + if i.Type == "http" && i.Scheme == "basic" { + return map[string]string{ + hostname + i.Name + "Username": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_USERNAME", + hostname + i.Name + "Password": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_PASSWORD", + } + } + return map[string]string{ + hostname + i.Name: "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name), + } +} + type OpenAPIInstructions struct { Server string `json:"server"` Path string `json:"path"` @@ -83,8 +123,8 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { return nil, fmt.Errorf("failed to create request: %w", err) } - // Check for authentication (only if using HTTPS) - if u.Scheme == "https" { + // Check for authentication (only if using HTTPS or localhost) + if u.Scheme == "https" || u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" { if len(instructions.SecurityInfos) > 0 { if err := handleAuths(req, envMap, instructions.SecurityInfos); err != nil { return nil, fmt.Errorf("error setting up authentication: %w", err) @@ -181,15 +221,9 @@ func handleAuths(req *http.Request, envMap map[string]string, infoSets [][]Secur for _, infoSet := range infoSets { var missing []string // Keep track of any missing environment variables for _, info := range infoSet { - envNames := []string{"GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name)} - if info.Type == "http" && info.Scheme == "basic" { - envNames = []string{ - "GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name) + "_USERNAME", - "GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name) + "_PASSWORD", - } - } + vars := info.getCredentialNamesAndEnvVars(req.URL.Hostname()) - for _, envName := range envNames { + for _, envName := range vars { if _, ok := envMap[envName]; !ok { missing = append(missing, envName) } @@ -203,28 +237,28 @@ func handleAuths(req *http.Request, envMap map[string]string, infoSets [][]Secur // We're using this info set, because no environment variables were missing. // Set up the request as needed. for _, info := range infoSet { - envName := "GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name) + envNames := maps.Values(info.getCredentialNamesAndEnvVars(req.URL.Hostname())) switch info.Type { case "apiKey": switch info.In { case "header": - req.Header.Set(info.APIKeyName, envMap[envName]) + req.Header.Set(info.APIKeyName, envMap[envNames[0]]) case "query": v := url.Values{} - v.Add(info.APIKeyName, envMap[envName]) + v.Add(info.APIKeyName, envMap[envNames[0]]) req.URL.RawQuery = v.Encode() case "cookie": req.AddCookie(&http.Cookie{ Name: info.APIKeyName, - Value: envMap[envName], + Value: envMap[envNames[0]], }) } case "http": switch info.Scheme { case "bearer": - req.Header.Set("Authorization", "Bearer "+envMap[envName]) + req.Header.Set("Authorization", "Bearer "+envMap[envNames[0]]) case "basic": - req.SetBasicAuth(envMap[envName+"_USERNAME"], envMap[envName+"_PASSWORD"]) + req.SetBasicAuth(envMap[envNames[0]], envMap[envNames[1]]) } } } diff --git a/pkg/loader/openapi.go b/pkg/loader/openapi.go index 47cee122..8790fbb7 100644 --- a/pkg/loader/openapi.go +++ b/pkg/loader/openapi.go @@ -278,6 +278,17 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { return nil, err } + if len(infos) > 0 { + // Set up credential tools for the first set of infos. + for _, info := range infos[0] { + operationServerURL, err := url.Parse(operationServer) + if err != nil { + return nil, fmt.Errorf("failed to parse operation server URL: %w", err) + } + tool.Credentials = info.GetCredentialToolStrings(operationServerURL.Hostname()) + } + } + // Register toolNames = append(toolNames, tool.Parameters.Name) tools = append(tools, tool) diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 2c218e2c..9ae7c07f 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -411,7 +411,9 @@ func (t ToolDef) String() string { _, _ = fmt.Fprintf(buf, "Internal Prompt: %v\n", *t.Parameters.InternalPrompt) } if len(t.Parameters.Credentials) > 0 { - _, _ = fmt.Fprintf(buf, "Credentials: %s\n", strings.Join(t.Parameters.Credentials, ", ")) + for _, cred := range t.Parameters.Credentials { + _, _ = fmt.Fprintf(buf, "Credential: %s\n", cred) + } } if t.Parameters.Chat { _, _ = fmt.Fprintf(buf, "Chat: true\n") diff --git a/pkg/types/tool_test.go b/pkg/types/tool_test.go index 514de1c1..f1edc340 100644 --- a/pkg/types/tool_test.go +++ b/pkg/types/tool_test.go @@ -50,7 +50,8 @@ Temperature: 0.800000 Parameter: arg1: desc1 Parameter: arg2: desc2 Internal Prompt: true -Credentials: Credential1, Credential2 +Credential: Credential1 +Credential: Credential2 Chat: true This is a sample instruction