From e68b9d00a6e05b3a941f63ffb696f91e554ac5ec Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 18 Oct 2024 20:33:49 +0200 Subject: Adding upstream version 9.0.3. Signed-off-by: Daniel Baumann --- models/db/collation.go | 159 +++++++++++++++ models/db/common.go | 53 +++++ models/db/consistency.go | 31 +++ models/db/context.go | 331 +++++++++++++++++++++++++++++++ models/db/context_committer_test.go | 102 ++++++++++ models/db/context_test.go | 87 +++++++++ models/db/convert.go | 64 ++++++ models/db/engine.go | 354 ++++++++++++++++++++++++++++++++++ models/db/engine_test.go | 154 +++++++++++++++ models/db/error.go | 74 +++++++ models/db/index.go | 148 ++++++++++++++ models/db/index_test.go | 127 ++++++++++++ models/db/install/db.go | 64 ++++++ models/db/iterate.go | 43 +++++ models/db/iterate_test.go | 45 +++++ models/db/list.go | 215 +++++++++++++++++++++ models/db/list_test.go | 53 +++++ models/db/log.go | 107 ++++++++++ models/db/main_test.go | 17 ++ models/db/name.go | 106 ++++++++++ models/db/paginator/main_test.go | 14 ++ models/db/paginator/paginator.go | 7 + models/db/paginator/paginator_test.go | 59 ++++++ models/db/search.go | 33 ++++ models/db/sequence.go | 70 +++++++ models/db/sql_postgres_with_schema.go | 74 +++++++ 26 files changed, 2591 insertions(+) create mode 100644 models/db/collation.go create mode 100644 models/db/common.go create mode 100644 models/db/consistency.go create mode 100644 models/db/context.go create mode 100644 models/db/context_committer_test.go create mode 100644 models/db/context_test.go create mode 100644 models/db/convert.go create mode 100755 models/db/engine.go create mode 100644 models/db/engine_test.go create mode 100644 models/db/error.go create mode 100644 models/db/index.go create mode 100644 models/db/index_test.go create mode 100644 models/db/install/db.go create mode 100644 models/db/iterate.go create mode 100644 models/db/iterate_test.go create mode 100644 models/db/list.go create mode 100644 models/db/list_test.go create mode 100644 models/db/log.go create mode 100644 models/db/main_test.go create mode 100644 models/db/name.go create mode 100644 models/db/paginator/main_test.go create mode 100644 models/db/paginator/paginator.go create mode 100644 models/db/paginator/paginator_test.go create mode 100644 models/db/search.go create mode 100644 models/db/sequence.go create mode 100644 models/db/sql_postgres_with_schema.go (limited to 'models/db') diff --git a/models/db/collation.go b/models/db/collation.go new file mode 100644 index 0000000..39d28fa --- /dev/null +++ b/models/db/collation.go @@ -0,0 +1,159 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "errors" + "fmt" + "strings" + + "code.gitea.io/gitea/modules/container" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "xorm.io/xorm" + "xorm.io/xorm/schemas" +) + +type CheckCollationsResult struct { + ExpectedCollation string + AvailableCollation container.Set[string] + DatabaseCollation string + IsCollationCaseSensitive func(s string) bool + CollationEquals func(a, b string) bool + ExistingTableNumber int + + InconsistentCollationColumns []string +} + +func findAvailableCollationsMySQL(x *xorm.Engine) (ret container.Set[string], err error) { + var res []struct { + Collation string + } + if err = x.SQL("SHOW COLLATION WHERE (Collation = 'utf8mb4_bin') OR (Collation LIKE '%\\_as\\_cs%')").Find(&res); err != nil { + return nil, err + } + ret = make(container.Set[string], len(res)) + for _, r := range res { + ret.Add(r.Collation) + } + return ret, nil +} + +func CheckCollations(x *xorm.Engine) (*CheckCollationsResult, error) { + dbTables, err := x.DBMetas() + if err != nil { + return nil, err + } + + res := &CheckCollationsResult{ + ExistingTableNumber: len(dbTables), + CollationEquals: func(a, b string) bool { return a == b }, + } + + var candidateCollations []string + if x.Dialect().URI().DBType == schemas.MYSQL { + if _, err = x.SQL("SELECT @@collation_database").Get(&res.DatabaseCollation); err != nil { + return nil, err + } + res.IsCollationCaseSensitive = func(s string) bool { + return s == "utf8mb4_bin" || strings.HasSuffix(s, "_as_cs") + } + candidateCollations = []string{"utf8mb4_0900_as_cs", "uca1400_as_cs", "utf8mb4_bin"} + res.AvailableCollation, err = findAvailableCollationsMySQL(x) + if err != nil { + return nil, err + } + res.CollationEquals = func(a, b string) bool { + // MariaDB adds the "utf8mb4_" prefix, eg: "utf8mb4_uca1400_as_cs", but not the name "uca1400_as_cs" in "SHOW COLLATION" + // At the moment, it's safe to ignore the database difference, just trim the prefix and compare. It could be fixed easily if there is any problem in the future. + return a == b || strings.TrimPrefix(a, "utf8mb4_") == strings.TrimPrefix(b, "utf8mb4_") + } + } else { + return nil, nil + } + + if res.DatabaseCollation == "" { + return nil, errors.New("unable to get collation for current database") + } + + res.ExpectedCollation = setting.Database.CharsetCollation + if res.ExpectedCollation == "" { + for _, collation := range candidateCollations { + if res.AvailableCollation.Contains(collation) { + res.ExpectedCollation = collation + break + } + } + } + + if res.ExpectedCollation == "" { + return nil, errors.New("unable to find a suitable collation for current database") + } + + allColumnsMatchExpected := true + allColumnsMatchDatabase := true + for _, table := range dbTables { + for _, col := range table.Columns() { + if col.Collation != "" { + allColumnsMatchExpected = allColumnsMatchExpected && res.CollationEquals(col.Collation, res.ExpectedCollation) + allColumnsMatchDatabase = allColumnsMatchDatabase && res.CollationEquals(col.Collation, res.DatabaseCollation) + if !res.IsCollationCaseSensitive(col.Collation) || !res.CollationEquals(col.Collation, res.DatabaseCollation) { + res.InconsistentCollationColumns = append(res.InconsistentCollationColumns, fmt.Sprintf("%s.%s", table.Name, col.Name)) + } + } + } + } + // if all columns match expected collation or all match database collation, then it could also be considered as "consistent" + if allColumnsMatchExpected || allColumnsMatchDatabase { + res.InconsistentCollationColumns = nil + } + return res, nil +} + +func CheckCollationsDefaultEngine() (*CheckCollationsResult, error) { + return CheckCollations(x) +} + +func alterDatabaseCollation(x *xorm.Engine, collation string) error { + if x.Dialect().URI().DBType == schemas.MYSQL { + _, err := x.Exec("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE " + collation) + return err + } + return errors.New("unsupported database type") +} + +// preprocessDatabaseCollation checks database & table column collation, and alter the database collation if needed +func preprocessDatabaseCollation(x *xorm.Engine) { + r, err := CheckCollations(x) + if err != nil { + log.Error("Failed to check database collation: %v", err) + } + if r == nil { + return // no check result means the database doesn't need to do such check/process (at the moment ....) + } + + // try to alter database collation to expected if the database is empty, it might fail in some cases (and it isn't necessary to succeed) + // at the moment. + if !r.CollationEquals(r.DatabaseCollation, r.ExpectedCollation) && r.ExistingTableNumber == 0 { + if err = alterDatabaseCollation(x, r.ExpectedCollation); err != nil { + log.Error("Failed to change database collation to %q: %v", r.ExpectedCollation, err) + } else { + if r, err = CheckCollations(x); err != nil { + log.Error("Failed to check database collation again after altering: %v", err) // impossible case + return + } + log.Warn("Current database has been altered to use collation %q", r.DatabaseCollation) + } + } + + // check column collation, and show warning/error to end users -- no need to fatal, do not block the startup + if !r.IsCollationCaseSensitive(r.DatabaseCollation) { + log.Warn("Current database is using a case-insensitive collation %q, although Forgejo could work with it, there might be some rare cases which don't work as expected.", r.DatabaseCollation) + } + + if len(r.InconsistentCollationColumns) > 0 { + log.Error("There are %d table columns using inconsistent collation, they should use %q. Please go to admin panel Self Check page", len(r.InconsistentCollationColumns), r.DatabaseCollation) + } +} diff --git a/models/db/common.go b/models/db/common.go new file mode 100644 index 0000000..f3fd3e7 --- /dev/null +++ b/models/db/common.go @@ -0,0 +1,53 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "strings" + + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" + + "xorm.io/builder" +) + +// BuildCaseInsensitiveLike returns a condition to check if the given value is like the given key case-insensitively. +// Handles especially SQLite correctly as UPPER there only transforms ASCII letters. +func BuildCaseInsensitiveLike(key, value string) builder.Cond { + if setting.Database.Type.IsSQLite3() { + return builder.Like{"UPPER(" + key + ")", util.ToUpperASCII(value)} + } + return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)} +} + +// BuildCaseInsensitiveIn returns a condition to check if the given value is in the given values case-insensitively. +// Handles especially SQLite correctly as UPPER there only transforms ASCII letters. +func BuildCaseInsensitiveIn(key string, values []string) builder.Cond { + uppers := make([]string, 0, len(values)) + if setting.Database.Type.IsSQLite3() { + for _, value := range values { + uppers = append(uppers, util.ToUpperASCII(value)) + } + } else { + for _, value := range values { + uppers = append(uppers, strings.ToUpper(value)) + } + } + + return builder.In("UPPER("+key+")", uppers) +} + +// BuilderDialect returns the xorm.Builder dialect of the engine +func BuilderDialect() string { + switch { + case setting.Database.Type.IsMySQL(): + return builder.MYSQL + case setting.Database.Type.IsSQLite3(): + return builder.SQLITE + case setting.Database.Type.IsPostgreSQL(): + return builder.POSTGRES + default: + return "" + } +} diff --git a/models/db/consistency.go b/models/db/consistency.go new file mode 100644 index 0000000..d19732c --- /dev/null +++ b/models/db/consistency.go @@ -0,0 +1,31 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + + "xorm.io/builder" +) + +// CountOrphanedObjects count subjects with have no existing refobject anymore +func CountOrphanedObjects(ctx context.Context, subject, refobject, joinCond string) (int64, error) { + return GetEngine(ctx). + Table("`"+subject+"`"). + Join("LEFT", "`"+refobject+"`", joinCond). + Where(builder.IsNull{"`" + refobject + "`.id"}). + Select("COUNT(`" + subject + "`.`id`)"). + Count() +} + +// DeleteOrphanedObjects delete subjects with have no existing refobject anymore +func DeleteOrphanedObjects(ctx context.Context, subject, refobject, joinCond string) error { + subQuery := builder.Select("`"+subject+"`.id"). + From("`"+subject+"`"). + Join("LEFT", "`"+refobject+"`", joinCond). + Where(builder.IsNull{"`" + refobject + "`.id"}) + b := builder.Delete(builder.In("id", subQuery)).From("`" + subject + "`") + _, err := GetEngine(ctx).Exec(b) + return err +} diff --git a/models/db/context.go b/models/db/context.go new file mode 100644 index 0000000..43f6125 --- /dev/null +++ b/models/db/context.go @@ -0,0 +1,331 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + "database/sql" + + "xorm.io/builder" + "xorm.io/xorm" +) + +// DefaultContext is the default context to run xorm queries in +// will be overwritten by Init with HammerContext +var DefaultContext context.Context + +// contextKey is a value for use with context.WithValue. +type contextKey struct { + name string +} + +// enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context +var ( + enginedContextKey = &contextKey{"engined"} + _ Engined = &Context{} +) + +// Context represents a db context +type Context struct { + context.Context + e Engine + transaction bool +} + +func newContext(ctx context.Context, e Engine, transaction bool) *Context { + return &Context{ + Context: ctx, + e: e, + transaction: transaction, + } +} + +// InTransaction if context is in a transaction +func (ctx *Context) InTransaction() bool { + return ctx.transaction +} + +// Engine returns db engine +func (ctx *Context) Engine() Engine { + return ctx.e +} + +// Value shadows Value for context.Context but allows us to get ourselves and an Engined object +func (ctx *Context) Value(key any) any { + if key == enginedContextKey { + return ctx + } + return ctx.Context.Value(key) +} + +// WithContext returns this engine tied to this context +func (ctx *Context) WithContext(other context.Context) *Context { + return newContext(ctx, ctx.e.Context(other), ctx.transaction) +} + +// Engined structs provide an Engine +type Engined interface { + Engine() Engine +} + +// GetEngine will get a db Engine from this context or return an Engine restricted to this context +func GetEngine(ctx context.Context) Engine { + if e := getEngine(ctx); e != nil { + return e + } + return x.Context(ctx) +} + +// getEngine will get a db Engine from this context or return nil +func getEngine(ctx context.Context) Engine { + if engined, ok := ctx.(Engined); ok { + return engined.Engine() + } + enginedInterface := ctx.Value(enginedContextKey) + if enginedInterface != nil { + return enginedInterface.(Engined).Engine() + } + return nil +} + +// Committer represents an interface to Commit or Close the Context +type Committer interface { + Commit() error + Close() error +} + +// halfCommitter is a wrapper of Committer. +// It can be closed early, but can't be committed early, it is useful for reusing a transaction. +type halfCommitter struct { + committer Committer + committed bool +} + +func (c *halfCommitter) Commit() error { + c.committed = true + // should do nothing, and the parent committer will commit later + return nil +} + +func (c *halfCommitter) Close() error { + if c.committed { + // it's "commit and close", should do nothing, and the parent committer will commit later + return nil + } + + // it's "rollback and close", let the parent committer rollback right now + return c.committer.Close() +} + +// TxContext represents a transaction Context, +// it will reuse the existing transaction in the parent context or create a new one. +// Some tips to use: +// +// 1 It's always recommended to use `WithTx` in new code instead of `TxContext`, since `WithTx` will handle the transaction automatically. +// 2. To maintain the old code which uses `TxContext`: +// a. Always call `Close()` before returning regardless of whether `Commit()` has been called. +// b. Always call `Commit()` before returning if there are no errors, even if the code did not change any data. +// c. Remember the `Committer` will be a halfCommitter when a transaction is being reused. +// So calling `Commit()` will do nothing, but calling `Close()` without calling `Commit()` will rollback the transaction. +// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function. +// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback. +func TxContext(parentCtx context.Context) (*Context, Committer, error) { + if sess, ok := inTransaction(parentCtx); ok { + return newContext(parentCtx, sess, true), &halfCommitter{committer: sess}, nil + } + + sess := x.NewSession() + if err := sess.Begin(); err != nil { + sess.Close() + return nil, nil, err + } + + return newContext(DefaultContext, sess, true), sess, nil +} + +// WithTx represents executing database operations on a transaction, if the transaction exist, +// this function will reuse it otherwise will create a new one and close it when finished. +func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error { + if sess, ok := inTransaction(parentCtx); ok { + err := f(newContext(parentCtx, sess, true)) + if err != nil { + // rollback immediately, in case the caller ignores returned error and tries to commit the transaction. + _ = sess.Close() + } + return err + } + return txWithNoCheck(parentCtx, f) +} + +func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error { + sess := x.NewSession() + defer sess.Close() + if err := sess.Begin(); err != nil { + return err + } + + if err := f(newContext(parentCtx, sess, true)); err != nil { + return err + } + + return sess.Commit() +} + +// Insert inserts records into database +func Insert(ctx context.Context, beans ...any) error { + _, err := GetEngine(ctx).Insert(beans...) + return err +} + +// Exec executes a sql with args +func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) { + return GetEngine(ctx).Exec(sqlAndArgs...) +} + +func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) { + if !cond.IsValid() { + panic("cond is invalid in db.Get(ctx, cond). This should not be possible.") + } + + var bean T + has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean) + if err != nil { + return nil, false, err + } else if !has { + return nil, false, nil + } + return &bean, true, nil +} + +func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) { + var bean T + has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean) + if err != nil { + return nil, false, err + } else if !has { + return nil, false, nil + } + return &bean, true, nil +} + +func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) { + if !cond.IsValid() { + panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.") + } + + var bean T + return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean) +} + +func ExistByID[T any](ctx context.Context, id int64) (bool, error) { + var bean T + return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean) +} + +// DeleteByID deletes the given bean with the given ID +func DeleteByID[T any](ctx context.Context, id int64) (int64, error) { + var bean T + return GetEngine(ctx).ID(id).NoAutoCondition().NoAutoTime().Delete(&bean) +} + +func DeleteByIDs[T any](ctx context.Context, ids ...int64) error { + if len(ids) == 0 { + return nil + } + + var bean T + _, err := GetEngine(ctx).In("id", ids).NoAutoCondition().NoAutoTime().Delete(&bean) + return err +} + +func Delete[T any](ctx context.Context, opts FindOptions) (int64, error) { + if opts == nil || !opts.ToConds().IsValid() { + panic("opts are empty or invalid in db.Delete(ctx, opts). This should not be possible.") + } + + var bean T + return GetEngine(ctx).Where(opts.ToConds()).NoAutoCondition().NoAutoTime().Delete(&bean) +} + +// DeleteByBean deletes all records according non-empty fields of the bean as conditions. +func DeleteByBean(ctx context.Context, bean any) (int64, error) { + return GetEngine(ctx).Delete(bean) +} + +// FindIDs finds the IDs for the given table name satisfying the given condition +// By passing a different value than "id" for "idCol", you can query for foreign IDs, i.e. the repo IDs which satisfy the condition +func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) { + ids := make([]int64, 0, 10) + if err := GetEngine(ctx).Table(tableName). + Cols(idCol). + Where(cond). + Find(&ids); err != nil { + return nil, err + } + return ids, nil +} + +// DecrByIDs decreases the given column for entities of the "bean" type with one of the given ids by one +// Timestamps of the entities won't be updated +func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean any) error { + _, err := GetEngine(ctx).Decr(decrCol).In("id", ids).NoAutoCondition().NoAutoTime().Update(bean) + return err +} + +// DeleteBeans deletes all given beans, beans must contain delete conditions. +func DeleteBeans(ctx context.Context, beans ...any) (err error) { + e := GetEngine(ctx) + for i := range beans { + if _, err = e.Delete(beans[i]); err != nil { + return err + } + } + return nil +} + +// TruncateBeans deletes all given beans, beans may contain delete conditions. +func TruncateBeans(ctx context.Context, beans ...any) (err error) { + e := GetEngine(ctx) + for i := range beans { + if _, err = e.Truncate(beans[i]); err != nil { + return err + } + } + return nil +} + +// CountByBean counts the number of database records according non-empty fields of the bean as conditions. +func CountByBean(ctx context.Context, bean any) (int64, error) { + return GetEngine(ctx).Count(bean) +} + +// TableName returns the table name according a bean object +func TableName(bean any) string { + return x.TableName(bean) +} + +// InTransaction returns true if the engine is in a transaction otherwise return false +func InTransaction(ctx context.Context) bool { + _, ok := inTransaction(ctx) + return ok +} + +func inTransaction(ctx context.Context) (*xorm.Session, bool) { + e := getEngine(ctx) + if e == nil { + return nil, false + } + + switch t := e.(type) { + case *xorm.Engine: + return nil, false + case *xorm.Session: + if t.IsInTx() { + return t, true + } + return nil, false + default: + return nil, false + } +} diff --git a/models/db/context_committer_test.go b/models/db/context_committer_test.go new file mode 100644 index 0000000..38e91f2 --- /dev/null +++ b/models/db/context_committer_test.go @@ -0,0 +1,102 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db // it's not db_test, because this file is for testing the private type halfCommitter + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type MockCommitter struct { + wants []string + gots []string +} + +func NewMockCommitter(wants ...string) *MockCommitter { + return &MockCommitter{ + wants: wants, + } +} + +func (c *MockCommitter) Commit() error { + c.gots = append(c.gots, "commit") + return nil +} + +func (c *MockCommitter) Close() error { + c.gots = append(c.gots, "close") + return nil +} + +func (c *MockCommitter) Assert(t *testing.T) { + assert.Equal(t, c.wants, c.gots, "want operations %v, but got %v", c.wants, c.gots) +} + +func Test_halfCommitter(t *testing.T) { + /* + Do something like: + + ctx, committer, err := db.TxContext(db.DefaultContext) + if err != nil { + return nil + } + defer committer.Close() + + // ... + + if err != nil { + return nil + } + + // ... + + return committer.Commit() + */ + + testWithCommitter := func(committer Committer, f func(committer Committer) error) { + if err := f(&halfCommitter{committer: committer}); err == nil { + committer.Commit() + } + committer.Close() + } + + t.Run("commit and close", func(t *testing.T) { + mockCommitter := NewMockCommitter("commit", "close") + + testWithCommitter(mockCommitter, func(committer Committer) error { + defer committer.Close() + return committer.Commit() + }) + + mockCommitter.Assert(t) + }) + + t.Run("rollback and close", func(t *testing.T) { + mockCommitter := NewMockCommitter("close", "close") + + testWithCommitter(mockCommitter, func(committer Committer) error { + defer committer.Close() + if true { + return fmt.Errorf("error") + } + return committer.Commit() + }) + + mockCommitter.Assert(t) + }) + + t.Run("close and commit", func(t *testing.T) { + mockCommitter := NewMockCommitter("close", "close") + + testWithCommitter(mockCommitter, func(committer Committer) error { + committer.Close() + committer.Commit() + return fmt.Errorf("error") + }) + + mockCommitter.Assert(t) + }) +} diff --git a/models/db/context_test.go b/models/db/context_test.go new file mode 100644 index 0000000..855f360 --- /dev/null +++ b/models/db/context_test.go @@ -0,0 +1,87 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db_test + +import ( + "context" + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInTransaction(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + assert.False(t, db.InTransaction(db.DefaultContext)) + require.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error { + assert.True(t, db.InTransaction(ctx)) + return nil + })) + + ctx, committer, err := db.TxContext(db.DefaultContext) + require.NoError(t, err) + defer committer.Close() + assert.True(t, db.InTransaction(ctx)) + require.NoError(t, db.WithTx(ctx, func(ctx context.Context) error { + assert.True(t, db.InTransaction(ctx)) + return nil + })) +} + +func TestTxContext(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + { // create new transaction + ctx, committer, err := db.TxContext(db.DefaultContext) + require.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + require.NoError(t, committer.Commit()) + } + + { // reuse the transaction created by TxContext and commit it + ctx, committer, err := db.TxContext(db.DefaultContext) + engine := db.GetEngine(ctx) + require.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + { + ctx, committer, err := db.TxContext(ctx) + require.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + assert.Equal(t, engine, db.GetEngine(ctx)) + require.NoError(t, committer.Commit()) + } + require.NoError(t, committer.Commit()) + } + + { // reuse the transaction created by TxContext and close it + ctx, committer, err := db.TxContext(db.DefaultContext) + engine := db.GetEngine(ctx) + require.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + { + ctx, committer, err := db.TxContext(ctx) + require.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + assert.Equal(t, engine, db.GetEngine(ctx)) + require.NoError(t, committer.Close()) + } + require.NoError(t, committer.Close()) + } + + { // reuse the transaction created by WithTx + require.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error { + assert.True(t, db.InTransaction(ctx)) + { + ctx, committer, err := db.TxContext(ctx) + require.NoError(t, err) + assert.True(t, db.InTransaction(ctx)) + require.NoError(t, committer.Commit()) + } + return nil + })) + } +} diff --git a/models/db/convert.go b/models/db/convert.go new file mode 100644 index 0000000..956e17d --- /dev/null +++ b/models/db/convert.go @@ -0,0 +1,64 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "fmt" + "strconv" + "strings" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "xorm.io/xorm" + "xorm.io/xorm/schemas" +) + +// ConvertDatabaseTable converts database and tables from utf8 to utf8mb4 if it's mysql and set ROW_FORMAT=dynamic +func ConvertDatabaseTable() error { + if x.Dialect().URI().DBType != schemas.MYSQL { + return nil + } + + r, err := CheckCollations(x) + if err != nil { + return err + } + + databaseName := strings.SplitN(setting.Database.Name, "?", 2)[0] + _, err = x.Exec(fmt.Sprintf("ALTER DATABASE `%s` CHARACTER SET utf8mb4 COLLATE %s", databaseName, r.ExpectedCollation)) + if err != nil { + return err + } + + tables, err := x.DBMetas() + if err != nil { + return err + } + for _, table := range tables { + if _, err := x.Exec(fmt.Sprintf("ALTER TABLE `%s` ROW_FORMAT=dynamic", table.Name)); err != nil { + return err + } + + if _, err := x.Exec(fmt.Sprintf("ALTER TABLE `%s` CONVERT TO CHARACTER SET utf8mb4 COLLATE %s", table.Name, r.ExpectedCollation)); err != nil { + return err + } + } + + return nil +} + +// Cell2Int64 converts a xorm.Cell type to int64, +// and handles possible irregular cases. +func Cell2Int64(val xorm.Cell) int64 { + switch (*val).(type) { + case []uint8: + log.Trace("Cell2Int64 ([]uint8): %v", *val) + + v, _ := strconv.ParseInt(string((*val).([]uint8)), 10, 64) + return v + default: + return (*val).(int64) + } +} diff --git a/models/db/engine.go b/models/db/engine.go new file mode 100755 index 0000000..6164959 --- /dev/null +++ b/models/db/engine.go @@ -0,0 +1,354 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2018 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "reflect" + "strings" + "time" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "xorm.io/xorm" + "xorm.io/xorm/contexts" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + + _ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver + _ "github.com/lib/pq" // Needed for the Postgresql driver +) + +var ( + x *xorm.Engine + tables []any + initFuncs []func() error +) + +// Engine represents a xorm engine or session. +type Engine interface { + Table(tableNameOrBean any) *xorm.Session + Count(...any) (int64, error) + Decr(column string, arg ...any) *xorm.Session + Delete(...any) (int64, error) + Truncate(...any) (int64, error) + Exec(...any) (sql.Result, error) + Find(any, ...any) error + Get(beans ...any) (bool, error) + ID(any) *xorm.Session + In(string, ...any) *xorm.Session + Incr(column string, arg ...any) *xorm.Session + Insert(...any) (int64, error) + Iterate(any, xorm.IterFunc) error + IsTableExist(any) (bool, error) + Join(joinOperator string, tablename, condition any, args ...any) *xorm.Session + SQL(any, ...any) *xorm.Session + Where(any, ...any) *xorm.Session + Asc(colNames ...string) *xorm.Session + Desc(colNames ...string) *xorm.Session + Limit(limit int, start ...int) *xorm.Session + NoAutoTime() *xorm.Session + SumInt(bean any, columnName string) (res int64, err error) + Sync(...any) error + Select(string) *xorm.Session + SetExpr(string, any) *xorm.Session + NotIn(string, ...any) *xorm.Session + OrderBy(any, ...any) *xorm.Session + Exist(...any) (bool, error) + Distinct(...string) *xorm.Session + Query(...any) ([]map[string][]byte, error) + Cols(...string) *xorm.Session + Context(ctx context.Context) *xorm.Session + Ping() error +} + +// TableInfo returns table's information via an object +func TableInfo(v any) (*schemas.Table, error) { + return x.TableInfo(v) +} + +// DumpTables dump tables information +func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { + return x.DumpTables(tables, w, tp...) +} + +// RegisterModel registers model, if initfunc provided, it will be invoked after data model sync +func RegisterModel(bean any, initFunc ...func() error) { + tables = append(tables, bean) + if len(initFuncs) > 0 && initFunc[0] != nil { + initFuncs = append(initFuncs, initFunc[0]) + } +} + +func init() { + gonicNames := []string{"SSL", "UID"} + for _, name := range gonicNames { + names.LintGonicMapper[name] = true + } +} + +// newXORMEngine returns a new XORM engine from the configuration +func newXORMEngine() (*xorm.Engine, error) { + connStr, err := setting.DBConnStr() + if err != nil { + return nil, err + } + + var engine *xorm.Engine + + if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 { + // OK whilst we sort out our schema issues - create a schema aware postgres + registerPostgresSchemaDriver() + engine, err = xorm.NewEngine("postgresschema", connStr) + } else { + engine, err = xorm.NewEngine(setting.Database.Type.String(), connStr) + } + + if err != nil { + return nil, err + } + if setting.Database.Type.IsMySQL() { + engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) + } + engine.SetSchema(setting.Database.Schema) + return engine, nil +} + +// SyncAllTables sync the schemas of all tables, is required by unit test code +func SyncAllTables() error { + _, err := x.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{ + WarnIfDatabaseColumnMissed: true, + }, tables...) + return err +} + +// InitEngine initializes the xorm.Engine and sets it as db.DefaultContext +func InitEngine(ctx context.Context) error { + xormEngine, err := newXORMEngine() + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + xormEngine.SetMapper(names.GonicMapper{}) + // WARNING: for serv command, MUST remove the output to os.stdout, + // so use log file to instead print to stdout. + xormEngine.SetLogger(NewXORMLogger(setting.Database.LogSQL)) + xormEngine.ShowSQL(setting.Database.LogSQL) + xormEngine.SetMaxOpenConns(setting.Database.MaxOpenConns) + xormEngine.SetMaxIdleConns(setting.Database.MaxIdleConns) + xormEngine.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) + xormEngine.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime) + xormEngine.SetDefaultContext(ctx) + + if setting.Database.SlowQueryThreshold > 0 { + xormEngine.AddHook(&SlowQueryHook{ + Treshold: setting.Database.SlowQueryThreshold, + Logger: log.GetLogger("xorm"), + }) + } + + errorLogger := log.GetLogger("xorm") + if setting.IsInTesting { + errorLogger = log.GetLogger(log.DEFAULT) + } + + xormEngine.AddHook(&ErrorQueryHook{ + Logger: errorLogger, + }) + + SetDefaultEngine(ctx, xormEngine) + return nil +} + +// SetDefaultEngine sets the default engine for db +func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) { + x = eng + DefaultContext = &Context{ + Context: ctx, + e: x, + } +} + +// UnsetDefaultEngine closes and unsets the default engine +// We hope the SetDefaultEngine and UnsetDefaultEngine can be paired, but it's impossible now, +// there are many calls to InitEngine -> SetDefaultEngine directly to overwrite the `x` and DefaultContext without close +// Global database engine related functions are all racy and there is no graceful close right now. +func UnsetDefaultEngine() { + if x != nil { + _ = x.Close() + x = nil + } + DefaultContext = nil +} + +// InitEngineWithMigration initializes a new xorm.Engine and sets it as the db.DefaultContext +// This function must never call .Sync() if the provided migration function fails. +// When called from the "doctor" command, the migration function is a version check +// that prevents the doctor from fixing anything in the database if the migration level +// is different from the expected value. +func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) { + if err = InitEngine(ctx); err != nil { + return err + } + + if err = x.Ping(); err != nil { + return err + } + + preprocessDatabaseCollation(x) + + // We have to run migrateFunc here in case the user is re-running installation on a previously created DB. + // If we do not then table schemas will be changed and there will be conflicts when the migrations run properly. + // + // Installation should only be being re-run if users want to recover an old database. + // However, we should think carefully about should we support re-install on an installed instance, + // as there may be other problems due to secret reinitialization. + if err = migrateFunc(x); err != nil { + return fmt.Errorf("migrate: %w", err) + } + + if err = SyncAllTables(); err != nil { + return fmt.Errorf("sync database struct error: %w", err) + } + + for _, initFunc := range initFuncs { + if err := initFunc(); err != nil { + return fmt.Errorf("initFunc failed: %w", err) + } + } + + return nil +} + +// NamesToBean return a list of beans or an error +func NamesToBean(names ...string) ([]any, error) { + beans := []any{} + if len(names) == 0 { + beans = append(beans, tables...) + return beans, nil + } + // Need to map provided names to beans... + beanMap := make(map[string]any) + for _, bean := range tables { + beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean + beanMap[strings.ToLower(x.TableName(bean))] = bean + beanMap[strings.ToLower(x.TableName(bean, true))] = bean + } + + gotBean := make(map[any]bool) + for _, name := range names { + bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))] + if !ok { + return nil, fmt.Errorf("no table found that matches: %s", name) + } + if !gotBean[bean] { + beans = append(beans, bean) + gotBean[bean] = true + } + } + return beans, nil +} + +// DumpDatabase dumps all data from database according the special database SQL syntax to file system. +func DumpDatabase(filePath, dbType string) error { + var tbs []*schemas.Table + for _, t := range tables { + t, err := x.TableInfo(t) + if err != nil { + return err + } + tbs = append(tbs, t) + } + + type Version struct { + ID int64 `xorm:"pk autoincr"` + Version int64 + } + t, err := x.TableInfo(&Version{}) + if err != nil { + return err + } + tbs = append(tbs, t) + + if len(dbType) > 0 { + return x.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType)) + } + return x.DumpTablesToFile(tbs, filePath) +} + +// MaxBatchInsertSize returns the table's max batch insert size +func MaxBatchInsertSize(bean any) int { + t, err := x.TableInfo(bean) + if err != nil { + return 50 + } + return 999 / len(t.ColumnsSeq()) +} + +// IsTableNotEmpty returns true if table has at least one record +func IsTableNotEmpty(beanOrTableName any) (bool, error) { + return x.Table(beanOrTableName).Exist() +} + +// DeleteAllRecords will delete all the records of this table +func DeleteAllRecords(tableName string) error { + _, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName)) + return err +} + +// GetMaxID will return max id of the table +func GetMaxID(beanOrTableName any) (maxID int64, err error) { + _, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID) + return maxID, err +} + +func SetLogSQL(ctx context.Context, on bool) { + e := GetEngine(ctx) + if x, ok := e.(*xorm.Engine); ok { + x.ShowSQL(on) + } else if sess, ok := e.(*xorm.Session); ok { + sess.Engine().ShowSQL(on) + } +} + +type SlowQueryHook struct { + Treshold time.Duration + Logger log.Logger +} + +var _ contexts.Hook = &SlowQueryHook{} + +func (SlowQueryHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { + return c.Ctx, nil +} + +func (h *SlowQueryHook) AfterProcess(c *contexts.ContextHook) error { + if c.ExecuteTime >= h.Treshold { + h.Logger.Log(8, log.WARN, "[Slow SQL Query] %s %v - %v", c.SQL, c.Args, c.ExecuteTime) + } + return nil +} + +type ErrorQueryHook struct { + Logger log.Logger +} + +var _ contexts.Hook = &ErrorQueryHook{} + +func (ErrorQueryHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { + return c.Ctx, nil +} + +func (h *ErrorQueryHook) AfterProcess(c *contexts.ContextHook) error { + if c.Err != nil && !errors.Is(c.Err, context.Canceled) { + h.Logger.Log(8, log.ERROR, "[Error SQL Query] %s %v - %v", c.SQL, c.Args, c.Err) + } + return nil +} diff --git a/models/db/engine_test.go b/models/db/engine_test.go new file mode 100644 index 0000000..230ee3f --- /dev/null +++ b/models/db/engine_test.go @@ -0,0 +1,154 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db_test + +import ( + "path/filepath" + "testing" + "time" + + "code.gitea.io/gitea/models/db" + issues_model "code.gitea.io/gitea/models/issues" + "code.gitea.io/gitea/models/unittest" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/test" + + _ "code.gitea.io/gitea/cmd" // for TestPrimaryKeys + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "xorm.io/xorm" +) + +func TestDumpDatabase(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + dir := t.TempDir() + + type Version struct { + ID int64 `xorm:"pk autoincr"` + Version int64 + } + require.NoError(t, db.GetEngine(db.DefaultContext).Sync(new(Version))) + + for _, dbType := range setting.SupportedDatabaseTypes { + require.NoError(t, db.DumpDatabase(filepath.Join(dir, dbType+".sql"), dbType)) + } +} + +func TestDeleteOrphanedObjects(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + countBefore, err := db.GetEngine(db.DefaultContext).Count(&issues_model.PullRequest{}) + require.NoError(t, err) + + _, err = db.GetEngine(db.DefaultContext).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003}) + require.NoError(t, err) + + orphaned, err := db.CountOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id") + require.NoError(t, err) + assert.EqualValues(t, 3, orphaned) + + err = db.DeleteOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id") + require.NoError(t, err) + + countAfter, err := db.GetEngine(db.DefaultContext).Count(&issues_model.PullRequest{}) + require.NoError(t, err) + assert.EqualValues(t, countBefore, countAfter) +} + +func TestPrimaryKeys(t *testing.T) { + // Some dbs require that all tables have primary keys, see + // https://github.com/go-gitea/gitea/issues/21086 + // https://github.com/go-gitea/gitea/issues/16802 + // To avoid creating tables without primary key again, this test will check them. + // Import "code.gitea.io/gitea/cmd" to make sure each db.RegisterModel in init functions has been called. + + beans, err := db.NamesToBean() + if err != nil { + t.Fatal(err) + } + + whitelist := map[string]string{ + "the_table_name_to_skip_checking": "Write a note here to explain why", + "forgejo_sem_ver": "seriously dude", + } + + for _, bean := range beans { + table, err := db.TableInfo(bean) + if err != nil { + t.Fatal(err) + } + if why, ok := whitelist[table.Name]; ok { + t.Logf("ignore %q because %q", table.Name, why) + continue + } + if len(table.PrimaryKeys) == 0 { + t.Errorf("table %q has no primary key", table.Name) + } + } +} + +func TestSlowQuery(t *testing.T) { + lc, cleanup := test.NewLogChecker("slow-query", log.INFO) + lc.StopMark("[Slow SQL Query]") + defer cleanup() + + e := db.GetEngine(db.DefaultContext) + engine, ok := e.(*xorm.Engine) + assert.True(t, ok) + + // It's not possible to clean this up with XORM, but it's luckily not harmful + // to leave around. + engine.AddHook(&db.SlowQueryHook{ + Treshold: time.Second * 10, + Logger: log.GetLogger("slow-query"), + }) + + // NOOP query. + e.Exec("SELECT 1 WHERE false;") + + _, stopped := lc.Check(100 * time.Millisecond) + assert.False(t, stopped) + + engine.AddHook(&db.SlowQueryHook{ + Treshold: 0, // Every query should be logged. + Logger: log.GetLogger("slow-query"), + }) + + // NOOP query. + e.Exec("SELECT 1 WHERE false;") + + _, stopped = lc.Check(100 * time.Millisecond) + assert.True(t, stopped) +} + +func TestErrorQuery(t *testing.T) { + lc, cleanup := test.NewLogChecker("error-query", log.INFO) + lc.StopMark("[Error SQL Query]") + defer cleanup() + + e := db.GetEngine(db.DefaultContext) + engine, ok := e.(*xorm.Engine) + assert.True(t, ok) + + // It's not possible to clean this up with XORM, but it's luckily not harmful + // to leave around. + engine.AddHook(&db.ErrorQueryHook{ + Logger: log.GetLogger("error-query"), + }) + + // Valid query. + e.Exec("SELECT 1 WHERE false;") + + _, stopped := lc.Check(100 * time.Millisecond) + assert.False(t, stopped) + + // Table doesn't exist. + e.Exec("SELECT column FROM table;") + + _, stopped = lc.Check(100 * time.Millisecond) + assert.True(t, stopped) +} diff --git a/models/db/error.go b/models/db/error.go new file mode 100644 index 0000000..665e970 --- /dev/null +++ b/models/db/error.go @@ -0,0 +1,74 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "fmt" + + "code.gitea.io/gitea/modules/util" +) + +// ErrCancelled represents an error due to context cancellation +type ErrCancelled struct { + Message string +} + +// IsErrCancelled checks if an error is a ErrCancelled. +func IsErrCancelled(err error) bool { + _, ok := err.(ErrCancelled) + return ok +} + +func (err ErrCancelled) Error() string { + return "Cancelled: " + err.Message +} + +// ErrCancelledf returns an ErrCancelled for the provided format and args +func ErrCancelledf(format string, args ...any) error { + return ErrCancelled{ + fmt.Sprintf(format, args...), + } +} + +// ErrSSHDisabled represents an "SSH disabled" error. +type ErrSSHDisabled struct{} + +// IsErrSSHDisabled checks if an error is a ErrSSHDisabled. +func IsErrSSHDisabled(err error) bool { + _, ok := err.(ErrSSHDisabled) + return ok +} + +func (err ErrSSHDisabled) Error() string { + return "SSH is disabled" +} + +// ErrNotExist represents a non-exist error. +type ErrNotExist struct { + Resource string + ID int64 +} + +// IsErrNotExist checks if an error is an ErrNotExist +func IsErrNotExist(err error) bool { + _, ok := err.(ErrNotExist) + return ok +} + +func (err ErrNotExist) Error() string { + name := "record" + if err.Resource != "" { + name = err.Resource + } + + if err.ID != 0 { + return fmt.Sprintf("%s does not exist [id: %d]", name, err.ID) + } + return fmt.Sprintf("%s does not exist", name) +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrNotExist) Unwrap() error { + return util.ErrNotExist +} diff --git a/models/db/index.go b/models/db/index.go new file mode 100644 index 0000000..259ddd6 --- /dev/null +++ b/models/db/index.go @@ -0,0 +1,148 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + "errors" + "fmt" + "strconv" + + "code.gitea.io/gitea/modules/setting" +) + +// ResourceIndex represents a resource index which could be used as issue/release and others +// We can create different tables i.e. issue_index, release_index, etc. +type ResourceIndex struct { + GroupID int64 `xorm:"pk"` + MaxIndex int64 `xorm:"index"` +} + +var ( + // ErrResouceOutdated represents an error when request resource outdated + ErrResouceOutdated = errors.New("resource outdated") + // ErrGetResourceIndexFailed represents an error when resource index retries 3 times + ErrGetResourceIndexFailed = errors.New("get resource index failed") +) + +// SyncMaxResourceIndex sync the max index with the resource +func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) { + e := GetEngine(ctx) + + // try to update the max_index and acquire the write-lock for the record + res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=? WHERE group_id=? AND max_index= threshold, nil +} diff --git a/models/db/iterate.go b/models/db/iterate.go new file mode 100644 index 0000000..e1caefa --- /dev/null +++ b/models/db/iterate.go @@ -0,0 +1,43 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + + "code.gitea.io/gitea/modules/setting" + + "xorm.io/builder" +) + +// Iterate iterate all the Bean object +func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error { + var start int + batchSize := setting.Database.IterateBufferSize + sess := GetEngine(ctx) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + beans := make([]*Bean, 0, batchSize) + if cond != nil { + sess = sess.Where(cond) + } + if err := sess.Limit(batchSize, start).Find(&beans); err != nil { + return err + } + if len(beans) == 0 { + return nil + } + start += len(beans) + + for _, bean := range beans { + if err := f(ctx, bean); err != nil { + return err + } + } + } + } +} diff --git a/models/db/iterate_test.go b/models/db/iterate_test.go new file mode 100644 index 0000000..7535d01 --- /dev/null +++ b/models/db/iterate_test.go @@ -0,0 +1,45 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db_test + +import ( + "context" + "testing" + + "code.gitea.io/gitea/models/db" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIterate(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + xe := unittest.GetXORMEngine() + require.NoError(t, xe.Sync(&repo_model.RepoUnit{})) + + cnt, err := db.GetEngine(db.DefaultContext).Count(&repo_model.RepoUnit{}) + require.NoError(t, err) + + var repoUnitCnt int + err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repo *repo_model.RepoUnit) error { + repoUnitCnt++ + return nil + }) + require.NoError(t, err) + assert.EqualValues(t, cnt, repoUnitCnt) + + err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error { + has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID) + if err != nil { + return err + } + if !has { + return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID} + } + return nil + }) + require.NoError(t, err) +} diff --git a/models/db/list.go b/models/db/list.go new file mode 100644 index 0000000..5c005a0 --- /dev/null +++ b/models/db/list.go @@ -0,0 +1,215 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + + "code.gitea.io/gitea/modules/setting" + + "xorm.io/builder" + "xorm.io/xorm" +) + +const ( + // DefaultMaxInSize represents default variables number on IN () in SQL + DefaultMaxInSize = 50 + defaultFindSliceSize = 10 +) + +// Paginator is the base for different ListOptions types +type Paginator interface { + GetSkipTake() (skip, take int) + IsListAll() bool +} + +// SetSessionPagination sets pagination for a database session +func SetSessionPagination(sess Engine, p Paginator) *xorm.Session { + skip, take := p.GetSkipTake() + + return sess.Limit(take, skip) +} + +// ListOptions options to paginate results +type ListOptions struct { + PageSize int + Page int // start from 1 + ListAll bool // if true, then PageSize and Page will not be taken +} + +var ListOptionsAll = ListOptions{ListAll: true} + +var ( + _ Paginator = &ListOptions{} + _ FindOptions = ListOptions{} +) + +// GetSkipTake returns the skip and take values +func (opts *ListOptions) GetSkipTake() (skip, take int) { + opts.SetDefaultValues() + return (opts.Page - 1) * opts.PageSize, opts.PageSize +} + +func (opts ListOptions) GetPage() int { + return opts.Page +} + +func (opts ListOptions) GetPageSize() int { + return opts.PageSize +} + +// IsListAll indicates PageSize and Page will be ignored +func (opts ListOptions) IsListAll() bool { + return opts.ListAll +} + +// SetDefaultValues sets default values +func (opts *ListOptions) SetDefaultValues() { + if opts.PageSize <= 0 { + opts.PageSize = setting.API.DefaultPagingNum + } + if opts.PageSize > setting.API.MaxResponseItems { + opts.PageSize = setting.API.MaxResponseItems + } + if opts.Page <= 0 { + opts.Page = 1 + } +} + +func (opts ListOptions) ToConds() builder.Cond { + return builder.NewCond() +} + +// AbsoluteListOptions absolute options to paginate results +type AbsoluteListOptions struct { + skip int + take int +} + +var _ Paginator = &AbsoluteListOptions{} + +// NewAbsoluteListOptions creates a list option with applied limits +func NewAbsoluteListOptions(skip, take int) *AbsoluteListOptions { + if skip < 0 { + skip = 0 + } + if take <= 0 { + take = setting.API.DefaultPagingNum + } + if take > setting.API.MaxResponseItems { + take = setting.API.MaxResponseItems + } + return &AbsoluteListOptions{skip, take} +} + +// IsListAll will always return false +func (opts *AbsoluteListOptions) IsListAll() bool { + return false +} + +// GetSkipTake returns the skip and take values +func (opts *AbsoluteListOptions) GetSkipTake() (skip, take int) { + return opts.skip, opts.take +} + +// FindOptions represents a find options +type FindOptions interface { + GetPage() int + GetPageSize() int + IsListAll() bool + ToConds() builder.Cond +} + +type JoinFunc func(sess Engine) error + +type FindOptionsJoin interface { + ToJoins() []JoinFunc +} + +type FindOptionsOrder interface { + ToOrders() string +} + +// Find represents a common find function which accept an options interface +func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) { + sess := GetEngine(ctx).Where(opts.ToConds()) + + if joinOpt, ok := opts.(FindOptionsJoin); ok { + for _, joinFunc := range joinOpt.ToJoins() { + if err := joinFunc(sess); err != nil { + return nil, err + } + } + } + if orderOpt, ok := opts.(FindOptionsOrder); ok { + if order := orderOpt.ToOrders(); order != "" { + sess.OrderBy(order) + } + } + + page, pageSize := opts.GetPage(), opts.GetPageSize() + if !opts.IsListAll() && pageSize > 0 { + if page == 0 { + page = 1 + } + sess.Limit(pageSize, (page-1)*pageSize) + } + + findPageSize := defaultFindSliceSize + if pageSize > 0 { + findPageSize = pageSize + } + objects := make([]*T, 0, findPageSize) + if err := sess.Find(&objects); err != nil { + return nil, err + } + return objects, nil +} + +// Count represents a common count function which accept an options interface +func Count[T any](ctx context.Context, opts FindOptions) (int64, error) { + sess := GetEngine(ctx).Where(opts.ToConds()) + if joinOpt, ok := opts.(FindOptionsJoin); ok { + for _, joinFunc := range joinOpt.ToJoins() { + if err := joinFunc(sess); err != nil { + return 0, err + } + } + } + + var object T + return sess.Count(&object) +} + +// FindAndCount represents a common findandcount function which accept an options interface +func FindAndCount[T any](ctx context.Context, opts FindOptions) ([]*T, int64, error) { + sess := GetEngine(ctx).Where(opts.ToConds()) + page, pageSize := opts.GetPage(), opts.GetPageSize() + if !opts.IsListAll() && pageSize > 0 && page >= 1 { + sess.Limit(pageSize, (page-1)*pageSize) + } + if joinOpt, ok := opts.(FindOptionsJoin); ok { + for _, joinFunc := range joinOpt.ToJoins() { + if err := joinFunc(sess); err != nil { + return nil, 0, err + } + } + } + if orderOpt, ok := opts.(FindOptionsOrder); ok { + if order := orderOpt.ToOrders(); order != "" { + sess.OrderBy(order) + } + } + + findPageSize := defaultFindSliceSize + if pageSize > 0 { + findPageSize = pageSize + } + objects := make([]*T, 0, findPageSize) + cnt, err := sess.FindAndCount(&objects) + if err != nil { + return nil, 0, err + } + return objects, cnt, nil +} diff --git a/models/db/list_test.go b/models/db/list_test.go new file mode 100644 index 0000000..82240d2 --- /dev/null +++ b/models/db/list_test.go @@ -0,0 +1,53 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db_test + +import ( + "testing" + + "code.gitea.io/gitea/models/db" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "xorm.io/builder" +) + +type mockListOptions struct { + db.ListOptions +} + +func (opts mockListOptions) IsListAll() bool { + return true +} + +func (opts mockListOptions) ToConds() builder.Cond { + return builder.NewCond() +} + +func TestFind(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + xe := unittest.GetXORMEngine() + require.NoError(t, xe.Sync(&repo_model.RepoUnit{})) + + var repoUnitCount int + _, err := db.GetEngine(db.DefaultContext).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount) + require.NoError(t, err) + assert.NotEmpty(t, repoUnitCount) + + opts := mockListOptions{} + repoUnits, err := db.Find[repo_model.RepoUnit](db.DefaultContext, opts) + require.NoError(t, err) + assert.Len(t, repoUnits, repoUnitCount) + + cnt, err := db.Count[repo_model.RepoUnit](db.DefaultContext, opts) + require.NoError(t, err) + assert.EqualValues(t, repoUnitCount, cnt) + + repoUnits, newCnt, err := db.FindAndCount[repo_model.RepoUnit](db.DefaultContext, opts) + require.NoError(t, err) + assert.EqualValues(t, cnt, newCnt) + assert.Len(t, repoUnits, repoUnitCount) +} diff --git a/models/db/log.go b/models/db/log.go new file mode 100644 index 0000000..457ee80 --- /dev/null +++ b/models/db/log.go @@ -0,0 +1,107 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "fmt" + "sync/atomic" + + "code.gitea.io/gitea/modules/log" + + xormlog "xorm.io/xorm/log" +) + +// XORMLogBridge a logger bridge from Logger to xorm +type XORMLogBridge struct { + showSQL atomic.Bool + logger log.Logger +} + +// NewXORMLogger inits a log bridge for xorm +func NewXORMLogger(showSQL bool) xormlog.Logger { + l := &XORMLogBridge{logger: log.GetLogger("xorm")} + l.showSQL.Store(showSQL) + return l +} + +const stackLevel = 8 + +// Log a message with defined skip and at logging level +func (l *XORMLogBridge) Log(skip int, level log.Level, format string, v ...any) { + l.logger.Log(skip+1, level, format, v...) +} + +// Debug show debug log +func (l *XORMLogBridge) Debug(v ...any) { + l.Log(stackLevel, log.DEBUG, "%s", fmt.Sprint(v...)) +} + +// Debugf show debug log +func (l *XORMLogBridge) Debugf(format string, v ...any) { + l.Log(stackLevel, log.DEBUG, format, v...) +} + +// Error show error log +func (l *XORMLogBridge) Error(v ...any) { + l.Log(stackLevel, log.ERROR, "%s", fmt.Sprint(v...)) +} + +// Errorf show error log +func (l *XORMLogBridge) Errorf(format string, v ...any) { + l.Log(stackLevel, log.ERROR, format, v...) +} + +// Info show information level log +func (l *XORMLogBridge) Info(v ...any) { + l.Log(stackLevel, log.INFO, "%s", fmt.Sprint(v...)) +} + +// Infof show information level log +func (l *XORMLogBridge) Infof(format string, v ...any) { + l.Log(stackLevel, log.INFO, format, v...) +} + +// Warn show warning log +func (l *XORMLogBridge) Warn(v ...any) { + l.Log(stackLevel, log.WARN, "%s", fmt.Sprint(v...)) +} + +// Warnf show warning log +func (l *XORMLogBridge) Warnf(format string, v ...any) { + l.Log(stackLevel, log.WARN, format, v...) +} + +// Level get logger level +func (l *XORMLogBridge) Level() xormlog.LogLevel { + switch l.logger.GetLevel() { + case log.TRACE, log.DEBUG: + return xormlog.LOG_DEBUG + case log.INFO: + return xormlog.LOG_INFO + case log.WARN: + return xormlog.LOG_WARNING + case log.ERROR: + return xormlog.LOG_ERR + case log.NONE: + return xormlog.LOG_OFF + } + return xormlog.LOG_UNKNOWN +} + +// SetLevel set the logger level +func (l *XORMLogBridge) SetLevel(lvl xormlog.LogLevel) { +} + +// ShowSQL set if record SQL +func (l *XORMLogBridge) ShowSQL(show ...bool) { + if len(show) == 0 { + show = []bool{true} + } + l.showSQL.Store(show[0]) +} + +// IsShowSQL if record SQL +func (l *XORMLogBridge) IsShowSQL() bool { + return l.showSQL.Load() +} diff --git a/models/db/main_test.go b/models/db/main_test.go new file mode 100644 index 0000000..7d80b40 --- /dev/null +++ b/models/db/main_test.go @@ -0,0 +1,17 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db_test + +import ( + "testing" + + "code.gitea.io/gitea/models/unittest" + + _ "code.gitea.io/gitea/models" + _ "code.gitea.io/gitea/models/repo" +) + +func TestMain(m *testing.M) { + unittest.MainTest(m) +} diff --git a/models/db/name.go b/models/db/name.go new file mode 100644 index 0000000..51be33a --- /dev/null +++ b/models/db/name.go @@ -0,0 +1,106 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "fmt" + "regexp" + "strings" + "unicode/utf8" + + "code.gitea.io/gitea/modules/util" +) + +var ( + // ErrNameEmpty name is empty error + ErrNameEmpty = util.SilentWrap{Message: "name is empty", Err: util.ErrInvalidArgument} + + // AlphaDashDotPattern characters prohibited in a user name (anything except A-Za-z0-9_.-) + AlphaDashDotPattern = regexp.MustCompile(`[^\w-\.]`) +) + +// ErrNameReserved represents a "reserved name" error. +type ErrNameReserved struct { + Name string +} + +// IsErrNameReserved checks if an error is a ErrNameReserved. +func IsErrNameReserved(err error) bool { + _, ok := err.(ErrNameReserved) + return ok +} + +func (err ErrNameReserved) Error() string { + return fmt.Sprintf("name is reserved [name: %s]", err.Name) +} + +// Unwrap unwraps this as a ErrInvalid err +func (err ErrNameReserved) Unwrap() error { + return util.ErrInvalidArgument +} + +// ErrNamePatternNotAllowed represents a "pattern not allowed" error. +type ErrNamePatternNotAllowed struct { + Pattern string +} + +// IsErrNamePatternNotAllowed checks if an error is an ErrNamePatternNotAllowed. +func IsErrNamePatternNotAllowed(err error) bool { + _, ok := err.(ErrNamePatternNotAllowed) + return ok +} + +func (err ErrNamePatternNotAllowed) Error() string { + return fmt.Sprintf("name pattern is not allowed [pattern: %s]", err.Pattern) +} + +// Unwrap unwraps this as a ErrInvalid err +func (err ErrNamePatternNotAllowed) Unwrap() error { + return util.ErrInvalidArgument +} + +// ErrNameCharsNotAllowed represents a "character not allowed in name" error. +type ErrNameCharsNotAllowed struct { + Name string +} + +// IsErrNameCharsNotAllowed checks if an error is an ErrNameCharsNotAllowed. +func IsErrNameCharsNotAllowed(err error) bool { + _, ok := err.(ErrNameCharsNotAllowed) + return ok +} + +func (err ErrNameCharsNotAllowed) Error() string { + return fmt.Sprintf("name is invalid [%s]: must be valid alpha or numeric or dash(-_) or dot characters", err.Name) +} + +// Unwrap unwraps this as a ErrInvalid err +func (err ErrNameCharsNotAllowed) Unwrap() error { + return util.ErrInvalidArgument +} + +// IsUsableName checks if name is reserved or pattern of name is not allowed +// based on given reserved names and patterns. +// Names are exact match, patterns can be prefix or suffix match with placeholder '*'. +func IsUsableName(names, patterns []string, name string) error { + name = strings.TrimSpace(strings.ToLower(name)) + if utf8.RuneCountInString(name) == 0 { + return ErrNameEmpty + } + + for i := range names { + if name == names[i] { + return ErrNameReserved{name} + } + } + + for _, pat := range patterns { + if pat[0] == '*' && strings.HasSuffix(name, pat[1:]) || + (pat[len(pat)-1] == '*' && strings.HasPrefix(name, pat[:len(pat)-1])) { + return ErrNamePatternNotAllowed{pat} + } + } + + return nil +} diff --git a/models/db/paginator/main_test.go b/models/db/paginator/main_test.go new file mode 100644 index 0000000..47993ae --- /dev/null +++ b/models/db/paginator/main_test.go @@ -0,0 +1,14 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package paginator + +import ( + "testing" + + "code.gitea.io/gitea/models/unittest" +) + +func TestMain(m *testing.M) { + unittest.MainTest(m) +} diff --git a/models/db/paginator/paginator.go b/models/db/paginator/paginator.go new file mode 100644 index 0000000..bcda47d --- /dev/null +++ b/models/db/paginator/paginator.go @@ -0,0 +1,7 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package paginator + +// dummy only. in the future, the models/db/list_options.go should be moved here to decouple from db package +// otherwise the unit test will cause cycle import diff --git a/models/db/paginator/paginator_test.go b/models/db/paginator/paginator_test.go new file mode 100644 index 0000000..2060221 --- /dev/null +++ b/models/db/paginator/paginator_test.go @@ -0,0 +1,59 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package paginator + +import ( + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" +) + +func TestPaginator(t *testing.T) { + cases := []struct { + db.Paginator + Skip int + Take int + Start int + End int + }{ + { + Paginator: &db.ListOptions{Page: -1, PageSize: -1}, + Skip: 0, + Take: setting.API.DefaultPagingNum, + Start: 0, + End: setting.API.DefaultPagingNum, + }, + { + Paginator: &db.ListOptions{Page: 2, PageSize: 10}, + Skip: 10, + Take: 10, + Start: 10, + End: 20, + }, + { + Paginator: db.NewAbsoluteListOptions(-1, -1), + Skip: 0, + Take: setting.API.DefaultPagingNum, + Start: 0, + End: setting.API.DefaultPagingNum, + }, + { + Paginator: db.NewAbsoluteListOptions(2, 10), + Skip: 2, + Take: 10, + Start: 2, + End: 12, + }, + } + + for _, c := range cases { + skip, take := c.Paginator.GetSkipTake() + + assert.Equal(t, c.Skip, skip) + assert.Equal(t, c.Take, take) + } +} diff --git a/models/db/search.go b/models/db/search.go new file mode 100644 index 0000000..37565f4 --- /dev/null +++ b/models/db/search.go @@ -0,0 +1,33 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +// SearchOrderBy is used to sort the result +type SearchOrderBy string + +func (s SearchOrderBy) String() string { + return string(s) +} + +// Strings for sorting result +const ( + SearchOrderByAlphabetically SearchOrderBy = "name ASC" + SearchOrderByAlphabeticallyReverse SearchOrderBy = "name DESC" + SearchOrderByLeastUpdated SearchOrderBy = "updated_unix ASC" + SearchOrderByRecentUpdated SearchOrderBy = "updated_unix DESC" + SearchOrderByOldest SearchOrderBy = "created_unix ASC" + SearchOrderByNewest SearchOrderBy = "created_unix DESC" + SearchOrderByID SearchOrderBy = "id ASC" + SearchOrderByIDReverse SearchOrderBy = "id DESC" + SearchOrderByStars SearchOrderBy = "num_stars ASC" + SearchOrderByStarsReverse SearchOrderBy = "num_stars DESC" + SearchOrderByForks SearchOrderBy = "num_forks ASC" + SearchOrderByForksReverse SearchOrderBy = "num_forks DESC" +) + +const ( + // Which means a condition to filter the records which don't match any id. + // It's different from zero which means the condition could be ignored. + NoConditionID = -1 +) diff --git a/models/db/sequence.go b/models/db/sequence.go new file mode 100644 index 0000000..f49ad93 --- /dev/null +++ b/models/db/sequence.go @@ -0,0 +1,70 @@ +// Copyright 2018 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + "fmt" + "regexp" + + "code.gitea.io/gitea/modules/setting" +) + +// CountBadSequences looks for broken sequences from recreate-table mistakes +func CountBadSequences(_ context.Context) (int64, error) { + if !setting.Database.Type.IsPostgreSQL() { + return 0, nil + } + + sess := x.NewSession() + defer sess.Close() + + var sequences []string + schema := x.Dialect().URI().Schema + + sess.Engine().SetSchema("") + if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil { + return 0, err + } + sess.Engine().SetSchema(schema) + + return int64(len(sequences)), nil +} + +// FixBadSequences fixes for broken sequences from recreate-table mistakes +func FixBadSequences(_ context.Context) error { + if !setting.Database.Type.IsPostgreSQL() { + return nil + } + + sess := x.NewSession() + defer sess.Close() + if err := sess.Begin(); err != nil { + return err + } + + var sequences []string + schema := sess.Engine().Dialect().URI().Schema + + sess.Engine().SetSchema("") + if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil { + return err + } + sess.Engine().SetSchema(schema) + + sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`) + + for _, sequence := range sequences { + tableName := sequenceRegexp.FindStringSubmatch(sequence)[1] + newSequenceName := tableName + "_id_seq" + if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil { + return err + } + if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil { + return err + } + } + + return sess.Commit() +} diff --git a/models/db/sql_postgres_with_schema.go b/models/db/sql_postgres_with_schema.go new file mode 100644 index 0000000..ec63447 --- /dev/null +++ b/models/db/sql_postgres_with_schema.go @@ -0,0 +1,74 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "database/sql" + "database/sql/driver" + "sync" + + "code.gitea.io/gitea/modules/setting" + + "github.com/lib/pq" + "xorm.io/xorm/dialects" +) + +var registerOnce sync.Once + +func registerPostgresSchemaDriver() { + registerOnce.Do(func() { + sql.Register("postgresschema", &postgresSchemaDriver{}) + dialects.RegisterDriver("postgresschema", dialects.QueryDriver("postgres")) + }) +} + +type postgresSchemaDriver struct { + pq.Driver +} + +// Open opens a new connection to the database. name is a connection string. +// This function opens the postgres connection in the default manner but immediately +// runs set_config to set the search_path appropriately +func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) { + conn, err := d.Driver.Open(name) + if err != nil { + return conn, err + } + schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema) + + // golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here + // and in any case pq does not implement it + if execer, ok := conn.(driver.Execer); ok { //nolint + _, err := execer.Exec(`SELECT set_config( + 'search_path', + $1 || ',' || current_setting('search_path'), + false)`, []driver.Value{schemaValue}) + if err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil + } + + stmt, err := conn.Prepare(`SELECT set_config( + 'search_path', + $1 || ',' || current_setting('search_path'), + false)`) + if err != nil { + _ = conn.Close() + return nil, err + } + defer stmt.Close() + + // driver.String.ConvertValue will never return err for string + + // golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here + _, err = stmt.Exec([]driver.Value{schemaValue}) //nolint + if err != nil { + _ = conn.Close() + return nil, err + } + + return conn, nil +} -- cgit v1.2.3