// Source: draincloud-core/internal/storage/transaction.go
// (67 lines, 1.1 KiB, Go; revision dated 2024-09-27 22:37:58 +00:00)

package storage
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/jmoiron/sqlx"
)
// txKey is an unexported context-key type, so values stored under it cannot
// collide with keys defined by any other package.
type txKey struct{}

// ctxKey is the context key used to look up the current transaction
// (see txFromContext).
var ctxKey = txKey{}
// DBTX is the query-execution surface returned by Conn. It combines the
// context-aware (sqlx.ExtContext) and plain (sqlx.Ext) sqlx method sets, which
// both *sqlx.DB and *sqlx.Tx implement, so callers can run queries without
// knowing whether a transaction is active.
type DBTX interface {
	sqlx.ExtContext
	sqlx.Ext
}
// Transaction runs fn inside a database transaction.
//
// If ctx already carries a transaction (txFromContext returns non-nil), fn is
// simply called with ctx and the enclosing call remains responsible for
// commit/rollback. Otherwise a new transaction is started on db with
// REPEATABLE READ isolation, and fn is called with a context carrying it.
//
// The new transaction is committed only if fn returns nil. It is rolled back
// if fn returns an error, panics, or exits via runtime.Goexit. A panic keeps
// propagating after the rollback. The returned error is:
//   - the begin error, wrapped;
//   - fn's error, joined with any rollback error;
//   - the commit error, wrapped;
//   - or nil on success.
func Transaction(ctx context.Context, db *sqlx.DB, fn func(context.Context) error) (err error) {
	if tx := txFromContext(ctx); tx != nil {
		// Join the enclosing transaction; its owner decides commit/rollback.
		return fn(ctx)
	}

	tx, err := db.BeginTxx(ctx, &sql.TxOptions{
		Isolation: sql.LevelRepeatableRead,
	})
	if err != nil {
		return fmt.Errorf("failed to begin tx: %w", err)
	}

	// done is set right before Commit. Any exit before that point (error
	// return, panic, Goexit) must roll back so partial work is never committed.
	done := false
	defer func() {
		if done {
			return
		}
		rbErr := tx.Rollback()
		// On a panic err is still nil. The rollback error is dropped there,
		// since the panic itself is what callers will see.
		if rbErr != nil && err != nil {
			err = errors.Join(err, rbErr)
		}
	}()

	if err = fn(txContext(ctx, tx)); err != nil {
		return err
	}

	// After Commit returns, database/sql treats the tx as finished either way.
	// A follow-up Rollback would only report sql.ErrTxDone, so skip it.
	done = true
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit tx: %w", err)
	}
	return nil
}
// Conn picks the executor a query should run against. It returns the
// transaction carried by ctx when there is one, and db otherwise.
func Conn(ctx context.Context, db DBTX) DBTX {
	active := txFromContext(ctx)
	if active == nil {
		return db
	}
	return active
}
// txFromContext returns the *sqlx.Tx stored in ctx under ctxKey. It returns
// nil when ctx has no value for that key, or the value is not a *sqlx.Tx.
func txFromContext(ctx context.Context) *sqlx.Tx {
	// A failed comma-ok assertion yields the zero value, i.e. a nil *sqlx.Tx.
	found, _ := ctx.Value(ctxKey).(*sqlx.Tx)
	return found
}
// txContext returns a child of parent that carries tx under ctxKey. This is
// the key txFromContext looks up, so Conn and nested Transaction calls made
// with the returned context see tx.
func txContext(parent context.Context, tx *sqlx.Tx) context.Context {
	// context.WithValue takes (parent, key, value). The private ctxKey must be
	// the key; using tx as the key would make the lookup in txFromContext fail.
	return context.WithValue(parent, ctxKey, tx)
}