diff options
author | Daniel Baumann <daniel@debian.org> | 2024-10-18 20:33:49 +0200 |
---|---|---|
committer | Daniel Baumann <daniel@debian.org> | 2024-10-18 20:33:49 +0200 |
commit | dd136858f1ea40ad3c94191d647487fa4f31926c (patch) | |
tree | 58fec94a7b2a12510c9664b21793f1ed560c6518 /modules/util | |
parent | Initial commit. (diff) | |
download | forgejo-dd136858f1ea40ad3c94191d647487fa4f31926c.tar.xz forgejo-dd136858f1ea40ad3c94191d647487fa4f31926c.zip |
Adding upstream version 9.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
Diffstat (limited to '')
40 files changed, 3347 insertions, 0 deletions
diff --git a/modules/util/color.go b/modules/util/color.go new file mode 100644 index 0000000..9c520dc --- /dev/null +++ b/modules/util/color.go @@ -0,0 +1,57 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT +package util + +import ( + "fmt" + "strconv" + "strings" +) + +// Get color as RGB values in 0..255 range from the hex color string (with or without #) +func HexToRBGColor(colorString string) (float64, float64, float64) { + hexString := colorString + if strings.HasPrefix(colorString, "#") { + hexString = colorString[1:] + } + // only support transfer of rgb, rgba, rrggbb and rrggbbaa + // if not in these formats, use default values 0, 0, 0 + if len(hexString) != 3 && len(hexString) != 4 && len(hexString) != 6 && len(hexString) != 8 { + return 0, 0, 0 + } + if len(hexString) == 3 || len(hexString) == 4 { + hexString = fmt.Sprintf("%c%c%c%c%c%c", hexString[0], hexString[0], hexString[1], hexString[1], hexString[2], hexString[2]) + } + if len(hexString) == 8 { + hexString = hexString[0:6] + } + color, err := strconv.ParseUint(hexString, 16, 64) + if err != nil { + return 0, 0, 0 + } + r := float64(uint8(0xFF & (uint32(color) >> 16))) + g := float64(uint8(0xFF & (uint32(color) >> 8))) + b := float64(uint8(0xFF & uint32(color))) + return r, g, b +} + +// Returns relative luminance for a SRGB color - https://en.wikipedia.org/wiki/Relative_luminance +// Keep this in sync with web_src/js/utils/color.js +func GetRelativeLuminance(color string) float64 { + r, g, b := HexToRBGColor(color) + return (0.2126729*r + 0.7151522*g + 0.0721750*b) / 255 +} + +func UseLightText(backgroundColor string) bool { + return GetRelativeLuminance(backgroundColor) < 0.453 +} + +// Given a background color, returns a black or white foreground color that the highest +// contrast ratio. In the future, the APCA contrast function, or CSS `contrast-color` will be better. +// https://github.com/color-js/color.js/blob/eb7b53f7a13bb716ec8b28c7a56f052cd599acd9/src/contrast/APCA.js#L42 +func ContrastColor(backgroundColor string) string { + if UseLightText(backgroundColor) { + return "#fff" + } + return "#000" +} diff --git a/modules/util/color_test.go b/modules/util/color_test.go new file mode 100644 index 0000000..abd5551 --- /dev/null +++ b/modules/util/color_test.go @@ -0,0 +1,63 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_HexToRBGColor(t *testing.T) { + cases := []struct { + colorString string + expectedR float64 + expectedG float64 + expectedB float64 + }{ + {"2b8685", 43, 134, 133}, + {"1e1", 17, 238, 17}, + {"#1e1", 17, 238, 17}, + {"1e16", 17, 238, 17}, + {"3bb6b3", 59, 182, 179}, + {"#3bb6b399", 59, 182, 179}, + {"#0", 0, 0, 0}, + {"#00000", 0, 0, 0}, + {"#1234567", 0, 0, 0}, + } + for n, c := range cases { + r, g, b := HexToRBGColor(c.colorString) + assert.InDelta(t, c.expectedR, r, 0, "case %d: error R should match: expected %f, but get %f", n, c.expectedR, r) + assert.InDelta(t, c.expectedG, g, 0, "case %d: error G should match: expected %f, but get %f", n, c.expectedG, g) + assert.InDelta(t, c.expectedB, b, 0, "case %d: error B should match: expected %f, but get %f", n, c.expectedB, b) + } +} + +func Test_UseLightText(t *testing.T) { + cases := []struct { + color string + expected string + }{ + {"#d73a4a", "#fff"}, + {"#0075ca", "#fff"}, + {"#cfd3d7", "#000"}, + {"#a2eeef", "#000"}, + {"#7057ff", "#fff"}, + {"#008672", "#fff"}, + {"#e4e669", "#000"}, + {"#d876e3", "#000"}, + {"#ffffff", "#000"}, + {"#2b8684", "#fff"}, + {"#2b8786", "#fff"}, + {"#2c8786", "#000"}, + {"#3bb6b3", "#000"}, + {"#7c7268", "#fff"}, + {"#7e716c", "#fff"}, + {"#81706d", "#fff"}, + {"#807070", "#fff"}, + {"#84b6eb", "#000"}, + } + for n, c := range cases { + assert.Equal(t, c.expected, ContrastColor(c.color), "case %d: error should match", n) + } +} diff --git a/modules/util/error.go b/modules/util/error.go new file mode 100644 index 0000000..0f35971 --- /dev/null +++ b/modules/util/error.go @@ -0,0 +1,65 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "errors" + "fmt" +) + +// Common Errors forming the base of our error system +// +// Many Errors returned by Gitea can be tested against these errors +// using errors.Is. +var ( + ErrInvalidArgument = errors.New("invalid argument") + ErrPermissionDenied = errors.New("permission denied") + ErrAlreadyExist = errors.New("resource already exists") + ErrNotExist = errors.New("resource does not exist") +) + +// SilentWrap provides a simple wrapper for a wrapped error where the wrapped error message plays no part in the error message +// Especially useful for "untyped" errors created with "errors.New(…)" that can be classified as 'invalid argument', 'permission denied', 'exists already', or 'does not exist' +type SilentWrap struct { + Message string + Err error +} + +// Error returns the message +func (w SilentWrap) Error() string { + return w.Message +} + +// Unwrap returns the underlying error +func (w SilentWrap) Unwrap() error { + return w.Err +} + +// NewSilentWrapErrorf returns an error that formats as the given text but unwraps as the provided error +func NewSilentWrapErrorf(unwrap error, message string, args ...any) error { + if len(args) == 0 { + return SilentWrap{Message: message, Err: unwrap} + } + return SilentWrap{Message: fmt.Sprintf(message, args...), Err: unwrap} +} + +// NewInvalidArgumentErrorf returns an error that formats as the given text but unwraps as an ErrInvalidArgument +func NewInvalidArgumentErrorf(message string, args ...any) error { + return NewSilentWrapErrorf(ErrInvalidArgument, message, args...) +} + +// NewPermissionDeniedErrorf returns an error that formats as the given text but unwraps as an ErrPermissionDenied +func NewPermissionDeniedErrorf(message string, args ...any) error { + return NewSilentWrapErrorf(ErrPermissionDenied, message, args...) +} + +// NewAlreadyExistErrorf returns an error that formats as the given text but unwraps as an ErrAlreadyExist +func NewAlreadyExistErrorf(message string, args ...any) error { + return NewSilentWrapErrorf(ErrAlreadyExist, message, args...) +} + +// NewNotExistErrorf returns an error that formats as the given text but unwraps as an ErrNotExist +func NewNotExistErrorf(message string, args ...any) error { + return NewSilentWrapErrorf(ErrNotExist, message, args...) +} diff --git a/modules/util/file_unix.go b/modules/util/file_unix.go new file mode 100644 index 0000000..79a29c8 --- /dev/null +++ b/modules/util/file_unix.go @@ -0,0 +1,27 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build !windows + +package util + +import ( + "os" + + "golang.org/x/sys/unix" +) + +var defaultUmask int + +func init() { + // at the moment, the umask could only be gotten by calling unix.Umask(newUmask) + // use 0o077 as temp new umask to reduce the risks if this umask is used anywhere else before the correct umask is recovered + tempUmask := 0o077 + defaultUmask = unix.Umask(tempUmask) + unix.Umask(defaultUmask) +} + +func ApplyUmask(f string, newMode os.FileMode) error { + mod := newMode & ^os.FileMode(defaultUmask) + return os.Chmod(f, mod) +} diff --git a/modules/util/file_unix_test.go b/modules/util/file_unix_test.go new file mode 100644 index 0000000..d60082a --- /dev/null +++ b/modules/util/file_unix_test.go @@ -0,0 +1,36 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build !windows + +package util + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyUmask(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "test-filemode-") + require.NoError(t, err) + + err = os.Chmod(f.Name(), 0o777) + require.NoError(t, err) + st, err := os.Stat(f.Name()) + require.NoError(t, err) + assert.EqualValues(t, 0o777, st.Mode().Perm()&0o777) + + oldDefaultUmask := defaultUmask + defaultUmask = 0o037 + defer func() { + defaultUmask = oldDefaultUmask + }() + err = ApplyUmask(f.Name(), os.ModePerm) + require.NoError(t, err) + st, err = os.Stat(f.Name()) + require.NoError(t, err) + assert.EqualValues(t, 0o740, st.Mode().Perm()&0o777) +} diff --git a/modules/util/file_windows.go b/modules/util/file_windows.go new file mode 100644 index 0000000..77a33d3 --- /dev/null +++ b/modules/util/file_windows.go @@ -0,0 +1,15 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//go:build windows + +package util + +import ( + "os" +) + +func ApplyUmask(f string, newMode os.FileMode) error { + // do nothing for Windows, because Windows doesn't use umask + return nil +} diff --git a/modules/util/filebuffer/file_backed_buffer.go b/modules/util/filebuffer/file_backed_buffer.go new file mode 100644 index 0000000..739543e --- /dev/null +++ b/modules/util/filebuffer/file_backed_buffer.go @@ -0,0 +1,156 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package filebuffer + +import ( + "bytes" + "errors" + "io" + "math" + "os" +) + +var ( + // ErrInvalidMemorySize occurs if the memory size is not in a valid range + ErrInvalidMemorySize = errors.New("Memory size must be greater 0 and lower math.MaxInt32") + // ErrWriteAfterRead occurs if Write is called after a read operation + ErrWriteAfterRead = errors.New("Write is unsupported after a read operation") +) + +type readAtSeeker interface { + io.ReadSeeker + io.ReaderAt +} + +// FileBackedBuffer uses a memory buffer with a fixed size. +// If more data is written a temporary file is used instead. +// It implements io.ReadWriteCloser, io.ReadSeekCloser and io.ReaderAt +type FileBackedBuffer struct { + maxMemorySize int64 + size int64 + buffer bytes.Buffer + file *os.File + reader readAtSeeker +} + +// New creates a file backed buffer with a specific maximum memory size +func New(maxMemorySize int) (*FileBackedBuffer, error) { + if maxMemorySize < 0 || maxMemorySize > math.MaxInt32 { + return nil, ErrInvalidMemorySize + } + + return &FileBackedBuffer{ + maxMemorySize: int64(maxMemorySize), + }, nil +} + +// CreateFromReader creates a file backed buffer and copies the provided reader data into it. +func CreateFromReader(r io.Reader, maxMemorySize int) (*FileBackedBuffer, error) { + b, err := New(maxMemorySize) + if err != nil { + return nil, err + } + + _, err = io.Copy(b, r) + if err != nil { + return nil, err + } + + return b, nil +} + +// Write implements io.Writer +func (b *FileBackedBuffer) Write(p []byte) (int, error) { + if b.reader != nil { + return 0, ErrWriteAfterRead + } + + var n int + var err error + + if b.file != nil { + n, err = b.file.Write(p) + } else { + if b.size+int64(len(p)) > b.maxMemorySize { + b.file, err = os.CreateTemp("", "gitea-buffer-") + if err != nil { + return 0, err + } + + _, err = io.Copy(b.file, &b.buffer) + if err != nil { + return 0, err + } + + return b.Write(p) + } + + n, err = b.buffer.Write(p) + } + + if err != nil { + return n, err + } + b.size += int64(n) + return n, nil +} + +// Size returns the byte size of the buffered data +func (b *FileBackedBuffer) Size() int64 { + return b.size +} + +func (b *FileBackedBuffer) switchToReader() error { + if b.reader != nil { + return nil + } + + if b.file != nil { + if _, err := b.file.Seek(0, io.SeekStart); err != nil { + return err + } + b.reader = b.file + } else { + b.reader = bytes.NewReader(b.buffer.Bytes()) + } + return nil +} + +// Read implements io.Reader +func (b *FileBackedBuffer) Read(p []byte) (int, error) { + if err := b.switchToReader(); err != nil { + return 0, err + } + + return b.reader.Read(p) +} + +// ReadAt implements io.ReaderAt +func (b *FileBackedBuffer) ReadAt(p []byte, off int64) (int, error) { + if err := b.switchToReader(); err != nil { + return 0, err + } + + return b.reader.ReadAt(p, off) +} + +// Seek implements io.Seeker +func (b *FileBackedBuffer) Seek(offset int64, whence int) (int64, error) { + if err := b.switchToReader(); err != nil { + return 0, err + } + + return b.reader.Seek(offset, whence) +} + +// Close implements io.Closer +func (b *FileBackedBuffer) Close() error { + if b.file != nil { + err := b.file.Close() + os.Remove(b.file.Name()) + b.file = nil + return err + } + return nil +} diff --git a/modules/util/filebuffer/file_backed_buffer_test.go b/modules/util/filebuffer/file_backed_buffer_test.go new file mode 100644 index 0000000..c56c1c6 --- /dev/null +++ b/modules/util/filebuffer/file_backed_buffer_test.go @@ -0,0 +1,36 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package filebuffer + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileBackedBuffer(t *testing.T) { + cases := []struct { + MaxMemorySize int + Data string + }{ + {5, "test"}, + {5, "testtest"}, + } + + for _, c := range cases { + buf, err := CreateFromReader(strings.NewReader(c.Data), c.MaxMemorySize) + require.NoError(t, err) + + assert.EqualValues(t, len(c.Data), buf.Size()) + + data, err := io.ReadAll(buf) + require.NoError(t, err) + assert.Equal(t, c.Data, string(data)) + + require.NoError(t, buf.Close()) + } +} diff --git a/modules/util/io.go b/modules/util/io.go new file mode 100644 index 0000000..1559b01 --- /dev/null +++ b/modules/util/io.go @@ -0,0 +1,78 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "bytes" + "errors" + "io" +) + +// ReadAtMost reads at most len(buf) bytes from r into buf. +// It returns the number of bytes copied. n is only less than len(buf) if r provides fewer bytes. +// If EOF or ErrUnexpectedEOF occurs while reading, err will be nil. +func ReadAtMost(r io.Reader, buf []byte) (n int, err error) { + n, err = io.ReadFull(r, buf) + if err == io.EOF || err == io.ErrUnexpectedEOF { + err = nil + } + return n, err +} + +// ReadWithLimit reads at most "limit" bytes from r into buf. +// If EOF or ErrUnexpectedEOF occurs while reading, err will be nil. +func ReadWithLimit(r io.Reader, n int) (buf []byte, err error) { + return readWithLimit(r, 1024, n) +} + +func readWithLimit(r io.Reader, batch, limit int) ([]byte, error) { + if limit <= batch { + buf := make([]byte, limit) + n, err := ReadAtMost(r, buf) + if err != nil { + return nil, err + } + return buf[:n], nil + } + res := bytes.NewBuffer(make([]byte, 0, batch)) + bufFix := make([]byte, batch) + eof := false + for res.Len() < limit && !eof { + bufTmp := bufFix + if res.Len()+batch > limit { + bufTmp = bufFix[:limit-res.Len()] + } + n, err := io.ReadFull(r, bufTmp) + if err == io.EOF || err == io.ErrUnexpectedEOF { + eof = true + } else if err != nil { + return nil, err + } + if _, err = res.Write(bufTmp[:n]); err != nil { + return nil, err + } + } + return res.Bytes(), nil +} + +// ErrNotEmpty is an error reported when there is a non-empty reader +var ErrNotEmpty = errors.New("not-empty") + +// IsEmptyReader reads a reader and ensures it is empty +func IsEmptyReader(r io.Reader) (err error) { + var buf [1]byte + + for { + n, err := r.Read(buf[:]) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if n > 0 { + return ErrNotEmpty + } + } +} diff --git a/modules/util/io_test.go b/modules/util/io_test.go new file mode 100644 index 0000000..870e713 --- /dev/null +++ b/modules/util/io_test.go @@ -0,0 +1,67 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "bytes" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type readerWithError struct { + buf *bytes.Buffer +} + +func (r *readerWithError) Read(p []byte) (n int, err error) { + if r.buf.Len() < 2 { + return 0, errors.New("test error") + } + return r.buf.Read(p) +} + +func TestReadWithLimit(t *testing.T) { + bs := []byte("0123456789abcdef") + + // normal test + buf, err := readWithLimit(bytes.NewBuffer(bs), 5, 2) + require.NoError(t, err) + assert.Equal(t, []byte("01"), buf) + + buf, err = readWithLimit(bytes.NewBuffer(bs), 5, 5) + require.NoError(t, err) + assert.Equal(t, []byte("01234"), buf) + + buf, err = readWithLimit(bytes.NewBuffer(bs), 5, 6) + require.NoError(t, err) + assert.Equal(t, []byte("012345"), buf) + + buf, err = readWithLimit(bytes.NewBuffer(bs), 5, len(bs)) + require.NoError(t, err) + assert.Equal(t, []byte("0123456789abcdef"), buf) + + buf, err = readWithLimit(bytes.NewBuffer(bs), 5, 100) + require.NoError(t, err) + assert.Equal(t, []byte("0123456789abcdef"), buf) + + // test with error + buf, err = readWithLimit(&readerWithError{bytes.NewBuffer(bs)}, 5, 10) + require.NoError(t, err) + assert.Equal(t, []byte("0123456789"), buf) + + buf, err = readWithLimit(&readerWithError{bytes.NewBuffer(bs)}, 5, 100) + require.ErrorContains(t, err, "test error") + assert.Empty(t, buf) + + // test public function + buf, err = ReadWithLimit(bytes.NewBuffer(bs), 2) + require.NoError(t, err) + assert.Equal(t, []byte("01"), buf) + + buf, err = ReadWithLimit(bytes.NewBuffer(bs), 9999999) + require.NoError(t, err) + assert.Equal(t, []byte("0123456789abcdef"), buf) +} diff --git a/modules/util/keypair.go b/modules/util/keypair.go new file mode 100644 index 0000000..07f27bd --- /dev/null +++ b/modules/util/keypair.go @@ -0,0 +1,57 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" +) + +// GenerateKeyPair generates a public and private keypair +func GenerateKeyPair(bits int) (string, string, error) { + priv, _ := rsa.GenerateKey(rand.Reader, bits) + privPem := pemBlockForPriv(priv) + pubPem, err := pemBlockForPub(&priv.PublicKey) + if err != nil { + return "", "", err + } + return privPem, pubPem, nil +} + +func pemBlockForPriv(priv *rsa.PrivateKey) string { + privBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(priv), + }) + return string(privBytes) +} + +func pemBlockForPub(pub *rsa.PublicKey) (string, error) { + pubASN1, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return "", err + } + pubBytes := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubASN1, + }) + return string(pubBytes), nil +} + +// CreatePublicKeyFingerprint creates a fingerprint of the given key. +// The fingerprint is the sha256 sum of the PKIX structure of the key. +func CreatePublicKeyFingerprint(key crypto.PublicKey) ([]byte, error) { + bytes, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + return nil, err + } + + checksum := sha256.Sum256(bytes) + + return checksum[:], nil +} diff --git a/modules/util/keypair_test.go b/modules/util/keypair_test.go new file mode 100644 index 0000000..ec9bca7 --- /dev/null +++ b/modules/util/keypair_test.go @@ -0,0 +1,62 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKeygen(t *testing.T) { + priv, pub, err := GenerateKeyPair(2048) + require.NoError(t, err) + + assert.NotEmpty(t, priv) + assert.NotEmpty(t, pub) + + assert.Regexp(t, regexp.MustCompile("^-----BEGIN RSA PRIVATE KEY-----.*"), priv) + assert.Regexp(t, regexp.MustCompile("^-----BEGIN PUBLIC KEY-----.*"), pub) +} + +func TestSignUsingKeys(t *testing.T) { + priv, pub, err := GenerateKeyPair(2048) + require.NoError(t, err) + + privPem, _ := pem.Decode([]byte(priv)) + if privPem == nil || privPem.Type != "RSA PRIVATE KEY" { + t.Fatal("key is wrong type") + } + + privParsed, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) + require.NoError(t, err) + + pubPem, _ := pem.Decode([]byte(pub)) + if pubPem == nil || pubPem.Type != "PUBLIC KEY" { + t.Fatal("key failed to decode") + } + + pubParsed, err := x509.ParsePKIXPublicKey(pubPem.Bytes) + require.NoError(t, err) + + // Sign + msg := "activity pub is great!" + h := sha256.New() + h.Write([]byte(msg)) + d := h.Sum(nil) + sig, err := rsa.SignPKCS1v15(rand.Reader, privParsed, crypto.SHA256, d) + require.NoError(t, err) + + // Verify + err = rsa.VerifyPKCS1v15(pubParsed.(*rsa.PublicKey), crypto.SHA256, d, sig) + require.NoError(t, err) +} diff --git a/modules/util/legacy.go b/modules/util/legacy.go new file mode 100644 index 0000000..2d4de01 --- /dev/null +++ b/modules/util/legacy.go @@ -0,0 +1,38 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "io" + "os" +) + +// CopyFile copies file from source to target path. +func CopyFile(src, dest string) error { + si, err := os.Lstat(src) + if err != nil { + return err + } + + sr, err := os.Open(src) + if err != nil { + return err + } + defer sr.Close() + + dw, err := os.Create(dest) + if err != nil { + return err + } + defer dw.Close() + + if _, err = io.Copy(dw, sr); err != nil { + return err + } + + if err = os.Chtimes(dest, si.ModTime(), si.ModTime()); err != nil { + return err + } + return os.Chmod(dest, si.Mode()) +} diff --git a/modules/util/legacy_test.go b/modules/util/legacy_test.go new file mode 100644 index 0000000..62c2f8a --- /dev/null +++ b/modules/util/legacy_test.go @@ -0,0 +1,38 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCopyFile(t *testing.T) { + testContent := []byte("hello") + + tmpDir := os.TempDir() + now := time.Now() + srcFile := fmt.Sprintf("%s/copy-test-%d-src.txt", tmpDir, now.UnixMicro()) + dstFile := fmt.Sprintf("%s/copy-test-%d-dst.txt", tmpDir, now.UnixMicro()) + + _ = os.Remove(srcFile) + _ = os.Remove(dstFile) + defer func() { + _ = os.Remove(srcFile) + _ = os.Remove(dstFile) + }() + + err := os.WriteFile(srcFile, testContent, 0o777) + require.NoError(t, err) + err = CopyFile(srcFile, dstFile) + require.NoError(t, err) + dstContent, err := os.ReadFile(dstFile) + require.NoError(t, err) + assert.Equal(t, testContent, dstContent) +} diff --git a/modules/util/pack.go b/modules/util/pack.go new file mode 100644 index 0000000..7fc074a --- /dev/null +++ b/modules/util/pack.go @@ -0,0 +1,33 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "bytes" + "encoding/gob" +) + +// PackData uses gob to encode the given data in sequence +func PackData(data ...any) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + for _, datum := range data { + if err := enc.Encode(datum); err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} + +// UnpackData uses gob to decode the given data in sequence +func UnpackData(buf []byte, data ...any) error { + r := bytes.NewReader(buf) + enc := gob.NewDecoder(r) + for _, datum := range data { + if err := enc.Decode(datum); err != nil { + return err + } + } + return nil +} diff --git a/modules/util/pack_test.go b/modules/util/pack_test.go new file mode 100644 index 0000000..42ada89 --- /dev/null +++ b/modules/util/pack_test.go @@ -0,0 +1,28 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPackAndUnpackData(t *testing.T) { + s := "string" + i := int64(4) + f := float32(4.1) + + var s2 string + var i2 int64 + var f2 float32 + + data, err := PackData(s, i, f) + require.NoError(t, err) + + require.NoError(t, UnpackData(data, &s2, &i2, &f2)) + require.NoError(t, UnpackData(data, &s2)) + require.Error(t, UnpackData(data, &i2)) + require.Error(t, UnpackData(data, &s2, &f2)) +} diff --git a/modules/util/paginate.go b/modules/util/paginate.go new file mode 100644 index 0000000..87f31b7 --- /dev/null +++ b/modules/util/paginate.go @@ -0,0 +1,33 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import "reflect" + +// PaginateSlice cut a slice as per pagination options +// if page = 0 it do not paginate +func PaginateSlice(list any, page, pageSize int) any { + if page <= 0 || pageSize <= 0 { + return list + } + if reflect.TypeOf(list).Kind() != reflect.Slice { + return list + } + + listValue := reflect.ValueOf(list) + + page-- + + if page*pageSize >= listValue.Len() { + return listValue.Slice(listValue.Len(), listValue.Len()).Interface() + } + + listValue = listValue.Slice(page*pageSize, listValue.Len()) + + if listValue.Len() > pageSize { + return listValue.Slice(0, pageSize).Interface() + } + + return listValue.Interface() +} diff --git a/modules/util/paginate_test.go b/modules/util/paginate_test.go new file mode 100644 index 0000000..6e69dd1 --- /dev/null +++ b/modules/util/paginate_test.go @@ -0,0 +1,46 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPaginateSlice(t *testing.T) { + stringSlice := []string{"a", "b", "c", "d", "e"} + result, ok := PaginateSlice(stringSlice, 1, 2).([]string) + assert.True(t, ok) + assert.EqualValues(t, []string{"a", "b"}, result) + + result, ok = PaginateSlice(stringSlice, 100, 2).([]string) + assert.True(t, ok) + assert.EqualValues(t, []string{}, result) + + result, ok = PaginateSlice(stringSlice, 3, 2).([]string) + assert.True(t, ok) + assert.EqualValues(t, []string{"e"}, result) + + result, ok = PaginateSlice(stringSlice, 1, 0).([]string) + assert.True(t, ok) + assert.EqualValues(t, []string{"a", "b", "c", "d", "e"}, result) + + result, ok = PaginateSlice(stringSlice, 1, -1).([]string) + assert.True(t, ok) + assert.EqualValues(t, []string{"a", "b", "c", "d", "e"}, result) + + type Test struct { + Val int + } + + testVar := []*Test{{Val: 2}, {Val: 3}, {Val: 4}} + testVar, ok = PaginateSlice(testVar, 1, 50).([]*Test) + assert.True(t, ok) + assert.EqualValues(t, []*Test{{Val: 2}, {Val: 3}, {Val: 4}}, testVar) + + testVar, ok = PaginateSlice(testVar, 2, 2).([]*Test) + assert.True(t, ok) + assert.EqualValues(t, []*Test{{Val: 4}}, testVar) +} diff --git a/modules/util/path.go b/modules/util/path.go new file mode 100644 index 0000000..185e7cf --- /dev/null +++ b/modules/util/path.go @@ -0,0 +1,322 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "errors" + "fmt" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "runtime" + "strings" +) + +// PathJoinRel joins the path elements into a single path, each element is cleaned by path.Clean separately. +// It only returns the following values (like path.Join), any redundant part (empty, relative dots, slashes) is removed. +// It's caller's duty to make every element not bypass its own directly level, to avoid security issues. +// +// empty => `` +// `` => `` +// `..` => `.` +// `dir` => `dir` +// `/dir/` => `dir` +// `foo\..\bar` => `foo\..\bar` +// {`foo`, ``, `bar`} => `foo/bar` +// {`foo`, `..`, `bar`} => `foo/bar` +func PathJoinRel(elem ...string) string { + elems := make([]string, len(elem)) + for i, e := range elem { + if e == "" { + continue + } + elems[i] = path.Clean("/" + e) + } + p := path.Join(elems...) + if p == "" { + return "" + } else if p == "/" { + return "." + } + return p[1:] +} + +// PathJoinRelX joins the path elements into a single path like PathJoinRel, +// and convert all backslashes to slashes. (X means "extended", also means the combination of `\` and `/`). +// It's caller's duty to make every element not bypass its own directly level, to avoid security issues. +// It returns similar results as PathJoinRel except: +// +// `foo\..\bar` => `bar` (because it's processed as `foo/../bar`) +// +// All backslashes are handled as slashes, the result only contains slashes. +func PathJoinRelX(elem ...string) string { + elems := make([]string, len(elem)) + for i, e := range elem { + if e == "" { + continue + } + elems[i] = path.Clean("/" + strings.ReplaceAll(e, "\\", "/")) + } + return PathJoinRel(elems...) +} + +const pathSeparator = string(os.PathSeparator) + +// FilePathJoinAbs joins the path elements into a single file path, each element is cleaned by filepath.Clean separately. +// All slashes/backslashes are converted to path separators before cleaning, the result only contains path separators. +// The first element must be an absolute path, caller should prepare the base path. +// It's caller's duty to make every element not bypass its own directly level, to avoid security issues. +// Like PathJoinRel, any redundant part (empty, relative dots, slashes) is removed. +// +// {`/foo`, ``, `bar`} => `/foo/bar` +// {`/foo`, `..`, `bar`} => `/foo/bar` +func FilePathJoinAbs(base string, sub ...string) string { + elems := make([]string, 1, len(sub)+1) + + // POSIX filesystem can have `\` in file names. Windows: `\` and `/` are both used for path separators + // to keep the behavior consistent, we do not allow `\` in file names, replace all `\` with `/` + if isOSWindows() { + elems[0] = filepath.Clean(base) + } else { + elems[0] = filepath.Clean(strings.ReplaceAll(base, "\\", pathSeparator)) + } + if !filepath.IsAbs(elems[0]) { + // This shouldn't happen. If there is really necessary to pass in relative path, return the full path with filepath.Abs() instead + panic(fmt.Sprintf("FilePathJoinAbs: %q (for path %v) is not absolute, do not guess a relative path based on current working directory", elems[0], elems)) + } + for _, s := range sub { + if s == "" { + continue + } + if isOSWindows() { + elems = append(elems, filepath.Clean(pathSeparator+s)) + } else { + elems = append(elems, filepath.Clean(pathSeparator+strings.ReplaceAll(s, "\\", pathSeparator))) + } + } + // the elems[0] must be an absolute path, just join them together + return filepath.Join(elems...) +} + +// IsDir returns true if given path is a directory, +// or returns false when it's a file or does not exist. +func IsDir(dir string) (bool, error) { + f, err := os.Stat(dir) + if err == nil { + return f.IsDir(), nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// IsFile returns true if given path is a file, +// or returns false when it's a directory or does not exist. +func IsFile(filePath string) (bool, error) { + f, err := os.Stat(filePath) + if err == nil { + return !f.IsDir(), nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// IsExist checks whether a file or directory exists. +// It returns false when the file or directory does not exist. +func IsExist(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil || os.IsExist(err) { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func statDir(dirPath, recPath string, includeDir, isDirOnly, followSymlinks bool) ([]string, error) { + dir, err := os.Open(dirPath) + if err != nil { + return nil, err + } + defer dir.Close() + + fis, err := dir.Readdir(0) + if err != nil { + return nil, err + } + + statList := make([]string, 0) + for _, fi := range fis { + if CommonSkip(fi.Name()) { + continue + } + + relPath := path.Join(recPath, fi.Name()) + curPath := path.Join(dirPath, fi.Name()) + if fi.IsDir() { + if includeDir { + statList = append(statList, relPath+"/") + } + s, err := statDir(curPath, relPath, includeDir, isDirOnly, followSymlinks) + if err != nil { + return nil, err + } + statList = append(statList, s...) + } else if !isDirOnly { + statList = append(statList, relPath) + } else if followSymlinks && fi.Mode()&os.ModeSymlink != 0 { + link, err := os.Readlink(curPath) + if err != nil { + return nil, err + } + + isDir, err := IsDir(link) + if err != nil { + return nil, err + } + if isDir { + if includeDir { + statList = append(statList, relPath+"/") + } + s, err := statDir(curPath, relPath, includeDir, isDirOnly, followSymlinks) + if err != nil { + return nil, err + } + statList = append(statList, s...) + } + } + } + return statList, nil +} + +// StatDir gathers information of given directory by depth-first. +// It returns slice of file list and includes subdirectories if enabled; +// it returns error and nil slice when error occurs in underlying functions, +// or given path is not a directory or does not exist. +// +// Slice does not include given path itself. +// If subdirectories is enabled, they will have suffix '/'. +func StatDir(rootPath string, includeDir ...bool) ([]string, error) { + if isDir, err := IsDir(rootPath); err != nil { + return nil, err + } else if !isDir { + return nil, errors.New("not a directory or does not exist: " + rootPath) + } + + isIncludeDir := false + if len(includeDir) != 0 { + isIncludeDir = includeDir[0] + } + return statDir(rootPath, "", isIncludeDir, false, false) +} + +func isOSWindows() bool { + return runtime.GOOS == "windows" +} + +var driveLetterRegexp = regexp.MustCompile("/[A-Za-z]:/") + +// FileURLToPath extracts the path information from a file://... url. +// It returns an error only if the URL is not a file URL. +func FileURLToPath(u *url.URL) (string, error) { + if u.Scheme != "file" { + return "", errors.New("URL scheme is not 'file': " + u.String()) + } + + path := u.Path + + if !isOSWindows() { + return path, nil + } + + // If it looks like there's a Windows drive letter at the beginning, strip off the leading slash. + if driveLetterRegexp.MatchString(path) { + return path[1:], nil + } + return path, nil +} + +// HomeDir returns path of '~'(in Linux) on Windows, +// it returns error when the variable does not exist. +func HomeDir() (home string, err error) { + // TODO: some users run Gitea with mismatched uid and "HOME=xxx" (they set HOME=xxx by environment manually) + // TODO: when running gitea as a sub command inside git, the HOME directory is not the user's home directory + // so at the moment we can not use `user.Current().HomeDir` + if isOSWindows() { + home = os.Getenv("USERPROFILE") + if home == "" { + home = os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") + } + } else { + home = os.Getenv("HOME") + } + + if home == "" { + return "", errors.New("cannot get home directory") + } + + return home, nil +} + +// CommonSkip will check a provided name to see if it represents file or directory that should not be watched +func CommonSkip(name string) bool { + if name == "" { + return true + } + + switch name[0] { + case '.': + return true + case 't', 'T': + return name[1:] == "humbs.db" + case 'd', 'D': + return name[1:] == "esktop.ini" + } + + return false +} + +// IsReadmeFileName reports whether name looks like a README file +// based on its name. +func IsReadmeFileName(name string) bool { + name = strings.ToLower(name) + if len(name) < 6 { + return false + } else if len(name) == 6 { + return name == "readme" + } + return name[:7] == "readme." +} + +// IsReadmeFileExtension reports whether name looks like a README file +// based on its name. It will look through the provided extensions and check if the file matches +// one of the extensions and provide the index in the extension list. +// If the filename is `readme.` with an unmatched extension it will match with the index equaling +// the length of the provided extension list. +// Note that the '.' should be provided in ext, e.g ".md" +func IsReadmeFileExtension(name string, ext ...string) (int, bool) { + name = strings.ToLower(name) + if len(name) < 6 || name[:6] != "readme" { + return 0, false + } + + for i, extension := range ext { + extension = strings.ToLower(extension) + if name[6:] == extension { + return i, true + } + } + + if name[6] == '.' { + return len(ext), true + } + + return 0, false +} diff --git a/modules/util/path_test.go b/modules/util/path_test.go new file mode 100644 index 0000000..3699f05 --- /dev/null +++ b/modules/util/path_test.go @@ -0,0 +1,213 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "net/url" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileURLToPath(t *testing.T) { + cases := []struct { + url string + expected string + haserror bool + windows bool + }{ + // case 0 + { + url: "", + haserror: true, + }, + // case 1 + { + url: "http://test.io", + haserror: true, + }, + // case 2 + { + url: "file:///path", + expected: "/path", + }, + // case 3 + { + url: "file:///C:/path", + expected: "C:/path", + windows: true, + }, + } + + for n, c := range cases { + if c.windows && runtime.GOOS != "windows" { + continue + } + u, _ := url.Parse(c.url) + p, err := FileURLToPath(u) + if c.haserror { + require.Error(t, err, "case %d: should return error", n) + } else { + require.NoError(t, err, "case %d: should not return error", n) + assert.Equal(t, c.expected, p, "case %d: should be equal", n) + } + } +} + +func TestMisc_IsReadmeFileName(t *testing.T) { + trueTestCases := []string{ + "readme", + "README", + "readME.mdown", + "README.md", + "readme.i18n.md", + } + falseTestCases := []string{ + "test.md", + "wow.MARKDOWN", + "LOL.mDoWn", + "test", + "abcdefg", + "abcdefghijklmnopqrstuvwxyz", + "test.md.test", + "readmf", + } + + for _, testCase := range trueTestCases { + assert.True(t, IsReadmeFileName(testCase)) + } + for _, testCase := range falseTestCases { + assert.False(t, IsReadmeFileName(testCase)) + } + + type extensionTestcase struct { + name string + expected bool + idx int + } + + exts := []string{".md", ".txt", ""} + testCasesExtensions := []extensionTestcase{ + { + name: "readme", + expected: true, + idx: 2, + }, + { + name: "readme.md", + expected: true, + idx: 0, + }, + { + name: "README.md", + expected: true, + idx: 0, + }, + { + name: "ReAdMe.Md", + expected: true, + idx: 0, + }, + { + name: "readme.txt", + expected: true, + idx: 1, + }, + { + name: "readme.doc", + expected: true, + idx: 3, + }, + { + name: "readmee.md", + }, + { + name: "readme..", + expected: true, + idx: 3, + }, + } + + for _, testCase := range testCasesExtensions { + idx, ok := IsReadmeFileExtension(testCase.name, exts...) + assert.Equal(t, testCase.expected, ok) + assert.Equal(t, testCase.idx, idx) + } +} + +func TestCleanPath(t *testing.T) { + cases := []struct { + elems []string + expected string + }{ + {[]string{}, ``}, + {[]string{``}, ``}, + {[]string{`..`}, `.`}, + {[]string{`a`}, `a`}, + {[]string{`/a/`}, `a`}, + {[]string{`../a/`, `../b`, `c/..`, `d`}, `a/b/d`}, + {[]string{`a\..\b`}, `a\..\b`}, + {[]string{`a`, ``, `b`}, `a/b`}, + {[]string{`a`, `..`, `b`}, `a/b`}, + {[]string{`lfs`, `repo/..`, `user/../path`}, `lfs/path`}, + } + for _, c := range cases { + assert.Equal(t, c.expected, PathJoinRel(c.elems...), "case: %v", c.elems) + } + + cases = []struct { + elems []string + expected string + }{ + {[]string{}, ``}, + {[]string{``}, ``}, + {[]string{`..`}, `.`}, + {[]string{`a`}, `a`}, + {[]string{`/a/`}, `a`}, + {[]string{`../a/`, `../b`, `c/..`, `d`}, `a/b/d`}, + {[]string{`a\..\b`}, `b`}, + {[]string{`a`, ``, `b`}, `a/b`}, + {[]string{`a`, `..`, `b`}, `a/b`}, + {[]string{`lfs`, `repo/..`, `user/../path`}, `lfs/path`}, + } + for _, c := range cases { + assert.Equal(t, c.expected, PathJoinRelX(c.elems...), "case: %v", c.elems) + } + + // for POSIX only, but the result is similar on Windows, because the first element must be an absolute path + if isOSWindows() { + cases = []struct { + elems []string + expected string + }{ + {[]string{`C:\..`}, `C:\`}, + {[]string{`C:\a`}, `C:\a`}, + {[]string{`C:\a/`}, `C:\a`}, + {[]string{`C:\..\a\`, `../b`, `c\..`, `d`}, `C:\a\b\d`}, + {[]string{`C:\a/..\b`}, `C:\b`}, + {[]string{`C:\a`, ``, `b`}, `C:\a\b`}, + {[]string{`C:\a`, `..`, `b`}, `C:\a\b`}, + {[]string{`C:\lfs`, `repo/..`, `user/../path`}, `C:\lfs\path`}, + } + } else { + cases = []struct { + elems []string + expected string + }{ + {[]string{`/..`}, `/`}, + {[]string{`/a`}, `/a`}, + {[]string{`/a/`}, `/a`}, + {[]string{`/../a/`, `../b`, `c/..`, `d`}, `/a/b/d`}, + {[]string{`/a\..\b`}, `/b`}, + {[]string{`/a`, ``, `b`}, `/a/b`}, + {[]string{`/a`, `..`, `b`}, `/a/b`}, + {[]string{`/lfs`, `repo/..`, `user/../path`}, `/lfs/path`}, + } + } + for _, c := range cases { + assert.Equal(t, c.expected, FilePathJoinAbs(c.elems[0], c.elems[1:]...), "case: %v", c.elems) + } +} diff --git a/modules/util/remove.go b/modules/util/remove.go new file mode 100644 index 0000000..d1e38fa --- /dev/null +++ b/modules/util/remove.go @@ -0,0 +1,104 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "os" + "runtime" + "syscall" + "time" +) + +const windowsSharingViolationError syscall.Errno = 32 + +// Remove removes the named file or (empty) directory with at most 5 attempts. +func Remove(name string) error { + var err error + for i := 0; i < 5; i++ { + err = os.Remove(name) + if err == nil { + break + } + unwrapped := err.(*os.PathError).Err + if unwrapped == syscall.EBUSY || unwrapped == syscall.ENOTEMPTY || unwrapped == syscall.EPERM || unwrapped == syscall.EMFILE || unwrapped == syscall.ENFILE { + // try again + <-time.After(100 * time.Millisecond) + continue + } + + if unwrapped == windowsSharingViolationError && runtime.GOOS == "windows" { + // try again + <-time.After(100 * time.Millisecond) + continue + } + + if unwrapped == syscall.ENOENT { + // it's already gone + return nil + } + } + return err +} + +// RemoveAll removes the named file or (empty) directory with at most 5 attempts. +func RemoveAll(name string) error { + var err error + for i := 0; i < 5; i++ { + err = os.RemoveAll(name) + if err == nil { + break + } + unwrapped := err.(*os.PathError).Err + if unwrapped == syscall.EBUSY || unwrapped == syscall.ENOTEMPTY || unwrapped == syscall.EPERM || unwrapped == syscall.EMFILE || unwrapped == syscall.ENFILE { + // try again + <-time.After(100 * time.Millisecond) + continue + } + + if unwrapped == windowsSharingViolationError && runtime.GOOS == "windows" { + // try again + <-time.After(100 * time.Millisecond) + continue + } + + if unwrapped == syscall.ENOENT { + // it's already gone + return nil + } + } + return err +} + +// Rename renames (moves) oldpath to newpath with at most 5 attempts. +func Rename(oldpath, newpath string) error { + var err error + for i := 0; i < 5; i++ { + err = os.Rename(oldpath, newpath) + if err == nil { + break + } + unwrapped := err.(*os.LinkError).Err + if unwrapped == syscall.EBUSY || unwrapped == syscall.ENOTEMPTY || unwrapped == syscall.EPERM || unwrapped == syscall.EMFILE || unwrapped == syscall.ENFILE { + // try again + <-time.After(100 * time.Millisecond) + continue + } + + if unwrapped == windowsSharingViolationError && runtime.GOOS == "windows" { + // try again + <-time.After(100 * time.Millisecond) + continue + } + + if i == 0 && os.IsNotExist(err) { + return err + } + + if unwrapped == syscall.ENOENT { + // it's already gone + return nil + } + } + return err +} diff --git a/modules/util/rotatingfilewriter/writer.go b/modules/util/rotatingfilewriter/writer.go new file mode 100644 index 0000000..c595f49 --- /dev/null +++ b/modules/util/rotatingfilewriter/writer.go @@ -0,0 +1,246 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package rotatingfilewriter + +import ( + "bufio" + "compress/gzip" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "code.gitea.io/gitea/modules/graceful/releasereopen" + "code.gitea.io/gitea/modules/util" +) + +type Options struct { + Rotate bool + MaximumSize int64 + RotateDaily bool + KeepDays int + Compress bool + CompressionLevel int +} + +type RotatingFileWriter struct { + mu sync.Mutex + fd *os.File + + currentSize int64 + openDate int + + options Options + + cancelReleaseReopen func() +} + +var ErrorPrintf func(format string, args ...any) + +// errorf tries to print error messages. Since this writer could be used by a logger system, this is the last chance to show the error in some cases +func errorf(format string, args ...any) { + if ErrorPrintf != nil { + ErrorPrintf("rotatingfilewriter: "+format+"\n", args...) + } +} + +// Open creates a new rotating file writer. +// Notice: if a file is opened by two rotators, there will be conflicts when rotating. +// In the future, there should be "rotating file manager" +func Open(filename string, options *Options) (*RotatingFileWriter, error) { + if options == nil { + options = &Options{} + } + + rfw := &RotatingFileWriter{ + options: *options, + } + + if err := rfw.open(filename); err != nil { + return nil, err + } + + rfw.cancelReleaseReopen = releasereopen.GetManager().Register(rfw) + return rfw, nil +} + +func (rfw *RotatingFileWriter) Write(b []byte) (int, error) { + if rfw.options.Rotate && ((rfw.options.MaximumSize > 0 && rfw.currentSize >= rfw.options.MaximumSize) || (rfw.options.RotateDaily && time.Now().Day() != rfw.openDate)) { + if err := rfw.DoRotate(); err != nil { + // if this writer is used by a logger system, it's the logger system's responsibility to handle/show the error + return 0, err + } + } + + n, err := rfw.fd.Write(b) + if err == nil { + rfw.currentSize += int64(n) + } + return n, err +} + +func (rfw *RotatingFileWriter) Flush() error { + return rfw.fd.Sync() +} + +func (rfw *RotatingFileWriter) Close() error { + rfw.mu.Lock() + if rfw.cancelReleaseReopen != nil { + rfw.cancelReleaseReopen() + rfw.cancelReleaseReopen = nil + } + rfw.mu.Unlock() + return rfw.fd.Close() +} + +func (rfw *RotatingFileWriter) open(filename string) error { + fd, err := os.OpenFile(filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o660) + if err != nil { + return err + } + + rfw.fd = fd + + finfo, err := fd.Stat() + if err != nil { + return err + } + rfw.currentSize = finfo.Size() + rfw.openDate = finfo.ModTime().Day() + + return nil +} + +func (rfw *RotatingFileWriter) ReleaseReopen() error { + return errors.Join( + rfw.fd.Close(), + rfw.open(rfw.fd.Name()), + ) +} + +// DoRotate the log file creating a backup like xx.2013-01-01.2 +func (rfw *RotatingFileWriter) DoRotate() error { + if !rfw.options.Rotate { + return nil + } + + rfw.mu.Lock() + defer rfw.mu.Unlock() + + prefix := fmt.Sprintf("%s.%s.", rfw.fd.Name(), time.Now().Format("2006-01-02")) + + var err error + fname := "" + for i := 1; err == nil && i <= 999; i++ { + fname = prefix + fmt.Sprintf("%03d", i) + _, err = os.Lstat(fname) + if rfw.options.Compress && err != nil { + _, err = os.Lstat(fname + ".gz") + } + } + // return error if the last file checked still existed + if err == nil { + return fmt.Errorf("cannot find free file to rename %s", rfw.fd.Name()) + } + + fd := rfw.fd + if err := fd.Close(); err != nil { // close file before rename + return err + } + + if err := util.Rename(fd.Name(), fname); err != nil { + return err + } + + if rfw.options.Compress { + go func() { + err := compressOldFile(fname, rfw.options.CompressionLevel) + if err != nil { + errorf("DoRotate: %v", err) + } + }() + } + + if err := rfw.open(fd.Name()); err != nil { + return err + } + + go deleteOldFiles( + filepath.Dir(fd.Name()), + filepath.Base(fd.Name()), + time.Now().AddDate(0, 0, -rfw.options.KeepDays), + ) + + return nil +} + +func compressOldFile(fname string, compressionLevel int) error { + reader, err := os.Open(fname) + if err != nil { + return fmt.Errorf("compressOldFile: failed to open existing file %s: %w", fname, err) + } + defer reader.Close() + + buffer := bufio.NewReader(reader) + fnameGz := fname + ".gz" + fw, err := os.OpenFile(fnameGz, os.O_WRONLY|os.O_CREATE, 0o660) + if err != nil { + return fmt.Errorf("compressOldFile: failed to open new file %s: %w", fnameGz, err) + } + defer fw.Close() + + zw, err := gzip.NewWriterLevel(fw, compressionLevel) + if err != nil { + return fmt.Errorf("compressOldFile: failed to create gzip writer: %w", err) + } + defer zw.Close() + + _, err = buffer.WriteTo(zw) + if err != nil { + _ = zw.Close() + _ = fw.Close() + _ = util.Remove(fname + ".gz") + return fmt.Errorf("compressOldFile: failed to write to gz file: %w", err) + } + _ = reader.Close() + + err = util.Remove(fname) + if err != nil { + return fmt.Errorf("compressOldFile: failed to delete old file: %w", err) + } + return nil +} + +func deleteOldFiles(dir, prefix string, removeBefore time.Time) { + err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) (returnErr error) { + defer func() { + if r := recover(); r != nil { + returnErr = fmt.Errorf("unable to delete old file '%s', error: %+v", path, r) + } + }() + + if err != nil { + return err + } + if d.IsDir() { + return nil + } + info, err := d.Info() + if err != nil { + return err + } + if info.ModTime().Before(removeBefore) { + if strings.HasPrefix(filepath.Base(path), prefix) { + return util.Remove(path) + } + } + return nil + }) + if err != nil { + errorf("deleteOldFiles: failed to delete old file: %v", err) + } +} diff --git a/modules/util/rotatingfilewriter/writer_test.go b/modules/util/rotatingfilewriter/writer_test.go new file mode 100644 index 0000000..5b3b351 --- /dev/null +++ b/modules/util/rotatingfilewriter/writer_test.go @@ -0,0 +1,49 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package rotatingfilewriter + +import ( + "compress/gzip" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompressOldFile(t *testing.T) { + tmpDir := t.TempDir() + fname := filepath.Join(tmpDir, "test") + nonGzip := filepath.Join(tmpDir, "test-nonGzip") + + f, err := os.OpenFile(fname, os.O_CREATE|os.O_WRONLY, 0o660) + require.NoError(t, err) + ng, err := os.OpenFile(nonGzip, os.O_CREATE|os.O_WRONLY, 0o660) + require.NoError(t, err) + + for i := 0; i < 999; i++ { + f.WriteString("This is a test file\n") + ng.WriteString("This is a test file\n") + } + f.Close() + ng.Close() + + err = compressOldFile(fname, gzip.DefaultCompression) + require.NoError(t, err) + + _, err = os.Lstat(fname + ".gz") + require.NoError(t, err) + + f, err = os.Open(fname + ".gz") + require.NoError(t, err) + zr, err := gzip.NewReader(f) + require.NoError(t, err) + data, err := io.ReadAll(zr) + require.NoError(t, err) + original, err := os.ReadFile(nonGzip) + require.NoError(t, err) + assert.Equal(t, original, data) +} diff --git a/modules/util/sanitize.go b/modules/util/sanitize.go new file mode 100644 index 0000000..0dd8b34 --- /dev/null +++ b/modules/util/sanitize.go @@ -0,0 +1,72 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "bytes" + "unicode" +) + +type sanitizedError struct { + err error +} + +func (err sanitizedError) Error() string { + return SanitizeCredentialURLs(err.err.Error()) +} + +func (err sanitizedError) Unwrap() error { + return err.err +} + +// SanitizeErrorCredentialURLs wraps the error and make sure the returned error message doesn't contain sensitive credentials in URLs +func SanitizeErrorCredentialURLs(err error) error { + return sanitizedError{err: err} +} + +const userPlaceholder = "sanitized-credential" + +var schemeSep = []byte("://") + +// SanitizeCredentialURLs remove all credentials in URLs (starting with "scheme://") for the input string: "https://user:pass@domain.com" => "https://sanitized-credential@domain.com" +func SanitizeCredentialURLs(s string) string { + bs := UnsafeStringToBytes(s) + schemeSepPos := bytes.Index(bs, schemeSep) + if schemeSepPos == -1 || bytes.IndexByte(bs[schemeSepPos:], '@') == -1 { + return s // fast return if there is no URL scheme or no userinfo + } + out := make([]byte, 0, len(bs)+len(userPlaceholder)) + for schemeSepPos != -1 { + schemeSepPos += 3 // skip the "://" + sepAtPos := -1 // the possible '@' position: "https://foo@[^here]host" + sepEndPos := schemeSepPos // the possible end position: "The https://host[^here] in log for test" + sepLoop: + for ; sepEndPos < len(bs); sepEndPos++ { + c := bs[sepEndPos] + if ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9') { + continue + } + switch c { + case '@': + sepAtPos = sepEndPos + case '-', '.', '_', '~', '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '%': + continue // due to RFC 3986, userinfo can contain - . _ ~ ! $ & ' ( ) * + , ; = : and any percent-encoded chars + default: + break sepLoop // if it is an invalid char for URL (eg: space, '/', and others), stop the loop + } + } + // if there is '@', and the string is like "s://u@h", then hide the "u" part + if sepAtPos != -1 && (schemeSepPos >= 4 && unicode.IsLetter(rune(bs[schemeSepPos-4]))) && sepAtPos-schemeSepPos > 0 && sepEndPos-sepAtPos > 0 { + out = append(out, bs[:schemeSepPos]...) + out = append(out, userPlaceholder...) + out = append(out, bs[sepAtPos:sepEndPos]...) + } else { + out = append(out, bs[:sepEndPos]...) + } + bs = bs[sepEndPos:] + schemeSepPos = bytes.Index(bs, schemeSep) + } + out = append(out, bs...) + return UnsafeBytesToString(out) +} diff --git a/modules/util/sanitize_test.go b/modules/util/sanitize_test.go new file mode 100644 index 0000000..0bcfd45 --- /dev/null +++ b/modules/util/sanitize_test.go @@ -0,0 +1,74 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSanitizeErrorCredentialURLs(t *testing.T) { + err := errors.New("error with https://a@b.com") + se := SanitizeErrorCredentialURLs(err) + assert.Equal(t, "error with https://"+userPlaceholder+"@b.com", se.Error()) +} + +func TestSanitizeCredentialURLs(t *testing.T) { + cases := []struct { + input string + expected string + }{ + { + "https://github.com/go-gitea/test_repo.git", + "https://github.com/go-gitea/test_repo.git", + }, + { + "https://mytoken@github.com/go-gitea/test_repo.git", + "https://" + userPlaceholder + "@github.com/go-gitea/test_repo.git", + }, + { + "https://user:password@github.com/go-gitea/test_repo.git", + "https://" + userPlaceholder + "@github.com/go-gitea/test_repo.git", + }, + { + "ftp://x@", + "ftp://" + userPlaceholder + "@", + }, + { + "ftp://x/@", + "ftp://x/@", + }, + { + "ftp://u@x/@", // test multiple @ chars + "ftp://" + userPlaceholder + "@x/@", + }, + { + "😊ftp://u@x😊", // test unicode + "😊ftp://" + userPlaceholder + "@x😊", + }, + { + "://@", + "://@", + }, + { + "//u:p@h", // do not process URLs without explicit scheme, they are not treated as "valid" URLs because there is no scheme context in string + "//u:p@h", + }, + { + "s://u@h", // the minimal pattern to be sanitized + "s://" + userPlaceholder + "@h", + }, + { + "URLs in log https://u:b@h and https://u:b@h:80/, with https://h.com and u@h.com", + "URLs in log https://" + userPlaceholder + "@h and https://" + userPlaceholder + "@h:80/, with https://h.com and u@h.com", + }, + } + + for n, c := range cases { + result := SanitizeCredentialURLs(c.input) + assert.Equal(t, c.expected, result, "case %d: error should match", n) + } +} diff --git a/modules/util/sec_to_time.go b/modules/util/sec_to_time.go new file mode 100644 index 0000000..ad0fb1a --- /dev/null +++ b/modules/util/sec_to_time.go @@ -0,0 +1,81 @@ +// Copyright 2022 Gitea. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "fmt" + "strings" +) + +// SecToTime converts an amount of seconds to a human-readable string. E.g. +// 66s -> 1 minute 6 seconds +// 52410s -> 14 hours 33 minutes +// 563418 -> 6 days 12 hours +// 1563418 -> 2 weeks 4 days +// 3937125s -> 1 month 2 weeks +// 45677465s -> 1 year 6 months +func SecToTime(durationVal any) string { + duration, _ := ToInt64(durationVal) + + formattedTime := "" + + // The following four variables are calculated by taking + // into account the previously calculated variables, this avoids + // pitfalls when using remainders. As that could lead to incorrect + // results when the calculated number equals the quotient number. + remainingDays := duration / (60 * 60 * 24) + years := remainingDays / 365 + remainingDays -= years * 365 + months := remainingDays * 12 / 365 + remainingDays -= months * 365 / 12 + weeks := remainingDays / 7 + remainingDays -= weeks * 7 + days := remainingDays + + // The following three variables are calculated without depending + // on the previous calculated variables. + hours := (duration / 3600) % 24 + minutes := (duration / 60) % 60 + seconds := duration % 60 + + // Extract only the relevant information of the time + // If the time is greater than a year, it makes no sense to display seconds. + switch { + case years > 0: + formattedTime = formatTime(years, "year", formattedTime) + formattedTime = formatTime(months, "month", formattedTime) + case months > 0: + formattedTime = formatTime(months, "month", formattedTime) + formattedTime = formatTime(weeks, "week", formattedTime) + case weeks > 0: + formattedTime = formatTime(weeks, "week", formattedTime) + formattedTime = formatTime(days, "day", formattedTime) + case days > 0: + formattedTime = formatTime(days, "day", formattedTime) + formattedTime = formatTime(hours, "hour", formattedTime) + case hours > 0: + formattedTime = formatTime(hours, "hour", formattedTime) + formattedTime = formatTime(minutes, "minute", formattedTime) + default: + formattedTime = formatTime(minutes, "minute", formattedTime) + formattedTime = formatTime(seconds, "second", formattedTime) + } + + // The formatTime() function always appends a space at the end. This will be trimmed + return strings.TrimRight(formattedTime, " ") +} + +// formatTime appends the given value to the existing forammattedTime. E.g: +// formattedTime = "1 year" +// input: value = 3, name = "month" +// output will be "1 year 3 months " +func formatTime(value int64, name, formattedTime string) string { + if value == 1 { + formattedTime = fmt.Sprintf("%s1 %s ", formattedTime, name) + } else if value > 1 { + formattedTime = fmt.Sprintf("%s%d %ss ", formattedTime, value, name) + } + + return formattedTime +} diff --git a/modules/util/sec_to_time_test.go b/modules/util/sec_to_time_test.go new file mode 100644 index 0000000..4d1213a --- /dev/null +++ b/modules/util/sec_to_time_test.go @@ -0,0 +1,30 @@ +// Copyright 2022 Gitea. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSecToTime(t *testing.T) { + second := int64(1) + minute := 60 * second + hour := 60 * minute + day := 24 * hour + year := 365 * day + + assert.Equal(t, "1 minute 6 seconds", SecToTime(minute+6*second)) + assert.Equal(t, "1 hour", SecToTime(hour)) + assert.Equal(t, "1 hour", SecToTime(hour+second)) + assert.Equal(t, "14 hours 33 minutes", SecToTime(14*hour+33*minute+30*second)) + assert.Equal(t, "6 days 12 hours", SecToTime(6*day+12*hour+30*minute+18*second)) + assert.Equal(t, "2 weeks 4 days", SecToTime((2*7+4)*day+2*hour+16*minute+58*second)) + assert.Equal(t, "4 weeks", SecToTime(4*7*day)) + assert.Equal(t, "4 weeks 1 day", SecToTime((4*7+1)*day)) + assert.Equal(t, "1 month 2 weeks", SecToTime((6*7+3)*day+13*hour+38*minute+45*second)) + assert.Equal(t, "11 months", SecToTime(year-25*day)) + assert.Equal(t, "1 year 5 months", SecToTime(year+163*day+10*hour+11*minute+5*second)) +} diff --git a/modules/util/shellquote.go b/modules/util/shellquote.go new file mode 100644 index 0000000..434dc42 --- /dev/null +++ b/modules/util/shellquote.go @@ -0,0 +1,101 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import "strings" + +// Bash has the definition of a metacharacter: +// * A character that, when unquoted, separates words. +// A metacharacter is one of: " \t\n|&;()<>" +// +// The following characters also have addition special meaning when unescaped: +// * ‘${[*?!"'`\’ +// +// Double Quotes preserve the literal value of all characters with then quotes +// excepting: ‘$’, ‘`’, ‘\’, and, when history expansion is enabled, ‘!’. +// The backslash retains its special meaning only when followed by one of the +// following characters: ‘$’, ‘`’, ‘"’, ‘\’, or newline. +// Backslashes preceding characters without a special meaning are left +// unmodified. A double quote may be quoted within double quotes by preceding +// it with a backslash. If enabled, history expansion will be performed unless +// an ‘!’ appearing in double quotes is escaped using a backslash. The +// backslash preceding the ‘!’ is not removed. +// +// -> This means that `!\n` cannot be safely expressed in `"`. +// +// Looking at the man page for Dash and ash the situation is similar. +// +// Now zsh requires that ‘}’, and ‘]’ are also enclosed in doublequotes or escaped +// +// Single quotes escape everything except a ‘'’ +// +// There's one other gotcha - ‘~’ at the start of a string needs to be expanded +// because people always expect that - of course if there is a special character before '/' +// this is not going to work + +const ( + tildePrefix = '~' + needsEscape = " \t\n|&;()<>${}[]*?!\"'`\\" + needsSingleQuote = "!\n" +) + +var ( + doubleQuoteEscaper = strings.NewReplacer(`$`, `\$`, "`", "\\`", `"`, `\"`, `\`, `\\`) + singleQuoteEscaper = strings.NewReplacer(`'`, `'\''`) + singleQuoteCoalescer = strings.NewReplacer(`''\'`, `\'`, `\'''`, `\'`) +) + +// ShellEscape will escape the provided string. +// We can't just use go-shellquote here because our preferences for escaping differ from those in that we want: +// +// * If the string doesn't require any escaping just leave it as it is. +// * If the string requires any escaping prefer double quote escaping +// * If we have ! or newlines then we need to use single quote escaping +func ShellEscape(toEscape string) string { + if len(toEscape) == 0 { + return toEscape + } + + start := 0 + + if toEscape[0] == tildePrefix { + // We're in the forcibly non-escaped section... + idx := strings.IndexRune(toEscape, '/') + if idx < 0 { + idx = len(toEscape) + } else { + idx++ + } + if !strings.ContainsAny(toEscape[:idx], needsEscape) { + // We'll assume that they intend ~ expansion to occur + start = idx + } + } + + // Now for simplicity we'll look at the rest of the string + if !strings.ContainsAny(toEscape[start:], needsEscape) { + return toEscape + } + + // OK we have to do some escaping + sb := &strings.Builder{} + _, _ = sb.WriteString(toEscape[:start]) + + // Do we have any characters which absolutely need to be within single quotes - that is simply ! or \n? + if strings.ContainsAny(toEscape[start:], needsSingleQuote) { + // We need to single quote escape. + sb2 := &strings.Builder{} + _, _ = sb2.WriteRune('\'') + _, _ = singleQuoteEscaper.WriteString(sb2, toEscape[start:]) + _, _ = sb2.WriteRune('\'') + _, _ = singleQuoteCoalescer.WriteString(sb, sb2.String()) + return sb.String() + } + + // OK we can just use " just escape the things that need escaping + _, _ = sb.WriteRune('"') + _, _ = doubleQuoteEscaper.WriteString(sb, toEscape[start:]) + _, _ = sb.WriteRune('"') + return sb.String() +} diff --git a/modules/util/shellquote_test.go b/modules/util/shellquote_test.go new file mode 100644 index 0000000..969998c --- /dev/null +++ b/modules/util/shellquote_test.go @@ -0,0 +1,91 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import "testing" + +func TestShellEscape(t *testing.T) { + tests := []struct { + name string + toEscape string + want string + }{ + { + "Simplest case - nothing to escape", + "a/b/c/d", + "a/b/c/d", + }, { + "Prefixed tilde - with normal stuff - should not escape", + "~/src/go/gitea/gitea", + "~/src/go/gitea/gitea", + }, { + "Typical windows path with spaces - should get doublequote escaped", + `C:\Program Files\Gitea v1.13 - I like lots of spaces\gitea`, + `"C:\\Program Files\\Gitea v1.13 - I like lots of spaces\\gitea"`, + }, { + "Forward-slashed windows path with spaces - should get doublequote escaped", + "C:/Program Files/Gitea v1.13 - I like lots of spaces/gitea", + `"C:/Program Files/Gitea v1.13 - I like lots of spaces/gitea"`, + }, { + "Prefixed tilde - but then a space filled path", + "~git/Gitea v1.13/gitea", + `~git/"Gitea v1.13/gitea"`, + }, { + "Bangs are unfortunately not predictable so need to be singlequoted", + "C:/Program Files/Gitea!/gitea", + `'C:/Program Files/Gitea!/gitea'`, + }, { + "Newlines are just irritating", + "/home/git/Gitea\n\nWHY-WOULD-YOU-DO-THIS\n\nGitea/gitea", + "'/home/git/Gitea\n\nWHY-WOULD-YOU-DO-THIS\n\nGitea/gitea'", + }, { + "Similarly we should nicely handle multiple single quotes if we have to single-quote", + "'!''!'''!''!'!'", + `\''!'\'\''!'\'\'\''!'\'\''!'\''!'\'`, + }, { + "Double quote < ...", + "~/<gitea", + "~/\"<gitea\"", + }, { + "Double quote > ...", + "~/gitea>", + "~/\"gitea>\"", + }, { + "Double quote and escape $ ...", + "~/$gitea", + "~/\"\\$gitea\"", + }, { + "Double quote {...", + "~/{gitea", + "~/\"{gitea\"", + }, { + "Double quote }...", + "~/gitea}", + "~/\"gitea}\"", + }, { + "Double quote ()...", + "~/(gitea)", + "~/\"(gitea)\"", + }, { + "Double quote and escape `...", + "~/gitea`", + "~/\"gitea\\`\"", + }, { + "Double quotes can handle a number of things without having to escape them but not everything ...", + "~/<gitea> ${gitea} `gitea` [gitea] (gitea) \"gitea\" \\gitea\\ 'gitea'", + "~/\"<gitea> \\${gitea} \\`gitea\\` [gitea] (gitea) \\\"gitea\\\" \\\\gitea\\\\ 'gitea'\"", + }, { + "Single quotes don't need to escape except for '...", + "~/<gitea> ${gitea} `gitea` (gitea) !gitea! \"gitea\" \\gitea\\ 'gitea'", + "~/'<gitea> ${gitea} `gitea` (gitea) !gitea! \"gitea\" \\gitea\\ '\\''gitea'\\'", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ShellEscape(tt.toEscape); got != tt.want { + t.Errorf("ShellEscape(%q):\nGot: %s\nWanted: %s", tt.toEscape, got, tt.want) + } + }) + } +} diff --git a/modules/util/slice.go b/modules/util/slice.go new file mode 100644 index 0000000..9c878c2 --- /dev/null +++ b/modules/util/slice.go @@ -0,0 +1,73 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "cmp" + "slices" + "strings" +) + +// SliceContainsString sequential searches if string exists in slice. +func SliceContainsString(slice []string, target string, insensitive ...bool) bool { + if len(insensitive) != 0 && insensitive[0] { + target = strings.ToLower(target) + return slices.ContainsFunc(slice, func(t string) bool { return strings.ToLower(t) == target }) + } + + return slices.Contains(slice, target) +} + +// SliceSortedEqual returns true if the two slices will be equal when they get sorted. +// It doesn't require that the slices have been sorted, and it doesn't sort them either. +func SliceSortedEqual[T comparable](s1, s2 []T) bool { + if len(s1) != len(s2) { + return false + } + + counts := make(map[T]int, len(s1)) + for _, v := range s1 { + counts[v]++ + } + for _, v := range s2 { + counts[v]-- + } + + for _, v := range counts { + if v != 0 { + return false + } + } + return true +} + +// SliceRemoveAll removes all the target elements from the slice. +func SliceRemoveAll[T comparable](slice []T, target T) []T { + return slices.DeleteFunc(slice, func(t T) bool { return t == target }) +} + +// Sorted returns the sorted slice +// Note: The parameter is sorted inline. +func Sorted[S ~[]E, E cmp.Ordered](values S) S { + slices.Sort(values) + return values +} + +// TODO: Replace with "maps.Values" once available, current it only in golang.org/x/exp/maps but not in standard library +func ValuesOfMap[K comparable, V any](m map[K]V) []V { + values := make([]V, 0, len(m)) + for _, v := range m { + values = append(values, v) + } + return values +} + +// TODO: Replace with "maps.Keys" once available, current it only in golang.org/x/exp/maps but not in standard library +func KeysOfMap[K comparable, V any](m map[K]V) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/modules/util/slice_test.go b/modules/util/slice_test.go new file mode 100644 index 0000000..a910f5e --- /dev/null +++ b/modules/util/slice_test.go @@ -0,0 +1,55 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSliceContainsString(t *testing.T) { + assert.True(t, SliceContainsString([]string{"c", "b", "a", "b"}, "a")) + assert.True(t, SliceContainsString([]string{"c", "b", "a", "b"}, "b")) + assert.True(t, SliceContainsString([]string{"c", "b", "a", "b"}, "A", true)) + assert.True(t, SliceContainsString([]string{"C", "B", "A", "B"}, "a", true)) + + assert.False(t, SliceContainsString([]string{"c", "b", "a", "b"}, "z")) + assert.False(t, SliceContainsString([]string{"c", "b", "a", "b"}, "A")) + assert.False(t, SliceContainsString([]string{}, "a")) + assert.False(t, SliceContainsString(nil, "a")) +} + +func TestSliceSortedEqual(t *testing.T) { + assert.True(t, SliceSortedEqual([]int{2, 0, 2, 3}, []int{2, 0, 2, 3})) + assert.True(t, SliceSortedEqual([]int{3, 0, 2, 2}, []int{2, 0, 2, 3})) + assert.True(t, SliceSortedEqual([]int{}, []int{})) + assert.True(t, SliceSortedEqual([]int(nil), nil)) + assert.True(t, SliceSortedEqual([]int(nil), []int{})) + assert.True(t, SliceSortedEqual([]int{}, []int{})) + + assert.True(t, SliceSortedEqual([]string{"2", "0", "2", "3"}, []string{"2", "0", "2", "3"})) + assert.True(t, SliceSortedEqual([]float64{2, 0, 2, 3}, []float64{2, 0, 2, 3})) + assert.True(t, SliceSortedEqual([]bool{false, true, false}, []bool{false, true, false})) + + assert.False(t, SliceSortedEqual([]int{2, 0, 2}, []int{2, 0, 2, 3})) + assert.False(t, SliceSortedEqual([]int{}, []int{2, 0, 2, 3})) + assert.False(t, SliceSortedEqual(nil, []int{2, 0, 2, 3})) + assert.False(t, SliceSortedEqual([]int{2, 0, 2, 4}, []int{2, 0, 2, 3})) + assert.False(t, SliceSortedEqual([]int{2, 0, 0, 3}, []int{2, 0, 2, 3})) +} + +func TestSliceRemoveAll(t *testing.T) { + assert.ElementsMatch(t, []int{2, 2, 3}, SliceRemoveAll([]int{2, 0, 2, 3}, 0)) + assert.ElementsMatch(t, []int{0, 3}, SliceRemoveAll([]int{2, 0, 2, 3}, 2)) + assert.Empty(t, SliceRemoveAll([]int{0, 0, 0, 0}, 0)) + assert.ElementsMatch(t, []int{2, 0, 2, 3}, SliceRemoveAll([]int{2, 0, 2, 3}, 4)) + assert.Empty(t, SliceRemoveAll([]int{}, 0)) + assert.ElementsMatch(t, []int(nil), SliceRemoveAll([]int(nil), 0)) + assert.Empty(t, SliceRemoveAll([]int{}, 0)) + + assert.ElementsMatch(t, []string{"2", "2", "3"}, SliceRemoveAll([]string{"2", "0", "2", "3"}, "0")) + assert.ElementsMatch(t, []float64{2, 2, 3}, SliceRemoveAll([]float64{2, 0, 2, 3}, 0)) + assert.ElementsMatch(t, []bool{false, false}, SliceRemoveAll([]bool{false, true, false}, true)) +} diff --git a/modules/util/string.go b/modules/util/string.go new file mode 100644 index 0000000..cf50f59 --- /dev/null +++ b/modules/util/string.go @@ -0,0 +1,97 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import "unsafe" + +func isSnakeCaseUpper(c byte) bool { + return 'A' <= c && c <= 'Z' +} + +func isSnakeCaseLowerOrNumber(c byte) bool { + return 'a' <= c && c <= 'z' || '0' <= c && c <= '9' +} + +// ToSnakeCase convert the input string to snake_case format. +// +// Some samples. +// +// "FirstName" => "first_name" +// "HTTPServer" => "http_server" +// "NoHTTPS" => "no_https" +// "GO_PATH" => "go_path" +// "GO PATH" => "go_path" // space is converted to underscore. +// "GO-PATH" => "go_path" // hyphen is converted to underscore. +func ToSnakeCase(input string) string { + if len(input) == 0 { + return "" + } + + var res []byte + if len(input) == 1 { + c := input[0] + if isSnakeCaseUpper(c) { + res = []byte{c + 'a' - 'A'} + } else if isSnakeCaseLowerOrNumber(c) { + res = []byte{c} + } else { + res = []byte{'_'} + } + } else { + res = make([]byte, 0, len(input)*4/3) + pos := 0 + needSep := false + for pos < len(input) { + c := input[pos] + if c >= 0x80 { + res = append(res, c) + pos++ + continue + } + isUpper := isSnakeCaseUpper(c) + if isUpper || isSnakeCaseLowerOrNumber(c) { + end := pos + 1 + if isUpper { + // skip the following upper letters + for end < len(input) && isSnakeCaseUpper(input[end]) { + end++ + } + if end-pos > 1 && end < len(input) && isSnakeCaseLowerOrNumber(input[end]) { + end-- + } + } + // skip the following lower or number letters + for end < len(input) && (isSnakeCaseLowerOrNumber(input[end]) || input[end] >= 0x80) { + end++ + } + if needSep { + res = append(res, '_') + } + res = append(res, input[pos:end]...) + pos = end + needSep = true + } else { + res = append(res, '_') + pos++ + needSep = false + } + } + for i := 0; i < len(res); i++ { + if isSnakeCaseUpper(res[i]) { + res[i] += 'a' - 'A' + } + } + } + return UnsafeBytesToString(res) +} + +// UnsafeBytesToString uses Go's unsafe package to convert a byte slice to a string. +func UnsafeBytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} + +// UnsafeStringToBytes uses Go's unsafe package to convert a string to a byte slice. +func UnsafeStringToBytes(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} diff --git a/modules/util/string_test.go b/modules/util/string_test.go new file mode 100644 index 0000000..0a4a8bb --- /dev/null +++ b/modules/util/string_test.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestToSnakeCase(t *testing.T) { + cases := map[string]string{ + // all old cases from the legacy package + "HTTPServer": "http_server", + "_camelCase": "_camel_case", + "NoHTTPS": "no_https", + "Wi_thF": "wi_th_f", + "_AnotherTES_TCaseP": "_another_tes_t_case_p", + "ALL": "all", + "_HELLO_WORLD_": "_hello_world_", + "HELLO_WORLD": "hello_world", + "HELLO____WORLD": "hello____world", + "TW": "tw", + "_C": "_c", + + " sentence case ": "__sentence_case__", + " Mixed-hyphen case _and SENTENCE_case and UPPER-case": "_mixed_hyphen_case__and_sentence_case_and_upper_case", + + // new cases + " ": "_", + "A": "a", + "A0": "a0", + "a0": "a0", + "Aa0": "aa0", + "啊": "啊", + "A啊": "a啊", + "Aa啊b": "aa啊b", + "A啊B": "a啊_b", + "Aa啊B": "aa啊_b", + "TheCase2": "the_case2", + "ObjIDs": "obj_i_ds", // the strange database column name which already exists + } + for input, expected := range cases { + assert.Equal(t, expected, ToSnakeCase(input)) + } +} diff --git a/modules/util/timer.go b/modules/util/timer.go new file mode 100644 index 0000000..f9a7950 --- /dev/null +++ b/modules/util/timer.go @@ -0,0 +1,36 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "sync" + "time" +) + +func Debounce(d time.Duration) func(f func()) { + type debouncer struct { + mu sync.Mutex + t *time.Timer + } + db := &debouncer{} + + return func(f func()) { + db.mu.Lock() + defer db.mu.Unlock() + + if db.t != nil { + db.t.Stop() + } + var trigger *time.Timer + trigger = time.AfterFunc(d, func() { + db.mu.Lock() + defer db.mu.Unlock() + if trigger == db.t { + f() + db.t = nil + } + }) + db.t = trigger + } +} diff --git a/modules/util/timer_test.go b/modules/util/timer_test.go new file mode 100644 index 0000000..602800c --- /dev/null +++ b/modules/util/timer_test.go @@ -0,0 +1,30 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDebounce(t *testing.T) { + var c int64 + d := Debounce(50 * time.Millisecond) + d(func() { atomic.AddInt64(&c, 1) }) + assert.EqualValues(t, 0, atomic.LoadInt64(&c)) + d(func() { atomic.AddInt64(&c, 1) }) + d(func() { atomic.AddInt64(&c, 1) }) + time.Sleep(100 * time.Millisecond) + assert.EqualValues(t, 1, atomic.LoadInt64(&c)) + d(func() { atomic.AddInt64(&c, 1) }) + assert.EqualValues(t, 1, atomic.LoadInt64(&c)) + d(func() { atomic.AddInt64(&c, 1) }) + d(func() { atomic.AddInt64(&c, 1) }) + d(func() { atomic.AddInt64(&c, 1) }) + time.Sleep(100 * time.Millisecond) + assert.EqualValues(t, 2, atomic.LoadInt64(&c)) +} diff --git a/modules/util/truncate.go b/modules/util/truncate.go new file mode 100644 index 0000000..77b116e --- /dev/null +++ b/modules/util/truncate.go @@ -0,0 +1,54 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "strings" + "unicode/utf8" +) + +// in UTF8 "…" is 3 bytes so doesn't really gain us anything... +const ( + utf8Ellipsis = "…" + asciiEllipsis = "..." +) + +// SplitStringAtByteN splits a string at byte n accounting for rune boundaries. (Combining characters are not accounted for.) +func SplitStringAtByteN(input string, n int) (left, right string) { + if len(input) <= n { + return input, "" + } + + if !utf8.ValidString(input) { + if n-3 < 0 { + return input, "" + } + return input[:n-3] + asciiEllipsis, asciiEllipsis + input[n-3:] + } + + end := 0 + for end <= n-3 { + _, size := utf8.DecodeRuneInString(input[end:]) + if end+size > n-3 { + break + } + end += size + } + + return input[:end] + utf8Ellipsis, utf8Ellipsis + input[end:] +} + +// SplitTrimSpace splits the string at given separator and trims leading and trailing space +func SplitTrimSpace(input, sep string) []string { + // replace CRLF with LF + input = strings.ReplaceAll(input, "\r\n", "\n") + + var stringList []string + for _, s := range strings.Split(input, sep) { + // trim leading and trailing space + stringList = append(stringList, strings.TrimSpace(s)) + } + + return stringList +} diff --git a/modules/util/truncate_test.go b/modules/util/truncate_test.go new file mode 100644 index 0000000..dfe1230 --- /dev/null +++ b/modules/util/truncate_test.go @@ -0,0 +1,46 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitString(t *testing.T) { + type testCase struct { + input string + n int + leftSub string + ellipsis string + } + + test := func(tc []*testCase, f func(input string, n int) (left, right string)) { + for _, c := range tc { + l, r := f(c.input, c.n) + if c.ellipsis != "" { + assert.Equal(t, c.leftSub+c.ellipsis, l, "test split %q at %d, expected leftSub: %q", c.input, c.n, c.leftSub) + assert.Equal(t, c.ellipsis+c.input[len(c.leftSub):], r, "test split %s at %d, expected rightSub: %q", c.input, c.n, c.input[len(c.leftSub):]) + } else { + assert.Equal(t, c.leftSub, l, "test split %q at %d, expected leftSub: %q", c.input, c.n, c.leftSub) + assert.Empty(t, r, "test split %q at %d, expected rightSub: %q", c.input, c.n, "") + } + } + } + + tc := []*testCase{ + {"abc123xyz", 0, "", utf8Ellipsis}, + {"abc123xyz", 1, "", utf8Ellipsis}, + {"abc123xyz", 4, "a", utf8Ellipsis}, + {"啊bc123xyz", 4, "", utf8Ellipsis}, + {"啊bc123xyz", 6, "啊", utf8Ellipsis}, + {"啊bc", 5, "啊bc", ""}, + {"啊bc", 6, "啊bc", ""}, + {"abc\xef\x03\xfe", 3, "", asciiEllipsis}, + {"abc\xef\x03\xfe", 4, "a", asciiEllipsis}, + {"\xef\x03", 1, "\xef\x03", ""}, + } + test(tc, SplitStringAtByteN) +} diff --git a/modules/util/url.go b/modules/util/url.go new file mode 100644 index 0000000..6237033 --- /dev/null +++ b/modules/util/url.go @@ -0,0 +1,50 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "net/url" + "path" + "strings" +) + +// PathEscapeSegments escapes segments of a path while not escaping forward slash +func PathEscapeSegments(path string) string { + slice := strings.Split(path, "/") + for index := range slice { + slice[index] = url.PathEscape(slice[index]) + } + escapedPath := strings.Join(slice, "/") + return escapedPath +} + +// URLJoin joins url components, like path.Join, but preserving contents +func URLJoin(base string, elems ...string) string { + if !strings.HasSuffix(base, "/") { + base += "/" + } + baseURL, err := url.Parse(base) + if err != nil { + return "" + } + joinedPath := path.Join(elems...) + argURL, err := url.Parse(joinedPath) + if err != nil { + return "" + } + joinedURL := baseURL.ResolveReference(argURL).String() + if !baseURL.IsAbs() && !strings.HasPrefix(base, "/") { + return joinedURL[1:] // Removing leading '/' if needed + } + return joinedURL +} + +func SanitizeURL(s string) (string, error) { + u, err := url.Parse(s) + if err != nil { + return "", err + } + u.User = nil + return u.String(), nil +} diff --git a/modules/util/util.go b/modules/util/util.go new file mode 100644 index 0000000..dcd7cf4 --- /dev/null +++ b/modules/util/util.go @@ -0,0 +1,264 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "fmt" + "math/big" + "strconv" + "strings" + + "code.gitea.io/gitea/modules/optional" + + "golang.org/x/crypto/ssh" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// OptionalBoolParse get the corresponding optional.Option[bool] of a string using strconv.ParseBool +func OptionalBoolParse(s string) optional.Option[bool] { + v, e := strconv.ParseBool(s) + if e != nil { + return optional.None[bool]() + } + return optional.Some(v) +} + +// IsEmptyString checks if the provided string is empty +func IsEmptyString(s string) bool { + return len(strings.TrimSpace(s)) == 0 +} + +// NormalizeEOL will convert Windows (CRLF) and Mac (CR) EOLs to UNIX (LF) +func NormalizeEOL(input []byte) []byte { + var right, left, pos int + if right = bytes.IndexByte(input, '\r'); right == -1 { + return input + } + length := len(input) + tmp := make([]byte, length) + + // We know that left < length because otherwise right would be -1 from IndexByte. + copy(tmp[pos:pos+right], input[left:left+right]) + pos += right + tmp[pos] = '\n' + left += right + 1 + pos++ + + for left < length { + if input[left] == '\n' { + left++ + } + + right = bytes.IndexByte(input[left:], '\r') + if right == -1 { + copy(tmp[pos:], input[left:]) + pos += length - left + break + } + copy(tmp[pos:pos+right], input[left:left+right]) + pos += right + tmp[pos] = '\n' + left += right + 1 + pos++ + } + return tmp[:pos] +} + +// CryptoRandomInt returns a crypto random integer between 0 and limit, inclusive +func CryptoRandomInt(limit int64) (int64, error) { + rInt, err := rand.Int(rand.Reader, big.NewInt(limit)) + if err != nil { + return 0, err + } + return rInt.Int64(), nil +} + +const alphanumericalChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +// CryptoRandomString generates a crypto random alphanumerical string, each byte is generated by [0,61] range +func CryptoRandomString(length int64) (string, error) { + buf := make([]byte, length) + limit := int64(len(alphanumericalChars)) + for i := range buf { + num, err := CryptoRandomInt(limit) + if err != nil { + return "", err + } + buf[i] = alphanumericalChars[num] + } + return string(buf), nil +} + +// CryptoRandomBytes generates `length` crypto bytes +// This differs from CryptoRandomString, as each byte in CryptoRandomString is generated by [0,61] range +// This function generates totally random bytes, each byte is generated by [0,255] range +func CryptoRandomBytes(length int64) ([]byte, error) { + buf := make([]byte, length) + _, err := rand.Read(buf) + return buf, err +} + +// ToUpperASCII returns s with all ASCII letters mapped to their upper case. +func ToUpperASCII(s string) string { + b := []byte(s) + for i, c := range b { + if 'a' <= c && c <= 'z' { + b[i] -= 'a' - 'A' + } + } + return string(b) +} + +// ToTitleCase returns s with all english words capitalized +func ToTitleCase(s string) string { + // `cases.Title` is not thread-safe, do not use global shared variable for it + return cases.Title(language.English).String(s) +} + +// ToTitleCaseNoLower returns s with all english words capitalized without lower-casing +func ToTitleCaseNoLower(s string) string { + // `cases.Title` is not thread-safe, do not use global shared variable for it + return cases.Title(language.English, cases.NoLower).String(s) +} + +// ToInt64 transform a given int into int64. +func ToInt64(number any) (int64, error) { + var value int64 + switch v := number.(type) { + case int: + value = int64(v) + case int8: + value = int64(v) + case int16: + value = int64(v) + case int32: + value = int64(v) + case int64: + value = v + + case uint: + value = int64(v) + case uint8: + value = int64(v) + case uint16: + value = int64(v) + case uint32: + value = int64(v) + case uint64: + value = int64(v) + + case float32: + value = int64(v) + case float64: + value = int64(v) + + case string: + var err error + if value, err = strconv.ParseInt(v, 10, 64); err != nil { + return 0, err + } + default: + return 0, fmt.Errorf("unable to convert %v to int64", number) + } + return value, nil +} + +// ToFloat64 transform a given int into float64. +func ToFloat64(number any) (float64, error) { + var value float64 + switch v := number.(type) { + case int: + value = float64(v) + case int8: + value = float64(v) + case int16: + value = float64(v) + case int32: + value = float64(v) + case int64: + value = float64(v) + + case uint: + value = float64(v) + case uint8: + value = float64(v) + case uint16: + value = float64(v) + case uint32: + value = float64(v) + case uint64: + value = float64(v) + + case float32: + value = float64(v) + case float64: + value = v + + case string: + var err error + if value, err = strconv.ParseFloat(v, 64); err != nil { + return 0, err + } + default: + return 0, fmt.Errorf("unable to convert %v to float64", number) + } + return value, nil +} + +// ToPointer returns the pointer of a copy of any given value +func ToPointer[T any](val T) *T { + return &val +} + +// Iif is an "inline-if", it returns "trueVal" if "condition" is true, otherwise "falseVal" +func Iif[T any](condition bool, trueVal, falseVal T) T { + if condition { + return trueVal + } + return falseVal +} + +// IfZero returns "def" if "v" is a zero value, otherwise "v" +func IfZero[T comparable](v, def T) T { + var zero T + if v == zero { + return def + } + return v +} + +func ReserveLineBreakForTextarea(input string) string { + // Since the content is from a form which is a textarea, the line endings are \r\n. + // It's a standard behavior of HTML. + // But we want to store them as \n like what GitHub does. + // And users are unlikely to really need to keep the \r. + // Other than this, we should respect the original content, even leading or trailing spaces. + return strings.ReplaceAll(input, "\r\n", "\n") +} + +// GenerateSSHKeypair generates a ed25519 SSH-compatible keypair. +func GenerateSSHKeypair() (publicKey, privateKey []byte, err error) { + public, private, err := ed25519.GenerateKey(nil) + if err != nil { + return nil, nil, fmt.Errorf("ed25519.GenerateKey: %w", err) + } + + privPEM, err := ssh.MarshalPrivateKey(private, "") + if err != nil { + return nil, nil, fmt.Errorf("ssh.MarshalPrivateKey: %w", err) + } + + sshPublicKey, err := ssh.NewPublicKey(public) + if err != nil { + return nil, nil, fmt.Errorf("ssh.NewPublicKey: %w", err) + } + + return ssh.MarshalAuthorizedKey(sshPublicKey), pem.EncodeToMemory(privPEM), nil +} diff --git a/modules/util/util_test.go b/modules/util/util_test.go new file mode 100644 index 0000000..549b53f --- /dev/null +++ b/modules/util/util_test.go @@ -0,0 +1,277 @@ +// Copyright 2018 The Gitea Authors. All rights reserved. +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package util_test + +import ( + "bytes" + "crypto/rand" + "regexp" + "strings" + "testing" + + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/modules/test" + "code.gitea.io/gitea/modules/util" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestURLJoin(t *testing.T) { + type test struct { + Expected string + Base string + Elements []string + } + newTest := func(expected, base string, elements ...string) test { + return test{Expected: expected, Base: base, Elements: elements} + } + for _, test := range []test{ + newTest("https://try.gitea.io/a/b/c", + "https://try.gitea.io", "a/b", "c"), + newTest("https://try.gitea.io/a/b/c", + "https://try.gitea.io/", "/a/b/", "/c/"), + newTest("https://try.gitea.io/a/c", + "https://try.gitea.io/", "/a/./b/", "../c/"), + newTest("a/b/c", + "a", "b/c/"), + newTest("a/b/d", + "a/", "b/c/", "/../d/"), + newTest("https://try.gitea.io/a/b/c#d", + "https://try.gitea.io", "a/b", "c#d"), + newTest("/a/b/d", + "/a/", "b/c/", "/../d/"), + newTest("/a/b/c", + "/a", "b/c/"), + newTest("/a/b/c#hash", + "/a", "b/c#hash"), + } { + assert.Equal(t, test.Expected, util.URLJoin(test.Base, test.Elements...)) + } +} + +func TestIsEmptyString(t *testing.T) { + cases := []struct { + s string + expected bool + }{ + {"", true}, + {" ", true}, + {" ", true}, + {" a", false}, + } + + for _, v := range cases { + assert.Equal(t, v.expected, util.IsEmptyString(v.s)) + } +} + +func Test_NormalizeEOL(t *testing.T) { + data1 := []string{ + "", + "This text starts with empty lines", + "another", + "", + "", + "", + "Some other empty lines in the middle", + "more.", + "And more.", + "Ends with empty lines too.", + "", + "", + "", + } + + data2 := []string{ + "This text does not start with empty lines", + "another", + "", + "", + "", + "Some other empty lines in the middle", + "more.", + "And more.", + "Ends without EOLtoo.", + } + + buildEOLData := func(data []string, eol string) []byte { + return []byte(strings.Join(data, eol)) + } + + dos := buildEOLData(data1, "\r\n") + unix := buildEOLData(data1, "\n") + mac := buildEOLData(data1, "\r") + + assert.Equal(t, unix, util.NormalizeEOL(dos)) + assert.Equal(t, unix, util.NormalizeEOL(mac)) + assert.Equal(t, unix, util.NormalizeEOL(unix)) + + dos = buildEOLData(data2, "\r\n") + unix = buildEOLData(data2, "\n") + mac = buildEOLData(data2, "\r") + + assert.Equal(t, unix, util.NormalizeEOL(dos)) + assert.Equal(t, unix, util.NormalizeEOL(mac)) + assert.Equal(t, unix, util.NormalizeEOL(unix)) + + assert.Equal(t, []byte("one liner"), util.NormalizeEOL([]byte("one liner"))) + assert.Equal(t, []byte("\n"), util.NormalizeEOL([]byte("\n"))) + assert.Equal(t, []byte("\ntwo liner"), util.NormalizeEOL([]byte("\ntwo liner"))) + assert.Equal(t, []byte("two liner\n"), util.NormalizeEOL([]byte("two liner\n"))) + assert.Equal(t, []byte{}, util.NormalizeEOL([]byte{})) + + assert.Equal(t, []byte("mix\nand\nmatch\n."), util.NormalizeEOL([]byte("mix\r\nand\rmatch\n."))) +} + +func Test_RandomInt(t *testing.T) { + randInt, err := util.CryptoRandomInt(255) + assert.GreaterOrEqual(t, randInt, int64(0)) + assert.LessOrEqual(t, randInt, int64(255)) + require.NoError(t, err) +} + +func Test_RandomString(t *testing.T) { + str1, err := util.CryptoRandomString(32) + require.NoError(t, err) + matches, err := regexp.MatchString(`^[a-zA-Z0-9]{32}$`, str1) + require.NoError(t, err) + assert.True(t, matches) + + str2, err := util.CryptoRandomString(32) + require.NoError(t, err) + matches, err = regexp.MatchString(`^[a-zA-Z0-9]{32}$`, str1) + require.NoError(t, err) + assert.True(t, matches) + + assert.NotEqual(t, str1, str2) + + str3, err := util.CryptoRandomString(256) + require.NoError(t, err) + matches, err = regexp.MatchString(`^[a-zA-Z0-9]{256}$`, str3) + require.NoError(t, err) + assert.True(t, matches) + + str4, err := util.CryptoRandomString(256) + require.NoError(t, err) + matches, err = regexp.MatchString(`^[a-zA-Z0-9]{256}$`, str4) + require.NoError(t, err) + assert.True(t, matches) + + assert.NotEqual(t, str3, str4) +} + +func Test_RandomBytes(t *testing.T) { + bytes1, err := util.CryptoRandomBytes(32) + require.NoError(t, err) + + bytes2, err := util.CryptoRandomBytes(32) + require.NoError(t, err) + + assert.NotEqual(t, bytes1, bytes2) + + bytes3, err := util.CryptoRandomBytes(256) + require.NoError(t, err) + + bytes4, err := util.CryptoRandomBytes(256) + require.NoError(t, err) + + assert.NotEqual(t, bytes3, bytes4) +} + +func TestOptionalBoolParse(t *testing.T) { + assert.Equal(t, optional.None[bool](), util.OptionalBoolParse("")) + assert.Equal(t, optional.None[bool](), util.OptionalBoolParse("x")) + + assert.Equal(t, optional.Some(false), util.OptionalBoolParse("0")) + assert.Equal(t, optional.Some(false), util.OptionalBoolParse("f")) + assert.Equal(t, optional.Some(false), util.OptionalBoolParse("False")) + + assert.Equal(t, optional.Some(true), util.OptionalBoolParse("1")) + assert.Equal(t, optional.Some(true), util.OptionalBoolParse("t")) + assert.Equal(t, optional.Some(true), util.OptionalBoolParse("True")) +} + +// Test case for any function which accepts and returns a single string. +type StringTest struct { + in, out string +} + +var upperTests = []StringTest{ + {"", ""}, + {"ONLYUPPER", "ONLYUPPER"}, + {"abc", "ABC"}, + {"AbC123", "ABC123"}, + {"azAZ09_", "AZAZ09_"}, + {"longStrinGwitHmixofsmaLLandcAps", "LONGSTRINGWITHMIXOFSMALLANDCAPS"}, + {"long\u0250string\u0250with\u0250nonascii\u2C6Fchars", "LONG\u0250STRING\u0250WITH\u0250NONASCII\u2C6FCHARS"}, + {"\u0250\u0250\u0250\u0250\u0250", "\u0250\u0250\u0250\u0250\u0250"}, + {"a\u0080\U0010FFFF", "A\u0080\U0010FFFF"}, + {"lél", "LéL"}, +} + +func TestToUpperASCII(t *testing.T) { + for _, tc := range upperTests { + assert.Equal(t, util.ToUpperASCII(tc.in), tc.out) + } +} + +func BenchmarkToUpper(b *testing.B) { + for _, tc := range upperTests { + b.Run(tc.in, func(b *testing.B) { + for i := 0; i < b.N; i++ { + util.ToUpperASCII(tc.in) + } + }) + } +} + +func TestToTitleCase(t *testing.T) { + assert.Equal(t, `Foo Bar Baz`, util.ToTitleCase(`foo bar baz`)) + assert.Equal(t, `Foo Bar Baz`, util.ToTitleCase(`FOO BAR BAZ`)) +} + +func TestToPointer(t *testing.T) { + assert.Equal(t, "abc", *util.ToPointer("abc")) + assert.Equal(t, 123, *util.ToPointer(123)) + abc := "abc" + assert.NotSame(t, &abc, util.ToPointer(abc)) + val123 := 123 + assert.NotSame(t, &val123, util.ToPointer(val123)) +} + +func TestReserveLineBreakForTextarea(t *testing.T) { + assert.Equal(t, "test\ndata", util.ReserveLineBreakForTextarea("test\r\ndata")) + assert.Equal(t, "test\ndata\n", util.ReserveLineBreakForTextarea("test\r\ndata\r\n")) +} + +const ( + testPublicKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAOhB7/zzhC+HXDdGOdLwJln5NYwm6UNXx3chmQSVTG4\n" + testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz +c2gtZWQyNTUxOQAAACADoQe/884Qvh1w3RjnS8CZZ+TWMJulDV8d3IZkElUxuAAA +AIggISIjICEiIwAAAAtzc2gtZWQyNTUxOQAAACADoQe/884Qvh1w3RjnS8CZZ+TW +MJulDV8d3IZkElUxuAAAAEAAAQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0e +HwOhB7/zzhC+HXDdGOdLwJln5NYwm6UNXx3chmQSVTG4AAAAAAECAwQF +-----END OPENSSH PRIVATE KEY-----` + "\n" +) + +func TestGeneratingEd25519Keypair(t *testing.T) { + defer test.MockProtect(&rand.Reader)() + + // Only 32 bytes needs to be provided to generate a ed25519 keypair. + // And another 32 bytes are required, which is included as random value + // in the OpenSSH format. + b := make([]byte, 64) + for i := 0; i < 64; i++ { + b[i] = byte(i) + } + rand.Reader = bytes.NewReader(b) + + publicKey, privateKey, err := util.GenerateSSHKeypair() + require.NoError(t, err) + assert.EqualValues(t, testPublicKey, string(publicKey)) + assert.EqualValues(t, testPrivateKey, string(privateKey)) +} |