diff options
Diffstat (limited to '')
-rw-r--r-- | modules/graceful/net_unix.go | 321 |
1 files changed, 321 insertions, 0 deletions
diff --git a/modules/graceful/net_unix.go b/modules/graceful/net_unix.go new file mode 100644 index 0000000..796e005 --- /dev/null +++ b/modules/graceful/net_unix.go @@ -0,0 +1,321 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler + +//go:build !windows + +package graceful + +import ( + "fmt" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" +) + +const ( + listenFDsEnv = "LISTEN_FDS" + startFD = 3 + unlinkFDsEnv = "GITEA_UNLINK_FDS" + + notifySocketEnv = "NOTIFY_SOCKET" + watchdogTimeoutEnv = "WATCHDOG_USEC" +) + +// In order to keep the working directory the same as when we started we record +// it at startup. +var originalWD, _ = os.Getwd() + +var ( + once = sync.Once{} + mutex = sync.Mutex{} + + providedListenersToUnlink = []bool{} + activeListenersToUnlink = []bool{} + providedListeners = []net.Listener{} + activeListeners = []net.Listener{} + + notifySocketAddr string + watchdogTimeout time.Duration +) + +func getProvidedFDs() (savedErr error) { + // Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error + once.Do(func() { + mutex.Lock() + defer mutex.Unlock() + // now handle some additional systemd provided things + notifySocketAddr = os.Getenv(notifySocketEnv) + if notifySocketAddr != "" { + log.Debug("Systemd Notify Socket provided: %s", notifySocketAddr) + savedErr = os.Unsetenv(notifySocketEnv) + if savedErr != nil { + log.Warn("Unable to Unset the NOTIFY_SOCKET environment variable: %v", savedErr) + return + } + // FIXME: We don't handle WATCHDOG_PID + timeoutStr := os.Getenv(watchdogTimeoutEnv) + if timeoutStr != "" { + savedErr = os.Unsetenv(watchdogTimeoutEnv) + if savedErr != nil { + log.Warn("Unable to Unset the WATCHDOG_USEC environment variable: %v", savedErr) + return + } + + s, err := strconv.ParseInt(timeoutStr, 10, 64) + if err != nil { + log.Error("Unable to parse the provided WATCHDOG_USEC: %v", err) + savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %w", err) + return + } + if s <= 0 { + log.Error("Unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr) + savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr) + return + } + watchdogTimeout = time.Duration(s) * time.Microsecond + } + } else { + log.Trace("No Systemd Notify Socket provided") + } + + numFDs := os.Getenv(listenFDsEnv) + if numFDs == "" { + return + } + n, err := strconv.Atoi(numFDs) + if err != nil { + savedErr = fmt.Errorf("%s is not a number: %s. Err: %w", listenFDsEnv, numFDs, err) + return + } + + fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDsEnv), ",") + providedListenersToUnlink = make([]bool, n) + for _, fdStr := range fdsToUnlinkStr { + i, err := strconv.Atoi(fdStr) + if err != nil || i < 0 || i >= n { + continue + } + providedListenersToUnlink[i] = true + } + + for i := startFD; i < n+startFD; i++ { + file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i)) + + l, err := net.FileListener(file) + if err == nil { + // Close the inherited file if it's a listener + if err = file.Close(); err != nil { + savedErr = fmt.Errorf("error closing provided socket fd %d: %w", i, err) + return + } + providedListeners = append(providedListeners, l) + continue + } + + // If needed we can handle packetconns here. + savedErr = fmt.Errorf("Error getting provided socket fd %d: %w", i, err) + return + } + }) + return savedErr +} + +// closeProvidedListeners closes all unused provided listeners. +func closeProvidedListeners() { + mutex.Lock() + defer mutex.Unlock() + for _, l := range providedListeners { + err := l.Close() + if err != nil { + log.Error("Error in closing unused provided listener: %v", err) + } + } + providedListeners = []net.Listener{} +} + +// DefaultGetListener obtains a listener for the stream-oriented local network address: +// "tcp", "tcp4", "tcp6", "unix" or "unixpacket". +func DefaultGetListener(network, address string) (net.Listener, error) { + // Add a deferral to say that we've tried to grab a listener + defer GetManager().InformCleanup() + switch network { + case "tcp", "tcp4", "tcp6": + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return GetListenerTCP(network, tcpAddr) + case "unix", "unixpacket": + unixAddr, err := net.ResolveUnixAddr(network, address) + if err != nil { + return nil, err + } + return GetListenerUnix(network, unixAddr) + default: + return nil, net.UnknownNetworkError(network) + } +} + +// GetListenerTCP announces on the local network address. The network must be: +// "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the +// matching network and address, or creates a new one using net.ListenTCP. +func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) { + if err := getProvidedFDs(); err != nil { + return nil, err + } + + mutex.Lock() + defer mutex.Unlock() + + // look for a provided listener + for i, l := range providedListeners { + if isSameAddr(l.Addr(), address) { + providedListeners = append(providedListeners[:i], providedListeners[i+1:]...) + needsUnlink := providedListenersToUnlink[i] + providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...) + + activeListeners = append(activeListeners, l) + activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink) + return l.(*net.TCPListener), nil + } + } + + // no provided listener for this address -> make a fresh listener + l, err := net.ListenTCP(network, address) + if err != nil { + return nil, err + } + activeListeners = append(activeListeners, l) + activeListenersToUnlink = append(activeListenersToUnlink, false) + return l, nil +} + +// GetListenerUnix announces on the local network address. The network must be: +// "unix" or "unixpacket". It returns a provided net.Listener for the +// matching network and address, or creates a new one using net.ListenUnix. +func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) { + if err := getProvidedFDs(); err != nil { + return nil, err + } + + mutex.Lock() + defer mutex.Unlock() + + // look for a provided listener + for i, l := range providedListeners { + if isSameAddr(l.Addr(), address) { + providedListeners = append(providedListeners[:i], providedListeners[i+1:]...) + needsUnlink := providedListenersToUnlink[i] + providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...) + + activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink) + activeListeners = append(activeListeners, l) + unixListener := l.(*net.UnixListener) + if needsUnlink { + unixListener.SetUnlinkOnClose(true) + } + return unixListener, nil + } + } + + // make a fresh listener + if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("Failed to remove unix socket %s: %w", address.Name, err) + } + + l, err := net.ListenUnix(network, address) + if err != nil { + return nil, err + } + + fileMode := os.FileMode(setting.UnixSocketPermission) + if err = os.Chmod(address.Name, fileMode); err != nil { + return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %w", fileMode.String(), err) + } + + activeListeners = append(activeListeners, l) + activeListenersToUnlink = append(activeListenersToUnlink, true) + return l, nil +} + +func isSameAddr(a1, a2 net.Addr) bool { + // If the addresses are not on the same network fail. + if a1.Network() != a2.Network() { + return false + } + + // If the two addresses have the same string representation they're equal + a1s := a1.String() + a2s := a2.String() + if a1s == a2s { + return true + } + + // This allows for ipv6 vs ipv4 local addresses to compare as equal. This + // scenario is common when listening on localhost. + const ipv6prefix = "[::]" + a1s = strings.TrimPrefix(a1s, ipv6prefix) + a2s = strings.TrimPrefix(a2s, ipv6prefix) + const ipv4prefix = "0.0.0.0" + a1s = strings.TrimPrefix(a1s, ipv4prefix) + a2s = strings.TrimPrefix(a2s, ipv4prefix) + return a1s == a2s +} + +func getActiveListeners() []net.Listener { + mutex.Lock() + defer mutex.Unlock() + listeners := make([]net.Listener, len(activeListeners)) + copy(listeners, activeListeners) + return listeners +} + +func getActiveListenersToUnlink() []bool { + mutex.Lock() + defer mutex.Unlock() + listenersToUnlink := make([]bool, len(activeListenersToUnlink)) + copy(listenersToUnlink, activeListenersToUnlink) + return listenersToUnlink +} + +func getNotifySocket() (*net.UnixConn, error) { + if err := getProvidedFDs(); err != nil { + // This error will be logged elsewhere + return nil, nil + } + + if notifySocketAddr == "" { + return nil, nil + } + + socketAddr := &net.UnixAddr{ + Name: notifySocketAddr, + Net: "unixgram", + } + + notifySocket, err := net.DialUnix(socketAddr.Net, nil, socketAddr) + if err != nil { + log.Warn("failed to dial NOTIFY_SOCKET %s: %v", socketAddr, err) + return nil, err + } + + return notifySocket, nil +} + +func getWatchdogTimeout() time.Duration { + if err := getProvidedFDs(); err != nil { + // This error will be logged elsewhere + return 0 + } + + return watchdogTimeout +} |