tokens system improved

This commit is contained in:
r8zavetr8v 2024-05-17 00:19:33 +03:00
parent fc158b92e1
commit 876957aabe
12 changed files with 528 additions and 34 deletions

View File

@ -1,4 +1,4 @@
# NoNameBlockchainAccounting backend
# blockd backend
## Build
### Locally
1. Install Go >= 1.22
@ -84,7 +84,10 @@ curl --location 'http://localhost:8081/join' \
Response:
``` json
{
"token": "token-here"
"token": "token",
"token_expired_at": 1715975501581,
"refresh_token": "refresh_token",
"refresh_token_expired_at": 1716407501581
}
```
@ -105,7 +108,36 @@ curl --location 'http://localhost:8081/login' \
Response:
``` json
{
"token": "token-here"
"token": "token",
"token_expired_at": 1715975501581,
"refresh_token": "refresh_token",
"refresh_token_expired_at": 1716407501581
}
```
## POST **/refresh**
### Request body:
token (string, **required**)
refresh_token (string, **required**)
### Example
Request:
``` bash
curl --location --request GET 'http://localhost:8081/refresh' \
--header 'Content-Type: application/json' \
--data '{
"token": "token",
"refresh_token": "refresh_token"
}'
```
Response:
``` json
{
"token": "token",
"token_expired_at": 1715975501581,
"refresh_token": "refresh_token",
"refresh_token_expired_at": 1716407501581
}
```

View File

@ -7,6 +7,7 @@ import (
"github.com/emochka2007/block-accounting/internal/usecase/interactors/jwt"
"github.com/emochka2007/block-accounting/internal/usecase/interactors/organizations"
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
orepo "github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
urepo "github.com/emochka2007/block-accounting/internal/usecase/repository/users"
@ -19,8 +20,12 @@ func provideUsersInteractor(
return users.NewUsersInteractor(log.WithGroup("users-interactor"), usersRepo)
}
func provideJWTInteractor(c config.Config, usersInteractor users.UsersInteractor) jwt.JWTInteractor {
return jwt.NewWardenJWT(c.Common.JWTSecret, usersInteractor)
func provideJWTInteractor(
c config.Config,
usersInteractor users.UsersInteractor,
authRepository auth.Repository,
) jwt.JWTInteractor {
return jwt.NewJWT(c.Common.JWTSecret, usersInteractor, authRepository)
}
func provideOrganizationsInteractor(

View File

@ -5,6 +5,7 @@ import (
"log/slog"
"github.com/emochka2007/block-accounting/internal/pkg/config"
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
"github.com/emochka2007/block-accounting/internal/usecase/repository/cache"
"github.com/emochka2007/block-accounting/internal/usecase/repository/organizations"
"github.com/emochka2007/block-accounting/internal/usecase/repository/users"
@ -19,6 +20,10 @@ func provideOrganizationsRepository(db *sql.DB) organizations.Repository {
return organizations.NewRepository(db)
}
func provideAuthRepository(db *sql.DB) auth.Repository {
return auth.NewRepository(db)
}
func provideRedisConnection(c config.Config) (*redis.Client, func()) {
r := redis.NewClient(&redis.Options{
Addr: c.DB.CacheHost,

View File

@ -20,6 +20,7 @@ func ProvideService(c config.Config) (service.Service, func(), error) {
provideUsersInteractor,
provideOrganizationsRepository,
provideOrganizationsInteractor,
provideAuthRepository,
provideJWTInteractor,
interfaceSet,
provideRestServer,

View File

@ -22,7 +22,8 @@ func ProvideService(c config.Config) (service.Service, func(), error) {
}
usersRepository := provideUsersRepository(db)
usersInteractor := provideUsersInteractor(logger, usersRepository)
jwtInteractor := provideJWTInteractor(c, usersInteractor)
authRepository := provideAuthRepository(db)
jwtInteractor := provideJWTInteractor(c, usersInteractor, authRepository)
authPresenter := provideAuthPresenter(jwtInteractor)
authController := provideAuthController(logger, usersInteractor, authPresenter, jwtInteractor)
organizationsRepository := provideOrganizationsRepository(db)

View File

@ -26,6 +26,7 @@ type AuthController interface {
JoinWithInvite(w http.ResponseWriter, req *http.Request) ([]byte, error)
Login(w http.ResponseWriter, req *http.Request) ([]byte, error)
Invite(w http.ResponseWriter, req *http.Request) ([]byte, error)
Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error)
}
type authController struct {
@ -114,6 +115,29 @@ func (c *authController) Login(w http.ResponseWriter, req *http.Request) ([]byte
return c.presenter.ResponseLogin(users[0])
}
func (c *authController) Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error) {
request, err := presenters.CreateRequest[domain.RefreshRequest](req)
if err != nil {
return nil, fmt.Errorf("error create refresh request. %w", err)
}
c.log.Debug(
"refresh request",
slog.String("token", request.Token),
slog.String("refresh_token", request.RefreshToken),
)
ctx, cancel := context.WithTimeout(req.Context(), 3*time.Second)
defer cancel()
newTokens, err := c.jwtInteractor.RefreshToken(ctx, request.Token, request.RefreshToken)
if err != nil {
return nil, fmt.Errorf("error refresh access token. %w", err)
}
return c.presenter.ResponseRefresh(newTokens)
}
// const mnemonicEntropyBitSize int = 256
func (c *authController) Invite(w http.ResponseWriter, req *http.Request) ([]byte, error) {

View File

@ -38,8 +38,16 @@ type LoginRequest struct {
Mnemonic string `json:"mnemonic"`
}
type RefreshRequest struct {
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type LoginResponse struct {
Token string `json:"token"`
Token string `json:"token"`
ExpiredAt int64 `json:"token_expired_at"`
RefreshToken string `json:"refresh_token"`
RTExpiredAt int64 `json:"refresh_token_expired_at"`
}
// Organizations

View File

@ -13,6 +13,7 @@ import (
type AuthPresenter interface {
ResponseJoin(user *models.User) ([]byte, error)
ResponseLogin(user *models.User) ([]byte, error)
ResponseRefresh(tokens jwt.AccessToken) ([]byte, error)
}
type authPresenter struct {
@ -28,16 +29,17 @@ func NewAuthPresenter(
}
func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
resp := new(domain.JoinResponse)
token, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
if err != nil {
return nil, fmt.Errorf("error create access token. %w", err)
}
resp.Token = token
out, err := json.Marshal(resp)
out, err := json.Marshal(domain.LoginResponse{
Token: tokens.Token,
RefreshToken: tokens.RefreshToken,
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
})
if err != nil {
return nil, fmt.Errorf("error marshal join response. %w", err)
}
@ -46,19 +48,34 @@ func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) {
}
func (p *authPresenter) ResponseLogin(user *models.User) ([]byte, error) {
resp := new(domain.LoginResponse)
token, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour)
if err != nil {
return nil, fmt.Errorf("error create access token. %w", err)
}
resp.Token = token
out, err := json.Marshal(resp)
out, err := json.Marshal(domain.LoginResponse{
Token: tokens.Token,
RefreshToken: tokens.RefreshToken,
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
})
if err != nil {
return nil, fmt.Errorf("error marshal login response. %w", err)
}
return out, nil
}
func (p *authPresenter) ResponseRefresh(tokens jwt.AccessToken) ([]byte, error) {
out, err := json.Marshal(domain.LoginResponse{
Token: tokens.Token,
RefreshToken: tokens.RefreshToken,
ExpiredAt: tokens.ExpiredAt.UnixMilli(),
RTExpiredAt: tokens.RTExpiredAt.UnixMilli(),
})
if err != nil {
return nil, fmt.Errorf("error marshal refresh response. %w", err)
}
return out, nil
}

View File

@ -96,6 +96,7 @@ func (s *Server) buildRouter() {
router.Post("/join", s.handle(s.controllers.Auth.Join, "join"))
router.Post("/login", s.handle(s.controllers.Auth.Login, "login"))
router.Get("/refresh", s.handle(s.controllers.Auth.Refresh, "refresh"))
router.Route("/organizations", func(r chi.Router) {
r = r.With(s.withAuthorization)

View File

@ -2,12 +2,15 @@ package jwt
import (
"context"
"crypto/sha512"
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/emochka2007/block-accounting/internal/pkg/models"
"github.com/emochka2007/block-accounting/internal/usecase/interactors/users"
"github.com/emochka2007/block-accounting/internal/usecase/repository/auth"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
@ -18,38 +21,59 @@ var (
)
type JWTInteractor interface {
NewToken(user models.UserIdentity, duration time.Duration) (string, error)
NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error)
User(token string) (*models.User, error)
RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error)
}
type jwtInteractor struct {
secret []byte
usersInteractor users.UsersInteractor
authRepository auth.Repository
}
func NewWardenJWT(secret []byte, usersInteractor users.UsersInteractor) JWTInteractor {
func NewJWT(
secret []byte,
usersInteractor users.UsersInteractor,
authRepository auth.Repository,
) JWTInteractor {
return &jwtInteractor{
secret: secret,
usersInteractor: usersInteractor,
authRepository: authRepository,
}
}
type AccessToken struct {
Token string
ExpiredAt time.Time
RefreshToken string
RTExpiredAt time.Time
}
// NewToken creates new JWT token for given user
func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (string, error) {
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["uid"] = user.Id().String()
claims["exp"] = time.Now().Add(duration).UnixMilli()
secret := w.secret
tokenString, err := token.SignedString([]byte(secret))
func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error) {
tokens, err := w.newTokens(user.Id(), duration)
if err != nil {
return "", fmt.Errorf("error sign token. %w", err)
return AccessToken{}, fmt.Errorf("error create new tokens. %w", err)
}
return tokenString, nil
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
defer cancel()
if err := w.authRepository.AddToken(ctx, auth.AddTokenParams{
UserId: user.Id(),
Token: tokens.Token,
TokenExpiredAt: tokens.ExpiredAt,
RefreshToken: tokens.RefreshToken,
RefreshTokenExpiredAt: tokens.RTExpiredAt,
CreatedAt: time.Now(),
}); err != nil {
return AccessToken{}, fmt.Errorf("error save tokens into repository. %w", err)
}
return tokens, nil
}
func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
@ -85,8 +109,24 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
defer cancel()
tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{
UserId: userId,
Token: tokenStr,
})
if err != nil {
return nil, fmt.Errorf("error fetch token from repository. %w", err)
}
if tokens.TokenExpiredAt.Before(time.Now()) {
return nil, fmt.Errorf("error token expired. %w", ErrorTokenExpired)
}
if tokens.UserId != userId {
return nil, errors.Join(fmt.Errorf("error invalid user id. %w", err), ErrorInvalidTokenClaims)
}
users, err := w.usersInteractor.Get(ctx, users.GetParams{
Ids: uuid.UUIDs{userId},
Ids: uuid.UUIDs{tokens.UserId},
})
if err != nil || len(users) == 0 {
return nil, fmt.Errorf("error fetch user from repository. %w", err)
@ -94,3 +134,153 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) {
return users[0], nil
}
func (w *jwtInteractor) RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error) {
claims := make(jwt.MapClaims)
_, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (interface{}, error) {
return w.secret, nil
})
if err != nil {
return AccessToken{}, errors.Join(fmt.Errorf("error parse jwt token. %w", err), ErrorInvalidTokenClaims)
}
var userIdString string
var ok bool
if userIdString, ok = claims["uid"].(string); !ok {
return AccessToken{}, ErrorInvalidTokenClaims
}
userId, err := uuid.Parse(userIdString)
if err != nil {
return AccessToken{}, errors.Join(fmt.Errorf("error parse user id. %w", err), ErrorInvalidTokenClaims)
}
_, err = jwt.ParseWithClaims(rToken, claims, func(t *jwt.Token) (interface{}, error) {
return w.secret, nil
})
if err != nil {
return AccessToken{}, errors.Join(fmt.Errorf("error parse refresh jwt token. %w", err), ErrorInvalidTokenClaims)
}
if expDate, ok := claims["exp"].(float64); ok {
if time.UnixMilli(int64(expDate)).Before(time.Now()) {
return AccessToken{}, fmt.Errorf("error refresh token expired. %w", ErrorTokenExpired)
}
} else {
return AccessToken{}, errors.Join(fmt.Errorf("error parse exp date. %w", err), ErrorInvalidTokenClaims)
}
if userIdString, ok = claims["uid"].(string); !ok {
return AccessToken{}, ErrorInvalidTokenClaims
}
rTokenUserId, err := uuid.Parse(userIdString)
if err != nil {
return AccessToken{}, errors.Join(
fmt.Errorf("error parse user id from refresh token. %w", err),
ErrorInvalidTokenClaims,
)
}
if userId != rTokenUserId {
return AccessToken{}, fmt.Errorf("error user ids corrupted. %w", ErrorInvalidTokenClaims)
}
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
defer cancel()
tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{
UserId: userId,
Token: token,
RefreshToken: rToken,
})
if err != nil {
return AccessToken{}, fmt.Errorf("error fetch token from repository. %w", err)
}
if tokens.RefreshTokenExpiredAt.Before(time.Now()) {
return AccessToken{}, fmt.Errorf("error token expired. %w", ErrorTokenExpired)
}
rtHash := sha512.New()
rtHash.Write([]byte(tokens.Token))
rtHashStringValid := base64.StdEncoding.EncodeToString(rtHash.Sum(nil))
rtHashRaw, ok := claims["rt_hash"]
if !ok {
return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims)
}
rtHashString, ok := rtHashRaw.(string)
if !ok {
return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims)
}
if rtHashString != rtHashStringValid {
return AccessToken{}, fmt.Errorf("error refresh token hash corrupted. %w", ErrorInvalidTokenClaims)
}
newTokens, err := w.newTokens(userId, 24*time.Hour)
if err != nil {
return AccessToken{}, fmt.Errorf("error create new tokens. %w", err)
}
if err = w.authRepository.RefreshToken(ctx, auth.RefreshTokenParams{
UserId: userId,
OldToken: token,
Token: newTokens.Token,
TokenExpiredAt: newTokens.ExpiredAt,
OldRefreshToken: rToken,
RefreshToken: newTokens.RefreshToken,
RefreshTokenExpiredAt: newTokens.RTExpiredAt,
}); err != nil {
return AccessToken{}, fmt.Errorf("error update tokens. %w", err)
}
return newTokens, nil
}
func (w *jwtInteractor) newTokens(userId uuid.UUID, duration time.Duration) (AccessToken, error) {
token := jwt.New(jwt.SigningMethodHS256)
expAt := time.Now().Add(duration)
claims := token.Claims.(jwt.MapClaims)
claims["uid"] = userId.String()
claims["exp"] = expAt.UnixMilli()
secret := w.secret
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
return AccessToken{}, fmt.Errorf("error sign token. %w", err)
}
refreshToken := jwt.New(jwt.SigningMethodHS256)
rtHash := sha512.New()
rtHash.Write([]byte(tokenString))
rtExpAt := expAt.Add(time.Hour * 24 * 5)
claims = refreshToken.Claims.(jwt.MapClaims)
claims["uid"] = userId.String()
claims["exp"] = rtExpAt.UnixMilli()
claims["rt_hash"] = base64.StdEncoding.EncodeToString(rtHash.Sum(nil))
rtokenString, err := refreshToken.SignedString([]byte(secret))
if err != nil {
return AccessToken{}, fmt.Errorf("error sign refresh token. %w", err)
}
return AccessToken{
Token: tokenString,
ExpiredAt: expAt,
RefreshToken: rtokenString,
RTExpiredAt: rtExpAt,
}, nil
}

View File

@ -0,0 +1,191 @@
package auth
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
sq "github.com/Masterminds/squirrel"
sqltools "github.com/emochka2007/block-accounting/internal/pkg/sqlutils"
"github.com/google/uuid"
)
type AddTokenParams struct {
UserId uuid.UUID
Token string
TokenExpiredAt time.Time
RefreshToken string
RefreshTokenExpiredAt time.Time
CreatedAt time.Time
}
type GetTokenParams struct {
UserId uuid.UUID
Token string
RefreshToken string
}
type RefreshTokenParams struct {
UserId uuid.UUID
OldToken string
Token string
TokenExpiredAt time.Time
OldRefreshToken string
RefreshToken string
RefreshTokenExpiredAt time.Time
}
type AccessToken struct {
UserId uuid.UUID
Token string
TokenExpiredAt time.Time
RefreshToken string
RefreshTokenExpiredAt time.Time
CreatedAt time.Time
}
type Repository interface {
AddToken(ctx context.Context, params AddTokenParams) error
GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error)
RefreshToken(ctx context.Context, params RefreshTokenParams) error
}
type repositorySQL struct {
db *sql.DB
}
func (r *repositorySQL) AddToken(ctx context.Context, params AddTokenParams) error {
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
query := sq.Insert("access_tokens").
Columns(
"user_id",
"token",
"refresh_token",
"token_expired_at",
"refresh_token_expired_at",
).
Values(
params.UserId,
params.Token,
params.RefreshToken,
params.TokenExpiredAt,
params.RefreshTokenExpiredAt,
).PlaceholderFormat(sq.Dollar)
if _, err := query.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
return fmt.Errorf("error add tokens. %w", err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (r *repositorySQL) RefreshToken(ctx context.Context, params RefreshTokenParams) error {
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
updateQuery := sq.Update("access_tokens").
SetMap(sq.Eq{
"token": params.Token,
"refresh_token": params.RefreshToken,
"token_expired_at": params.TokenExpiredAt,
"refresh_token_expired_at": params.RefreshTokenExpiredAt,
}).
Where(sq.Eq{
"user_id": params.UserId,
"token": params.OldToken,
"refresh_token": params.OldRefreshToken,
}).PlaceholderFormat(sq.Dollar)
if _, err := updateQuery.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil {
return fmt.Errorf("error update tokens. %w", err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (r *repositorySQL) GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error) {
var token *AccessToken = new(AccessToken)
if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error {
query := sq.Select(
"user_id",
"token",
"token_expired_at",
"refresh_token",
"refresh_token_expired_at",
"created_at",
).From("access_tokens").
Where(sq.Eq{
"token": params.Token,
"user_id": params.UserId,
}).PlaceholderFormat(sq.Dollar)
if params.RefreshToken != "" {
query = query.Where(sq.Eq{
"refresh_token": params.RefreshToken,
})
}
rows, err := query.RunWith(r.Conn(ctx)).QueryContext(ctx)
if err != nil {
return fmt.Errorf("error fetch token from database. %w", err)
}
defer func() {
if cErr := rows.Close(); cErr != nil {
err = errors.Join(fmt.Errorf("error close database rows. %w", cErr), err)
}
}()
for rows.Next() {
if err := rows.Scan(
&token.UserId,
&token.Token,
&token.TokenExpiredAt,
&token.RefreshToken,
&token.RefreshTokenExpiredAt,
&token.CreatedAt,
); err != nil {
return fmt.Errorf("error scan row. %w", err)
}
}
return nil
}); err != nil {
return nil, err
}
return token, nil
}
func NewRepository(db *sql.DB) Repository {
return &repositorySQL{
db: db,
}
}
func (s *repositorySQL) Conn(ctx context.Context) sqltools.DBTX {
if tx, ok := ctx.Value(sqltools.TxCtxKey).(*sql.Tx); ok {
return tx
}
return s.db
}

View File

@ -24,6 +24,22 @@ create index if not exists index_users_phone
create index if not exists index_users_seed
on users using hash (seed);
create table if not exists access_tokens (
user_id uuid not null references users(id),
token varchar(350) not null,
token_expired_at timestamp,
refresh_token varchar(350) not null,
refresh_token_expired_at timestamp,
created_at timestamp default current_timestamp,
remote_addr string
);
create index if not exists index_access_tokens_token_refresh_token
on access_tokens (token, refresh_token);
create index if not exists index_access_tokens_token_refresh_token_exp
on access_tokens (token, refresh_token, token_expired_at, refresh_token_expired_at);
create table if not exists organizations (
id uuid primary key unique,
name varchar(300) default 'My Organization' not null,
@ -120,9 +136,12 @@ create table contracts (
title varchar(250) default 'New Contract',
description text not null,
address bytea not null,
created_by uuid not null references users(id),
organization_id uuid not null references organizations(id),
created_at timestamp default current_timestamp,
updated_at timestamp default current_timestamp
);