block-accounting/backend/internal/pkg/sqlutils/tx.go
2024-05-29 01:57:31 +03:00

86 lines
1.8 KiB
Go

package sqltools
import (
"context"
"database/sql"
"errors"
"fmt"
)
// DBTX is the set of query and exec methods needed to run SQL statements.
// Both *sql.DB and *sql.Tx from database/sql provide all of them, so code
// written against DBTX can run either directly on the pool or inside a
// transaction.
type DBTX interface {
	// Context-aware variants.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row

	// Variants that take no context.
	Exec(query string, args ...any) (sql.Result, error)
	Query(query string, args ...any) (*sql.Rows, error)
	QueryRow(query string, args ...any) *sql.Row
}
// TransactionalStorage describes a storage that can run a unit of work in a
// transaction and hand out the DBTX to execute statements on.
type TransactionalStorage interface {
	// Conn returns the DBTX to use for the given context.
	// NOTE(review): presumably the transaction carried by ctx when there is
	// one, otherwise the plain connection — confirm against implementations.
	Conn(ctx context.Context) DBTX

	// Transaction runs fn against db; fn receives a context to pass on to the
	// storage calls it makes.
	Transaction(ctx context.Context, db *sql.DB, fn func(ctx context.Context) error) error
}
// txCtxKey is the unexported type of the transaction context key; being a
// distinct type, it cannot collide with context keys defined elsewhere.
type txCtxKey struct{}

// TxCtxKey is the context key under which Transaction stores the active
// *sql.Tx.
var TxCtxKey txCtxKey
// Transaction runs fn inside a database transaction.
//
// If ctx already carries a *sql.Tx under TxCtxKey, fn joins that transaction:
// it is called with ctx unchanged and nothing is committed or rolled back
// here. A non-nil error from fn is returned wrapped.
//
// Otherwise a transaction is started on db with repeatable-read isolation and
// stored under TxCtxKey in the context passed to fn. It is committed when fn
// returns nil, and rolled back when fn returns an error or does not return
// normally (panic or runtime.Goexit).
//
// Returned errors:
//   - the BeginTx error, wrapped, when the transaction cannot be started;
//   - the error from fn, joined with the rollback error if rollback fails;
//   - the Commit error, wrapped, joined with the rollback error if the
//     follow-up rollback fails.
//
// sql.ErrTxDone from Rollback is not reported: it only means the transaction
// has already been finished.
func Transaction(ctx context.Context, db *sql.DB, fn func(context.Context) error) (err error) {
	// Nested call: the outermost Transaction owns commit and rollback.
	if hasExternalTransaction(ctx) {
		if err = fn(ctx); err != nil {
			return fmt.Errorf("error perform operation. %w", err)
		}
		return nil
	}

	tx, beginErr := db.BeginTx(ctx, &sql.TxOptions{
		Isolation: sql.LevelRepeatableRead,
	})
	if beginErr != nil {
		// Return before the cleanup is registered: tx is nil here and must
		// not be rolled back.
		return fmt.Errorf("error begin transaction. %w", beginErr)
	}

	// returned stays false if fn panics or exits the goroutine, so the
	// deferred cleanup can tell that apart from a normal nil-error return.
	returned := false
	defer func() {
		if !returned {
			// Never commit half-done work. The rollback error is dropped on
			// purpose: the panic keeps propagating and is the primary failure.
			_ = tx.Rollback()
			return
		}

		if err != nil {
			if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
				err = errors.Join(fmt.Errorf("error rollback transaction. %w", rbErr), err)
			}
			return
		}

		commitErr := tx.Commit()
		if commitErr == nil {
			return
		}
		err = fmt.Errorf("error commit transaction. %w", commitErr)

		// Commit may fail without finishing the transaction (for example when
		// the context is already cancelled), so still try to roll back.
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			err = errors.Join(fmt.Errorf("error rollback transaction. %w", rbErr), err)
		}
	}()

	err = fn(context.WithValue(ctx, TxCtxKey, tx))
	returned = true

	return err
}
// hasExternalTransaction reports whether ctx carries a *sql.Tx under
// TxCtxKey, i.e. whether the caller is already running inside a transaction.
func hasExternalTransaction(ctx context.Context) bool {
	_, found := ctx.Value(TxCtxKey).(*sql.Tx)

	return found
}