diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index c12f976f..f60f09a1 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -373,12 +373,20 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. } opt := complete(opts...) + var locationPath, locationName string + if opt.Location != "" { + locationPath = path.Dir(opt.Location) + locationName = path.Base(opt.Location) + } + prg := types.Program{ ToolSet: types.ToolSet{}, } tools, err := readTool(ctx, opt.Cache, &prg, &source{ Content: []byte(content), - Location: "inline", + Path: locationPath, + Name: locationName, + Location: opt.Location, }, subToolName) if err != nil { return types.Program{}, err @@ -388,12 +396,18 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. } type Options struct { - Cache *cache.Client + Cache *cache.Client + Location string } func complete(opts ...Options) (result Options) { for _, opt := range opts { result.Cache = types.FirstSet(opt.Cache, result.Cache) + result.Location = types.FirstSet(opt.Location, result.Location) + } + + if result.Location == "" { + result.Location = "inline" } return diff --git a/pkg/loader/url.go b/pkg/loader/url.go index bc4d5c9f..2035469e 100644 --- a/pkg/loader/url.go +++ b/pkg/loader/url.go @@ -111,11 +111,20 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string req.Header.Set("Authorization", "Bearer "+bearerToken) } - data, err := getWithDefaults(req) + data, defaulted, err := getWithDefaults(req) if err != nil { return nil, false, fmt.Errorf("error loading %s: %v", url, err) } + if defaulted != "" { + pathString = url + name = defaulted + if repo != nil { + repo.Path = path.Join(repo.Path, repo.Name) + repo.Name = defaulted + } + } + log.Debugf("opened %s", url) result := &source{ @@ -137,31 +146,32 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string return result, true, nil } -func getWithDefaults(req *http.Request) ([]byte, error) { +func getWithDefaults(req *http.Request) ([]byte, string, error) { originalPath := req.URL.Path // First, try to get the original path as is. It might be an OpenAPI definition. resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return nil, "", err } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { - if toolBytes, err := io.ReadAll(resp.Body); err == nil && isOpenAPI(toolBytes) != 0 { - return toolBytes, nil - } + toolBytes, err := io.ReadAll(resp.Body) + return toolBytes, "", err + } + + base := path.Base(originalPath) + if strings.Contains(base, ".") { + return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status) } for i, def := range types.DefaultFiles { - base := path.Base(originalPath) - if !strings.Contains(base, ".") { - req.URL.Path = path.Join(originalPath, def) - } + req.URL.Path = path.Join(originalPath, def) resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return nil, "", err } defer resp.Body.Close() @@ -170,11 +180,13 @@ func getWithDefaults(req *http.Request) ([]byte, error) { } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status) + return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status) } - return io.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) + return data, def, err } + panic("unreachable") } diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index c16b4429..e17a2d1a 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -183,7 +183,7 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { logger.Debugf("executing tool: %+v", reqObject) var ( def fmt.Stringer = &reqObject.ToolDefs - programLoader loaderFunc = loader.ProgramFromSource + programLoader = loaderWithLocation(loader.ProgramFromSource, reqObject.Location) ) if reqObject.Content != "" { def = &reqObject.content diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index dc155557..0d055614 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -16,6 +16,14 @@ import ( type loaderFunc func(context.Context, string, string, ...loader.Options) (types.Program, error) +func loaderWithLocation(f loaderFunc, loc string) loaderFunc { + return func(ctx context.Context, s string, s2 string, options ...loader.Options) (types.Program, error) { + return f(ctx, s, s2, append(options, loader.Options{ + Location: loc, + })...) + } +} + func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer) { g, err := gptscript.New(ctx, s.gptscriptOpts, opts) if err != nil { diff --git a/pkg/sdkserver/types.go b/pkg/sdkserver/types.go index 6f940c8b..478c6565 100644 --- a/pkg/sdkserver/types.go +++ b/pkg/sdkserver/types.go @@ -61,6 +61,7 @@ type toolOrFileRequest struct { CredentialContext string `json:"credentialContext"` CredentialOverrides []string `json:"credentialOverrides"` Confirm bool `json:"confirm"` + Location string `json:"location,omitempty"` } type content struct {