diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go index c00308e7..c4178801 100644 --- a/pkg/sdkserver/datasets.go +++ b/pkg/sdkserver/datasets.go @@ -79,7 +79,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, req.Env, req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -147,7 +147,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, req.Env, req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -207,7 +207,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, req.Env, req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -270,7 +270,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, req.Env, req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index c01eeb21..dfad4a18 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -28,6 +28,7 @@ type server struct { gptscriptOpts gptscript.Options address, token string datasetTool, workspaceTool string + serverToolsEnv []string client *gptscript.GPTScript events *broadcaster.Broadcaster[event] diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index 04bff085..79d6daf7 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -28,6 +28,7 @@ type Options struct { ListenAddress string DatasetTool, WorkspaceTool string + ServerToolsEnv []string Debug bool DisableServerErrorLogging bool } @@ -105,11 +106,13 @@ func run(ctx context.Context, listener net.Listener, opts Options) error { } s := &server{ - gptscriptOpts: opts.Options, - address: listener.Addr().String(), - token: token, - datasetTool: opts.DatasetTool, - workspaceTool: opts.WorkspaceTool, + gptscriptOpts: opts.Options, + address: listener.Addr().String(), + token: token, + datasetTool: opts.DatasetTool, + workspaceTool: opts.WorkspaceTool, + serverToolsEnv: opts.ServerToolsEnv, + client: g, events: events, runtimeManager: runtimes.Default(opts.Options.Cache.CacheDir, opts.SystemToolsDir), @@ -176,6 +179,9 @@ func complete(opts ...Options) Options { if result.DatasetTool == "" { result.DatasetTool = "github.com/gptscript-ai/datasets" } + if len(result.ServerToolsEnv) == 0 { + result.ServerToolsEnv = os.Environ() + } return result } diff --git a/pkg/sdkserver/workspaces.go b/pkg/sdkserver/workspaces.go index ed6602ea..f0d7ef00 100644 --- a/pkg/sdkserver/workspaces.go +++ b/pkg/sdkserver/workspaces.go @@ -30,6 +30,10 @@ type createWorkspaceRequest struct { FromWorkspaceIDs []string `json:"fromWorkspaceIDs"` } +func (s *server) getServerToolsEnv(env []string) []string { + return append(s.serverToolsEnv, env...) +} + func (s *server) createWorkspace(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) var reqObject createWorkspaceRequest @@ -51,7 +55,7 @@ func (s *server) createWorkspace(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"provider": "%s", "workspace_ids": "%s"}`, reqObject.ProviderType, strings.Join(reqObject.FromWorkspaceIDs, ","), @@ -86,7 +90,7 @@ func (s *server) deleteWorkspace(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s"}`, reqObject.ID, @@ -123,7 +127,7 @@ func (s *server) listWorkspaceContents(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s", "ls_prefix": "%s"}`, reqObject.ID, reqObject.Prefix, @@ -159,7 +163,7 @@ func (s *server) removeAllWithPrefixInWorkspace(w http.ResponseWriter, r *http.R out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s", "prefix": "%s"}`, reqObject.ID, reqObject.Prefix, @@ -196,7 +200,7 @@ func (s *server) writeFileInWorkspace(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s", "file_path": "%s", "body": "%s"}`, reqObject.ID, reqObject.FilePath, reqObject.Contents, @@ -232,7 +236,7 @@ func (s *server) removeFileInWorkspace(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s", "file_path": "%s"}`, reqObject.ID, reqObject.FilePath, @@ -268,7 +272,7 @@ func (s *server) readFileInWorkspace(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s", "file_path": "%s"}`, reqObject.ID, reqObject.FilePath, @@ -304,7 +308,7 @@ func (s *server) statFileInWorkspace(w http.ResponseWriter, r *http.Request) { out, err := s.client.Run( r.Context(), prg, - reqObject.Env, + s.getServerToolsEnv(reqObject.Env), fmt.Sprintf( `{"workspace_id": "%s", "file_path": "%s"}`, reqObject.ID, reqObject.FilePath,