mirror of
https://github.com/emo2007/block-accounting.git
synced 2025-04-04 13:46:27 +00:00
tokens system improved
This commit is contained in:
parent
fc158b92e1
commit
876957aabe
@ -1,4 +1,4 @@
|
||||
# NoNameBlockchainAccounting backend
|
||||
# blockd backend
|
||||
## Build
|
||||
### Locally
|
||||
1. Install Go >= 1.22
|
||||
@ -84,7 +84,10 @@ curl --location 'http://localhost:8081/join' \
|
||||
Response:
|
||||
``` json
|
||||
{
|
||||
"token": "token-here"
|
||||
"token": "token",
|
||||
"token_expired_at": 1715975501581,
|
||||
"refresh_token": "refresh_token",
|
||||
"refresh_token_expired_at": 1716407501581
|
||||
}
|
||||
```
|
||||
|
||||
@ -105,7 +108,36 @@ curl --location 'http://localhost:8081/login' \
|
||||
Response:
|
||||
``` json
|
||||
{
|
||||
"token": "token-here"
|
||||
"token": "token",
|
||||
"token_expired_at": 1715975501581,
|
||||
"refresh_token": "refresh_token",
|
||||
"refresh_token_expired_at": 1716407501581
|
||||
}
|
||||
```
|
||||
|
||||
## POST **/refresh**
|
||||
### Request body:
|
||||
token (string, **required**)
|
||||
refresh_token (string, **required**)
|
||||
|
||||
### Example
|
||||
Request:
|
||||
``` bash
|
||||
curl --location --request GET 'http://localhost:8081/refresh' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"token": "token",
|
||||
"refresh_token": "refresh_token"
|
||||
}'
|
||||
```
|
||||
|
||||
Response:
|
||||
``` json
|
||||
{
|
||||
"token": "token",
|
||||
"token_expired_at": 1715975501581,
|
||||
"refresh_token": "refresh_token",
|
||||
"refresh_token_expired_at": 1716407501581
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/jwt"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/organizations"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
|
||||
orepo "github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
|
||||
urepo "github.com/emochka2007/block-accounting/internal/usecase/repository/users"
|
||||
@ -19,8 +20,12 @@ func provideUsersInteractor(
|
||||
return users.NewUsersInteractor(log.WithGroup("users-interactor"), usersRepo)
|
||||
}
|
||||
|
||||
func provideJWTInteractor(c config.Config, usersInteractor users.UsersInteractor) jwt.JWTInteractor {
|
||||
return jwt.NewWardenJWT(c.Common.JWTSecret, usersInteractor)
|
||||
func provideJWTInteractor(
|
||||
c config.Config,
|
||||
usersInteractor users.UsersInteractor,
|
||||
authRepository auth.Repository,
|
||||
) jwt.JWTInteractor {
|
||||
return jwt.NewJWT(c.Common.JWTSecret, usersInteractor, authRepository)
|
||||
}
|
||||
|
||||
func provideOrganizationsInteractor(
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/emochka2007/block-accounting/internal/pkg/config"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/users"
|
||||
@ -19,6 +20,10 @@ func provideOrganizationsRepository(db *sql.DB) organizations.Repository {
|
||||
return organizations.NewRepository(db)
|
||||
}
|
||||
|
||||
func provideAuthRepository(db *sql.DB) auth.Repository {
|
||||
return auth.NewRepository(db)
|
||||
}
|
||||
|
||||
func provideRedisConnection(c config.Config) (*redis.Client, func()) {
|
||||
r := redis.NewClient(&redis.Options{
|
||||
Addr: c.DB.CacheHost,
|
||||
|
@ -20,6 +20,7 @@ func ProvideService(c config.Config) (service.Service, func(), error) {
|
||||
provideUsersInteractor,
|
||||
provideOrganizationsRepository,
|
||||
provideOrganizationsInteractor,
|
||||
provideAuthRepository,
|
||||
provideJWTInteractor,
|
||||
interfaceSet,
|
||||
provideRestServer,
|
||||
|
@ -22,7 +22,8 @@ func ProvideService(c config.Config) (service.Service, func(), error) {
|
||||
}
|
||||
usersRepository := provideUsersRepository(db)
|
||||
usersInteractor := provideUsersInteractor(logger, usersRepository)
|
||||
jwtInteractor := provideJWTInteractor(c, usersInteractor)
|
||||
authRepository := provideAuthRepository(db)
|
||||
jwtInteractor := provideJWTInteractor(c, usersInteractor, authRepository)
|
||||
authPresenter := provideAuthPresenter(jwtInteractor)
|
||||
authController := provideAuthController(logger, usersInteractor, authPresenter, jwtInteractor)
|
||||
organizationsRepository := provideOrganizationsRepository(db)
|
||||
|
@ -26,6 +26,7 @@ type AuthController interface {
|
||||
JoinWithInvite(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||
Login(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||
Invite(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||
Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||
}
|
||||
|
||||
type authController struct {
|
||||
@ -114,6 +115,29 @@ func (c *authController) Login(w http.ResponseWriter, req *http.Request) ([]byte
|
||||
return c.presenter.ResponseLogin(users[0])
|
||||
}
|
||||
|
||||
func (c *authController) Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error) {
|
||||
request, err := presenters.CreateRequest[domain.RefreshRequest](req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error create refresh request. %w", err)
|
||||
}
|
||||
|
||||
c.log.Debug(
|
||||
"refresh request",
|
||||
slog.String("token", request.Token),
|
||||
slog.String("refresh_token", request.RefreshToken),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
newTokens, err := c.jwtInteractor.RefreshToken(ctx, request.Token, request.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error refresh access token. %w", err)
|
||||
}
|
||||
|
||||
return c.presenter.ResponseRefresh(newTokens)
|
||||
}
|
||||
|
||||
// const mnemonicEntropyBitSize int = 256
|
||||
|
||||
func (c *authController) Invite(w http.ResponseWriter, req *http.Request) ([]byte, error) {
|
||||
|
@ -38,8 +38,16 @@ type LoginRequest struct {
|
||||
Mnemonic string `json:"mnemonic"`
|
||||
}
|
||||
|
||||
type RefreshRequest struct {
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
Token string `json:"token"`
|
||||
ExpiredAt int64 `json:"token_expired_at"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RTExpiredAt int64 `json:"refresh_token_expired_at"`
|
||||
}
|
||||
|
||||
// Organizations
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
type AuthPresenter interface {
|
||||
ResponseJoin(user *models.User) ([]byte, error)
|
||||
ResponseLogin(user *models.User) ([]byte, error)
|
||||
ResponseRefresh(tokens jwt.AccessToken) ([]byte, error)
|
||||
}
|
||||
|
||||
type authPresenter struct {
|
||||
@ -28,16 +29,17 @@ func NewAuthPresenter(
|
||||
}
|
||||
|
||||
func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
|
||||
resp := new(domain.JoinResponse)
|
||||
|
||||
token, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
||||
tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error create access token. %w", err)
|
||||
}
|
||||
|
||||
resp.Token = token
|
||||
|
||||
out, err := json.Marshal(resp)
|
||||
out, err := json.Marshal(domain.LoginResponse{
|
||||
Token: tokens.Token,
|
||||
RefreshToken: tokens.RefreshToken,
|
||||
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
|
||||
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshal join response. %w", err)
|
||||
}
|
||||
@ -46,19 +48,34 @@ func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (p *authPresenter) ResponseLogin(user *models.User) ([]byte, error) {
|
||||
resp := new(domain.LoginResponse)
|
||||
|
||||
token, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
||||
tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error create access token. %w", err)
|
||||
}
|
||||
|
||||
resp.Token = token
|
||||
|
||||
out, err := json.Marshal(resp)
|
||||
out, err := json.Marshal(domain.LoginResponse{
|
||||
Token: tokens.Token,
|
||||
RefreshToken: tokens.RefreshToken,
|
||||
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
|
||||
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshal login response. %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p *authPresenter) ResponseRefresh(tokens jwt.AccessToken) ([]byte, error) {
|
||||
out, err := json.Marshal(domain.LoginResponse{
|
||||
Token: tokens.Token,
|
||||
RefreshToken: tokens.RefreshToken,
|
||||
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
|
||||
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshal refresh response. %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
@ -96,6 +96,7 @@ func (s *Server) buildRouter() {
|
||||
|
||||
router.Post("/join", s.handle(s.controllers.Auth.Join, "join"))
|
||||
router.Post("/login", s.handle(s.controllers.Auth.Login, "login"))
|
||||
router.Get("/refresh", s.handle(s.controllers.Auth.Refresh, "refresh"))
|
||||
|
||||
router.Route("/organizations", func(r chi.Router) {
|
||||
r = r.With(s.withAuthorization)
|
||||
|
@ -2,12 +2,15 @@ package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/emochka2007/block-accounting/internal/pkg/models"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
|
||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -18,38 +21,59 @@ var (
|
||||
)
|
||||
|
||||
type JWTInteractor interface {
|
||||
NewToken(user models.UserIdentity, duration time.Duration) (string, error)
|
||||
NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error)
|
||||
User(token string) (*models.User, error)
|
||||
RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error)
|
||||
}
|
||||
|
||||
type jwtInteractor struct {
|
||||
secret []byte
|
||||
usersInteractor users.UsersInteractor
|
||||
authRepository auth.Repository
|
||||
}
|
||||
|
||||
func NewWardenJWT(secret []byte, usersInteractor users.UsersInteractor) JWTInteractor {
|
||||
func NewJWT(
|
||||
secret []byte,
|
||||
usersInteractor users.UsersInteractor,
|
||||
authRepository auth.Repository,
|
||||
) JWTInteractor {
|
||||
return &jwtInteractor{
|
||||
secret: secret,
|
||||
usersInteractor: usersInteractor,
|
||||
authRepository: authRepository,
|
||||
}
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
Token string
|
||||
ExpiredAt time.Time
|
||||
|
||||
RefreshToken string
|
||||
RTExpiredAt time.Time
|
||||
}
|
||||
|
||||
// NewToken creates new JWT token for given user
|
||||
func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (string, error) {
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
claims["uid"] = user.Id().String()
|
||||
claims["exp"] = time.Now().Add(duration).UnixMilli()
|
||||
|
||||
secret := w.secret
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error) {
|
||||
tokens, err := w.newTokens(user.Id(), duration)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error sign token. %w", err)
|
||||
return AccessToken{}, fmt.Errorf("error create new tokens. %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := w.authRepository.AddToken(ctx, auth.AddTokenParams{
|
||||
UserId: user.Id(),
|
||||
Token: tokens.Token,
|
||||
TokenExpiredAt: tokens.ExpiredAt,
|
||||
RefreshToken: tokens.RefreshToken,
|
||||
RefreshTokenExpiredAt: tokens.RTExpiredAt,
|
||||
CreatedAt: time.Now(),
|
||||
}); err != nil {
|
||||
return AccessToken{}, fmt.Errorf("error save tokens into repository. %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
||||
@ -85,8 +109,24 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{
|
||||
UserId: userId,
|
||||
Token: tokenStr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetch token from repository. %w", err)
|
||||
}
|
||||
|
||||
if tokens.TokenExpiredAt.Before(time.Now()) {
|
||||
return nil, fmt.Errorf("error token expired. %w", ErrorTokenExpired)
|
||||
}
|
||||
|
||||
if tokens.UserId != userId {
|
||||
return nil, errors.Join(fmt.Errorf("error invalid user id. %w", err), ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
users, err := w.usersInteractor.Get(ctx, users.GetParams{
|
||||
Ids: uuid.UUIDs{userId},
|
||||
Ids: uuid.UUIDs{tokens.UserId},
|
||||
})
|
||||
if err != nil || len(users) == 0 {
|
||||
return nil, fmt.Errorf("error fetch user from repository. %w", err)
|
||||
@ -94,3 +134,153 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
||||
|
||||
return users[0], nil
|
||||
}
|
||||
|
||||
func (w *jwtInteractor) RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error) {
|
||||
claims := make(jwt.MapClaims)
|
||||
|
||||
_, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (interface{}, error) {
|
||||
return w.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return AccessToken{}, errors.Join(fmt.Errorf("error parse jwt token. %w", err), ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
var userIdString string
|
||||
var ok bool
|
||||
|
||||
if userIdString, ok = claims["uid"].(string); !ok {
|
||||
return AccessToken{}, ErrorInvalidTokenClaims
|
||||
}
|
||||
|
||||
userId, err := uuid.Parse(userIdString)
|
||||
if err != nil {
|
||||
return AccessToken{}, errors.Join(fmt.Errorf("error parse user id. %w", err), ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
_, err = jwt.ParseWithClaims(rToken, claims, func(t *jwt.Token) (interface{}, error) {
|
||||
return w.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return AccessToken{}, errors.Join(fmt.Errorf("error parse refresh jwt token. %w", err), ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
if expDate, ok := claims["exp"].(float64); ok {
|
||||
if time.UnixMilli(int64(expDate)).Before(time.Now()) {
|
||||
return AccessToken{}, fmt.Errorf("error refresh token expired. %w", ErrorTokenExpired)
|
||||
}
|
||||
} else {
|
||||
return AccessToken{}, errors.Join(fmt.Errorf("error parse exp date. %w", err), ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
if userIdString, ok = claims["uid"].(string); !ok {
|
||||
return AccessToken{}, ErrorInvalidTokenClaims
|
||||
}
|
||||
|
||||
rTokenUserId, err := uuid.Parse(userIdString)
|
||||
if err != nil {
|
||||
return AccessToken{}, errors.Join(
|
||||
fmt.Errorf("error parse user id from refresh token. %w", err),
|
||||
ErrorInvalidTokenClaims,
|
||||
)
|
||||
}
|
||||
|
||||
if userId != rTokenUserId {
|
||||
return AccessToken{}, fmt.Errorf("error user ids corrupted. %w", ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{
|
||||
UserId: userId,
|
||||
Token: token,
|
||||
RefreshToken: rToken,
|
||||
})
|
||||
if err != nil {
|
||||
return AccessToken{}, fmt.Errorf("error fetch token from repository. %w", err)
|
||||
}
|
||||
|
||||
if tokens.RefreshTokenExpiredAt.Before(time.Now()) {
|
||||
return AccessToken{}, fmt.Errorf("error token expired. %w", ErrorTokenExpired)
|
||||
}
|
||||
|
||||
rtHash := sha512.New()
|
||||
rtHash.Write([]byte(tokens.Token))
|
||||
|
||||
rtHashStringValid := base64.StdEncoding.EncodeToString(rtHash.Sum(nil))
|
||||
|
||||
rtHashRaw, ok := claims["rt_hash"]
|
||||
if !ok {
|
||||
return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
rtHashString, ok := rtHashRaw.(string)
|
||||
if !ok {
|
||||
return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
if rtHashString != rtHashStringValid {
|
||||
return AccessToken{}, fmt.Errorf("error refresh token hash corrupted. %w", ErrorInvalidTokenClaims)
|
||||
}
|
||||
|
||||
newTokens, err := w.newTokens(userId, 24*time.Hour)
|
||||
if err != nil {
|
||||
return AccessToken{}, fmt.Errorf("error create new tokens. %w", err)
|
||||
}
|
||||
|
||||
if err = w.authRepository.RefreshToken(ctx, auth.RefreshTokenParams{
|
||||
UserId: userId,
|
||||
OldToken: token,
|
||||
Token: newTokens.Token,
|
||||
TokenExpiredAt: newTokens.ExpiredAt,
|
||||
OldRefreshToken: rToken,
|
||||
RefreshToken: newTokens.RefreshToken,
|
||||
RefreshTokenExpiredAt: newTokens.RTExpiredAt,
|
||||
}); err != nil {
|
||||
return AccessToken{}, fmt.Errorf("error update tokens. %w", err)
|
||||
}
|
||||
|
||||
return newTokens, nil
|
||||
}
|
||||
|
||||
func (w *jwtInteractor) newTokens(userId uuid.UUID, duration time.Duration) (AccessToken, error) {
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
|
||||
expAt := time.Now().Add(duration)
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
claims["uid"] = userId.String()
|
||||
claims["exp"] = expAt.UnixMilli()
|
||||
|
||||
secret := w.secret
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return AccessToken{}, fmt.Errorf("error sign token. %w", err)
|
||||
}
|
||||
|
||||
refreshToken := jwt.New(jwt.SigningMethodHS256)
|
||||
|
||||
rtHash := sha512.New()
|
||||
|
||||
rtHash.Write([]byte(tokenString))
|
||||
|
||||
rtExpAt := expAt.Add(time.Hour * 24 * 5)
|
||||
|
||||
claims = refreshToken.Claims.(jwt.MapClaims)
|
||||
claims["uid"] = userId.String()
|
||||
claims["exp"] = rtExpAt.UnixMilli()
|
||||
claims["rt_hash"] = base64.StdEncoding.EncodeToString(rtHash.Sum(nil))
|
||||
|
||||
rtokenString, err := refreshToken.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return AccessToken{}, fmt.Errorf("error sign refresh token. %w", err)
|
||||
}
|
||||
|
||||
return AccessToken{
|
||||
Token: tokenString,
|
||||
ExpiredAt: expAt,
|
||||
RefreshToken: rtokenString,
|
||||
RTExpiredAt: rtExpAt,
|
||||
}, nil
|
||||
}
|
||||
|
191
backend/internal/usecase/repository/auth/repository.go
Normal file
191
backend/internal/usecase/repository/auth/repository.go
Normal file
@ -0,0 +1,191 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
sqltools "github.com/emochka2007/block-accounting/internal/pkg/sqlutils"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AddTokenParams struct {
|
||||
UserId uuid.UUID
|
||||
|
||||
Token string
|
||||
TokenExpiredAt time.Time
|
||||
|
||||
RefreshToken string
|
||||
RefreshTokenExpiredAt time.Time
|
||||
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type GetTokenParams struct {
|
||||
UserId uuid.UUID
|
||||
Token string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
type RefreshTokenParams struct {
|
||||
UserId uuid.UUID
|
||||
|
||||
OldToken string
|
||||
Token string
|
||||
TokenExpiredAt time.Time
|
||||
|
||||
OldRefreshToken string
|
||||
RefreshToken string
|
||||
RefreshTokenExpiredAt time.Time
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
UserId uuid.UUID
|
||||
|
||||
Token string
|
||||
TokenExpiredAt time.Time
|
||||
|
||||
RefreshToken string
|
||||
RefreshTokenExpiredAt time.Time
|
||||
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
AddToken(ctx context.Context, params AddTokenParams) error
|
||||
GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error)
|
||||
RefreshToken(ctx context.Context, params RefreshTokenParams) error
|
||||
}
|
||||
|
||||
type repositorySQL struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (r *repositorySQL) AddToken(ctx context.Context, params AddTokenParams) error {
|
||||
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
|
||||
query := sq.Insert("access_tokens").
|
||||
Columns(
|
||||
"user_id",
|
||||
"token",
|
||||
"refresh_token",
|
||||
"token_expired_at",
|
||||
"refresh_token_expired_at",
|
||||
).
|
||||
Values(
|
||||
params.UserId,
|
||||
params.Token,
|
||||
params.RefreshToken,
|
||||
params.TokenExpiredAt,
|
||||
params.RefreshTokenExpiredAt,
|
||||
).PlaceholderFormat(sq.Dollar)
|
||||
|
||||
if _, err := query.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
|
||||
return fmt.Errorf("error add tokens. %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repositorySQL) RefreshToken(ctx context.Context, params RefreshTokenParams) error {
|
||||
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
|
||||
updateQuery := sq.Update("access_tokens").
|
||||
SetMap(sq.Eq{
|
||||
"token": params.Token,
|
||||
"refresh_token": params.RefreshToken,
|
||||
"token_expired_at": params.TokenExpiredAt,
|
||||
"refresh_token_expired_at": params.RefreshTokenExpiredAt,
|
||||
}).
|
||||
Where(sq.Eq{
|
||||
"user_id": params.UserId,
|
||||
"token": params.OldToken,
|
||||
"refresh_token": params.OldRefreshToken,
|
||||
}).PlaceholderFormat(sq.Dollar)
|
||||
|
||||
if _, err := updateQuery.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
|
||||
return fmt.Errorf("error update tokens. %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repositorySQL) GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error) {
|
||||
var token *AccessToken = new(AccessToken)
|
||||
|
||||
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
|
||||
query := sq.Select(
|
||||
"user_id",
|
||||
"token",
|
||||
"token_expired_at",
|
||||
"refresh_token",
|
||||
"refresh_token_expired_at",
|
||||
"created_at",
|
||||
).From("access_tokens").
|
||||
Where(sq.Eq{
|
||||
"token": params.Token,
|
||||
"user_id": params.UserId,
|
||||
}).PlaceholderFormat(sq.Dollar)
|
||||
|
||||
if params.RefreshToken != "" {
|
||||
query = query.Where(sq.Eq{
|
||||
"refresh_token": params.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
rows, err := query.RunWith(r.Conn(ctx)).QueryContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetch token from database. %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if cErr := rows.Close(); cErr != nil {
|
||||
err = errors.Join(fmt.Errorf("error close database rows. %w", cErr), err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(
|
||||
&token.UserId,
|
||||
&token.Token,
|
||||
&token.TokenExpiredAt,
|
||||
&token.RefreshToken,
|
||||
&token.RefreshTokenExpiredAt,
|
||||
&token.CreatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("error scan row. %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func NewRepository(db *sql.DB) Repository {
|
||||
return &repositorySQL{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *repositorySQL) Conn(ctx context.Context) sqltools.DBTX {
|
||||
if tx, ok := ctx.Value(sqltools.TxCtxKey).(*sql.Tx); ok {
|
||||
return tx
|
||||
}
|
||||
|
||||
return s.db
|
||||
}
|
@ -24,6 +24,22 @@ create index if not exists index_users_phone
|
||||
create index if not exists index_users_seed
|
||||
on users using hash (seed);
|
||||
|
||||
create table if not exists access_tokens (
|
||||
user_id uuid not null references users(id),
|
||||
token varchar(350) not null,
|
||||
token_expired_at timestamp,
|
||||
refresh_token varchar(350) not null,
|
||||
refresh_token_expired_at timestamp,
|
||||
created_at timestamp default current_timestamp,
|
||||
remote_addr string
|
||||
);
|
||||
|
||||
create index if not exists index_access_tokens_token_refresh_token
|
||||
on access_tokens (token, refresh_token);
|
||||
|
||||
create index if not exists index_access_tokens_token_refresh_token_exp
|
||||
on access_tokens (token, refresh_token, token_expired_at, refresh_token_expired_at);
|
||||
|
||||
create table if not exists organizations (
|
||||
id uuid primary key unique,
|
||||
name varchar(300) default 'My Organization' not null,
|
||||
@ -120,9 +136,12 @@ create table contracts (
|
||||
title varchar(250) default 'New Contract',
|
||||
description text not null,
|
||||
|
||||
address bytea not null,
|
||||
|
||||
created_by uuid not null references users(id),
|
||||
organization_id uuid not null references organizations(id),
|
||||
|
||||
created_at timestamp default current_timestamp,
|
||||
updated_at timestamp default current_timestamp
|
||||
);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user