From 308ce6c10ce77835a9b4d2ca9a17d449260c6adb Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sun, 17 Aug 2025 04:34:16 +0800 Subject: Add shutdown timeouts --- forged/internal/incoming/ssh/ssh.go | 39 +++++++++++++++++++--------------- forged/internal/incoming/web/web.go | 42 ++++++++++++++++++++----------------- 2 files changed, 45 insertions(+), 36 deletions(-) (limited to 'forged/internal/incoming') diff --git a/forged/internal/incoming/ssh/ssh.go b/forged/internal/incoming/ssh/ssh.go index 77812d1..9f9bdff 100644 --- a/forged/internal/incoming/ssh/ssh.go +++ b/forged/internal/incoming/ssh/ssh.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "time" gliderssh "github.com/gliderlabs/ssh" "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" @@ -12,27 +13,30 @@ import ( ) type Config struct { - Net string `scfg:"net"` - Addr string `scfg:"addr"` - Key string `scfg:"key"` - Root string `scfg:"root"` + Net string `scfg:"net"` + Addr string `scfg:"addr"` + Key string `scfg:"key"` + Root string `scfg:"root"` + ShutdownTimeout uint32 `scfg:"shutdown_timeout"` } type Server struct { - gliderServer *gliderssh.Server - privkey gossh.Signer - pubkeyString string - pubkeyFP string - net string - addr string - root string + gliderServer *gliderssh.Server + privkey gossh.Signer + pubkeyString string + pubkeyFP string + net string + addr string + root string + shutdownTimeout uint32 } func New(config Config) (server *Server, err error) { server = &Server{ - net: config.Net, - addr: config.Addr, - root: config.Root, + net: config.Net, + addr: config.Addr, + root: config.Root, + shutdownTimeout: config.ShutdownTimeout, } var privkeyBytes []byte @@ -63,9 +67,10 @@ func (server *Server) Run(ctx context.Context) (err error) { go func() { <-ctx.Done() - _ = server.gliderServer.Close() - _ = listener.Close() // unnecessary? - // TODO: Log the error + shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second) + defer cancel() + _ = server.gliderServer.Shutdown(shCtx) + _ = listener.Close() }() if err = server.gliderServer.Serve(listener); err != nil { diff --git a/forged/internal/incoming/web/web.go b/forged/internal/incoming/web/web.go index f66ad64..391f6ff 100644 --- a/forged/internal/incoming/web/web.go +++ b/forged/internal/incoming/web/web.go @@ -10,10 +10,11 @@ import ( ) type Server struct { - net string - addr string - root string - httpServer *http.Server + net string + addr string + root string + httpServer *http.Server + shutdownTimeout uint32 } type handler struct{} @@ -22,23 +23,25 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } type Config struct { - Net string `scfg:"net"` - Addr string `scfg:"addr"` - Root string `scfg:"root"` - CookieExpiry int `scfg:"cookie_expiry"` - ReadTimeout uint32 `scfg:"read_timeout"` - WriteTimeout uint32 `scfg:"write_timeout"` - IdleTimeout uint32 `scfg:"idle_timeout"` - MaxHeaderBytes int `scfg:"max_header_bytes"` - ReverseProxy bool `scfg:"reverse_proxy"` + Net string `scfg:"net"` + Addr string `scfg:"addr"` + Root string `scfg:"root"` + CookieExpiry int `scfg:"cookie_expiry"` + ReadTimeout uint32 `scfg:"read_timeout"` + WriteTimeout uint32 `scfg:"write_timeout"` + IdleTimeout uint32 `scfg:"idle_timeout"` + MaxHeaderBytes int `scfg:"max_header_bytes"` + ReverseProxy bool `scfg:"reverse_proxy"` + ShutdownTimeout uint32 `scfg:"shutdown_timeout"` } func New(config Config) (server *Server) { handler := &handler{} return &Server{ - net: config.Net, - addr: config.Addr, - root: config.Root, + net: config.Net, + addr: config.Addr, + root: config.Root, + shutdownTimeout: config.ShutdownTimeout, httpServer: &http.Server{ Handler: handler, ReadTimeout: time.Duration(config.ReadTimeout) * time.Second, @@ -57,9 +60,10 @@ func (server *Server) Run(ctx context.Context) (err error) { go func() { <-ctx.Done() - _ = server.httpServer.Close() - _ = listener.Close() // unnecessary? - // TODO: Log the error + shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second) + defer cancel() + _ = server.httpServer.Shutdown(shCtx) + _ = listener.Close() }() if err = server.httpServer.Serve(listener); err != nil { -- cgit v1.2.3