package storage import ( "context" "database/sql" "errors" "fmt" "github.com/jmoiron/sqlx" ) type txKey struct{} var ctxKey txKey = txKey{} type DBTX interface { sqlx.Ext sqlx.ExtContext } func Transaction(ctx context.Context, db *sqlx.DB, fn func(context.Context) error) (err error) { tx := txFromContext(ctx) if tx == nil { tx, err = db.BeginTxx(ctx, &sql.TxOptions{ Isolation: sql.LevelRepeatableRead, }) if err != nil { return fmt.Errorf("failed to begin tx: %w", err) } defer func() { if err == nil { err = tx.Commit() } if err != nil { if rbErr := tx.Rollback(); rbErr != nil { err = errors.Join(err, rbErr) } } }() ctx = txContext(ctx, tx) } return fn(ctx) } func Conn(ctx context.Context, db DBTX) DBTX { if tx := txFromContext(ctx); tx != nil { return tx } return db } func txFromContext(ctx context.Context) *sqlx.Tx { if tx, ok := ctx.Value(ctxKey).(*sqlx.Tx); ok { return tx } return nil } func txContext(parent context.Context, tx *sqlx.Tx) context.Context { return context.WithValue(parent, tx, ctxKey) }