aboutsummaryrefslogtreecommitdiff
path: root/forged
diff options
context:
space:
mode:
Diffstat (limited to 'forged')
-rw-r--r--forged/internal/incoming/hooks/hooks.go21
-rw-r--r--forged/internal/incoming/lmtp/config.go21
-rw-r--r--forged/internal/incoming/ssh/ssh.go13
-rw-r--r--forged/internal/incoming/web/web.go16
-rw-r--r--forged/internal/server/server.go54
5 files changed, 55 insertions, 70 deletions
diff --git a/forged/internal/incoming/hooks/hooks.go b/forged/internal/incoming/hooks/hooks.go
index 3be0811..52ccb0f 100644
--- a/forged/internal/incoming/hooks/hooks.go
+++ b/forged/internal/incoming/hooks/hooks.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
+ "time"
"github.com/gliderlabs/ssh"
"go.lindenii.runxiyu.org/forge/forged/internal/common/cmap"
@@ -51,25 +52,29 @@ func (server *Server) Run(ctx context.Context) error {
_ = listener.Close()
}()
- go func() {
- <-ctx.Done()
+ stop := context.AfterFunc(ctx, func() {
_ = listener.Close()
- // TODO: Log the error
- }()
+ })
+ defer stop()
for {
conn, err := listener.Accept()
if err != nil {
- if errors.Is(err, net.ErrClosed) {
+ if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
return nil
}
return fmt.Errorf("accept conn: %w", err)
}
- go server.handleConn(conn)
+ go server.handleConn(ctx, conn)
}
}
-func (server *Server) handleConn(conn net.Conn) {
- panic("TODO: handle hook connection")
+func (server *Server) handleConn(ctx context.Context, conn net.Conn) {
+ defer conn.Close()
+ unblock := context.AfterFunc(ctx, func() {
+ _ = conn.SetDeadline(time.Now())
+ _ = conn.Close()
+ })
+ defer unblock()
}
diff --git a/forged/internal/incoming/lmtp/config.go b/forged/internal/incoming/lmtp/config.go
index ce32f3d..def3ce9 100644
--- a/forged/internal/incoming/lmtp/config.go
+++ b/forged/internal/incoming/lmtp/config.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
+ "time"
"go.lindenii.runxiyu.org/forge/forged/internal/common/misc"
)
@@ -44,25 +45,29 @@ func (server *Server) Run(ctx context.Context) error {
_ = listener.Close()
}()
- go func() {
- <-ctx.Done()
+ stop := context.AfterFunc(ctx, func() {
_ = listener.Close()
- // TODO: Log the error
- }()
+ })
+ defer stop()
for {
conn, err := listener.Accept()
if err != nil {
- if errors.Is(err, net.ErrClosed) {
+ if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
return nil
}
return fmt.Errorf("accept conn: %w", err)
}
- go server.handleConn(conn)
+ go server.handleConn(ctx, conn)
}
}
-func (server *Server) handleConn(conn net.Conn) {
- panic("TODO: handle LMTP connection")
+func (server *Server) handleConn(ctx context.Context, conn net.Conn) {
+ defer conn.Close()
+ unblock := context.AfterFunc(ctx, func() {
+ _ = conn.SetDeadline(time.Now())
+ _ = conn.Close()
+ })
+ defer unblock()
}
diff --git a/forged/internal/incoming/ssh/ssh.go b/forged/internal/incoming/ssh/ssh.go
index 9f9bdff..0c722c0 100644
--- a/forged/internal/incoming/ssh/ssh.go
+++ b/forged/internal/incoming/ssh/ssh.go
@@ -61,20 +61,23 @@ func New(config Config) (server *Server, err error) {
func (server *Server) Run(ctx context.Context) (err error) {
listener, err := misc.Listen(server.net, server.addr)
+ if err != nil {
+ return fmt.Errorf("listen for SSH: %w", err)
+ }
defer func() {
_ = listener.Close()
}()
- go func() {
- <-ctx.Done()
- shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second)
+ stop := context.AfterFunc(ctx, func() {
+ shCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(server.shutdownTimeout)*time.Second)
defer cancel()
_ = server.gliderServer.Shutdown(shCtx)
_ = listener.Close()
- }()
+ })
+ defer stop()
if err = server.gliderServer.Serve(listener); err != nil {
- if errors.Is(err, gliderssh.ErrServerClosed) {
+ if errors.Is(err, gliderssh.ErrServerClosed) || ctx.Err() != nil {
return nil
}
return fmt.Errorf("serve SSH: %w", err)
diff --git a/forged/internal/incoming/web/web.go b/forged/internal/incoming/web/web.go
index 391f6ff..dc2d9b4 100644
--- a/forged/internal/incoming/web/web.go
+++ b/forged/internal/incoming/web/web.go
@@ -3,6 +3,7 @@ package web
import (
"context"
"fmt"
+ "net"
"net/http"
"time"
@@ -53,21 +54,26 @@ func New(config Config) (server *Server) {
}
func (server *Server) Run(ctx context.Context) (err error) {
+ server.httpServer.BaseContext = func(_ net.Listener) context.Context { return ctx }
+
listener, err := misc.Listen(server.net, server.addr)
+ if err != nil {
+ return fmt.Errorf("listen for web: %w", err)
+ }
defer func() {
_ = listener.Close()
}()
- go func() {
- <-ctx.Done()
- shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second)
+ stop := context.AfterFunc(ctx, func() {
+ shCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Duration(server.shutdownTimeout)*time.Second)
defer cancel()
_ = server.httpServer.Shutdown(shCtx)
_ = listener.Close()
- }()
+ })
+ defer stop()
if err = server.httpServer.Serve(listener); err != nil {
- if err == http.ErrServerClosed {
+ if err == http.ErrServerClosed || ctx.Err() != nil {
return nil
}
return fmt.Errorf("serve web: %w", err)
diff --git a/forged/internal/server/server.go b/forged/internal/server/server.go
index 472df7a..ab677e0 100644
--- a/forged/internal/server/server.go
+++ b/forged/internal/server/server.go
@@ -10,6 +10,7 @@ import (
"go.lindenii.runxiyu.org/forge/forged/internal/incoming/lmtp"
"go.lindenii.runxiyu.org/forge/forged/internal/incoming/ssh"
"go.lindenii.runxiyu.org/forge/forged/internal/incoming/web"
+ "golang.org/x/sync/errgroup"
)
type Server struct {
@@ -51,57 +52,22 @@ func (server *Server) Run(ctx context.Context) (err error) {
// TODO: Not running git2d because it should be run separately.
// This needs to be documented somewhere, hence a TODO here for now.
- subCtx, cancel := context.WithCancel(ctx)
- defer cancel()
+ g, gctx := errgroup.WithContext(ctx)
- server.database, err = database.Open(subCtx, server.config.DB)
+ server.database, err = database.Open(gctx, server.config.DB)
if err != nil {
return fmt.Errorf("open database: %w", err)
}
+ defer server.database.Close()
- errCh := make(chan error)
+ g.Go(func() error { return server.hookServer.Run(gctx) })
+ g.Go(func() error { return server.lmtpServer.Run(gctx) })
+ g.Go(func() error { return server.webServer.Run(gctx) })
+ g.Go(func() error { return server.sshServer.Run(gctx) })
- go func() {
- if err := server.hookServer.Run(subCtx); err != nil {
- select {
- case errCh <- err:
- default:
- }
- }
- }()
-
- go func() {
- if err := server.lmtpServer.Run(subCtx); err != nil {
- select {
- case errCh <- err:
- default:
- }
- }
- }()
-
- go func() {
- if err := server.webServer.Run(subCtx); err != nil {
- select {
- case errCh <- err:
- default:
- }
- }
- }()
-
- go func() {
- if err := server.sshServer.Run(subCtx); err != nil {
- select {
- case errCh <- err:
- default:
- }
- }
- }()
-
- select {
- case err := <-errCh:
+ if err := g.Wait(); err != nil {
return fmt.Errorf("server error: %w", err)
- case <-ctx.Done():
}
- return nil
+ return ctx.Err()
}