67 lines
1.1 KiB
Go
67 lines
1.1 KiB
Go
|
package storage
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
|
||
|
"github.com/jmoiron/sqlx"
|
||
|
)
|
||
|
|
||
|
type txKey struct{}
|
||
|
|
||
|
var ctxKey txKey = txKey{}
|
||
|
|
||
|
type DBTX interface {
|
||
|
sqlx.Ext
|
||
|
sqlx.ExtContext
|
||
|
}
|
||
|
|
||
|
func Transaction(ctx context.Context, db *sqlx.DB, fn func(context.Context) error) (err error) {
|
||
|
tx := txFromContext(ctx)
|
||
|
if tx == nil {
|
||
|
tx, err = db.BeginTxx(ctx, &sql.TxOptions{
|
||
|
Isolation: sql.LevelRepeatableRead,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to begin tx: %w", err)
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err == nil {
|
||
|
err = tx.Commit()
|
||
|
}
|
||
|
if err != nil {
|
||
|
if rbErr := tx.Rollback(); rbErr != nil {
|
||
|
err = errors.Join(err, rbErr)
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
ctx = txContext(ctx, tx)
|
||
|
}
|
||
|
|
||
|
return fn(ctx)
|
||
|
}
|
||
|
|
||
|
func Conn(ctx context.Context, db DBTX) DBTX {
|
||
|
if tx := txFromContext(ctx); tx != nil {
|
||
|
return tx
|
||
|
}
|
||
|
|
||
|
return db
|
||
|
}
|
||
|
|
||
|
func txFromContext(ctx context.Context) *sqlx.Tx {
|
||
|
if tx, ok := ctx.Value(ctxKey).(*sqlx.Tx); ok {
|
||
|
return tx
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func txContext(parent context.Context, tx *sqlx.Tx) context.Context {
|
||
|
return context.WithValue(parent, tx, ctxKey)
|
||
|
}
|