diff --git a/pkg/sdkserver/prompt.go b/pkg/sdkserver/prompt.go index a5b15162..8d34fc53 100644 --- a/pkg/sdkserver/prompt.go +++ b/pkg/sdkserver/prompt.go @@ -43,6 +43,11 @@ func (s *server) promptResponse(w http.ResponseWriter, r *http.Request) { func (s *server) prompt(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) + if r.Header.Get("Authorization") != "Bearer "+s.token { + writeError(logger, w, http.StatusUnauthorized, fmt.Errorf("invalid token")) + return + } + id := r.PathValue("id") s.lock.RLock() diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 262b43ed..c049a791 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -27,9 +27,9 @@ import ( const toolRunTimeout = 15 * time.Minute type server struct { - address string - client *gptscript.GPTScript - events *broadcaster.Broadcaster[event] + address, token string + client *gptscript.GPTScript + events *broadcaster.Broadcaster[event] lock sync.RWMutex waitingToConfirm map[string]chan runner.AuthorizerResponse @@ -165,9 +165,9 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { reqObject.Env = append(os.Environ(), reqObject.Env...) // Don't overwrite the PromptURLEnvVar if it is already set in the environment. - if !slices.ContainsFunc(reqObject.Env, func(s string) bool { return strings.HasPrefix(s, types.PromptURLEnvVar+"=") }) { + if !slices.ContainsFunc(reqObject.Env, func(s string) bool { return strings.HasPrefix(s, types.PromptTokenEnvVar+"=") }) { // Append a prompt URL for this run. - reqObject.Env = append(reqObject.Env, fmt.Sprintf("%s=http://%s/prompt/%s", types.PromptURLEnvVar, s.address, runID)) + reqObject.Env = append(reqObject.Env, fmt.Sprintf("%s=http://%s/prompt/%s", types.PromptURLEnvVar, s.address, runID), fmt.Sprintf("%s=%s", types.PromptTokenEnvVar, s.token)) } logger.Debugf("executing tool: %+v", reqObject) diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index 1e1cacc1..1433606a 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -13,6 +13,7 @@ import ( "time" "github.com/acorn-io/broadcaster" + "github.com/google/uuid" "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/runner" @@ -52,6 +53,7 @@ func Start(ctx context.Context, opts Options) error { s := &server{ address: opts.ListenAddress, + token: uuid.NewString(), client: g, events: events, waitingToConfirm: make(map[string]chan runner.AuthorizerResponse),