From 6aeb06b88cf9e5de49111b2c6d5c0e3b1a8f76ce Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 1 Apr 2024 16:14:36 -0400 Subject: [PATCH 1/2] fix: various fixes for openapi support Signed-off-by: Grant Linville --- pkg/loader/loader.go | 21 ++++++++++++++------- pkg/loader/openapi.go | 29 +++++++++++++++++++---------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 85ec985f..41248957 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -131,7 +131,11 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN var tools []types.Tool if isOpenAPI(data) { if t, err := openapi3.NewLoader().LoadFromData(data); err == nil { - tools, err = getOpenAPITools(t) + if base.Remote { + tools, err = getOpenAPITools(t, base.Location) + } else { + tools, err = getOpenAPITools(t, "") + } if err != nil { return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err) } @@ -256,6 +260,9 @@ func ProgramFromSource(ctx context.Context, content, subToolName string) (types. } func Program(ctx context.Context, name, subToolName string) (types.Program, error) { + if subToolName == "" { + name, subToolName = SplitToolRef(name) + } prg := types.Program{ Name: name, ToolSet: types.ToolSet{}, @@ -306,7 +313,7 @@ func input(ctx context.Context, base *source, name string) (*source, error) { } func SplitToolRef(targetToolName string) (toolName, subTool string) { - subTool, toolName, ok := strings.Cut(targetToolName, " from ") + subTool, toolName, ok := strings.Cut(strings.ReplaceAll(targetToolName, "\t", " "), " from ") if ok { toolName = strings.TrimSpace(toolName) subTool = strings.TrimSpace(subTool) @@ -318,14 +325,14 @@ func SplitToolRef(targetToolName string) (toolName, subTool string) { } func isOpenAPI(data []byte) bool { - var version struct { - OpenAPI string `json:"openapi" yaml:"openapi"` + var fragment struct { + Paths map[string]any `json:"paths,omitempty"` } - if err := json.Unmarshal(data, &version); err != nil { - if err := yaml.Unmarshal(data, &version); err != nil { + if err := json.Unmarshal(data, &fragment); err != nil { + if err := yaml.Unmarshal(data, &fragment); err != nil { return false } } - return strings.HasPrefix(version.OpenAPI, "3.") + return len(fragment.Paths) > 0 } diff --git a/pkg/loader/openapi.go b/pkg/loader/openapi.go index 502bdd94..eefd010f 100644 --- a/pkg/loader/openapi.go +++ b/pkg/loader/openapi.go @@ -3,6 +3,7 @@ package loader import ( "encoding/json" "fmt" + "net/url" "slices" "strings" @@ -15,9 +16,19 @@ import ( // Each operation will become a tool definition. // The tool's Instructions will be in the format "#!sys.openapi '{JSON Instructions}'", // where the JSON Instructions are a JSON-serialized engine.OpenAPIInstructions struct. -func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) { +func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { + // Determine the default server. if len(t.Servers) == 0 { - return nil, fmt.Errorf("no servers found in OpenAPI spec") + if defaultHost != "" { + u, err := url.Parse(defaultHost) + if err != nil { + return nil, fmt.Errorf("invalid default host URL: %w", err) + } + u.Path = "/" + t.Servers = []*openapi3.Server{{URL: u.String()}} + } else { + return nil, fmt.Errorf("no servers found in OpenAPI spec") + } } defaultServer, err := parseServer(t.Servers[0]) if err != nil { @@ -39,6 +50,7 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) { } } + // Generate a tool for each operation. var ( toolNames []string tools []types.Tool @@ -103,10 +115,9 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) { }, } - toolNames = append(toolNames, tool.Parameters.Name) - - // Handle query, path, and header parameters - for _, param := range operation.Parameters { + // Handle query, path, and header parameters, based on the parameters for this operation + // and the parameters for this path. + for _, param := range append(operation.Parameters, pathObj.Parameters...) { arg := param.Value.Schema.Value if arg.Description == "" { @@ -159,10 +170,6 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) { tool.Parameters.Arguments.Properties["requestBodyContent"] = &openapi3.SchemaRef{Value: arg} break } - - if bodyMIME == "" { - return nil, fmt.Errorf("no supported MIME types found for request body in operation %s", operation.OperationID) - } } // See if there is any auth defined for this operation @@ -226,6 +233,8 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) { return nil, err } + // Register + toolNames = append(toolNames, tool.Parameters.Name) tools = append(tools, tool) operationNum++ } From 1a152397783427b168ec03c6bf15aa42ad14dad1 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 1 Apr 2024 16:18:23 -0400 Subject: [PATCH 2/2] skip operations with a request body with no supported MIME type Signed-off-by: Grant Linville --- pkg/loader/openapi.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/loader/openapi.go b/pkg/loader/openapi.go index eefd010f..40411a04 100644 --- a/pkg/loader/openapi.go +++ b/pkg/loader/openapi.go @@ -66,6 +66,7 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { } } + operations: for method, operation := range pathObj.Operations() { // Handle operation-level server override, if one exists operationServer := pathServer @@ -170,6 +171,11 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) { tool.Parameters.Arguments.Properties["requestBodyContent"] = &openapi3.SchemaRef{Value: arg} break } + + if bodyMIME == "" { + // No supported MIME types found, so just skip this operation and move on. + continue operations + } } // See if there is any auth defined for this operation