Skip to content

fix: various fixes for openapi support #208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN
var tools []types.Tool
if isOpenAPI(data) {
if t, err := openapi3.NewLoader().LoadFromData(data); err == nil {
tools, err = getOpenAPITools(t)
if base.Remote {
tools, err = getOpenAPITools(t, base.Location)
} else {
tools, err = getOpenAPITools(t, "")
}
if err != nil {
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
}
Expand Down Expand Up @@ -256,6 +260,9 @@ func ProgramFromSource(ctx context.Context, content, subToolName string) (types.
}

func Program(ctx context.Context, name, subToolName string) (types.Program, error) {
if subToolName == "" {
name, subToolName = SplitToolRef(name)
}
prg := types.Program{
Name: name,
ToolSet: types.ToolSet{},
Expand Down Expand Up @@ -306,7 +313,7 @@ func input(ctx context.Context, base *source, name string) (*source, error) {
}

func SplitToolRef(targetToolName string) (toolName, subTool string) {
subTool, toolName, ok := strings.Cut(targetToolName, " from ")
subTool, toolName, ok := strings.Cut(strings.ReplaceAll(targetToolName, "\t", " "), " from ")
if ok {
toolName = strings.TrimSpace(toolName)
subTool = strings.TrimSpace(subTool)
Expand All @@ -318,14 +325,14 @@ func SplitToolRef(targetToolName string) (toolName, subTool string) {
}

func isOpenAPI(data []byte) bool {
var version struct {
OpenAPI string `json:"openapi" yaml:"openapi"`
var fragment struct {
Paths map[string]any `json:"paths,omitempty"`
Comment on lines +328 to +329
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to check for the existence of the paths map, rather than the version string, since the version string is sometimes missing.

}

if err := json.Unmarshal(data, &version); err != nil {
if err := yaml.Unmarshal(data, &version); err != nil {
if err := json.Unmarshal(data, &fragment); err != nil {
if err := yaml.Unmarshal(data, &fragment); err != nil {
return false
}
}
return strings.HasPrefix(version.OpenAPI, "3.")
return len(fragment.Paths) > 0
}
29 changes: 22 additions & 7 deletions pkg/loader/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package loader
import (
"encoding/json"
"fmt"
"net/url"
"slices"
"strings"

Expand All @@ -15,9 +16,19 @@ import (
// Each operation will become a tool definition.
// The tool's Instructions will be in the format "#!sys.openapi '{JSON Instructions}'",
// where the JSON Instructions are a JSON-serialized engine.OpenAPIInstructions struct.
func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) {
// Determine the default server.
if len(t.Servers) == 0 {
return nil, fmt.Errorf("no servers found in OpenAPI spec")
if defaultHost != "" {
u, err := url.Parse(defaultHost)
if err != nil {
return nil, fmt.Errorf("invalid default host URL: %w", err)
}
u.Path = "/"
t.Servers = []*openapi3.Server{{URL: u.String()}}
} else {
return nil, fmt.Errorf("no servers found in OpenAPI spec")
}
Comment on lines +19 to +31
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds support for a "default host" if no server is specified at the global level in the OpenAPI document.

}
defaultServer, err := parseServer(t.Servers[0])
if err != nil {
Expand All @@ -39,6 +50,7 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
}
}

// Generate a tool for each operation.
var (
toolNames []string
tools []types.Tool
Expand All @@ -54,6 +66,7 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
}
}

operations:
for method, operation := range pathObj.Operations() {
// Handle operation-level server override, if one exists
operationServer := pathServer
Expand Down Expand Up @@ -103,10 +116,9 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
},
}

toolNames = append(toolNames, tool.Parameters.Name)

// Handle query, path, and header parameters
for _, param := range operation.Parameters {
// Handle query, path, and header parameters, based on the parameters for this operation
// and the parameters for this path.
for _, param := range append(operation.Parameters, pathObj.Parameters...) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameters can also be defined at the path level (pathObj.Parameters), so we are now handling those as well.

arg := param.Value.Schema.Value

if arg.Description == "" {
Expand Down Expand Up @@ -161,7 +173,8 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
}

if bodyMIME == "" {
return nil, fmt.Errorf("no supported MIME types found for request body in operation %s", operation.OperationID)
// No supported MIME types found, so just skip this operation and move on.
continue operations
Comment on lines 175 to +177
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now skip tools with no supported MIME types in the body instead of just returning an error for the whole document.

}
}

Expand Down Expand Up @@ -226,6 +239,8 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
return nil, err
}

// Register
toolNames = append(toolNames, tool.Parameters.Name)
tools = append(tools, tool)
operationNum++
}
Expand Down