36 lines
838 B
Go
36 lines
838 B
Go
|
|
package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
|
||
|
|
"gorm.io/gorm"
|
||
|
|
)
|
||
|
|
|
||
|
|
// txKeyType is an unexported context-key type. Because no other package can
// name it, values stored under txKey cannot collide with keys set elsewhere.
type txKeyType struct{}

// txKey is the context key under which WithTransaction stores the active
// transactional *gorm.DB. Its value is the zero txKeyType.
var txKey txKeyType
|
||
|
|
|
||
|
|
// DBFromContext returns the transactional *gorm.DB from context if present; otherwise returns base.
|
||
|
|
func DBFromContext(ctx context.Context, base *gorm.DB) *gorm.DB {
|
||
|
|
if v := ctx.Value(txKey); v != nil {
|
||
|
|
if tx, ok := v.(*gorm.DB); ok && tx != nil {
|
||
|
|
return tx
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return base
|
||
|
|
}
|
||
|
|
|
||
|
|
type TxManager struct {
|
||
|
|
db *gorm.DB
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewTxManager returns a TxManager that starts its transactions on db.
func NewTxManager(db *gorm.DB) *TxManager {
	return &TxManager{
		db: db,
	}
}
|
||
|
|
|
||
|
|
// WithTransaction runs fn inside a DB transaction, injecting the *gorm.DB tx into ctx.
|
||
|
|
func (m *TxManager) WithTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
||
|
|
return m.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
|
|
ctxTx := context.WithValue(ctx, txKey, tx)
|
||
|
|
return fn(ctxTx)
|
||
|
|
})
|
||
|
|
}
|