mirror of
https://github.com/emo2007/block-accounting.git
synced 2024-11-13 22:06:27 +00:00
86 lines
1.8 KiB
Go
86 lines
1.8 KiB
Go
|
package sqltools
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
)
|
||
|
|
||
|
type DBTX interface {
|
||
|
Query(query string, args ...any) (*sql.Rows, error)
|
||
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||
|
QueryRow(query string, args ...any) *sql.Row
|
||
|
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||
|
Exec(query string, args ...any) (sql.Result, error)
|
||
|
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||
|
}
|
||
|
|
||
|
type TransactionalStorage interface {
|
||
|
Transaction(ctx context.Context, db *sql.DB, f func(ctx context.Context) error) error
|
||
|
Conn(ctx context.Context) DBTX
|
||
|
}
|
||
|
|
||
|
type txCtxKey struct{}
|
||
|
|
||
|
var TxCtxKey = txCtxKey{}
|
||
|
|
||
|
func Transaction(ctx context.Context, db *sql.DB, fn func(context.Context) error) (err error) {
|
||
|
var tx *sql.Tx = new(sql.Tx)
|
||
|
|
||
|
hasExternalTx := hasExternalTransaction(ctx)
|
||
|
|
||
|
defer func() {
|
||
|
if hasExternalTx {
|
||
|
if err != nil {
|
||
|
err = fmt.Errorf("error perform operation. %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
if rbErr := tx.Rollback(); rbErr != nil {
|
||
|
err = errors.Join(fmt.Errorf("error rollback transaction. %w", rbErr), err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
err = err
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if commitErr := tx.Commit(); commitErr != nil {
|
||
|
if rbErr := tx.Rollback(); rbErr != nil {
|
||
|
err = errors.Join(fmt.Errorf("error rollback transaction. %w", rbErr), commitErr, err)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
err = fmt.Errorf("error commit transaction. %w", err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
if !hasExternalTx {
|
||
|
tx, err = db.BeginTx(ctx, &sql.TxOptions{
|
||
|
Isolation: sql.LevelRepeatableRead,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("error begin transaction. %w", err)
|
||
|
}
|
||
|
|
||
|
ctx = context.WithValue(ctx, TxCtxKey, tx)
|
||
|
}
|
||
|
|
||
|
return fn(ctx)
|
||
|
}
|
||
|
|
||
|
func hasExternalTransaction(ctx context.Context) bool {
|
||
|
if _, ok := ctx.Value(TxCtxKey).(*sql.Tx); ok {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|