mirror of
https://github.com/emo2007/block-accounting.git
synced 2025-04-12 08:56:28 +00:00
tokens system improved
This commit is contained in:
parent
fc158b92e1
commit
876957aabe
@ -1,4 +1,4 @@
|
|||||||
# NoNameBlockchainAccounting backend
|
# blockd backend
|
||||||
## Build
|
## Build
|
||||||
### Locally
|
### Locally
|
||||||
1. Install Go >= 1.22
|
1. Install Go >= 1.22
|
||||||
@ -84,7 +84,10 @@ curl --location 'http://localhost:8081/join' \
|
|||||||
Response:
|
Response:
|
||||||
``` json
|
``` json
|
||||||
{
|
{
|
||||||
"token": "token-here"
|
"token": "token",
|
||||||
|
"token_expired_at": 1715975501581,
|
||||||
|
"refresh_token": "refresh_token",
|
||||||
|
"refresh_token_expired_at": 1716407501581
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -105,7 +108,36 @@ curl --location 'http://localhost:8081/login' \
|
|||||||
Response:
|
Response:
|
||||||
``` json
|
``` json
|
||||||
{
|
{
|
||||||
"token": "token-here"
|
"token": "token",
|
||||||
|
"token_expired_at": 1715975501581,
|
||||||
|
"refresh_token": "refresh_token",
|
||||||
|
"refresh_token_expired_at": 1716407501581
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## POST **/refresh**
|
||||||
|
### Request body:
|
||||||
|
token (string, **required**)
|
||||||
|
refresh_token (string, **required**)
|
||||||
|
|
||||||
|
### Example
|
||||||
|
Request:
|
||||||
|
``` bash
|
||||||
|
curl --location --request GET 'http://localhost:8081/refresh' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"token": "token",
|
||||||
|
"refresh_token": "refresh_token"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
``` json
|
||||||
|
{
|
||||||
|
"token": "token",
|
||||||
|
"token_expired_at": 1715975501581,
|
||||||
|
"refresh_token": "refresh_token",
|
||||||
|
"refresh_token_expired_at": 1716407501581
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/jwt"
|
"github.com/emochka2007/block-accounting/internal/usecase/interactors/jwt"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/organizations"
|
"github.com/emochka2007/block-accounting/internal/usecase/interactors/organizations"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
|
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
|
||||||
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
|
||||||
orepo "github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
|
orepo "github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
|
||||||
urepo "github.com/emochka2007/block-accounting/internal/usecase/repository/users"
|
urepo "github.com/emochka2007/block-accounting/internal/usecase/repository/users"
|
||||||
@ -19,8 +20,12 @@ func provideUsersInteractor(
|
|||||||
return users.NewUsersInteractor(log.WithGroup("users-interactor"), usersRepo)
|
return users.NewUsersInteractor(log.WithGroup("users-interactor"), usersRepo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func provideJWTInteractor(c config.Config, usersInteractor users.UsersInteractor) jwt.JWTInteractor {
|
func provideJWTInteractor(
|
||||||
return jwt.NewWardenJWT(c.Common.JWTSecret, usersInteractor)
|
c config.Config,
|
||||||
|
usersInteractor users.UsersInteractor,
|
||||||
|
authRepository auth.Repository,
|
||||||
|
) jwt.JWTInteractor {
|
||||||
|
return jwt.NewJWT(c.Common.JWTSecret, usersInteractor, authRepository)
|
||||||
}
|
}
|
||||||
|
|
||||||
func provideOrganizationsInteractor(
|
func provideOrganizationsInteractor(
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
"github.com/emochka2007/block-accounting/internal/pkg/config"
|
"github.com/emochka2007/block-accounting/internal/pkg/config"
|
||||||
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/repository/users"
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/users"
|
||||||
@ -19,6 +20,10 @@ func provideOrganizationsRepository(db *sql.DB) organizations.Repository {
|
|||||||
return organizations.NewRepository(db)
|
return organizations.NewRepository(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func provideAuthRepository(db *sql.DB) auth.Repository {
|
||||||
|
return auth.NewRepository(db)
|
||||||
|
}
|
||||||
|
|
||||||
func provideRedisConnection(c config.Config) (*redis.Client, func()) {
|
func provideRedisConnection(c config.Config) (*redis.Client, func()) {
|
||||||
r := redis.NewClient(&redis.Options{
|
r := redis.NewClient(&redis.Options{
|
||||||
Addr: c.DB.CacheHost,
|
Addr: c.DB.CacheHost,
|
||||||
|
@ -20,6 +20,7 @@ func ProvideService(c config.Config) (service.Service, func(), error) {
|
|||||||
provideUsersInteractor,
|
provideUsersInteractor,
|
||||||
provideOrganizationsRepository,
|
provideOrganizationsRepository,
|
||||||
provideOrganizationsInteractor,
|
provideOrganizationsInteractor,
|
||||||
|
provideAuthRepository,
|
||||||
provideJWTInteractor,
|
provideJWTInteractor,
|
||||||
interfaceSet,
|
interfaceSet,
|
||||||
provideRestServer,
|
provideRestServer,
|
||||||
|
@ -22,7 +22,8 @@ func ProvideService(c config.Config) (service.Service, func(), error) {
|
|||||||
}
|
}
|
||||||
usersRepository := provideUsersRepository(db)
|
usersRepository := provideUsersRepository(db)
|
||||||
usersInteractor := provideUsersInteractor(logger, usersRepository)
|
usersInteractor := provideUsersInteractor(logger, usersRepository)
|
||||||
jwtInteractor := provideJWTInteractor(c, usersInteractor)
|
authRepository := provideAuthRepository(db)
|
||||||
|
jwtInteractor := provideJWTInteractor(c, usersInteractor, authRepository)
|
||||||
authPresenter := provideAuthPresenter(jwtInteractor)
|
authPresenter := provideAuthPresenter(jwtInteractor)
|
||||||
authController := provideAuthController(logger, usersInteractor, authPresenter, jwtInteractor)
|
authController := provideAuthController(logger, usersInteractor, authPresenter, jwtInteractor)
|
||||||
organizationsRepository := provideOrganizationsRepository(db)
|
organizationsRepository := provideOrganizationsRepository(db)
|
||||||
|
@ -26,6 +26,7 @@ type AuthController interface {
|
|||||||
JoinWithInvite(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
JoinWithInvite(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||||
Login(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
Login(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||||
Invite(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
Invite(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||||
|
Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authController struct {
|
type authController struct {
|
||||||
@ -114,6 +115,29 @@ func (c *authController) Login(w http.ResponseWriter, req *http.Request) ([]byte
|
|||||||
return c.presenter.ResponseLogin(users[0])
|
return c.presenter.ResponseLogin(users[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *authController) Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error) {
|
||||||
|
request, err := presenters.CreateRequest[domain.RefreshRequest](req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error create refresh request. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Debug(
|
||||||
|
"refresh request",
|
||||||
|
slog.String("token", request.Token),
|
||||||
|
slog.String("refresh_token", request.RefreshToken),
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
newTokens, err := c.jwtInteractor.RefreshToken(ctx, request.Token, request.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error refresh access token. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.presenter.ResponseRefresh(newTokens)
|
||||||
|
}
|
||||||
|
|
||||||
// const mnemonicEntropyBitSize int = 256
|
// const mnemonicEntropyBitSize int = 256
|
||||||
|
|
||||||
func (c *authController) Invite(w http.ResponseWriter, req *http.Request) ([]byte, error) {
|
func (c *authController) Invite(w http.ResponseWriter, req *http.Request) ([]byte, error) {
|
||||||
|
@ -38,8 +38,16 @@ type LoginRequest struct {
|
|||||||
Mnemonic string `json:"mnemonic"`
|
Mnemonic string `json:"mnemonic"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RefreshRequest struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
ExpiredAt int64 `json:"token_expired_at"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
RTExpiredAt int64 `json:"refresh_token_expired_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Organizations
|
// Organizations
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
type AuthPresenter interface {
|
type AuthPresenter interface {
|
||||||
ResponseJoin(user *models.User) ([]byte, error)
|
ResponseJoin(user *models.User) ([]byte, error)
|
||||||
ResponseLogin(user *models.User) ([]byte, error)
|
ResponseLogin(user *models.User) ([]byte, error)
|
||||||
|
ResponseRefresh(tokens jwt.AccessToken) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authPresenter struct {
|
type authPresenter struct {
|
||||||
@ -28,16 +29,17 @@ func NewAuthPresenter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
|
func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
|
||||||
resp := new(domain.JoinResponse)
|
tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
||||||
|
|
||||||
token, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error create access token. %w", err)
|
return nil, fmt.Errorf("error create access token. %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Token = token
|
out, err := json.Marshal(domain.LoginResponse{
|
||||||
|
Token: tokens.Token,
|
||||||
out, err := json.Marshal(resp)
|
RefreshToken: tokens.RefreshToken,
|
||||||
|
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
|
||||||
|
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshal join response. %w", err)
|
return nil, fmt.Errorf("error marshal join response. %w", err)
|
||||||
}
|
}
|
||||||
@ -46,19 +48,34 @@ func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *authPresenter) ResponseLogin(user *models.User) ([]byte, error) {
|
func (p *authPresenter) ResponseLogin(user *models.User) ([]byte, error) {
|
||||||
resp := new(domain.LoginResponse)
|
tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
||||||
|
|
||||||
token, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error create access token. %w", err)
|
return nil, fmt.Errorf("error create access token. %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Token = token
|
out, err := json.Marshal(domain.LoginResponse{
|
||||||
|
Token: tokens.Token,
|
||||||
out, err := json.Marshal(resp)
|
RefreshToken: tokens.RefreshToken,
|
||||||
|
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
|
||||||
|
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshal login response. %w", err)
|
return nil, fmt.Errorf("error marshal login response. %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *authPresenter) ResponseRefresh(tokens jwt.AccessToken) ([]byte, error) {
|
||||||
|
out, err := json.Marshal(domain.LoginResponse{
|
||||||
|
Token: tokens.Token,
|
||||||
|
RefreshToken: tokens.RefreshToken,
|
||||||
|
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
|
||||||
|
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshal refresh response. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
@ -96,6 +96,7 @@ func (s *Server) buildRouter() {
|
|||||||
|
|
||||||
router.Post("/join", s.handle(s.controllers.Auth.Join, "join"))
|
router.Post("/join", s.handle(s.controllers.Auth.Join, "join"))
|
||||||
router.Post("/login", s.handle(s.controllers.Auth.Login, "login"))
|
router.Post("/login", s.handle(s.controllers.Auth.Login, "login"))
|
||||||
|
router.Get("/refresh", s.handle(s.controllers.Auth.Refresh, "refresh"))
|
||||||
|
|
||||||
router.Route("/organizations", func(r chi.Router) {
|
router.Route("/organizations", func(r chi.Router) {
|
||||||
r = r.With(s.withAuthorization)
|
r = r.With(s.withAuthorization)
|
||||||
|
@ -2,12 +2,15 @@ package jwt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/emochka2007/block-accounting/internal/pkg/models"
|
"github.com/emochka2007/block-accounting/internal/pkg/models"
|
||||||
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
|
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
|
||||||
|
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -18,38 +21,59 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type JWTInteractor interface {
|
type JWTInteractor interface {
|
||||||
NewToken(user models.UserIdentity, duration time.Duration) (string, error)
|
NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error)
|
||||||
User(token string) (*models.User, error)
|
User(token string) (*models.User, error)
|
||||||
|
RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type jwtInteractor struct {
|
type jwtInteractor struct {
|
||||||
secret []byte
|
secret []byte
|
||||||
usersInteractor users.UsersInteractor
|
usersInteractor users.UsersInteractor
|
||||||
|
authRepository auth.Repository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWardenJWT(secret []byte, usersInteractor users.UsersInteractor) JWTInteractor {
|
func NewJWT(
|
||||||
|
secret []byte,
|
||||||
|
usersInteractor users.UsersInteractor,
|
||||||
|
authRepository auth.Repository,
|
||||||
|
) JWTInteractor {
|
||||||
return &jwtInteractor{
|
return &jwtInteractor{
|
||||||
secret: secret,
|
secret: secret,
|
||||||
usersInteractor: usersInteractor,
|
usersInteractor: usersInteractor,
|
||||||
|
authRepository: authRepository,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccessToken struct {
|
||||||
|
Token string
|
||||||
|
ExpiredAt time.Time
|
||||||
|
|
||||||
|
RefreshToken string
|
||||||
|
RTExpiredAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// NewToken creates new JWT token for given user
|
// NewToken creates new JWT token for given user
|
||||||
func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (string, error) {
|
func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error) {
|
||||||
token := jwt.New(jwt.SigningMethodHS256)
|
tokens, err := w.newTokens(user.Id(), duration)
|
||||||
|
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
|
||||||
claims["uid"] = user.Id().String()
|
|
||||||
claims["exp"] = time.Now().Add(duration).UnixMilli()
|
|
||||||
|
|
||||||
secret := w.secret
|
|
||||||
|
|
||||||
tokenString, err := token.SignedString([]byte(secret))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error sign token. %w", err)
|
return AccessToken{}, fmt.Errorf("error create new tokens. %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenString, nil
|
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := w.authRepository.AddToken(ctx, auth.AddTokenParams{
|
||||||
|
UserId: user.Id(),
|
||||||
|
Token: tokens.Token,
|
||||||
|
TokenExpiredAt: tokens.ExpiredAt,
|
||||||
|
RefreshToken: tokens.RefreshToken,
|
||||||
|
RefreshTokenExpiredAt: tokens.RTExpiredAt,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}); err != nil {
|
||||||
|
return AccessToken{}, fmt.Errorf("error save tokens into repository. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
||||||
@ -85,8 +109,24 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
|||||||
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{
|
||||||
|
UserId: userId,
|
||||||
|
Token: tokenStr,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error fetch token from repository. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokens.TokenExpiredAt.Before(time.Now()) {
|
||||||
|
return nil, fmt.Errorf("error token expired. %w", ErrorTokenExpired)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokens.UserId != userId {
|
||||||
|
return nil, errors.Join(fmt.Errorf("error invalid user id. %w", err), ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
users, err := w.usersInteractor.Get(ctx, users.GetParams{
|
users, err := w.usersInteractor.Get(ctx, users.GetParams{
|
||||||
Ids: uuid.UUIDs{userId},
|
Ids: uuid.UUIDs{tokens.UserId},
|
||||||
})
|
})
|
||||||
if err != nil || len(users) == 0 {
|
if err != nil || len(users) == 0 {
|
||||||
return nil, fmt.Errorf("error fetch user from repository. %w", err)
|
return nil, fmt.Errorf("error fetch user from repository. %w", err)
|
||||||
@ -94,3 +134,153 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
|
|||||||
|
|
||||||
return users[0], nil
|
return users[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *jwtInteractor) RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error) {
|
||||||
|
claims := make(jwt.MapClaims)
|
||||||
|
|
||||||
|
_, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
return w.secret, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Join(fmt.Errorf("error parse jwt token. %w", err), ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
var userIdString string
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
if userIdString, ok = claims["uid"].(string); !ok {
|
||||||
|
return AccessToken{}, ErrorInvalidTokenClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
userId, err := uuid.Parse(userIdString)
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Join(fmt.Errorf("error parse user id. %w", err), ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = jwt.ParseWithClaims(rToken, claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
return w.secret, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Join(fmt.Errorf("error parse refresh jwt token. %w", err), ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expDate, ok := claims["exp"].(float64); ok {
|
||||||
|
if time.UnixMilli(int64(expDate)).Before(time.Now()) {
|
||||||
|
return AccessToken{}, fmt.Errorf("error refresh token expired. %w", ErrorTokenExpired)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return AccessToken{}, errors.Join(fmt.Errorf("error parse exp date. %w", err), ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userIdString, ok = claims["uid"].(string); !ok {
|
||||||
|
return AccessToken{}, ErrorInvalidTokenClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
rTokenUserId, err := uuid.Parse(userIdString)
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, errors.Join(
|
||||||
|
fmt.Errorf("error parse user id from refresh token. %w", err),
|
||||||
|
ErrorInvalidTokenClaims,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userId != rTokenUserId {
|
||||||
|
return AccessToken{}, fmt.Errorf("error user ids corrupted. %w", ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{
|
||||||
|
UserId: userId,
|
||||||
|
Token: token,
|
||||||
|
RefreshToken: rToken,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, fmt.Errorf("error fetch token from repository. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokens.RefreshTokenExpiredAt.Before(time.Now()) {
|
||||||
|
return AccessToken{}, fmt.Errorf("error token expired. %w", ErrorTokenExpired)
|
||||||
|
}
|
||||||
|
|
||||||
|
rtHash := sha512.New()
|
||||||
|
rtHash.Write([]byte(tokens.Token))
|
||||||
|
|
||||||
|
rtHashStringValid := base64.StdEncoding.EncodeToString(rtHash.Sum(nil))
|
||||||
|
|
||||||
|
rtHashRaw, ok := claims["rt_hash"]
|
||||||
|
if !ok {
|
||||||
|
return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
rtHashString, ok := rtHashRaw.(string)
|
||||||
|
if !ok {
|
||||||
|
return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rtHashString != rtHashStringValid {
|
||||||
|
return AccessToken{}, fmt.Errorf("error refresh token hash corrupted. %w", ErrorInvalidTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
newTokens, err := w.newTokens(userId, 24*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, fmt.Errorf("error create new tokens. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = w.authRepository.RefreshToken(ctx, auth.RefreshTokenParams{
|
||||||
|
UserId: userId,
|
||||||
|
OldToken: token,
|
||||||
|
Token: newTokens.Token,
|
||||||
|
TokenExpiredAt: newTokens.ExpiredAt,
|
||||||
|
OldRefreshToken: rToken,
|
||||||
|
RefreshToken: newTokens.RefreshToken,
|
||||||
|
RefreshTokenExpiredAt: newTokens.RTExpiredAt,
|
||||||
|
}); err != nil {
|
||||||
|
return AccessToken{}, fmt.Errorf("error update tokens. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *jwtInteractor) newTokens(userId uuid.UUID, duration time.Duration) (AccessToken, error) {
|
||||||
|
token := jwt.New(jwt.SigningMethodHS256)
|
||||||
|
|
||||||
|
expAt := time.Now().Add(duration)
|
||||||
|
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
claims["uid"] = userId.String()
|
||||||
|
claims["exp"] = expAt.UnixMilli()
|
||||||
|
|
||||||
|
secret := w.secret
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, fmt.Errorf("error sign token. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken := jwt.New(jwt.SigningMethodHS256)
|
||||||
|
|
||||||
|
rtHash := sha512.New()
|
||||||
|
|
||||||
|
rtHash.Write([]byte(tokenString))
|
||||||
|
|
||||||
|
rtExpAt := expAt.Add(time.Hour * 24 * 5)
|
||||||
|
|
||||||
|
claims = refreshToken.Claims.(jwt.MapClaims)
|
||||||
|
claims["uid"] = userId.String()
|
||||||
|
claims["exp"] = rtExpAt.UnixMilli()
|
||||||
|
claims["rt_hash"] = base64.StdEncoding.EncodeToString(rtHash.Sum(nil))
|
||||||
|
|
||||||
|
rtokenString, err := refreshToken.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
return AccessToken{}, fmt.Errorf("error sign refresh token. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return AccessToken{
|
||||||
|
Token: tokenString,
|
||||||
|
ExpiredAt: expAt,
|
||||||
|
RefreshToken: rtokenString,
|
||||||
|
RTExpiredAt: rtExpAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
191
backend/internal/usecase/repository/auth/repository.go
Normal file
191
backend/internal/usecase/repository/auth/repository.go
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
sq "github.com/Masterminds/squirrel"
|
||||||
|
sqltools "github.com/emochka2007/block-accounting/internal/pkg/sqlutils"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AddTokenParams struct {
|
||||||
|
UserId uuid.UUID
|
||||||
|
|
||||||
|
Token string
|
||||||
|
TokenExpiredAt time.Time
|
||||||
|
|
||||||
|
RefreshToken string
|
||||||
|
RefreshTokenExpiredAt time.Time
|
||||||
|
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetTokenParams struct {
|
||||||
|
UserId uuid.UUID
|
||||||
|
Token string
|
||||||
|
RefreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RefreshTokenParams struct {
|
||||||
|
UserId uuid.UUID
|
||||||
|
|
||||||
|
OldToken string
|
||||||
|
Token string
|
||||||
|
TokenExpiredAt time.Time
|
||||||
|
|
||||||
|
OldRefreshToken string
|
||||||
|
RefreshToken string
|
||||||
|
RefreshTokenExpiredAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccessToken struct {
|
||||||
|
UserId uuid.UUID
|
||||||
|
|
||||||
|
Token string
|
||||||
|
TokenExpiredAt time.Time
|
||||||
|
|
||||||
|
RefreshToken string
|
||||||
|
RefreshTokenExpiredAt time.Time
|
||||||
|
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Repository interface {
|
||||||
|
AddToken(ctx context.Context, params AddTokenParams) error
|
||||||
|
GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error)
|
||||||
|
RefreshToken(ctx context.Context, params RefreshTokenParams) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type repositorySQL struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *repositorySQL) AddToken(ctx context.Context, params AddTokenParams) error {
|
||||||
|
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
|
||||||
|
query := sq.Insert("access_tokens").
|
||||||
|
Columns(
|
||||||
|
"user_id",
|
||||||
|
"token",
|
||||||
|
"refresh_token",
|
||||||
|
"token_expired_at",
|
||||||
|
"refresh_token_expired_at",
|
||||||
|
).
|
||||||
|
Values(
|
||||||
|
params.UserId,
|
||||||
|
params.Token,
|
||||||
|
params.RefreshToken,
|
||||||
|
params.TokenExpiredAt,
|
||||||
|
params.RefreshTokenExpiredAt,
|
||||||
|
).PlaceholderFormat(sq.Dollar)
|
||||||
|
|
||||||
|
if _, err := query.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
|
||||||
|
return fmt.Errorf("error add tokens. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *repositorySQL) RefreshToken(ctx context.Context, params RefreshTokenParams) error {
|
||||||
|
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
|
||||||
|
updateQuery := sq.Update("access_tokens").
|
||||||
|
SetMap(sq.Eq{
|
||||||
|
"token": params.Token,
|
||||||
|
"refresh_token": params.RefreshToken,
|
||||||
|
"token_expired_at": params.TokenExpiredAt,
|
||||||
|
"refresh_token_expired_at": params.RefreshTokenExpiredAt,
|
||||||
|
}).
|
||||||
|
Where(sq.Eq{
|
||||||
|
"user_id": params.UserId,
|
||||||
|
"token": params.OldToken,
|
||||||
|
"refresh_token": params.OldRefreshToken,
|
||||||
|
}).PlaceholderFormat(sq.Dollar)
|
||||||
|
|
||||||
|
if _, err := updateQuery.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
|
||||||
|
return fmt.Errorf("error update tokens. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *repositorySQL) GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error) {
|
||||||
|
var token *AccessToken = new(AccessToken)
|
||||||
|
|
||||||
|
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
|
||||||
|
query := sq.Select(
|
||||||
|
"user_id",
|
||||||
|
"token",
|
||||||
|
"token_expired_at",
|
||||||
|
"refresh_token",
|
||||||
|
"refresh_token_expired_at",
|
||||||
|
"created_at",
|
||||||
|
).From("access_tokens").
|
||||||
|
Where(sq.Eq{
|
||||||
|
"token": params.Token,
|
||||||
|
"user_id": params.UserId,
|
||||||
|
}).PlaceholderFormat(sq.Dollar)
|
||||||
|
|
||||||
|
if params.RefreshToken != "" {
|
||||||
|
query = query.Where(sq.Eq{
|
||||||
|
"refresh_token": params.RefreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := query.RunWith(r.Conn(ctx)).QueryContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error fetch token from database. %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if cErr := rows.Close(); cErr != nil {
|
||||||
|
err = errors.Join(fmt.Errorf("error close database rows. %w", cErr), err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(
|
||||||
|
&token.UserId,
|
||||||
|
&token.Token,
|
||||||
|
&token.TokenExpiredAt,
|
||||||
|
&token.RefreshToken,
|
||||||
|
&token.RefreshTokenExpiredAt,
|
||||||
|
&token.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("error scan row. %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRepository(db *sql.DB) Repository {
|
||||||
|
return &repositorySQL{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repositorySQL) Conn(ctx context.Context) sqltools.DBTX {
|
||||||
|
if tx, ok := ctx.Value(sqltools.TxCtxKey).(*sql.Tx); ok {
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.db
|
||||||
|
}
|
@ -24,6 +24,22 @@ create index if not exists index_users_phone
|
|||||||
create index if not exists index_users_seed
|
create index if not exists index_users_seed
|
||||||
on users using hash (seed);
|
on users using hash (seed);
|
||||||
|
|
||||||
|
create table if not exists access_tokens (
|
||||||
|
user_id uuid not null references users(id),
|
||||||
|
token varchar(350) not null,
|
||||||
|
token_expired_at timestamp,
|
||||||
|
refresh_token varchar(350) not null,
|
||||||
|
refresh_token_expired_at timestamp,
|
||||||
|
created_at timestamp default current_timestamp,
|
||||||
|
remote_addr string
|
||||||
|
);
|
||||||
|
|
||||||
|
create index if not exists index_access_tokens_token_refresh_token
|
||||||
|
on access_tokens (token, refresh_token);
|
||||||
|
|
||||||
|
create index if not exists index_access_tokens_token_refresh_token_exp
|
||||||
|
on access_tokens (token, refresh_token, token_expired_at, refresh_token_expired_at);
|
||||||
|
|
||||||
create table if not exists organizations (
|
create table if not exists organizations (
|
||||||
id uuid primary key unique,
|
id uuid primary key unique,
|
||||||
name varchar(300) default 'My Organization' not null,
|
name varchar(300) default 'My Organization' not null,
|
||||||
@ -120,9 +136,12 @@ create table contracts (
|
|||||||
title varchar(250) default 'New Contract',
|
title varchar(250) default 'New Contract',
|
||||||
description text not null,
|
description text not null,
|
||||||
|
|
||||||
|
address bytea not null,
|
||||||
|
|
||||||
created_by uuid not null references users(id),
|
created_by uuid not null references users(id),
|
||||||
organization_id uuid not null references organizations(id),
|
organization_id uuid not null references organizations(id),
|
||||||
|
|
||||||
created_at timestamp default current_timestamp,
|
created_at timestamp default current_timestamp,
|
||||||
updated_at timestamp default current_timestamp
|
updated_at timestamp default current_timestamp
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user