From 1124ed1c6952c2e8ff033a64a7b5dfc9fc877087 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 16 Aug 2024 16:34:48 -0400 Subject: [PATCH] feat: add ability to list models from other providers Signed-off-by: Donnie Adams --- pkg/sdkserver/routes.go | 41 ++++++++++++----------------------------- pkg/sdkserver/types.go | 4 +++- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 4309bc28..6cb1e620 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -73,39 +73,13 @@ func (s *server) version(w http.ResponseWriter, r *http.Request) { // listTools will return the output of `gptscript --list-tools` func (s *server) listTools(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) - var prg types.Program - if r.ContentLength != 0 { - reqObject := new(toolOrFileRequest) - err := json.NewDecoder(r.Body).Decode(reqObject) - if err != nil { - writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) - return - } - - if reqObject.Content != "" { - prg, err = loader.ProgramFromSource(r.Context(), reqObject.Content, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) - } else if reqObject.File != "" { - prg, err = loader.Program(r.Context(), reqObject.File, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) - } else { - prg, err = loader.ProgramFromSource(r.Context(), reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{Cache: s.client.Cache}) - } - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) - return - } - } - - tools := s.client.ListTools(r.Context(), prg) + tools := s.client.ListTools(r.Context(), types.Program{}) sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name }) lines := make([]string, 0, len(tools)) for _, tool := range tools { - if tool.Name == "" { - tool.Name = prg.Name - } - // Don't print instructions tool.Instructions = "" @@ -118,22 +92,31 @@ func (s *server) listTools(w http.ResponseWriter, r *http.Request) { // listModels will return the output of `gptscript --list-models` func (s *server) listModels(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) + client := s.client + var providers []string if r.ContentLength != 0 { reqObject := new(modelsRequest) - if err := json.NewDecoder(r.Body).Decode(reqObject); err != nil { + err := json.NewDecoder(r.Body).Decode(reqObject) + if err != nil { writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) return } providers = reqObject.Providers + + client, err = gptscript.New(r.Context(), s.gptscriptOpts, gptscript.Options{Env: reqObject.Env, Runner: runner.Options{CredentialOverrides: reqObject.CredentialOverrides}}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to create client: %w", err)) + return + } } if s.gptscriptOpts.DefaultModelProvider != "" { providers = append(providers, s.gptscriptOpts.DefaultModelProvider) } - out, err := s.client.ListModels(r.Context(), providers...) + out, err := client.ListModels(r.Context(), providers...) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to list models: %w", err)) return diff --git a/pkg/sdkserver/types.go b/pkg/sdkserver/types.go index b24ca645..e26bbba5 100644 --- a/pkg/sdkserver/types.go +++ b/pkg/sdkserver/types.go @@ -100,7 +100,9 @@ type parseRequest struct { } type modelsRequest struct { - Providers []string `json:"providers"` + Providers []string `json:"providers"` + Env []string `json:"env"` + CredentialOverrides []string `json:"credentialOverrides"` } type runInfo struct {