diff options
author | Runxi Yu <me@runxiyu.org> | 2025-08-17 04:48:47 +0800 |
---|---|---|
committer | Runxi Yu <me@runxiyu.org> | 2025-08-17 05:22:09 +0800 |
commit | b6dd3ffce416fa86c223fad5d2f6c3db5d5727e4 (patch) | |
tree | f3cc533b1ef5184ee881d9d0103a47ec7f09bd34 /forged/internal | |
parent | Add shutdown timeouts (diff) | |
download | forge-b6dd3ffce416fa86c223fad5d2f6c3db5d5727e4.tar.gz forge-b6dd3ffce416fa86c223fad5d2f6c3db5d5727e4.tar.zst forge-b6dd3ffce416fa86c223fad5d2f6c3db5d5727e4.zip |
A few other context fixes
Diffstat (limited to 'forged/internal')
-rw-r--r-- | forged/internal/incoming/hooks/hooks.go | 21 | ||||
-rw-r--r-- | forged/internal/incoming/lmtp/config.go | 21 | ||||
-rw-r--r-- | forged/internal/incoming/ssh/ssh.go | 13 | ||||
-rw-r--r-- | forged/internal/incoming/web/web.go | 16 | ||||
-rw-r--r-- | forged/internal/server/server.go | 54 |
5 files changed, 55 insertions, 70 deletions
diff --git a/forged/internal/incoming/hooks/hooks.go b/forged/internal/incoming/hooks/hooks.go index 3be0811..52ccb0f 100644 --- a/forged/internal/incoming/hooks/hooks.go +++ b/forged/internal/incoming/hooks/hooks.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "time" "github.com/gliderlabs/ssh" "go.lindenii.runxiyu.org/forge/forged/internal/common/cmap" @@ -51,25 +52,29 @@ func (server *Server) Run(ctx context.Context) error { _ = listener.Close() }() - go func() { - <-ctx.Done() + stop := context.AfterFunc(ctx, func() { _ = listener.Close() - // TODO: Log the error - }() + }) + defer stop() for { conn, err := listener.Accept() if err != nil { - if errors.Is(err, net.ErrClosed) { + if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { return nil } return fmt.Errorf("accept conn: %w", err) } - go server.handleConn(conn) + go server.handleConn(ctx, conn) } } -func (server *Server) handleConn(conn net.Conn) { - panic("TODO: handle hook connection") +func (server *Server) handleConn(ctx context.Context, conn net.Conn) { + defer conn.Close() + unblock := context.AfterFunc(ctx, func() { + _ = conn.SetDeadline(time.Now()) + _ = conn.Close() + }) + defer unblock() } diff --git a/forged/internal/incoming/lmtp/config.go b/forged/internal/incoming/lmtp/config.go index ce32f3d..def3ce9 100644 --- a/forged/internal/incoming/lmtp/config.go +++ b/forged/internal/incoming/lmtp/config.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "time" "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" ) @@ -44,25 +45,29 @@ func (server *Server) Run(ctx context.Context) error { _ = listener.Close() }() - go func() { - <-ctx.Done() + stop := context.AfterFunc(ctx, func() { _ = listener.Close() - // TODO: Log the error - }() + }) + defer stop() for { conn, err := listener.Accept() if err != nil { - if errors.Is(err, net.ErrClosed) { + if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { return nil } return fmt.Errorf("accept conn: %w", err) } - go server.handleConn(conn) + go server.handleConn(ctx, conn) } } -func (server *Server) handleConn(conn net.Conn) { - panic("TODO: handle LMTP connection") +func (server *Server) handleConn(ctx context.Context, conn net.Conn) { + defer conn.Close() + unblock := context.AfterFunc(ctx, func() { + _ = conn.SetDeadline(time.Now()) + _ = conn.Close() + }) + defer unblock() } diff --git a/forged/internal/incoming/ssh/ssh.go b/forged/internal/incoming/ssh/ssh.go index 9f9bdff..0c722c0 100644 --- a/forged/internal/incoming/ssh/ssh.go +++ b/forged/internal/incoming/ssh/ssh.go @@ -61,20 +61,23 @@ func New(config Config) (server *Server, err error) { func (server *Server) Run(ctx context.Context) (err error) { listener, err := misc.Listen(server.net, server.addr) + if err != nil { + return fmt.Errorf("listen for SSH: %w", err) + } defer func() { _ = listener.Close() }() - go func() { - <-ctx.Done() - shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second) + stop := context.AfterFunc(ctx, func() { + shCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(server.shutdownTimeout)*time.Second) defer cancel() _ = server.gliderServer.Shutdown(shCtx) _ = listener.Close() - }() + }) + defer stop() if err = server.gliderServer.Serve(listener); err != nil { - if errors.Is(err, gliderssh.ErrServerClosed) { + if errors.Is(err, gliderssh.ErrServerClosed) || ctx.Err() != nil { return nil } return fmt.Errorf("serve SSH: %w", err) diff --git a/forged/internal/incoming/web/web.go b/forged/internal/incoming/web/web.go index 391f6ff..dc2d9b4 100644 --- a/forged/internal/incoming/web/web.go +++ b/forged/internal/incoming/web/web.go @@ -3,6 +3,7 @@ package web import ( "context" "fmt" + "net" "net/http" "time" @@ -53,21 +54,26 @@ func New(config Config) (server *Server) { } func (server *Server) Run(ctx context.Context) (err error) { + server.httpServer.BaseContext = func(_ net.Listener) context.Context { return ctx } + listener, err := misc.Listen(server.net, server.addr) + if err != nil { + return fmt.Errorf("listen for web: %w", err) + } defer func() { _ = listener.Close() }() - go func() { - <-ctx.Done() - shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second) + stop := context.AfterFunc(ctx, func() { + shCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(server.shutdownTimeout)*time.Second) defer cancel() _ = server.httpServer.Shutdown(shCtx) _ = listener.Close() - }() + }) + defer stop() if err = server.httpServer.Serve(listener); err != nil { - if err == http.ErrServerClosed { + if err == http.ErrServerClosed || ctx.Err() != nil { return nil } return fmt.Errorf("serve web: %w", err) diff --git a/forged/internal/server/server.go b/forged/internal/server/server.go index 472df7a..ab677e0 100644 --- a/forged/internal/server/server.go +++ b/forged/internal/server/server.go @@ -10,6 +10,7 @@ import ( "go.lindenii.runxiyu.org/forge/forged/internal/incoming/lmtp" "go.lindenii.runxiyu.org/forge/forged/internal/incoming/ssh" "go.lindenii.runxiyu.org/forge/forged/internal/incoming/web" + "golang.org/x/sync/errgroup" ) type Server struct { @@ -51,57 +52,22 @@ func (server *Server) Run(ctx context.Context) (err error) { // TODO: Not running git2d because it should be run separately. // This needs to be documented somewhere, hence a TODO here for now. - subCtx, cancel := context.WithCancel(ctx) - defer cancel() + g, gctx := errgroup.WithContext(ctx) - server.database, err = database.Open(subCtx, server.config.DB) + server.database, err = database.Open(gctx, server.config.DB) if err != nil { return fmt.Errorf("open database: %w", err) } + defer server.database.Close() - errCh := make(chan error) + g.Go(func() error { return server.hookServer.Run(gctx) }) + g.Go(func() error { return server.lmtpServer.Run(gctx) }) + g.Go(func() error { return server.webServer.Run(gctx) }) + g.Go(func() error { return server.sshServer.Run(gctx) }) - go func() { - if err := server.hookServer.Run(subCtx); err != nil { - select { - case errCh <- err: - default: - } - } - }() - - go func() { - if err := server.lmtpServer.Run(subCtx); err != nil { - select { - case errCh <- err: - default: - } - } - }() - - go func() { - if err := server.webServer.Run(subCtx); err != nil { - select { - case errCh <- err: - default: - } - } - }() - - go func() { - if err := server.sshServer.Run(subCtx); err != nil { - select { - case errCh <- err: - default: - } - } - }() - - select { - case err := <-errCh: + if err := g.Wait(); err != nil { return fmt.Errorf("server error: %w", err) - case <-ctx.Done(): } - return nil + return ctx.Err() } |