summaryrefslogtreecommitdiffstats
path: root/modules/graceful/net_unix.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--modules/graceful/net_unix.go321
1 files changed, 321 insertions, 0 deletions
diff --git a/modules/graceful/net_unix.go b/modules/graceful/net_unix.go
new file mode 100644
index 0000000..796e005
--- /dev/null
+++ b/modules/graceful/net_unix.go
@@ -0,0 +1,321 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
+
+//go:build !windows
+
+package graceful
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/setting"
+ "code.gitea.io/gitea/modules/util"
+)
+
+const (
+ listenFDsEnv = "LISTEN_FDS"
+ startFD = 3
+ unlinkFDsEnv = "GITEA_UNLINK_FDS"
+
+ notifySocketEnv = "NOTIFY_SOCKET"
+ watchdogTimeoutEnv = "WATCHDOG_USEC"
+)
+
+// In order to keep the working directory the same as when we started we record
+// it at startup.
+var originalWD, _ = os.Getwd()
+
+var (
+ once = sync.Once{}
+ mutex = sync.Mutex{}
+
+ providedListenersToUnlink = []bool{}
+ activeListenersToUnlink = []bool{}
+ providedListeners = []net.Listener{}
+ activeListeners = []net.Listener{}
+
+ notifySocketAddr string
+ watchdogTimeout time.Duration
+)
+
+func getProvidedFDs() (savedErr error) {
+ // Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
+ once.Do(func() {
+ mutex.Lock()
+ defer mutex.Unlock()
+ // now handle some additional systemd provided things
+ notifySocketAddr = os.Getenv(notifySocketEnv)
+ if notifySocketAddr != "" {
+ log.Debug("Systemd Notify Socket provided: %s", notifySocketAddr)
+ savedErr = os.Unsetenv(notifySocketEnv)
+ if savedErr != nil {
+ log.Warn("Unable to Unset the NOTIFY_SOCKET environment variable: %v", savedErr)
+ return
+ }
+ // FIXME: We don't handle WATCHDOG_PID
+ timeoutStr := os.Getenv(watchdogTimeoutEnv)
+ if timeoutStr != "" {
+ savedErr = os.Unsetenv(watchdogTimeoutEnv)
+ if savedErr != nil {
+ log.Warn("Unable to Unset the WATCHDOG_USEC environment variable: %v", savedErr)
+ return
+ }
+
+ s, err := strconv.ParseInt(timeoutStr, 10, 64)
+ if err != nil {
+ log.Error("Unable to parse the provided WATCHDOG_USEC: %v", err)
+ savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %w", err)
+ return
+ }
+ if s <= 0 {
+ log.Error("Unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr)
+ savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr)
+ return
+ }
+ watchdogTimeout = time.Duration(s) * time.Microsecond
+ }
+ } else {
+ log.Trace("No Systemd Notify Socket provided")
+ }
+
+ numFDs := os.Getenv(listenFDsEnv)
+ if numFDs == "" {
+ return
+ }
+ n, err := strconv.Atoi(numFDs)
+ if err != nil {
+ savedErr = fmt.Errorf("%s is not a number: %s. Err: %w", listenFDsEnv, numFDs, err)
+ return
+ }
+
+ fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDsEnv), ",")
+ providedListenersToUnlink = make([]bool, n)
+ for _, fdStr := range fdsToUnlinkStr {
+ i, err := strconv.Atoi(fdStr)
+ if err != nil || i < 0 || i >= n {
+ continue
+ }
+ providedListenersToUnlink[i] = true
+ }
+
+ for i := startFD; i < n+startFD; i++ {
+ file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
+
+ l, err := net.FileListener(file)
+ if err == nil {
+ // Close the inherited file if it's a listener
+ if err = file.Close(); err != nil {
+ savedErr = fmt.Errorf("error closing provided socket fd %d: %w", i, err)
+ return
+ }
+ providedListeners = append(providedListeners, l)
+ continue
+ }
+
+ // If needed we can handle packetconns here.
+ savedErr = fmt.Errorf("Error getting provided socket fd %d: %w", i, err)
+ return
+ }
+ })
+ return savedErr
+}
+
+// closeProvidedListeners closes all unused provided listeners.
+func closeProvidedListeners() {
+ mutex.Lock()
+ defer mutex.Unlock()
+ for _, l := range providedListeners {
+ err := l.Close()
+ if err != nil {
+ log.Error("Error in closing unused provided listener: %v", err)
+ }
+ }
+ providedListeners = []net.Listener{}
+}
+
+// DefaultGetListener obtains a listener for the stream-oriented local network address:
+// "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
+func DefaultGetListener(network, address string) (net.Listener, error) {
+ // Add a deferral to say that we've tried to grab a listener
+ defer GetManager().InformCleanup()
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ tcpAddr, err := net.ResolveTCPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return GetListenerTCP(network, tcpAddr)
+ case "unix", "unixpacket":
+ unixAddr, err := net.ResolveUnixAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return GetListenerUnix(network, unixAddr)
+ default:
+ return nil, net.UnknownNetworkError(network)
+ }
+}
+
+// GetListenerTCP announces on the local network address. The network must be:
+// "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
+// matching network and address, or creates a new one using net.ListenTCP.
+func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
+ if err := getProvidedFDs(); err != nil {
+ return nil, err
+ }
+
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ // look for a provided listener
+ for i, l := range providedListeners {
+ if isSameAddr(l.Addr(), address) {
+ providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
+ needsUnlink := providedListenersToUnlink[i]
+ providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
+
+ activeListeners = append(activeListeners, l)
+ activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
+ return l.(*net.TCPListener), nil
+ }
+ }
+
+ // no provided listener for this address -> make a fresh listener
+ l, err := net.ListenTCP(network, address)
+ if err != nil {
+ return nil, err
+ }
+ activeListeners = append(activeListeners, l)
+ activeListenersToUnlink = append(activeListenersToUnlink, false)
+ return l, nil
+}
+
+// GetListenerUnix announces on the local network address. The network must be:
+// "unix" or "unixpacket". It returns a provided net.Listener for the
+// matching network and address, or creates a new one using net.ListenUnix.
+func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
+ if err := getProvidedFDs(); err != nil {
+ return nil, err
+ }
+
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ // look for a provided listener
+ for i, l := range providedListeners {
+ if isSameAddr(l.Addr(), address) {
+ providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
+ needsUnlink := providedListenersToUnlink[i]
+ providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
+
+ activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
+ activeListeners = append(activeListeners, l)
+ unixListener := l.(*net.UnixListener)
+ if needsUnlink {
+ unixListener.SetUnlinkOnClose(true)
+ }
+ return unixListener, nil
+ }
+ }
+
+ // make a fresh listener
+ if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) {
+ return nil, fmt.Errorf("Failed to remove unix socket %s: %w", address.Name, err)
+ }
+
+ l, err := net.ListenUnix(network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ fileMode := os.FileMode(setting.UnixSocketPermission)
+ if err = os.Chmod(address.Name, fileMode); err != nil {
+ return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %w", fileMode.String(), err)
+ }
+
+ activeListeners = append(activeListeners, l)
+ activeListenersToUnlink = append(activeListenersToUnlink, true)
+ return l, nil
+}
+
+func isSameAddr(a1, a2 net.Addr) bool {
+ // If the addresses are not on the same network fail.
+ if a1.Network() != a2.Network() {
+ return false
+ }
+
+ // If the two addresses have the same string representation they're equal
+ a1s := a1.String()
+ a2s := a2.String()
+ if a1s == a2s {
+ return true
+ }
+
+ // This allows for ipv6 vs ipv4 local addresses to compare as equal. This
+ // scenario is common when listening on localhost.
+ const ipv6prefix = "[::]"
+ a1s = strings.TrimPrefix(a1s, ipv6prefix)
+ a2s = strings.TrimPrefix(a2s, ipv6prefix)
+ const ipv4prefix = "0.0.0.0"
+ a1s = strings.TrimPrefix(a1s, ipv4prefix)
+ a2s = strings.TrimPrefix(a2s, ipv4prefix)
+ return a1s == a2s
+}
+
+func getActiveListeners() []net.Listener {
+ mutex.Lock()
+ defer mutex.Unlock()
+ listeners := make([]net.Listener, len(activeListeners))
+ copy(listeners, activeListeners)
+ return listeners
+}
+
+func getActiveListenersToUnlink() []bool {
+ mutex.Lock()
+ defer mutex.Unlock()
+ listenersToUnlink := make([]bool, len(activeListenersToUnlink))
+ copy(listenersToUnlink, activeListenersToUnlink)
+ return listenersToUnlink
+}
+
+func getNotifySocket() (*net.UnixConn, error) {
+ if err := getProvidedFDs(); err != nil {
+ // This error will be logged elsewhere
+ return nil, nil
+ }
+
+ if notifySocketAddr == "" {
+ return nil, nil
+ }
+
+ socketAddr := &net.UnixAddr{
+ Name: notifySocketAddr,
+ Net: "unixgram",
+ }
+
+ notifySocket, err := net.DialUnix(socketAddr.Net, nil, socketAddr)
+ if err != nil {
+ log.Warn("failed to dial NOTIFY_SOCKET %s: %v", socketAddr, err)
+ return nil, err
+ }
+
+ return notifySocket, nil
+}
+
+func getWatchdogTimeout() time.Duration {
+ if err := getProvidedFDs(); err != nil {
+ // This error will be logged elsewhere
+ return 0
+ }
+
+ return watchdogTimeout
+}