Skip to content

Commit fbf688e

Browse files
committed
fix: various fixes for openapi support
Signed-off-by: Grant Linville <[email protected]>
1 parent d566dea commit fbf688e

File tree

2 files changed

+33
-17
lines changed

2 files changed

+33
-17
lines changed

pkg/loader/loader.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,11 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN
131131
var tools []types.Tool
132132
if isOpenAPI(data) {
133133
if t, err := openapi3.NewLoader().LoadFromData(data); err == nil {
134-
tools, err = getOpenAPITools(t)
134+
if base.Remote {
135+
tools, err = getOpenAPITools(t, base.Location)
136+
} else {
137+
tools, err = getOpenAPITools(t, "")
138+
}
135139
if err != nil {
136140
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
137141
}
@@ -256,6 +260,9 @@ func ProgramFromSource(ctx context.Context, content, subToolName string) (types.
256260
}
257261

258262
func Program(ctx context.Context, name, subToolName string) (types.Program, error) {
263+
if subToolName == "" {
264+
name, subToolName = SplitToolRef(name)
265+
}
259266
prg := types.Program{
260267
Name: name,
261268
ToolSet: types.ToolSet{},
@@ -306,7 +313,7 @@ func input(ctx context.Context, base *source, name string) (*source, error) {
306313
}
307314

308315
func SplitToolRef(targetToolName string) (toolName, subTool string) {
309-
subTool, toolName, ok := strings.Cut(targetToolName, " from ")
316+
subTool, toolName, ok := strings.Cut(strings.ReplaceAll(targetToolName, "\t", " "), " from ")
310317
if ok {
311318
toolName = strings.TrimSpace(toolName)
312319
subTool = strings.TrimSpace(subTool)
@@ -318,14 +325,14 @@ func SplitToolRef(targetToolName string) (toolName, subTool string) {
318325
}
319326

320327
func isOpenAPI(data []byte) bool {
321-
var version struct {
322-
OpenAPI string `json:"openapi" yaml:"openapi"`
328+
var fragment struct {
329+
Paths map[string]any `json:"paths,omitempty"`
323330
}
324331

325-
if err := json.Unmarshal(data, &version); err != nil {
326-
if err := yaml.Unmarshal(data, &version); err != nil {
332+
if err := json.Unmarshal(data, &fragment); err != nil {
333+
if err := yaml.Unmarshal(data, &fragment); err != nil {
327334
return false
328335
}
329336
}
330-
return strings.HasPrefix(version.OpenAPI, "3.")
337+
return len(fragment.Paths) > 0
331338
}

pkg/loader/openapi.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package loader
33
import (
44
"encoding/json"
55
"fmt"
6+
"net/url"
67
"slices"
78
"strings"
89

@@ -15,15 +16,26 @@ import (
1516
// Each operation will become a tool definition.
1617
// The tool's Instructions will be in the format "#!sys.openapi '{JSON Instructions}'",
1718
// where the JSON Instructions are a JSON-serialized engine.OpenAPIInstructions struct.
18-
func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
19+
func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) {
20+
// Determine the default server.
1921
if len(t.Servers) == 0 {
20-
return nil, fmt.Errorf("no servers found in OpenAPI spec")
22+
if defaultHost != "" {
23+
u, err := url.Parse(defaultHost)
24+
if err != nil {
25+
return nil, fmt.Errorf("invalid default host URL: %w", err)
26+
}
27+
u.Path = "/"
28+
t.Servers = []*openapi3.Server{{URL: u.String()}}
29+
} else {
30+
return nil, fmt.Errorf("no servers found in OpenAPI spec")
31+
}
2132
}
2233
defaultServer, err := parseServer(t.Servers[0])
2334
if err != nil {
2435
return nil, err
2536
}
2637

38+
// Generate a tool for each operation.
2739
var (
2840
toolNames []string
2941
tools []types.Tool
@@ -80,10 +92,9 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
8092
},
8193
}
8294

83-
toolNames = append(toolNames, tool.Parameters.Name)
84-
85-
// Handle query, path, and header parameters
86-
for _, param := range operation.Parameters {
95+
// Handle query, path, and header parameters, based on the parameters for this operation
96+
// and the parameters for this path.
97+
for _, param := range append(operation.Parameters, pathObj.Parameters...) {
8798
arg := param.Value.Schema.Value
8899

89100
if arg.Description == "" {
@@ -136,10 +147,6 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
136147
tool.Parameters.Arguments.Properties["requestBodyContent"] = &openapi3.SchemaRef{Value: arg}
137148
break
138149
}
139-
140-
if bodyMIME == "" {
141-
return nil, fmt.Errorf("no supported MIME types found for request body in operation %s", operation.OperationID)
142-
}
143150
}
144151

145152
// OpenAI will get upset if we have an object schema with no properties,
@@ -154,6 +161,8 @@ func getOpenAPITools(t *openapi3.T) ([]types.Tool, error) {
154161
return nil, err
155162
}
156163

164+
// Register
165+
toolNames = append(toolNames, tool.Parameters.Name)
157166
tools = append(tools, tool)
158167
operationNum++
159168
}

0 commit comments

Comments
 (0)