From dd136858f1ea40ad3c94191d647487fa4f31926c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 18 Oct 2024 20:33:49 +0200 Subject: Adding upstream version 9.0.0. Signed-off-by: Daniel Baumann --- modules/graceful/context.go | 36 +++ modules/graceful/manager.go | 260 +++++++++++++++++ modules/graceful/manager_common.go | 108 +++++++ modules/graceful/manager_unix.go | 201 +++++++++++++ modules/graceful/manager_windows.go | 190 ++++++++++++ modules/graceful/net_unix.go | 321 +++++++++++++++++++++ modules/graceful/net_windows.go | 19 ++ modules/graceful/releasereopen/releasereopen.go | 61 ++++ .../graceful/releasereopen/releasereopen_test.go | 44 +++ modules/graceful/restart_unix.go | 115 ++++++++ modules/graceful/server.go | 284 ++++++++++++++++++ modules/graceful/server_hooks.go | 73 +++++ modules/graceful/server_http.go | 37 +++ 13 files changed, 1749 insertions(+) create mode 100644 modules/graceful/context.go create mode 100644 modules/graceful/manager.go create mode 100644 modules/graceful/manager_common.go create mode 100644 modules/graceful/manager_unix.go create mode 100644 modules/graceful/manager_windows.go create mode 100644 modules/graceful/net_unix.go create mode 100644 modules/graceful/net_windows.go create mode 100644 modules/graceful/releasereopen/releasereopen.go create mode 100644 modules/graceful/releasereopen/releasereopen_test.go create mode 100644 modules/graceful/restart_unix.go create mode 100644 modules/graceful/server.go create mode 100644 modules/graceful/server_hooks.go create mode 100644 modules/graceful/server_http.go (limited to 'modules/graceful') diff --git a/modules/graceful/context.go b/modules/graceful/context.go new file mode 100644 index 0000000..c9c4ca4 --- /dev/null +++ b/modules/graceful/context.go @@ -0,0 +1,36 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package graceful + +import ( + "context" +) + +// Shutdown procedure: +// * cancel ShutdownContext: the registered context consumers have time to do their cleanup (they could use the hammer context) +// * cancel HammerContext: the all context consumers have limited time to do their cleanup (wait for a few seconds) +// * cancel TerminateContext: the registered context consumers have time to do their cleanup (but they shouldn't use shutdown/hammer context anymore) +// * cancel manager context +// If the shutdown is triggered again during the shutdown procedure, the hammer context will be canceled immediately to force to shut down. + +// ShutdownContext returns a context.Context that is Done at shutdown +// Callers using this context should ensure that they are registered as a running server +// in order that they are waited for. +func (g *Manager) ShutdownContext() context.Context { + return g.shutdownCtx +} + +// HammerContext returns a context.Context that is Done at hammer +// Callers using this context should ensure that they are registered as a running server +// in order that they are waited for. +func (g *Manager) HammerContext() context.Context { + return g.hammerCtx +} + +// TerminateContext returns a context.Context that is Done at terminate +// Callers using this context should ensure that they are registered as a terminating server +// in order that they are waited for. +func (g *Manager) TerminateContext() context.Context { + return g.terminateCtx +} diff --git a/modules/graceful/manager.go b/modules/graceful/manager.go new file mode 100644 index 0000000..077eac6 --- /dev/null +++ b/modules/graceful/manager.go @@ -0,0 +1,260 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package graceful + +import ( + "context" + "runtime/pprof" + "sync" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/process" + "code.gitea.io/gitea/modules/setting" +) + +type state uint8 + +const ( + stateInit state = iota + stateRunning + stateShuttingDown + stateTerminate +) + +type RunCanceler interface { + Run() + Cancel() +} + +// There are some places that could inherit sockets: +// +// * HTTP or HTTPS main listener +// * HTTP or HTTPS install listener +// * HTTP redirection fallback +// * Builtin SSH listener +// +// If you add a new place you must increment this number +// and add a function to call manager.InformCleanup if it's not going to be used +const numberOfServersToCreate = 4 + +var ( + manager *Manager + initOnce sync.Once +) + +// GetManager returns the Manager +func GetManager() *Manager { + InitManager(context.Background()) + return manager +} + +// InitManager creates the graceful manager in the provided context +func InitManager(ctx context.Context) { + initOnce.Do(func() { + manager = newGracefulManager(ctx) + + // Set the process default context to the HammerContext + process.DefaultContext = manager.HammerContext() + }) +} + +// RunWithCancel helps to run a function with a custom context, the Cancel function will be called at shutdown +// The Cancel function should stop the Run function in predictable time. +func (g *Manager) RunWithCancel(rc RunCanceler) { + g.RunAtShutdown(context.Background(), rc.Cancel) + g.runningServerWaitGroup.Add(1) + defer g.runningServerWaitGroup.Done() + defer func() { + if err := recover(); err != nil { + log.Critical("PANIC during RunWithCancel: %v\nStacktrace: %s", err, log.Stack(2)) + g.doShutdown() + } + }() + rc.Run() +} + +// RunWithShutdownContext takes a function that has a context to watch for shutdown. +// After the provided context is Done(), the main function must return once shutdown is complete. +// (Optionally the HammerContext may be obtained and waited for however, this should be avoided if possible.) +func (g *Manager) RunWithShutdownContext(run func(context.Context)) { + g.runningServerWaitGroup.Add(1) + defer g.runningServerWaitGroup.Done() + defer func() { + if err := recover(); err != nil { + log.Critical("PANIC during RunWithShutdownContext: %v\nStacktrace: %s", err, log.Stack(2)) + g.doShutdown() + } + }() + ctx := g.ShutdownContext() + pprof.SetGoroutineLabels(ctx) // We don't have a label to restore back to but I think this is fine + run(ctx) +} + +// RunAtTerminate adds to the terminate wait group and creates a go-routine to run the provided function at termination +func (g *Manager) RunAtTerminate(terminate func()) { + g.terminateWaitGroup.Add(1) + g.lock.Lock() + defer g.lock.Unlock() + g.toRunAtTerminate = append(g.toRunAtTerminate, + func() { + defer g.terminateWaitGroup.Done() + defer func() { + if err := recover(); err != nil { + log.Critical("PANIC during RunAtTerminate: %v\nStacktrace: %s", err, log.Stack(2)) + } + }() + terminate() + }) +} + +// RunAtShutdown creates a go-routine to run the provided function at shutdown +func (g *Manager) RunAtShutdown(ctx context.Context, shutdown func()) { + g.lock.Lock() + defer g.lock.Unlock() + g.toRunAtShutdown = append(g.toRunAtShutdown, + func() { + defer func() { + if err := recover(); err != nil { + log.Critical("PANIC during RunAtShutdown: %v\nStacktrace: %s", err, log.Stack(2)) + } + }() + select { + case <-ctx.Done(): + return + default: + shutdown() + } + }) +} + +func (g *Manager) doShutdown() { + if !g.setStateTransition(stateRunning, stateShuttingDown) { + g.DoImmediateHammer() + return + } + g.lock.Lock() + g.shutdownCtxCancel() + atShutdownCtx := pprof.WithLabels(g.hammerCtx, pprof.Labels("gracefulLifecycle", "post-shutdown")) + pprof.SetGoroutineLabels(atShutdownCtx) + for _, fn := range g.toRunAtShutdown { + go fn() + } + g.lock.Unlock() + + if setting.GracefulHammerTime >= 0 { + go g.doHammerTime(setting.GracefulHammerTime) + } + go func() { + g.runningServerWaitGroup.Wait() + // Mop up any remaining unclosed events. + g.doHammerTime(0) + <-time.After(1 * time.Second) + g.doTerminate() + g.terminateWaitGroup.Wait() + g.lock.Lock() + g.managerCtxCancel() + g.lock.Unlock() + }() +} + +func (g *Manager) doHammerTime(d time.Duration) { + time.Sleep(d) + g.lock.Lock() + select { + case <-g.hammerCtx.Done(): + default: + log.Warn("Setting Hammer condition") + g.hammerCtxCancel() + atHammerCtx := pprof.WithLabels(g.terminateCtx, pprof.Labels("gracefulLifecycle", "post-hammer")) + pprof.SetGoroutineLabels(atHammerCtx) + } + g.lock.Unlock() +} + +func (g *Manager) doTerminate() { + if !g.setStateTransition(stateShuttingDown, stateTerminate) { + return + } + g.lock.Lock() + select { + case <-g.terminateCtx.Done(): + default: + log.Warn("Terminating") + g.terminateCtxCancel() + atTerminateCtx := pprof.WithLabels(g.managerCtx, pprof.Labels("gracefulLifecycle", "post-terminate")) + pprof.SetGoroutineLabels(atTerminateCtx) + + for _, fn := range g.toRunAtTerminate { + go fn() + } + } + g.lock.Unlock() +} + +// IsChild returns if the current process is a child of previous Gitea process +func (g *Manager) IsChild() bool { + return g.isChild +} + +// IsShutdown returns a channel which will be closed at shutdown. +// The order of closure is shutdown, hammer (potentially), terminate +func (g *Manager) IsShutdown() <-chan struct{} { + return g.shutdownCtx.Done() +} + +// IsHammer returns a channel which will be closed at hammer. +// Servers running within the running server wait group should respond to IsHammer +// if not shutdown already +func (g *Manager) IsHammer() <-chan struct{} { + return g.hammerCtx.Done() +} + +// ServerDone declares a running server done and subtracts one from the +// running server wait group. Users probably do not want to call this +// and should use one of the RunWithShutdown* functions +func (g *Manager) ServerDone() { + g.runningServerWaitGroup.Done() +} + +func (g *Manager) setStateTransition(old, new state) bool { + g.lock.Lock() + if g.state != old { + g.lock.Unlock() + return false + } + g.state = new + g.lock.Unlock() + return true +} + +// InformCleanup tells the cleanup wait group that we have either taken a listener or will not be taking a listener. +// At the moment the total number of servers (numberOfServersToCreate) are pre-defined as a const before global init, +// so this function MUST be called if a server is not used. +func (g *Manager) InformCleanup() { + g.createServerCond.L.Lock() + defer g.createServerCond.L.Unlock() + g.createdServer++ + g.createServerCond.Signal() +} + +// Done allows the manager to be viewed as a context.Context, it returns a channel that is closed when the server is finished terminating +func (g *Manager) Done() <-chan struct{} { + return g.managerCtx.Done() +} + +// Err allows the manager to be viewed as a context.Context done at Terminate +func (g *Manager) Err() error { + return g.managerCtx.Err() +} + +// Value allows the manager to be viewed as a context.Context done at Terminate +func (g *Manager) Value(key any) any { + return g.managerCtx.Value(key) +} + +// Deadline returns nil as there is no fixed Deadline for the manager, it allows the manager to be viewed as a context.Context +func (g *Manager) Deadline() (deadline time.Time, ok bool) { + return g.managerCtx.Deadline() +} diff --git a/modules/graceful/manager_common.go b/modules/graceful/manager_common.go new file mode 100644 index 0000000..892957e --- /dev/null +++ b/modules/graceful/manager_common.go @@ -0,0 +1,108 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package graceful + +import ( + "context" + "runtime/pprof" + "sync" + "time" +) + +// FIXME: it seems that there is a bug when using systemd Type=notify: the "Install Page" (INSTALL_LOCK=false) doesn't notify properly. +// At the moment, no idea whether it also affects Windows Service, or whether it's a regression bug. It needs to be investigated later. + +type systemdNotifyMsg string + +const ( + readyMsg systemdNotifyMsg = "READY=1" + stoppingMsg systemdNotifyMsg = "STOPPING=1" + reloadingMsg systemdNotifyMsg = "RELOADING=1" + watchdogMsg systemdNotifyMsg = "WATCHDOG=1" +) + +func statusMsg(msg string) systemdNotifyMsg { + return systemdNotifyMsg("STATUS=" + msg) +} + +// Manager manages the graceful shutdown process +type Manager struct { + ctx context.Context + isChild bool + forked bool + lock sync.RWMutex + state state + shutdownCtx context.Context + hammerCtx context.Context + terminateCtx context.Context + managerCtx context.Context + shutdownCtxCancel context.CancelFunc + hammerCtxCancel context.CancelFunc + terminateCtxCancel context.CancelFunc + managerCtxCancel context.CancelFunc + runningServerWaitGroup sync.WaitGroup + terminateWaitGroup sync.WaitGroup + createServerCond sync.Cond + createdServer int + shutdownRequested chan struct{} + + toRunAtShutdown []func() + toRunAtTerminate []func() +} + +func newGracefulManager(ctx context.Context) *Manager { + manager := &Manager{ctx: ctx, shutdownRequested: make(chan struct{})} + manager.createServerCond.L = &sync.Mutex{} + manager.prepare(ctx) + manager.start() + return manager +} + +func (g *Manager) prepare(ctx context.Context) { + g.terminateCtx, g.terminateCtxCancel = context.WithCancel(ctx) + g.shutdownCtx, g.shutdownCtxCancel = context.WithCancel(ctx) + g.hammerCtx, g.hammerCtxCancel = context.WithCancel(ctx) + g.managerCtx, g.managerCtxCancel = context.WithCancel(ctx) + + g.terminateCtx = pprof.WithLabels(g.terminateCtx, pprof.Labels("gracefulLifecycle", "with-terminate")) + g.shutdownCtx = pprof.WithLabels(g.shutdownCtx, pprof.Labels("gracefulLifecycle", "with-shutdown")) + g.hammerCtx = pprof.WithLabels(g.hammerCtx, pprof.Labels("gracefulLifecycle", "with-hammer")) + g.managerCtx = pprof.WithLabels(g.managerCtx, pprof.Labels("gracefulLifecycle", "with-manager")) + + if !g.setStateTransition(stateInit, stateRunning) { + panic("invalid graceful manager state: transition from init to running failed") + } +} + +// DoImmediateHammer causes an immediate hammer +func (g *Manager) DoImmediateHammer() { + g.notify(statusMsg("Sending immediate hammer")) + g.doHammerTime(0 * time.Second) +} + +// DoGracefulShutdown causes a graceful shutdown +func (g *Manager) DoGracefulShutdown() { + g.lock.Lock() + select { + case <-g.shutdownRequested: + default: + close(g.shutdownRequested) + } + forked := g.forked + g.lock.Unlock() + + if !forked { + g.notify(stoppingMsg) + } else { + g.notify(statusMsg("Shutting down after fork")) + } + g.doShutdown() +} + +// RegisterServer registers the running of a listening server, in the case of unix this means that the parent process can now die. +// Any call to RegisterServer must be matched by a call to ServerDone +func (g *Manager) RegisterServer() { + KillParent() + g.runningServerWaitGroup.Add(1) +} diff --git a/modules/graceful/manager_unix.go b/modules/graceful/manager_unix.go new file mode 100644 index 0000000..931b0f1 --- /dev/null +++ b/modules/graceful/manager_unix.go @@ -0,0 +1,201 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build !windows + +package graceful + +import ( + "context" + "errors" + "os" + "os/signal" + "runtime/pprof" + "strconv" + "syscall" + "time" + + "code.gitea.io/gitea/modules/graceful/releasereopen" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/process" + "code.gitea.io/gitea/modules/setting" +) + +func pidMsg() systemdNotifyMsg { + return systemdNotifyMsg("MAINPID=" + strconv.Itoa(os.Getpid())) +} + +// Notify systemd of status via the notify protocol +func (g *Manager) notify(msg systemdNotifyMsg) { + conn, err := getNotifySocket() + if err != nil { + // the err is logged in getNotifySocket + return + } + if conn == nil { + return + } + defer conn.Close() + + if _, err = conn.Write([]byte(msg)); err != nil { + log.Warn("Failed to notify NOTIFY_SOCKET: %v", err) + return + } +} + +func (g *Manager) start() { + // Now label this and all goroutines created by this goroutine with the gracefulLifecycle manager + pprof.SetGoroutineLabels(g.managerCtx) + defer pprof.SetGoroutineLabels(g.ctx) + + g.isChild = len(os.Getenv(listenFDsEnv)) > 0 && os.Getppid() > 1 + + g.notify(statusMsg("Starting Gitea")) + g.notify(pidMsg()) + go g.handleSignals(g.managerCtx) + + // Handle clean up of unused provided listeners and delayed start-up + startupDone := make(chan struct{}) + go func() { + defer func() { + close(startupDone) + // Close the unused listeners + closeProvidedListeners() + }() + // Wait for all servers to be created + g.createServerCond.L.Lock() + for { + if g.createdServer >= numberOfServersToCreate { + g.createServerCond.L.Unlock() + g.notify(readyMsg) + return + } + select { + case <-g.IsShutdown(): + g.createServerCond.L.Unlock() + return + default: + } + g.createServerCond.Wait() + } + }() + if setting.StartupTimeout > 0 { + go func() { + select { + case <-startupDone: + return + case <-g.IsShutdown(): + g.createServerCond.Signal() + return + case <-time.After(setting.StartupTimeout): + log.Error("Startup took too long! Shutting down") + g.notify(statusMsg("Startup took too long! Shutting down")) + g.notify(stoppingMsg) + g.doShutdown() + } + }() + } +} + +func (g *Manager) handleSignals(ctx context.Context) { + ctx, _, finished := process.GetManager().AddTypedContext(ctx, "Graceful: HandleSignals", process.SystemProcessType, true) + defer finished() + + signalChannel := make(chan os.Signal, 1) + + signal.Notify( + signalChannel, + syscall.SIGHUP, + syscall.SIGUSR1, + syscall.SIGUSR2, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGTSTP, + ) + + watchdogTimeout := getWatchdogTimeout() + t := &time.Ticker{} + if watchdogTimeout != 0 { + g.notify(watchdogMsg) + t = time.NewTicker(watchdogTimeout / 2) + } + + pid := syscall.Getpid() + for { + select { + case sig := <-signalChannel: + switch sig { + case syscall.SIGHUP: + log.Info("PID: %d. Received SIGHUP. Attempting GracefulRestart...", pid) + g.DoGracefulRestart() + case syscall.SIGUSR1: + log.Warn("PID %d. Received SIGUSR1. Releasing and reopening logs", pid) + g.notify(statusMsg("Releasing and reopening logs")) + if err := releasereopen.GetManager().ReleaseReopen(); err != nil { + log.Error("Error whilst releasing and reopening logs: %v", err) + } + case syscall.SIGUSR2: + log.Warn("PID %d. Received SIGUSR2. Hammering...", pid) + g.DoImmediateHammer() + case syscall.SIGINT: + log.Warn("PID %d. Received SIGINT. Shutting down...", pid) + g.DoGracefulShutdown() + case syscall.SIGTERM: + log.Warn("PID %d. Received SIGTERM. Shutting down...", pid) + g.DoGracefulShutdown() + case syscall.SIGTSTP: + log.Info("PID %d. Received SIGTSTP.", pid) + default: + log.Info("PID %d. Received %v.", pid, sig) + } + case <-t.C: + g.notify(watchdogMsg) + case <-ctx.Done(): + log.Warn("PID: %d. Background context for manager closed - %v - Shutting down...", pid, ctx.Err()) + g.DoGracefulShutdown() + return + } + } +} + +func (g *Manager) doFork() error { + g.lock.Lock() + if g.forked { + g.lock.Unlock() + return errors.New("another process already forked. Ignoring this one") + } + g.forked = true + g.lock.Unlock() + + g.notify(reloadingMsg) + + // We need to move the file logs to append pids + setting.RestartLogsWithPIDSuffix() + + _, err := RestartProcess() + + return err +} + +// DoGracefulRestart causes a graceful restart +func (g *Manager) DoGracefulRestart() { + if setting.GracefulRestartable { + log.Info("PID: %d. Forking...", os.Getpid()) + err := g.doFork() + if err != nil { + if err.Error() == "another process already forked. Ignoring this one" { + g.DoImmediateHammer() + } else { + log.Error("Error whilst forking from PID: %d : %v", os.Getpid(), err) + } + } + // doFork calls RestartProcess which starts a new Gitea process, so this parent process needs to exit + // Otherwise some resources (eg: leveldb lock) will be held by this parent process and the new process will fail to start + log.Info("PID: %d. Shutting down after forking ...", os.Getpid()) + g.doShutdown() + } else { + log.Info("PID: %d. Not set restartable. Shutting down...", os.Getpid()) + g.notify(stoppingMsg) + g.doShutdown() + } +} diff --git a/modules/graceful/manager_windows.go b/modules/graceful/manager_windows.go new file mode 100644 index 0000000..bee4438 --- /dev/null +++ b/modules/graceful/manager_windows.go @@ -0,0 +1,190 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT +// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler + +//go:build windows + +package graceful + +import ( + "os" + "runtime/pprof" + "strconv" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" +) + +// WindowsServiceName is the name of the Windows service +var WindowsServiceName = "gitea" + +const ( + hammerCode = 128 + hammerCmd = svc.Cmd(hammerCode) + acceptHammerCode = svc.Accepted(hammerCode) +) + +func (g *Manager) start() { + // Now label this and all goroutines created by this goroutine with the gracefulLifecycle manager + pprof.SetGoroutineLabels(g.managerCtx) + defer pprof.SetGoroutineLabels(g.ctx) + + if skip, _ := strconv.ParseBool(os.Getenv("SKIP_MINWINSVC")); skip { + log.Trace("Skipping SVC check as SKIP_MINWINSVC is set") + return + } + + // Make SVC process + run := svc.Run + + //lint:ignore SA1019 We use IsAnInteractiveSession because IsWindowsService has a different permissions profile + isAnInteractiveSession, err := svc.IsAnInteractiveSession() //nolint:staticcheck + if err != nil { + log.Error("Unable to ascertain if running as an Windows Service: %v", err) + return + } + if isAnInteractiveSession { + log.Trace("Not running a service ... using the debug SVC manager") + run = debug.Run + } + go func() { + _ = run(WindowsServiceName, g) + }() +} + +// Execute makes Manager implement svc.Handler +func (g *Manager) Execute(args []string, changes <-chan svc.ChangeRequest, status chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { + if setting.StartupTimeout > 0 { + status <- svc.Status{State: svc.StartPending, WaitHint: uint32(setting.StartupTimeout / time.Millisecond)} + } else { + status <- svc.Status{State: svc.StartPending} + } + + log.Trace("Awaiting server start-up") + // Now need to wait for everything to start... + if !g.awaitServer(setting.StartupTimeout) { + log.Trace("... start-up failed ... Stopped") + return false, 1 + } + + log.Trace("Sending Running state to SVC") + + // We need to implement some way of svc.AcceptParamChange/svc.ParamChange + status <- svc.Status{ + State: svc.Running, + Accepts: svc.AcceptStop | svc.AcceptShutdown | acceptHammerCode, + } + + log.Trace("Started") + + waitTime := 30 * time.Second + +loop: + for { + select { + case <-g.ctx.Done(): + log.Trace("Shutting down") + g.DoGracefulShutdown() + waitTime += setting.GracefulHammerTime + break loop + case <-g.shutdownRequested: + log.Trace("Shutting down") + waitTime += setting.GracefulHammerTime + break loop + case change := <-changes: + switch change.Cmd { + case svc.Interrogate: + log.Trace("SVC sent interrogate") + status <- change.CurrentStatus + case svc.Stop, svc.Shutdown: + log.Trace("SVC requested shutdown - shutting down") + g.DoGracefulShutdown() + waitTime += setting.GracefulHammerTime + break loop + case hammerCode: + log.Trace("SVC requested hammer - shutting down and hammering immediately") + g.DoGracefulShutdown() + g.DoImmediateHammer() + break loop + default: + log.Debug("Unexpected control request: %v", change.Cmd) + } + } + } + + log.Trace("Sending StopPending state to SVC") + status <- svc.Status{ + State: svc.StopPending, + WaitHint: uint32(waitTime / time.Millisecond), + } + +hammerLoop: + for { + select { + case change := <-changes: + switch change.Cmd { + case svc.Interrogate: + log.Trace("SVC sent interrogate") + status <- change.CurrentStatus + case svc.Stop, svc.Shutdown, hammerCmd: + log.Trace("SVC requested hammer - hammering immediately") + g.DoImmediateHammer() + break hammerLoop + default: + log.Debug("Unexpected control request: %v", change.Cmd) + } + case <-g.hammerCtx.Done(): + break hammerLoop + } + } + + log.Trace("Stopped") + return false, 0 +} + +func (g *Manager) awaitServer(limit time.Duration) bool { + c := make(chan struct{}) + go func() { + g.createServerCond.L.Lock() + for { + if g.createdServer >= numberOfServersToCreate { + g.createServerCond.L.Unlock() + close(c) + return + } + select { + case <-g.IsShutdown(): + g.createServerCond.L.Unlock() + return + default: + } + g.createServerCond.Wait() + } + }() + + var tc <-chan time.Time + if limit > 0 { + tc = time.After(limit) + } + select { + case <-c: + return true // completed normally + case <-tc: + return false // timed out + case <-g.IsShutdown(): + g.createServerCond.Signal() + return false + } +} + +func (g *Manager) notify(msg systemdNotifyMsg) { + // Windows doesn't use systemd to notify +} + +func KillParent() { + // Windows doesn't need to "kill parent" because there is no graceful restart +} diff --git a/modules/graceful/net_unix.go b/modules/graceful/net_unix.go new file mode 100644 index 0000000..796e005 --- /dev/null +++ b/modules/graceful/net_unix.go @@ -0,0 +1,321 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler + +//go:build !windows + +package graceful + +import ( + "fmt" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" +) + +const ( + listenFDsEnv = "LISTEN_FDS" + startFD = 3 + unlinkFDsEnv = "GITEA_UNLINK_FDS" + + notifySocketEnv = "NOTIFY_SOCKET" + watchdogTimeoutEnv = "WATCHDOG_USEC" +) + +// In order to keep the working directory the same as when we started we record +// it at startup. +var originalWD, _ = os.Getwd() + +var ( + once = sync.Once{} + mutex = sync.Mutex{} + + providedListenersToUnlink = []bool{} + activeListenersToUnlink = []bool{} + providedListeners = []net.Listener{} + activeListeners = []net.Listener{} + + notifySocketAddr string + watchdogTimeout time.Duration +) + +func getProvidedFDs() (savedErr error) { + // Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error + once.Do(func() { + mutex.Lock() + defer mutex.Unlock() + // now handle some additional systemd provided things + notifySocketAddr = os.Getenv(notifySocketEnv) + if notifySocketAddr != "" { + log.Debug("Systemd Notify Socket provided: %s", notifySocketAddr) + savedErr = os.Unsetenv(notifySocketEnv) + if savedErr != nil { + log.Warn("Unable to Unset the NOTIFY_SOCKET environment variable: %v", savedErr) + return + } + // FIXME: We don't handle WATCHDOG_PID + timeoutStr := os.Getenv(watchdogTimeoutEnv) + if timeoutStr != "" { + savedErr = os.Unsetenv(watchdogTimeoutEnv) + if savedErr != nil { + log.Warn("Unable to Unset the WATCHDOG_USEC environment variable: %v", savedErr) + return + } + + s, err := strconv.ParseInt(timeoutStr, 10, 64) + if err != nil { + log.Error("Unable to parse the provided WATCHDOG_USEC: %v", err) + savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %w", err) + return + } + if s <= 0 { + log.Error("Unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr) + savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr) + return + } + watchdogTimeout = time.Duration(s) * time.Microsecond + } + } else { + log.Trace("No Systemd Notify Socket provided") + } + + numFDs := os.Getenv(listenFDsEnv) + if numFDs == "" { + return + } + n, err := strconv.Atoi(numFDs) + if err != nil { + savedErr = fmt.Errorf("%s is not a number: %s. Err: %w", listenFDsEnv, numFDs, err) + return + } + + fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDsEnv), ",") + providedListenersToUnlink = make([]bool, n) + for _, fdStr := range fdsToUnlinkStr { + i, err := strconv.Atoi(fdStr) + if err != nil || i < 0 || i >= n { + continue + } + providedListenersToUnlink[i] = true + } + + for i := startFD; i < n+startFD; i++ { + file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i)) + + l, err := net.FileListener(file) + if err == nil { + // Close the inherited file if it's a listener + if err = file.Close(); err != nil { + savedErr = fmt.Errorf("error closing provided socket fd %d: %w", i, err) + return + } + providedListeners = append(providedListeners, l) + continue + } + + // If needed we can handle packetconns here. + savedErr = fmt.Errorf("Error getting provided socket fd %d: %w", i, err) + return + } + }) + return savedErr +} + +// closeProvidedListeners closes all unused provided listeners. +func closeProvidedListeners() { + mutex.Lock() + defer mutex.Unlock() + for _, l := range providedListeners { + err := l.Close() + if err != nil { + log.Error("Error in closing unused provided listener: %v", err) + } + } + providedListeners = []net.Listener{} +} + +// DefaultGetListener obtains a listener for the stream-oriented local network address: +// "tcp", "tcp4", "tcp6", "unix" or "unixpacket". +func DefaultGetListener(network, address string) (net.Listener, error) { + // Add a deferral to say that we've tried to grab a listener + defer GetManager().InformCleanup() + switch network { + case "tcp", "tcp4", "tcp6": + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return GetListenerTCP(network, tcpAddr) + case "unix", "unixpacket": + unixAddr, err := net.ResolveUnixAddr(network, address) + if err != nil { + return nil, err + } + return GetListenerUnix(network, unixAddr) + default: + return nil, net.UnknownNetworkError(network) + } +} + +// GetListenerTCP announces on the local network address. The network must be: +// "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the +// matching network and address, or creates a new one using net.ListenTCP. +func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) { + if err := getProvidedFDs(); err != nil { + return nil, err + } + + mutex.Lock() + defer mutex.Unlock() + + // look for a provided listener + for i, l := range providedListeners { + if isSameAddr(l.Addr(), address) { + providedListeners = append(providedListeners[:i], providedListeners[i+1:]...) + needsUnlink := providedListenersToUnlink[i] + providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...) + + activeListeners = append(activeListeners, l) + activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink) + return l.(*net.TCPListener), nil + } + } + + // no provided listener for this address -> make a fresh listener + l, err := net.ListenTCP(network, address) + if err != nil { + return nil, err + } + activeListeners = append(activeListeners, l) + activeListenersToUnlink = append(activeListenersToUnlink, false) + return l, nil +} + +// GetListenerUnix announces on the local network address. The network must be: +// "unix" or "unixpacket". It returns a provided net.Listener for the +// matching network and address, or creates a new one using net.ListenUnix. +func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) { + if err := getProvidedFDs(); err != nil { + return nil, err + } + + mutex.Lock() + defer mutex.Unlock() + + // look for a provided listener + for i, l := range providedListeners { + if isSameAddr(l.Addr(), address) { + providedListeners = append(providedListeners[:i], providedListeners[i+1:]...) + needsUnlink := providedListenersToUnlink[i] + providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...) + + activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink) + activeListeners = append(activeListeners, l) + unixListener := l.(*net.UnixListener) + if needsUnlink { + unixListener.SetUnlinkOnClose(true) + } + return unixListener, nil + } + } + + // make a fresh listener + if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("Failed to remove unix socket %s: %w", address.Name, err) + } + + l, err := net.ListenUnix(network, address) + if err != nil { + return nil, err + } + + fileMode := os.FileMode(setting.UnixSocketPermission) + if err = os.Chmod(address.Name, fileMode); err != nil { + return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %w", fileMode.String(), err) + } + + activeListeners = append(activeListeners, l) + activeListenersToUnlink = append(activeListenersToUnlink, true) + return l, nil +} + +func isSameAddr(a1, a2 net.Addr) bool { + // If the addresses are not on the same network fail. + if a1.Network() != a2.Network() { + return false + } + + // If the two addresses have the same string representation they're equal + a1s := a1.String() + a2s := a2.String() + if a1s == a2s { + return true + } + + // This allows for ipv6 vs ipv4 local addresses to compare as equal. This + // scenario is common when listening on localhost. + const ipv6prefix = "[::]" + a1s = strings.TrimPrefix(a1s, ipv6prefix) + a2s = strings.TrimPrefix(a2s, ipv6prefix) + const ipv4prefix = "0.0.0.0" + a1s = strings.TrimPrefix(a1s, ipv4prefix) + a2s = strings.TrimPrefix(a2s, ipv4prefix) + return a1s == a2s +} + +func getActiveListeners() []net.Listener { + mutex.Lock() + defer mutex.Unlock() + listeners := make([]net.Listener, len(activeListeners)) + copy(listeners, activeListeners) + return listeners +} + +func getActiveListenersToUnlink() []bool { + mutex.Lock() + defer mutex.Unlock() + listenersToUnlink := make([]bool, len(activeListenersToUnlink)) + copy(listenersToUnlink, activeListenersToUnlink) + return listenersToUnlink +} + +func getNotifySocket() (*net.UnixConn, error) { + if err := getProvidedFDs(); err != nil { + // This error will be logged elsewhere + return nil, nil + } + + if notifySocketAddr == "" { + return nil, nil + } + + socketAddr := &net.UnixAddr{ + Name: notifySocketAddr, + Net: "unixgram", + } + + notifySocket, err := net.DialUnix(socketAddr.Net, nil, socketAddr) + if err != nil { + log.Warn("failed to dial NOTIFY_SOCKET %s: %v", socketAddr, err) + return nil, err + } + + return notifySocket, nil +} + +func getWatchdogTimeout() time.Duration { + if err := getProvidedFDs(); err != nil { + // This error will be logged elsewhere + return 0 + } + + return watchdogTimeout +} diff --git a/modules/graceful/net_windows.go b/modules/graceful/net_windows.go new file mode 100644 index 0000000..9667bd4 --- /dev/null +++ b/modules/graceful/net_windows.go @@ -0,0 +1,19 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler + +//go:build windows + +package graceful + +import "net" + +// DefaultGetListener obtains a listener for the local network address. +// On windows this is basically just a shim around net.Listen. +func DefaultGetListener(network, address string) (net.Listener, error) { + // Add a deferral to say that we've tried to grab a listener + defer GetManager().InformCleanup() + + return net.Listen(network, address) +} diff --git a/modules/graceful/releasereopen/releasereopen.go b/modules/graceful/releasereopen/releasereopen.go new file mode 100644 index 0000000..de5b07c --- /dev/null +++ b/modules/graceful/releasereopen/releasereopen.go @@ -0,0 +1,61 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package releasereopen + +import ( + "errors" + "sync" +) + +type ReleaseReopener interface { + ReleaseReopen() error +} + +type Manager struct { + mu sync.Mutex + counter int64 + + releaseReopeners map[int64]ReleaseReopener +} + +func (r *Manager) Register(rr ReleaseReopener) (cancel func()) { + r.mu.Lock() + defer r.mu.Unlock() + + r.counter++ + currentCounter := r.counter + r.releaseReopeners[r.counter] = rr + + return func() { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.releaseReopeners, currentCounter) + } +} + +func (r *Manager) ReleaseReopen() error { + r.mu.Lock() + defer r.mu.Unlock() + + var errs []error + for _, rr := range r.releaseReopeners { + if err := rr.ReleaseReopen(); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func GetManager() *Manager { + return manager +} + +func NewManager() *Manager { + return &Manager{ + releaseReopeners: make(map[int64]ReleaseReopener), + } +} + +var manager = NewManager() diff --git a/modules/graceful/releasereopen/releasereopen_test.go b/modules/graceful/releasereopen/releasereopen_test.go new file mode 100644 index 0000000..6ab9f95 --- /dev/null +++ b/modules/graceful/releasereopen/releasereopen_test.go @@ -0,0 +1,44 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package releasereopen + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testReleaseReopener struct { + count int +} + +func (t *testReleaseReopener) ReleaseReopen() error { + t.count++ + return nil +} + +func TestManager(t *testing.T) { + m := NewManager() + + t1 := &testReleaseReopener{} + t2 := &testReleaseReopener{} + t3 := &testReleaseReopener{} + + _ = m.Register(t1) + c2 := m.Register(t2) + _ = m.Register(t3) + + require.NoError(t, m.ReleaseReopen()) + assert.EqualValues(t, 1, t1.count) + assert.EqualValues(t, 1, t2.count) + assert.EqualValues(t, 1, t3.count) + + c2() + + require.NoError(t, m.ReleaseReopen()) + assert.EqualValues(t, 2, t1.count) + assert.EqualValues(t, 1, t2.count) + assert.EqualValues(t, 2, t3.count) +} diff --git a/modules/graceful/restart_unix.go b/modules/graceful/restart_unix.go new file mode 100644 index 0000000..98d5c5c --- /dev/null +++ b/modules/graceful/restart_unix.go @@ -0,0 +1,115 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler + +//go:build !windows + +package graceful + +import ( + "fmt" + "net" + "os" + "os/exec" + "strconv" + "strings" + "sync" + "syscall" + "time" +) + +var killParent sync.Once + +// KillParent sends the kill signal to the parent process if we are a child +func KillParent() { + killParent.Do(func() { + if GetManager().IsChild() { + ppid := syscall.Getppid() + if ppid > 1 { + _ = syscall.Kill(ppid, syscall.SIGTERM) + } + } + }) +} + +// RestartProcess starts a new process passing it the active listeners. It +// doesn't fork, but starts a new process using the same environment and +// arguments as when it was originally started. This allows for a newly +// deployed binary to be started. It returns the pid of the newly started +// process when successful. +func RestartProcess() (int, error) { + listeners := getActiveListeners() + + // Extract the fds from the listeners. + files := make([]*os.File, len(listeners)) + for i, l := range listeners { + var err error + // Now, all our listeners actually have File() functions so instead of + // individually casting we just use a hacky interface + files[i], err = l.(filer).File() + if err != nil { + return 0, err + } + + if unixListener, ok := l.(*net.UnixListener); ok { + unixListener.SetUnlinkOnClose(false) + } + // Remember to close these at the end. + defer func(i int) { + _ = files[i].Close() + }(i) + } + + // Use the original binary location. This works with symlinks such that if + // the file it points to has been changed we will use the updated symlink. + argv0, err := exec.LookPath(os.Args[0]) + if err != nil { + return 0, err + } + + // Pass on the environment and replace the old count key with the new one. + var env []string + for _, v := range os.Environ() { + if !strings.HasPrefix(v, listenFDsEnv+"=") { + env = append(env, v) + } + } + env = append(env, fmt.Sprintf("%s=%d", listenFDsEnv, len(listeners))) + + if notifySocketAddr != "" { + env = append(env, fmt.Sprintf("%s=%s", notifySocketEnv, notifySocketAddr)) + } + + if watchdogTimeout != 0 { + watchdogStr := strconv.FormatInt(int64(watchdogTimeout/time.Millisecond), 10) + env = append(env, fmt.Sprintf("%s=%s", watchdogTimeoutEnv, watchdogStr)) + } + + sb := &strings.Builder{} + for i, unlink := range getActiveListenersToUnlink() { + if !unlink { + continue + } + _, _ = sb.WriteString(strconv.Itoa(i)) + _, _ = sb.WriteString(",") + } + unlinkStr := sb.String() + if len(unlinkStr) > 0 { + unlinkStr = unlinkStr[:len(unlinkStr)-1] + env = append(env, fmt.Sprintf("%s=%s", unlinkFDsEnv, unlinkStr)) + } + + allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...) + process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{ + Dir: originalWD, + Env: env, + Files: allFiles, + }) + if err != nil { + return 0, err + } + processPid := process.Pid + _ = process.Release() // no wait, so release + return processPid, nil +} diff --git a/modules/graceful/server.go b/modules/graceful/server.go new file mode 100644 index 0000000..2525a83 --- /dev/null +++ b/modules/graceful/server.go @@ -0,0 +1,284 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// This code is highly inspired by endless go + +package graceful + +import ( + "crypto/tls" + "net" + "os" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/proxyprotocol" + "code.gitea.io/gitea/modules/setting" +) + +// GetListener returns a net listener +// This determines the implementation of net.Listener which the server will use, +// so that downstreams could provide their own Listener, such as with a hidden service or a p2p network +var GetListener = DefaultGetListener + +// ServeFunction represents a listen.Accept loop +type ServeFunction = func(net.Listener) error + +// Server represents our graceful server +type Server struct { + network string + address string + listener net.Listener + wg sync.WaitGroup + state state + lock *sync.RWMutex + BeforeBegin func(network, address string) + OnShutdown func() + PerWriteTimeout time.Duration + PerWritePerKbTimeout time.Duration +} + +// NewServer creates a server on network at provided address +func NewServer(network, address, name string) *Server { + if GetManager().IsChild() { + log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid()) + } else { + log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid()) + } + srv := &Server{ + wg: sync.WaitGroup{}, + state: stateInit, + lock: &sync.RWMutex{}, + network: network, + address: address, + PerWriteTimeout: setting.PerWriteTimeout, + PerWritePerKbTimeout: setting.PerWritePerKbTimeout, + } + + srv.BeforeBegin = func(network, addr string) { + log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid()) + } + + return srv +} + +// ListenAndServe listens on the provided network address and then calls Serve +// to handle requests on incoming connections. +func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error { + go srv.awaitShutdown() + + listener, err := GetListener(srv.network, srv.address) + if err != nil { + log.Error("Unable to GetListener: %v", err) + return err + } + + // we need to wrap the listener to take account of our lifecycle + listener = newWrappedListener(listener, srv) + + // Now we need to take account of ProxyProtocol settings... + if useProxyProtocol { + listener = &proxyprotocol.Listener{ + Listener: listener, + ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, + AcceptUnknown: setting.ProxyProtocolAcceptUnknown, + } + } + srv.listener = listener + + srv.BeforeBegin(srv.network, srv.address) + + return srv.Serve(serve) +} + +// ListenAndServeTLSConfig listens on the provided network address and then calls +// Serve to handle requests on incoming TLS connections. +func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error { + go srv.awaitShutdown() + + if tlsConfig.MinVersion == 0 { + tlsConfig.MinVersion = tls.VersionTLS12 + } + + listener, err := GetListener(srv.network, srv.address) + if err != nil { + log.Error("Unable to get Listener: %v", err) + return err + } + + // we need to wrap the listener to take account of our lifecycle + listener = newWrappedListener(listener, srv) + + // Now we need to take account of ProxyProtocol settings... If we're not bridging then we expect that the proxy will forward the connection to us + if useProxyProtocol && !proxyProtocolTLSBridging { + listener = &proxyprotocol.Listener{ + Listener: listener, + ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, + AcceptUnknown: setting.ProxyProtocolAcceptUnknown, + } + } + + // Now handle the tls protocol + listener = tls.NewListener(listener, tlsConfig) + + // Now if we're bridging then we need the proxy to tell us who we're bridging for... + if useProxyProtocol && proxyProtocolTLSBridging { + listener = &proxyprotocol.Listener{ + Listener: listener, + ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, + AcceptUnknown: setting.ProxyProtocolAcceptUnknown, + } + } + + srv.listener = listener + srv.BeforeBegin(srv.network, srv.address) + + return srv.Serve(serve) +} + +// Serve accepts incoming HTTP connections on the wrapped listener l, creating a new +// service goroutine for each. The service goroutines read requests and then call +// handler to reply to them. Handler is typically nil, in which case the +// DefaultServeMux is used. +// +// In addition to the standard Serve behaviour each connection is added to a +// sync.Waitgroup so that all outstanding connections can be served before shutting +// down the server. +func (srv *Server) Serve(serve ServeFunction) error { + defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid()) + srv.setState(stateRunning) + GetManager().RegisterServer() + err := serve(srv.listener) + log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid()) + srv.wg.Wait() + srv.setState(stateTerminate) + GetManager().ServerDone() + // use of closed means that the listeners are closed - i.e. we should be shutting down - return nil + if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") { + return nil + } + return err +} + +func (srv *Server) getState() state { + srv.lock.RLock() + defer srv.lock.RUnlock() + + return srv.state +} + +func (srv *Server) setState(st state) { + srv.lock.Lock() + defer srv.lock.Unlock() + + srv.state = st +} + +type filer interface { + File() (*os.File, error) +} + +type wrappedListener struct { + net.Listener + stopped bool + server *Server +} + +func newWrappedListener(l net.Listener, srv *Server) *wrappedListener { + return &wrappedListener{ + Listener: l, + server: srv, + } +} + +func (wl *wrappedListener) Accept() (net.Conn, error) { + var c net.Conn + // Set keepalive on TCPListeners connections. + if tcl, ok := wl.Listener.(*net.TCPListener); ok { + tc, err := tcl.AcceptTCP() + if err != nil { + return nil, err + } + _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener + _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener + c = tc + } else { + var err error + c, err = wl.Listener.Accept() + if err != nil { + return nil, err + } + } + + closed := int32(0) + + c = &wrappedConn{ + Conn: c, + server: wl.server, + closed: &closed, + perWriteTimeout: wl.server.PerWriteTimeout, + perWritePerKbTimeout: wl.server.PerWritePerKbTimeout, + } + + wl.server.wg.Add(1) + return c, nil +} + +func (wl *wrappedListener) Close() error { + if wl.stopped { + return syscall.EINVAL + } + + wl.stopped = true + return wl.Listener.Close() +} + +func (wl *wrappedListener) File() (*os.File, error) { + // returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes + return wl.Listener.(filer).File() +} + +type wrappedConn struct { + net.Conn + server *Server + closed *int32 + deadline time.Time + perWriteTimeout time.Duration + perWritePerKbTimeout time.Duration +} + +func (w *wrappedConn) Write(p []byte) (n int, err error) { + if w.perWriteTimeout > 0 { + minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout + minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout) + + w.deadline = w.deadline.Add(minTimeout) + if minDeadline.After(w.deadline) { + w.deadline = minDeadline + } + _ = w.Conn.SetWriteDeadline(w.deadline) + } + return w.Conn.Write(p) +} + +func (w *wrappedConn) Close() error { + if atomic.CompareAndSwapInt32(w.closed, 0, 1) { + defer func() { + if err := recover(); err != nil { + select { + case <-GetManager().IsHammer(): + // Likely deadlocked request released at hammertime + log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err) + default: + log.Error("Panic during connection close! %v", err) + } + } + }() + w.server.wg.Done() + } + return w.Conn.Close() +} diff --git a/modules/graceful/server_hooks.go b/modules/graceful/server_hooks.go new file mode 100644 index 0000000..9b67589 --- /dev/null +++ b/modules/graceful/server_hooks.go @@ -0,0 +1,73 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package graceful + +import ( + "os" + "runtime" + + "code.gitea.io/gitea/modules/log" +) + +// awaitShutdown waits for the shutdown signal from the Manager +func (srv *Server) awaitShutdown() { + select { + case <-GetManager().IsShutdown(): + // Shutdown + srv.doShutdown() + case <-GetManager().IsHammer(): + // Hammer + srv.doShutdown() + srv.doHammer() + } + <-GetManager().IsHammer() + srv.doHammer() +} + +// shutdown closes the listener so that no new connections are accepted +// and starts a goroutine that will hammer (stop all running requests) the server +// after setting.GracefulHammerTime. +func (srv *Server) doShutdown() { + // only shutdown if we're running. + if srv.getState() != stateRunning { + return + } + + srv.setState(stateShuttingDown) + + if srv.OnShutdown != nil { + srv.OnShutdown() + } + err := srv.listener.Close() + if err != nil { + log.Error("PID: %d Listener.Close() error: %v", os.Getpid(), err) + } else { + log.Info("PID: %d Listener (%s) closed.", os.Getpid(), srv.listener.Addr()) + } +} + +func (srv *Server) doHammer() { + defer func() { + // We call srv.wg.Done() until it panics. + // This happens if we call Done() when the WaitGroup counter is already at 0 + // So if it panics -> we're done, Serve() will return and the + // parent will goroutine will exit. + if r := recover(); r != nil { + log.Error("WaitGroup at 0: Error: %v", r) + } + }() + if srv.getState() != stateShuttingDown { + return + } + log.Warn("Forcefully shutting down parent") + for { + if srv.getState() == stateTerminate { + break + } + srv.wg.Done() + + // Give other goroutines a chance to finish before we forcibly stop them. + runtime.Gosched() + } +} diff --git a/modules/graceful/server_http.go b/modules/graceful/server_http.go new file mode 100644 index 0000000..7c855ac --- /dev/null +++ b/modules/graceful/server_http.go @@ -0,0 +1,37 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package graceful + +import ( + "context" + "crypto/tls" + "net" + "net/http" +) + +func newHTTPServer(network, address, name string, handler http.Handler) (*Server, ServeFunction) { + server := NewServer(network, address, name) + httpServer := http.Server{ + Handler: handler, + BaseContext: func(net.Listener) context.Context { return GetManager().HammerContext() }, + } + server.OnShutdown = func() { + httpServer.SetKeepAlivesEnabled(false) + } + return server, httpServer.Serve +} + +// HTTPListenAndServe listens on the provided network address and then calls Serve +// to handle requests on incoming connections. +func HTTPListenAndServe(network, address, name string, handler http.Handler, useProxyProtocol bool) error { + server, lHandler := newHTTPServer(network, address, name, handler) + return server.ListenAndServe(lHandler, useProxyProtocol) +} + +// HTTPListenAndServeTLSConfig listens on the provided network address and then calls Serve +// to handle requests on incoming connections. +func HTTPListenAndServeTLSConfig(network, address, name string, tlsConfig *tls.Config, handler http.Handler, useProxyProtocol, proxyProtocolTLSBridging bool) error { + server, lHandler := newHTTPServer(network, address, name, handler) + return server.ListenAndServeTLSConfig(tlsConfig, lHandler, useProxyProtocol, proxyProtocolTLSBridging) +} -- cgit v1.2.3