summaryrefslogtreecommitdiffstats
path: root/models/db/context.go
blob: 43f612518aacfbf502f24de84cc402e8a1ca4411 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package db

import (
	"context"
	"database/sql"

	"xorm.io/builder"
	"xorm.io/xorm"
)

// contextKey is the type of the keys this package stores in a context.Context.
// Being an unexported struct type, its values cannot collide with keys defined
// by any other package.
type contextKey struct {
	name string
}

var (
	// DefaultContext is the default context in which xorm queries run.
	// It will be overwritten by Init with HammerContext.
	DefaultContext context.Context

	// enginedContextKey is the key under which a *Context answers Value lookups
	// with itself, so getEngine can find the current Engined through any
	// context derived from it.
	enginedContextKey = &contextKey{"engined"}
)

// Compile-time assertion that *Context implements Engined.
var _ Engined = (*Context)(nil)

// Context is a context.Context that additionally carries the database Engine
// queries should run on, together with a flag telling whether that Engine is
// a transaction.
type Context struct {
	context.Context
	e           Engine
	transaction bool
}

// newContext builds a *Context that embeds ctx and carries the engine e;
// transaction records whether e is transactional.
func newContext(ctx context.Context, e Engine, transaction bool) *Context {
	return &Context{
		e:           e,
		transaction: transaction,
		Context:     ctx,
	}
}

// InTransaction reports whether the engine carried by ctx is a transaction.
func (ctx *Context) InTransaction() bool {
	return ctx.transaction
}

// Engine returns the database engine carried by ctx.
func (ctx *Context) Engine() Engine {
	return ctx.e
}

// Value overrides context.Context.Value. Looking up enginedContextKey yields
// ctx itself (an Engined), so contexts derived from ctx still expose its
// engine; every other key is delegated to the embedded context.
func (ctx *Context) Value(key any) any {
	switch key {
	case enginedContextKey:
		return ctx
	default:
		return ctx.Context.Value(key)
	}
}

// WithContext returns a new *Context whose engine is ctx's engine bound to
// other. The result still embeds ctx itself, so values and cancellation are
// looked up through ctx rather than through other.
func (ctx *Context) WithContext(other context.Context) *Context {
	bound := ctx.e.Context(other)
	return newContext(ctx, bound, ctx.transaction)
}

// Engined is implemented by anything that can hand out a database Engine.
type Engined interface {
	Engine() Engine
}

// GetEngine returns the Engine carried by ctx when there is one; otherwise it
// returns the default engine bound to ctx.
func GetEngine(ctx context.Context) Engine {
	e := getEngine(ctx)
	if e == nil {
		e = x.Context(ctx)
	}
	return e
}

// getEngine returns the Engine carried by ctx, or nil if ctx carries none.
// ctx is checked first for being an Engined itself, then for exposing one
// under enginedContextKey (which is how derived contexts find it).
func getEngine(ctx context.Context) Engine {
	engined, ok := ctx.(Engined)
	if !ok {
		found := ctx.Value(enginedContextKey)
		if found == nil {
			return nil
		}
		engined = found.(Engined)
	}
	return engined.Engine()
}

// Committer finishes a transaction: Commit commits it, and Close releases it,
// rolling back whatever was not committed.
type Committer interface {
	Commit() error
	Close() error
}

// halfCommitter wraps the Committer of an outer transaction that is being
// reused. It may close (roll back) early, but it never commits early: its
// Commit only records the intent and leaves the real commit to the outer
// committer.
type halfCommitter struct {
	committer Committer
	committed bool
}

// Commit marks the reused transaction as committed without committing it;
// the outer committer commits later.
func (c *halfCommitter) Commit() error {
	c.committed = true
	return nil
}

// Close is a no-op after Commit ("commit and close"). Without a prior Commit
// it means "rollback and close", so the outer committer is closed right away
// and its error is returned.
func (c *halfCommitter) Close() error {
	if !c.committed {
		return c.committer.Close()
	}
	return nil
}

// TxContext returns a transactional Context and the Committer that finishes it.
// If parentCtx already carries a transaction, that transaction is reused and the
// returned Committer is a halfCommitter wrapping it; otherwise a new session is
// opened, a transaction is begun on it, and the session itself is returned as
// the Committer.
//
// New code should prefer WithTx, which commits or rolls back automatically.
// When maintaining code that uses TxContext, always call Close() before
// returning, whether or not Commit() has been called, and always call Commit()
// before returning when there is no error, even if no data was changed.
//
// On a reused transaction Commit() does nothing, while Close() without a prior
// Commit() rolls back the whole transaction, including the operations done
// further up the call stack, not only those in the current function. Rolling
// back is not forbidden, but do it only when there is an error and a rollback
// is really intended.
//
// A newly created transaction's Context embeds DefaultContext rather than
// parentCtx, so values and cancellation of parentCtx are not carried into it.
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
	if outer, reuse := inTransaction(parentCtx); reuse {
		txCtx := newContext(parentCtx, outer, true)
		return txCtx, &halfCommitter{committer: outer}, nil
	}

	newSess := x.NewSession()
	if beginErr := newSess.Begin(); beginErr != nil {
		// The transaction never started; release the session before reporting.
		_ = newSess.Close()
		return nil, nil, beginErr
	}

	txCtx := newContext(DefaultContext, newSess, true)
	return txCtx, newSess, nil
}

// WithTx runs f inside a database transaction. If parentCtx already carries a
// transaction, f runs in it, and if f fails that transaction is rolled back at
// once. Otherwise a new transaction is created, committed when f succeeds, and
// closed before WithTx returns.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
	sess, reuse := inTransaction(parentCtx)
	if !reuse {
		return txWithNoCheck(parentCtx, f)
	}

	err := f(newContext(parentCtx, sess, true))
	if err != nil {
		// Roll back right away, in case the caller ignores the returned error
		// and tries to commit the transaction anyway.
		_ = sess.Close()
	}
	return err
}

// txWithNoCheck always opens a fresh session, runs f in a transaction on it,
// and commits when f returns nil. It does not look for a transaction already
// carried by parentCtx.
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
	sess := x.NewSession()
	// Runs on every path: after a successful Commit it just releases the
	// session, otherwise it discards the uncommitted transaction.
	defer sess.Close()

	if err := sess.Begin(); err != nil {
		return err
	}

	txCtx := newContext(parentCtx, sess, true)
	if err := f(txCtx); err != nil {
		return err
	}
	return sess.Commit()
}

// Insert inserts the given beans as records, using the engine carried by ctx.
func Insert(ctx context.Context, beans ...any) error {
	e := GetEngine(ctx)
	if _, err := e.Insert(beans...); err != nil {
		return err
	}
	return nil
}

// Exec executes a SQL statement with its arguments, using the engine carried
// by ctx, and returns the driver's result.
func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
	e := GetEngine(ctx)
	return e.Exec(sqlAndArgs...)
}

// Get fetches one record of type T matching cond.
// It returns (nil, false, nil) when no record matches, and (nil, false, err)
// on a query error. It panics when cond is nil or not a valid condition, so an
// unconditioned Get can never be issued by accident.
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
	// A nil interface would otherwise panic inside IsValid with a bare
	// nil-pointer dereference; report it with the descriptive message instead.
	if cond == nil || !cond.IsValid() {
		panic("cond is invalid in db.Get(ctx, cond). This should not be possible.")
	}

	var bean T
	has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
	if err != nil {
		return nil, false, err
	}
	if !has {
		return nil, false, nil
	}
	return &bean, true, nil
}

// GetByID fetches the record of type T identified by id.
// It returns (nil, false, nil) when no such record exists, and
// (nil, false, err) on a query error.
func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
	bean := new(T)
	found, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(bean)
	switch {
	case err != nil:
		return nil, false, err
	case !found:
		return nil, false, nil
	default:
		return bean, true, nil
	}
}

// Exist reports whether at least one record of type T matches cond.
// It panics when cond is nil or not a valid condition, so an unconditioned
// existence check can never be issued by accident.
func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
	// A nil interface would otherwise panic inside IsValid with a bare
	// nil-pointer dereference; report it with the descriptive message instead.
	if cond == nil || !cond.IsValid() {
		panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.")
	}

	var bean T
	return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
}

// ExistByID reports whether a record of type T identified by id exists.
func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
	return GetEngine(ctx).ID(id).NoAutoCondition().Exist(new(T))
}

// DeleteByID deletes the record of type T identified by id and returns the
// affected row count reported by the engine. Automatic timestamps are not
// applied.
func DeleteByID[T any](ctx context.Context, id int64) (int64, error) {
	sess := GetEngine(ctx).ID(id).NoAutoCondition().NoAutoTime()
	return sess.Delete(new(T))
}

// DeleteByIDs deletes every record of type T whose "id" column is one of ids.
// It is a no-op (no query is issued) when ids is empty.
func DeleteByIDs[T any](ctx context.Context, ids ...int64) error {
	if len(ids) > 0 {
		_, err := GetEngine(ctx).In("id", ids).NoAutoCondition().NoAutoTime().Delete(new(T))
		return err
	}
	return nil
}

// Delete deletes every record of type T matching the conditions built from
// opts, and returns the affected row count reported by the engine. Automatic
// timestamps are not applied.
// It panics when opts is nil or yields no valid condition; because the bean's
// own fields are not used as conditions (NoAutoCondition), this guards against
// an accidental unconditioned delete.
func Delete[T any](ctx context.Context, opts FindOptions) (int64, error) {
	const invalidOpts = "opts are empty or invalid in db.Delete(ctx, opts). This should not be possible."

	if opts == nil {
		panic(invalidOpts)
	}
	// Build the condition once and use it for both validation and the query.
	// A nil Cond gets the descriptive panic instead of a bare nil-pointer
	// dereference inside IsValid.
	conds := opts.ToConds()
	if conds == nil || !conds.IsValid() {
		panic(invalidOpts)
	}

	var bean T
	return GetEngine(ctx).Where(conds).NoAutoCondition().NoAutoTime().Delete(&bean)
}

// DeleteByBean deletes every record matching the non-empty fields of bean,
// which act as the delete conditions, and returns the affected row count
// reported by the engine.
func DeleteByBean(ctx context.Context, bean any) (int64, error) {
	e := GetEngine(ctx)
	return e.Delete(bean)
}

// FindIDs returns the values of column idCol from table tableName for every
// row satisfying cond. Passing a column other than "id" as idCol queries
// foreign IDs, e.g. the repo IDs of the matching rows.
func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) {
	ids := make([]int64, 0, 10)
	query := GetEngine(ctx).Table(tableName).Cols(idCol).Where(cond)
	if err := query.Find(&ids); err != nil {
		return nil, err
	}
	return ids, nil
}

// DecrByIDs decrements column decrCol by one for every record of bean's type
// whose "id" is one of ids. Timestamps of the entities are not updated.
// It is a no-op (no query is issued) when ids is empty.
func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean any) error {
	// An empty IN list can match no rows, so skip the database round trip,
	// just as DeleteByIDs does.
	if len(ids) == 0 {
		return nil
	}
	_, err := GetEngine(ctx).Decr(decrCol).In("id", ids).NoAutoCondition().NoAutoTime().Update(bean)
	return err
}

// DeleteBeans deletes the records matching each bean in turn; every bean must
// carry its own delete conditions. It stops at, and returns, the first error.
func DeleteBeans(ctx context.Context, beans ...any) (err error) {
	engine := GetEngine(ctx)
	for _, bean := range beans {
		_, err = engine.Delete(bean)
		if err != nil {
			return err
		}
	}
	return nil
}

// TruncateBeans calls the engine's Truncate for each bean in turn; beans may
// carry delete conditions. It stops at, and returns, the first error.
func TruncateBeans(ctx context.Context, beans ...any) (err error) {
	engine := GetEngine(ctx)
	for _, bean := range beans {
		_, err = engine.Truncate(bean)
		if err != nil {
			return err
		}
	}
	return nil
}

// CountByBean counts the records matching the non-empty fields of bean, which
// act as the query conditions.
func CountByBean(ctx context.Context, bean any) (int64, error) {
	e := GetEngine(ctx)
	return e.Count(bean)
}

// TableName returns the table name that the default engine maps bean to.
func TableName(bean any) string {
	name := x.TableName(bean)
	return name
}

// InTransaction reports whether the engine carried by ctx is in a transaction.
func InTransaction(ctx context.Context) bool {
	_, inTx := inTransaction(ctx)
	return inTx
}

// inTransaction returns the session carried by ctx when that session is in an
// open transaction. A missing engine, a plain *xorm.Engine, or any other
// Engine implementation yields (nil, false).
func inTransaction(ctx context.Context) (*xorm.Session, bool) {
	// The comma-ok assertion is false for a nil Engine as well, so no separate
	// nil check is needed.
	sess, isSession := getEngine(ctx).(*xorm.Session)
	if !isSession || !sess.IsInTx() {
		return nil, false
	}
	return sess, true
}