diff options
Diffstat (limited to '')
-rw-r--r-- | modules/setting/database.go | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/modules/setting/database.go b/modules/setting/database.go new file mode 100644 index 0000000..76fae27 --- /dev/null +++ b/modules/setting/database.go @@ -0,0 +1,204 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package setting + +import ( + "errors" + "fmt" + "net" + "net/url" + "os" + "path/filepath" + "strings" + "time" +) + +var ( + // SupportedDatabaseTypes includes all XORM supported databases type, sqlite3 maybe added by `database_sqlite3.go` + SupportedDatabaseTypes = []string{"mysql", "postgres"} + // DatabaseTypeNames contains the friendly names for all database types + DatabaseTypeNames = map[string]string{"mysql": "MySQL", "postgres": "PostgreSQL", "sqlite3": "SQLite3"} + + // EnableSQLite3 use SQLite3, set by build flag + EnableSQLite3 bool + + // Database holds the database settings + Database = struct { + Type DatabaseType + Host string + Name string + User string + Passwd string + Schema string + SSLMode string + Path string + LogSQL bool + MysqlCharset string + CharsetCollation string + Timeout int // seconds + SQLiteJournalMode string + DBConnectRetries int + DBConnectBackoff time.Duration + MaxIdleConns int + MaxOpenConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + IterateBufferSize int + AutoMigration bool + SlowQueryThreshold time.Duration + }{ + Timeout: 500, + IterateBufferSize: 50, + } +) + +// LoadDBSetting loads the database settings +func LoadDBSetting() { + loadDBSetting(CfgProvider) +} + +func loadDBSetting(rootCfg ConfigProvider) { + sec := rootCfg.Section("database") + Database.Type = DatabaseType(sec.Key("DB_TYPE").String()) + + Database.Host = sec.Key("HOST").String() + Database.Name = sec.Key("NAME").String() + Database.User = sec.Key("USER").String() + if len(Database.Passwd) == 0 { + Database.Passwd = sec.Key("PASSWD").String() + } + Database.Schema = sec.Key("SCHEMA").String() + Database.SSLMode = sec.Key("SSL_MODE").MustString("disable") + Database.CharsetCollation = sec.Key("CHARSET_COLLATION").String() + + Database.Path = sec.Key("PATH").MustString(filepath.Join(AppDataPath, "forgejo.db")) + Database.Timeout = sec.Key("SQLITE_TIMEOUT").MustInt(500) + Database.SQLiteJournalMode = sec.Key("SQLITE_JOURNAL_MODE").MustString("") + + Database.MaxIdleConns = sec.Key("MAX_IDLE_CONNS").MustInt(2) + if Database.Type.IsMySQL() { + Database.ConnMaxLifetime = sec.Key("CONN_MAX_LIFETIME").MustDuration(3 * time.Second) + } else { + Database.ConnMaxLifetime = sec.Key("CONN_MAX_LIFETIME").MustDuration(0) + } + Database.ConnMaxIdleTime = sec.Key("CONN_MAX_IDLETIME").MustDuration(0) + Database.MaxOpenConns = sec.Key("MAX_OPEN_CONNS").MustInt(100) + + Database.IterateBufferSize = sec.Key("ITERATE_BUFFER_SIZE").MustInt(50) + Database.LogSQL = sec.Key("LOG_SQL").MustBool(false) + Database.DBConnectRetries = sec.Key("DB_RETRIES").MustInt(10) + Database.DBConnectBackoff = sec.Key("DB_RETRY_BACKOFF").MustDuration(3 * time.Second) + Database.AutoMigration = sec.Key("AUTO_MIGRATION").MustBool(true) + + deprecatedSetting(rootCfg, "database", "SLOW_QUERY_TRESHOLD", "database", "SLOW_QUERY_THRESHOLD", "1.23") + if sec.HasKey("SLOW_QUERY_TRESHOLD") && !sec.HasKey("SLOW_QUERY_THRESHOLD") { + Database.SlowQueryThreshold = sec.Key("SLOW_QUERY_TRESHOLD").MustDuration(5 * time.Second) + } else { + Database.SlowQueryThreshold = sec.Key("SLOW_QUERY_THRESHOLD").MustDuration(5 * time.Second) + } +} + +// DBConnStr returns database connection string +func DBConnStr() (string, error) { + var connStr string + paramSep := "?" + if strings.Contains(Database.Name, paramSep) { + paramSep = "&" + } + switch Database.Type { + case "mysql": + connType := "tcp" + if len(Database.Host) > 0 && Database.Host[0] == '/' { // looks like a unix socket + connType = "unix" + } + tls := Database.SSLMode + if tls == "disable" { // allow (Postgres-inspired) default value to work in MySQL + tls = "false" + } + connStr = fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s", + Database.User, Database.Passwd, connType, Database.Host, Database.Name, paramSep, tls) + case "postgres": + connStr = getPostgreSQLConnectionString(Database.Host, Database.User, Database.Passwd, Database.Name, Database.SSLMode) + case "sqlite3": + if !EnableSQLite3 { + return "", errors.New("this Gitea binary was not built with SQLite3 support") + } + if err := os.MkdirAll(filepath.Dir(Database.Path), os.ModePerm); err != nil { + return "", fmt.Errorf("Failed to create directories: %w", err) + } + journalMode := "" + if Database.SQLiteJournalMode != "" { + journalMode = "&_journal_mode=" + Database.SQLiteJournalMode + } + connStr = fmt.Sprintf("file:%s?cache=shared&mode=rwc&_busy_timeout=%d&_txlock=immediate%s", + Database.Path, Database.Timeout, journalMode) + default: + return "", fmt.Errorf("unknown database type: %s", Database.Type) + } + + return connStr, nil +} + +// parsePostgreSQLHostPort parses given input in various forms defined in +// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING +// and returns proper host and port number. +func parsePostgreSQLHostPort(info string) (host, port string) { + if h, p, err := net.SplitHostPort(info); err == nil { + host, port = h, p + } else { + // treat the "info" as "host", if it's an IPv6 address, remove the wrapper + host = info + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + } + + // set fallback values + if host == "" { + host = "127.0.0.1" + } + if port == "" { + port = "5432" + } + return host, port +} + +func getPostgreSQLConnectionString(dbHost, dbUser, dbPasswd, dbName, dbsslMode string) (connStr string) { + dbName, dbParam, _ := strings.Cut(dbName, "?") + host, port := parsePostgreSQLHostPort(dbHost) + connURL := url.URL{ + Scheme: "postgres", + User: url.UserPassword(dbUser, dbPasswd), + Host: net.JoinHostPort(host, port), + Path: dbName, + OmitHost: false, + RawQuery: dbParam, + } + query := connURL.Query() + if strings.HasPrefix(host, "/") { // looks like a unix socket + query.Add("host", host) + connURL.Host = ":" + port + } + query.Set("sslmode", dbsslMode) + connURL.RawQuery = query.Encode() + return connURL.String() +} + +type DatabaseType string + +func (t DatabaseType) String() string { + return string(t) +} + +func (t DatabaseType) IsSQLite3() bool { + return t == "sqlite3" +} + +func (t DatabaseType) IsMySQL() bool { + return t == "mysql" +} + +func (t DatabaseType) IsPostgreSQL() bool { + return t == "postgres" +} |