From 3c53e2aa3f4cae850557ef231c4949fd43750f21 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 19 Jul 2024 21:54:01 -0400 Subject: [PATCH] feat: improve SDK server start up Additionally, this change includes a way to run the server embeddedly in another process that may use stdin. Signed-off-by: Donnie Adams --- pkg/cli/sdk_server.go | 4 +-- pkg/sdkserver/server.go | 74 +++++++++++++++++++++++++++++------------ 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/pkg/cli/sdk_server.go b/pkg/cli/sdk_server.go index a2ac8488..c9cf480f 100644 --- a/pkg/cli/sdk_server.go +++ b/pkg/cli/sdk_server.go @@ -29,11 +29,11 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error { // Don't use cmd.Context() as we don't want to die on ctrl+c ctx := context.Background() if term.IsTerminal(int(os.Stdin.Fd())) { - // Only support CTRL+C if stdin is the terminal. When ran as a SDK it will be a pipe + // Only support CTRL+C if stdin is the terminal. When ran as an SDK it will be a pipe ctx = cmd.Context() } - return sdkserver.Start(ctx, sdkserver.Options{ + return sdkserver.Run(ctx, sdkserver.Options{ Options: opts, ListenAddress: c.ListenAddress, Debug: c.Debug, diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index 4556f69e..d0ca5a5f 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "log/slog" "net" "net/http" "os" @@ -29,7 +28,18 @@ type Options struct { Debug bool } -func Start(ctx context.Context, opts Options) error { +// Run will start the server and block until the server is shut down. +func Run(ctx context.Context, opts Options) error { + listener, err := newListener(opts) + if err != nil { + return err + } + + _, err = io.WriteString(os.Stderr, listener.Addr().String()+"\n") + if err != nil { + return fmt.Errorf("failed to write to address to stderr: %w", err) + } + sigCtx, cancel := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL) defer cancel() go func() { @@ -40,6 +50,34 @@ func Start(ctx context.Context, opts Options) error { cancel() }() + return run(sigCtx, listener, opts) +} + +// EmbeddedStart allows running the server as an embedded process that may use Stdin for input. +// It returns the address the server is listening on. +func EmbeddedStart(ctx context.Context, opts Options) (string, error) { + listener, err := newListener(opts) + if err != nil { + return "", err + } + + go func() { + _ = run(ctx, listener, opts) + }() + + return listener.Addr().String(), nil +} + +func (s *server) close() { + s.client.Close(true) + s.events.Close() +} + +func newListener(opts Options) (net.Listener, error) { + return net.Listen("tcp", opts.ListenAddress) +} + +func run(ctx context.Context, listener net.Listener, opts Options) error { if opts.Debug { mvl.SetDebug() } @@ -58,11 +96,6 @@ func Start(ctx context.Context, opts Options) error { return err } - listener, err := net.Listen("tcp", opts.ListenAddress) - if err != nil { - return fmt.Errorf("failed to listen on %s: %w", opts.ListenAddress, err) - } - s := &server{ gptscriptOpts: opts.Options, address: listener.Addr().String(), @@ -72,11 +105,11 @@ func Start(ctx context.Context, opts Options) error { waitingToConfirm: make(map[string]chan runner.AuthorizerResponse), waitingToPrompt: make(map[string]chan map[string]string), } - defer s.Close() + defer s.close() s.addRoutes(http.DefaultServeMux) - server := http.Server{ + httpServer := &http.Server{ Handler: apply(http.DefaultServeMux, contentType("application/json"), addRequestID, @@ -86,25 +119,22 @@ func Start(ctx context.Context, opts Options) error { ), } - slog.Info("Starting server", "addr", s.address) - - context.AfterFunc(sigCtx, func() { - ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + logger := mvl.Package() + done := make(chan struct{}) + context.AfterFunc(ctx, func() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - slog.Info("Shutting down server") - _ = server.Shutdown(ctx) - slog.Info("Server stopped") + logger.Infof("Shutting down server") + _ = httpServer.Shutdown(ctx) + logger.Infof("Server stopped") + close(done) }) - if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err = httpServer.Serve(listener); !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("server error: %w", err) } + <-done return nil } - -func (s *server) Close() { - s.client.Close(true) - s.events.Close() -}