From b8071a888f1bd376da00ffc911db5ec5fce89f2d Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 9 Aug 2024 21:45:38 -0400 Subject: [PATCH] fix: use the default model provider when listing models Signed-off-by: Donnie Adams --- pkg/cli/gptscript.go | 3 +++ pkg/sdkserver/routes.go | 4 ++++ pkg/server/server.go | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 4926c6fd..2d7e90d9 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -406,6 +406,9 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { defer gptScript.Close(true) if r.ListModels { + if r.DefaultModelProvider != "" { + args = append(args, r.DefaultModelProvider) + } return r.listModels(ctx, gptScript, args) } diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index c0d7a41b..98957624 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -127,6 +127,10 @@ func (s *server) listModels(w http.ResponseWriter, r *http.Request) { providers = reqObject.Providers } + if s.gptscriptOpts.DefaultModelProvider != "" { + providers = append(providers, s.gptscriptOpts.DefaultModelProvider) + } + out, err := s.client.ListModels(r.Context(), providers...) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to list models: %w", err)) diff --git a/pkg/server/server.go b/pkg/server/server.go index d1c57100..00734c38 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -24,5 +24,6 @@ func ContextWithNewRunID(ctx context.Context) context.Context { } func RunIDFromContext(ctx context.Context) string { - return ctx.Value(execKey{}).(string) + runID, _ := ctx.Value(execKey{}).(string) + return runID }