block-accounting/backend/internal/infrastructure/repository/auth/repository.go
2024-06-04 00:54:17 +03:00

273 lines
6.1 KiB
Go

package auth
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/emochka2007/block-accounting/internal/pkg/models"
sqltools "github.com/emochka2007/block-accounting/internal/pkg/sqlutils"
"github.com/google/uuid"
)
// AddTokenParams is the input of Repository.AddToken: a freshly issued
// access/refresh token pair for one user.
type AddTokenParams struct {
	// UserId is the owner of the token pair (access_tokens.user_id).
	UserId uuid.UUID
	// Token is the access token value (access_tokens.token).
	Token string
	// TokenExpiredAt is the access token expiry (access_tokens.token_expired_at).
	TokenExpiredAt time.Time
	// RefreshToken is the refresh token value (access_tokens.refresh_token).
	RefreshToken string
	// RefreshTokenExpiredAt is the refresh token expiry
	// (access_tokens.refresh_token_expired_at).
	RefreshTokenExpiredAt time.Time
	// CreatedAt is the creation timestamp of the pair.
	// NOTE(review): confirm whether the repository persists this value.
	CreatedAt time.Time
	// RemoteAddr is the client address the pair was issued to
	// (access_tokens.remote_addr).
	RemoteAddr string
}
// GetTokenParams is the lookup filter of Repository.GetTokens.
type GetTokenParams struct {
	// UserId must match access_tokens.user_id.
	UserId uuid.UUID
	// Token must match access_tokens.token.
	Token string
	// RefreshToken, when non-empty, must additionally match
	// access_tokens.refresh_token; an empty string disables that filter.
	RefreshToken string
}
// RefreshTokenParams is the input of Repository.RefreshToken. The Old* fields
// together with UserId select the row to rotate; the remaining fields are the
// replacement values written to that row.
type RefreshTokenParams struct {
	// UserId selects the row by access_tokens.user_id.
	UserId uuid.UUID
	// OldToken selects the row by its current access_tokens.token.
	OldToken string
	// Token is the new access token value.
	Token string
	// TokenExpiredAt is the new access token expiry.
	TokenExpiredAt time.Time
	// OldRefreshToken selects the row by its current access_tokens.refresh_token.
	OldRefreshToken string
	// RefreshToken is the new refresh token value.
	RefreshToken string
	// RefreshTokenExpiredAt is the new refresh token expiry.
	RefreshTokenExpiredAt time.Time
}
// AccessToken is one access_tokens row as read by Repository.GetTokens.
type AccessToken struct {
	// UserId is access_tokens.user_id.
	UserId uuid.UUID
	// Token is access_tokens.token.
	Token string
	// TokenExpiredAt is access_tokens.token_expired_at.
	TokenExpiredAt time.Time
	// RefreshToken is access_tokens.refresh_token.
	RefreshToken string
	// RefreshTokenExpiredAt is access_tokens.refresh_token_expired_at.
	RefreshTokenExpiredAt time.Time
	// CreatedAt is access_tokens.created_at.
	CreatedAt time.Time
}
// Repository is the persistence API for authentication tokens and
// organization invite links.
type Repository interface {
	// AddToken stores a newly issued access/refresh token pair.
	AddToken(ctx context.Context, params AddTokenParams) error
	// GetTokens looks up a stored token pair matching params.
	GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error)
	// RefreshToken replaces an existing token pair with a new one.
	RefreshToken(ctx context.Context, params RefreshTokenParams) error
	// AddInvite stores a new invite link.
	AddInvite(ctx context.Context, params AddInviteParams) error
	// MarkAsUsedLink records usage of the invite identified by linkHash and
	// returns the organization the invite belongs to.
	MarkAsUsedLink(ctx context.Context, linkHash string, usedAt time.Time) (uuid.UUID, error)
}
// repositorySQL implements Repository on top of database/sql, building its
// queries with squirrel.
type repositorySQL struct {
	// db is the fallback connection used when ctx carries no transaction.
	db *sql.DB
}
// AddToken inserts a new row into access_tokens for params.UserId. The insert
// runs inside sqltools.Transaction.
//
// params.CreatedAt is written to created_at only when it is non-zero.
// Otherwise the column is omitted so a zero time.Time is never stored.
// NOTE(review): omitting the column assumes created_at has a database-side
// default — confirm against the schema.
func (r *repositorySQL) AddToken(ctx context.Context, params AddTokenParams) error {
	return sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
		columns := []string{
			"user_id",
			"token",
			"refresh_token",
			"token_expired_at",
			"refresh_token_expired_at",
			"remote_addr",
		}
		values := []any{
			params.UserId,
			params.Token,
			params.RefreshToken,
			params.TokenExpiredAt,
			params.RefreshTokenExpiredAt,
			params.RemoteAddr,
		}

		// CreatedAt used to be silently dropped; persist it when the caller set it.
		if !params.CreatedAt.IsZero() {
			columns = append(columns, "created_at")
			values = append(values, params.CreatedAt)
		}

		query := sq.Insert("access_tokens").
			Columns(columns...).
			Values(values...).
			PlaceholderFormat(sq.Dollar)

		if _, err := query.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
			return fmt.Errorf("error add tokens. %w", err)
		}

		return nil
	})
}
// RefreshToken rotates a token pair inside sqltools.Transaction.
//
// It overwrites token, refresh_token and both expiry columns on the
// access_tokens row whose user_id, token and refresh_token equal
// params.UserId, params.OldToken and params.OldRefreshToken.
//
// If no row matched, it returns an error wrapping sql.ErrNoRows. That happens
// when the old pair was already rotated, for example by a concurrent refresh.
// Without the check, the new tokens would be reported as stored when they
// were not.
func (r *repositorySQL) RefreshToken(ctx context.Context, params RefreshTokenParams) error {
	return sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
		updateQuery := sq.Update("access_tokens").
			SetMap(sq.Eq{
				"token":                    params.Token,
				"refresh_token":            params.RefreshToken,
				"token_expired_at":         params.TokenExpiredAt,
				"refresh_token_expired_at": params.RefreshTokenExpiredAt,
			}).
			Where(sq.Eq{
				"user_id":       params.UserId,
				"token":         params.OldToken,
				"refresh_token": params.OldRefreshToken,
			}).
			PlaceholderFormat(sq.Dollar)

		res, err := updateQuery.RunWith(r.Conn(ctx)).ExecContext(ctx)
		if err != nil {
			return fmt.Errorf("error update tokens. %w", err)
		}

		affected, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("error fetch affected rows count. %w", err)
		}
		if affected == 0 {
			return fmt.Errorf("error update tokens. %w", sql.ErrNoRows)
		}

		return nil
	})
}
// GetTokens reads the access_tokens row matching params.Token and
// params.UserId, and additionally params.RefreshToken when that is non-empty.
// The query runs inside sqltools.Transaction.
//
// If several rows match, the last one scanned wins. If none match, the
// returned *AccessToken is zero-valued and the error is nil.
// NOTE(review): callers must not treat a zero-valued result as a valid token.
//
// Errors from closing the result set and from row iteration are now
// propagated instead of being silently dropped.
func (r *repositorySQL) GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error) {
	token := new(AccessToken)

	// The named result lets the deferred rows.Close() merge its error into
	// what the closure actually returns.
	if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) (err error) {
		query := sq.Select(
			"user_id",
			"token",
			"token_expired_at",
			"refresh_token",
			"refresh_token_expired_at",
			"created_at",
		).From("access_tokens").
			Where(sq.Eq{
				"token":   params.Token,
				"user_id": params.UserId,
			}).
			PlaceholderFormat(sq.Dollar)

		if params.RefreshToken != "" {
			query = query.Where(sq.Eq{
				"refresh_token": params.RefreshToken,
			})
		}

		rows, err := query.RunWith(r.Conn(ctx)).QueryContext(ctx)
		if err != nil {
			return fmt.Errorf("error fetch token from database. %w", err)
		}
		defer func() {
			if cErr := rows.Close(); cErr != nil {
				err = errors.Join(fmt.Errorf("error close database rows. %w", cErr), err)
			}
		}()

		for rows.Next() {
			if err := rows.Scan(
				&token.UserId,
				&token.Token,
				&token.TokenExpiredAt,
				&token.RefreshToken,
				&token.RefreshTokenExpiredAt,
				&token.CreatedAt,
			); err != nil {
				return fmt.Errorf("error scan row. %w", err)
			}
		}

		// rows.Next returning false can mean an error, not just the end of the results.
		if err := rows.Err(); err != nil {
			return fmt.Errorf("error iterate token rows. %w", err)
		}

		return nil
	}); err != nil {
		return nil, err
	}

	return token, nil
}
// AddInviteParams is the input of Repository.AddInvite.
type AddInviteParams struct {
	// LinkHash identifies the invite link (invites.link_hash).
	LinkHash string
	// OrganizationID is the organization the invite grants access to.
	OrganizationID uuid.UUID
	// CreatedBy is the inviting user; only CreatedBy.Id() is persisted
	// (invites.created_by).
	CreatedBy models.User
	// CreatedAt is stored in invites.created_at.
	CreatedAt time.Time
	// ExpiredAt is stored in invites.expired_at.
	ExpiredAt time.Time
}
// AddInvite writes a single row to the invites table inside
// sqltools.Transaction. The creator is recorded via params.CreatedBy.Id().
func (r *repositorySQL) AddInvite(
	ctx context.Context,
	params AddInviteParams,
) error {
	insertInvite := func(txCtx context.Context) error {
		stmt := sq.
			Insert("invites").
			Columns("link_hash", "organization_id", "created_by", "created_at", "expired_at").
			Values(
				params.LinkHash,
				params.OrganizationID,
				params.CreatedBy.Id(),
				params.CreatedAt,
				params.ExpiredAt,
			).
			PlaceholderFormat(sq.Dollar)

		_, execErr := stmt.RunWith(r.Conn(txCtx)).ExecContext(txCtx)
		if execErr != nil {
			return fmt.Errorf("error add invite link. %w", execErr)
		}
		return nil
	}

	return sqltools.Transaction(ctx, r.db, insertInvite)
}
// MarkAsUsedLink looks up the invite identified by linkHash and stamps its
// used_at with usedAt. It returns the invite's organization_id. Both queries
// run inside one sqltools.Transaction.
//
// It returns ErrorInviteLinkExpired when the invite's expired_at is before
// time.Now(). Lookup failures are wrapped; for an unknown hash that is
// presumably sql.ErrNoRows — confirm with the driver. On any error the
// returned ID is uuid.Nil.
//
// NOTE(review): an already-used link is not rejected here; add a used_at
// check if invites must be single-use.
func (r *repositorySQL) MarkAsUsedLink(
	ctx context.Context,
	linkHash string,
	usedAt time.Time,
) (uuid.UUID, error) {
	var orgID uuid.UUID

	if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
		query := sq.Select("organization_id", "expired_at").
			From("invites").
			Where(sq.Eq{"link_hash": linkHash}).
			Limit(1).
			PlaceholderFormat(sq.Dollar)

		var expAt time.Time
		if err := query.RunWith(r.Conn(ctx)).QueryRowContext(ctx).Scan(&orgID, &expAt); err != nil {
			return fmt.Errorf("error fetch expiration date from database. %w", err)
		}

		if expAt.Before(time.Now()) {
			return ErrorInviteLinkExpired
		}

		// The WHERE clause is essential: without it the UPDATE stamped
		// used_at on every invite in the table.
		updateQuery := sq.Update("invites").
			SetMap(sq.Eq{"used_at": usedAt}).
			Where(sq.Eq{"link_hash": linkHash}).
			PlaceholderFormat(sq.Dollar)

		if _, err := updateQuery.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
			return fmt.Errorf("error mark invite link as used. %w", err)
		}

		return nil
	}); err != nil {
		return uuid.Nil, err
	}

	return orgID, nil
}
// NewRepository builds the SQL-backed Repository that issues its queries
// against db (or against a transaction carried in the call's context).
func NewRepository(db *sql.DB) Repository {
	repo := new(repositorySQL)
	repo.db = db
	return repo
}
// Conn returns the executor queries should run on.
//
// If ctx holds a *sql.Tx under sqltools.TxCtxKey, that transaction is
// returned; presumably sqltools.Transaction stores it there — confirm.
// Otherwise the repository's *sql.DB is returned.
//
// The receiver is named r for consistency with the other repositorySQL methods.
func (r *repositorySQL) Conn(ctx context.Context) sqltools.DBTX {
	if tx, ok := ctx.Value(sqltools.TxCtxKey).(*sql.Tx); ok {
		return tx
	}
	return r.db
}