From 876957aabe8e4895d6c27ae8809679c16983311f Mon Sep 17 00:00:00 2001 From: optclblast Date: Fri, 17 May 2024 00:19:33 +0300 Subject: [PATCH] tokens system improved --- backend/README.md | 38 ++- backend/internal/factory/interactors.go | 9 +- backend/internal/factory/repositories.go | 5 + backend/internal/factory/wire.go | 1 + backend/internal/factory/wire_gen.go | 3 +- .../interface/rest/controllers/auth.go | 24 ++ backend/internal/interface/rest/domain/dto.go | 10 +- .../interface/rest/presenters/auth.go | 41 +++- backend/internal/interface/rest/server.go | 1 + .../internal/usecase/interactors/jwt/jwt.go | 220 ++++++++++++++++-- .../usecase/repository/auth/repository.go | 191 +++++++++++++++ backend/migrations/blockd.sql | 19 ++ 12 files changed, 528 insertions(+), 34 deletions(-) create mode 100644 backend/internal/usecase/repository/auth/repository.go diff --git a/backend/README.md b/backend/README.md index c9311d8..a1fce16 100644 --- a/backend/README.md +++ b/backend/README.md @@ -1,4 +1,4 @@ -# NoNameBlockchainAccounting backend +# blockd backend ## Build ### Locally 1. Install Go >= 1.22 @@ -84,7 +84,10 @@ curl --location 'http://localhost:8081/join' \ Response: ``` json { - "token": "token-here" + "token": "token", + "token_expired_at": 1715975501581, + "refresh_token": "refresh_token", + "refresh_token_expired_at": 1716407501581 } ``` @@ -105,7 +108,36 @@ curl --location 'http://localhost:8081/login' \ Response: ``` json { - "token": "token-here" + "token": "token", + "token_expired_at": 1715975501581, + "refresh_token": "refresh_token", + "refresh_token_expired_at": 1716407501581 +} +``` + +## POST **/refresh** +### Request body: +token (string, **required**) +refresh_token (string, **required**) + +### Example +Request: +``` bash +curl --location --request GET 'http://localhost:8081/refresh' \ +--header 'Content-Type: application/json' \ +--data '{ + "token": "token", + "refresh_token": "refresh_token" +}' +``` + +Response: +``` json +{ + "token": "token", + "token_expired_at": 1715975501581, + "refresh_token": "refresh_token", + "refresh_token_expired_at": 1716407501581 } ``` diff --git a/backend/internal/factory/interactors.go b/backend/internal/factory/interactors.go index 0a5ee45..f86a786 100644 --- a/backend/internal/factory/interactors.go +++ b/backend/internal/factory/interactors.go @@ -7,6 +7,7 @@ import ( "github.com/emochka2007/block-accounting/internal/usecase/interactors/jwt" "github.com/emochka2007/block-accounting/internal/usecase/interactors/organizations" "github.com/emochka2007/block-accounting/internal/usecase/interactors/users" + "github.com/emochka2007/block-accounting/internal/usecase/repository/auth" "github.com/emochka2007/block-accounting/internal/usecase/repository/cache" orepo "github.com/emochka2007/block-accounting/internal/usecase/repository/organizations" urepo "github.com/emochka2007/block-accounting/internal/usecase/repository/users" @@ -19,8 +20,12 @@ func provideUsersInteractor( return users.NewUsersInteractor(log.WithGroup("users-interactor"), usersRepo) } -func provideJWTInteractor(c config.Config, usersInteractor users.UsersInteractor) jwt.JWTInteractor { - return jwt.NewWardenJWT(c.Common.JWTSecret, usersInteractor) +func provideJWTInteractor( + c config.Config, + usersInteractor users.UsersInteractor, + authRepository auth.Repository, +) jwt.JWTInteractor { + return jwt.NewJWT(c.Common.JWTSecret, usersInteractor, authRepository) } func provideOrganizationsInteractor( diff --git a/backend/internal/factory/repositories.go b/backend/internal/factory/repositories.go index 4a8e14e..1c10c45 100644 --- a/backend/internal/factory/repositories.go +++ b/backend/internal/factory/repositories.go @@ -5,6 +5,7 @@ import ( "log/slog" "github.com/emochka2007/block-accounting/internal/pkg/config" + "github.com/emochka2007/block-accounting/internal/usecase/repository/auth" "github.com/emochka2007/block-accounting/internal/usecase/repository/cache" "github.com/emochka2007/block-accounting/internal/usecase/repository/organizations" "github.com/emochka2007/block-accounting/internal/usecase/repository/users" @@ -19,6 +20,10 @@ func provideOrganizationsRepository(db *sql.DB) organizations.Repository { return organizations.NewRepository(db) } +func provideAuthRepository(db *sql.DB) auth.Repository { + return auth.NewRepository(db) +} + func provideRedisConnection(c config.Config) (*redis.Client, func()) { r := redis.NewClient(&redis.Options{ Addr: c.DB.CacheHost, diff --git a/backend/internal/factory/wire.go b/backend/internal/factory/wire.go index 4834ccd..3558957 100644 --- a/backend/internal/factory/wire.go +++ b/backend/internal/factory/wire.go @@ -20,6 +20,7 @@ func ProvideService(c config.Config) (service.Service, func(), error) { provideUsersInteractor, provideOrganizationsRepository, provideOrganizationsInteractor, + provideAuthRepository, provideJWTInteractor, interfaceSet, provideRestServer, diff --git a/backend/internal/factory/wire_gen.go b/backend/internal/factory/wire_gen.go index b4dc17b..8402230 100644 --- a/backend/internal/factory/wire_gen.go +++ b/backend/internal/factory/wire_gen.go @@ -22,7 +22,8 @@ func ProvideService(c config.Config) (service.Service, func(), error) { } usersRepository := provideUsersRepository(db) usersInteractor := provideUsersInteractor(logger, usersRepository) - jwtInteractor := provideJWTInteractor(c, usersInteractor) + authRepository := provideAuthRepository(db) + jwtInteractor := provideJWTInteractor(c, usersInteractor, authRepository) authPresenter := provideAuthPresenter(jwtInteractor) authController := provideAuthController(logger, usersInteractor, authPresenter, jwtInteractor) organizationsRepository := provideOrganizationsRepository(db) diff --git a/backend/internal/interface/rest/controllers/auth.go b/backend/internal/interface/rest/controllers/auth.go index 53a5067..eeb626c 100644 --- a/backend/internal/interface/rest/controllers/auth.go +++ b/backend/internal/interface/rest/controllers/auth.go @@ -26,6 +26,7 @@ type AuthController interface { JoinWithInvite(w http.ResponseWriter, req *http.Request) ([]byte, error) Login(w http.ResponseWriter, req *http.Request) ([]byte, error) Invite(w http.ResponseWriter, req *http.Request) ([]byte, error) + Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error) } type authController struct { @@ -114,6 +115,29 @@ func (c *authController) Login(w http.ResponseWriter, req *http.Request) ([]byte return c.presenter.ResponseLogin(users[0]) } +func (c *authController) Refresh(w http.ResponseWriter, req *http.Request) ([]byte, error) { + request, err := presenters.CreateRequest[domain.RefreshRequest](req) + if err != nil { + return nil, fmt.Errorf("error create refresh request. %w", err) + } + + c.log.Debug( + "refresh request", + slog.String("token", request.Token), + slog.String("refresh_token", request.RefreshToken), + ) + + ctx, cancel := context.WithTimeout(req.Context(), 3*time.Second) + defer cancel() + + newTokens, err := c.jwtInteractor.RefreshToken(ctx, request.Token, request.RefreshToken) + if err != nil { + return nil, fmt.Errorf("error refresh access token. %w", err) + } + + return c.presenter.ResponseRefresh(newTokens) +} + // const mnemonicEntropyBitSize int = 256 func (c *authController) Invite(w http.ResponseWriter, req *http.Request) ([]byte, error) { diff --git a/backend/internal/interface/rest/domain/dto.go b/backend/internal/interface/rest/domain/dto.go index 0b43150..444c613 100644 --- a/backend/internal/interface/rest/domain/dto.go +++ b/backend/internal/interface/rest/domain/dto.go @@ -38,8 +38,16 @@ type LoginRequest struct { Mnemonic string `json:"mnemonic"` } +type RefreshRequest struct { + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` +} + type LoginResponse struct { - Token string `json:"token"` + Token string `json:"token"` + ExpiredAt int64 `json:"token_expired_at"` + RefreshToken string `json:"refresh_token"` + RTExpiredAt int64 `json:"refresh_token_expired_at"` } // Organizations diff --git a/backend/internal/interface/rest/presenters/auth.go b/backend/internal/interface/rest/presenters/auth.go index 94a64d1..2a64f91 100644 --- a/backend/internal/interface/rest/presenters/auth.go +++ b/backend/internal/interface/rest/presenters/auth.go @@ -13,6 +13,7 @@ import ( type AuthPresenter interface { ResponseJoin(user *models.User) ([]byte, error) ResponseLogin(user *models.User) ([]byte, error) + ResponseRefresh(tokens jwt.AccessToken) ([]byte, error) } type authPresenter struct { @@ -28,16 +29,17 @@ func NewAuthPresenter( } func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) { - resp := new(domain.JoinResponse) - - token, err := p.jwtInteractor.NewToken(user, 24*time.Hour) + tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour) if err != nil { return nil, fmt.Errorf("error create access token. %w", err) } - resp.Token = token - - out, err := json.Marshal(resp) + out, err := json.Marshal(domain.LoginResponse{ + Token: tokens.Token, + RefreshToken: tokens.RefreshToken, + ExpiredAt: tokens.ExpiredAt.UnixMilli(), + RTExpiredAt: tokens.RTExpiredAt.UnixMilli(), + }) if err != nil { return nil, fmt.Errorf("error marshal join response. %w", err) } @@ -46,19 +48,34 @@ func (p *authPresenter) ResponseJoin(user *models.User) ([]byte, error) { } func (p *authPresenter) ResponseLogin(user *models.User) ([]byte, error) { - resp := new(domain.LoginResponse) - - token, err := p.jwtInteractor.NewToken(user, 24*time.Hour) + tokens, err := p.jwtInteractor.NewToken(user, 24*time.Hour) if err != nil { return nil, fmt.Errorf("error create access token. %w", err) } - resp.Token = token - - out, err := json.Marshal(resp) + out, err := json.Marshal(domain.LoginResponse{ + Token: tokens.Token, + RefreshToken: tokens.RefreshToken, + ExpiredAt: tokens.ExpiredAt.UnixMilli(), + RTExpiredAt: tokens.RTExpiredAt.UnixMilli(), + }) if err != nil { return nil, fmt.Errorf("error marshal login response. %w", err) } return out, nil } + +func (p *authPresenter) ResponseRefresh(tokens jwt.AccessToken) ([]byte, error) { + out, err := json.Marshal(domain.LoginResponse{ + Token: tokens.Token, + RefreshToken: tokens.RefreshToken, + ExpiredAt: tokens.ExpiredAt.UnixMilli(), + RTExpiredAt: tokens.RTExpiredAt.UnixMilli(), + }) + if err != nil { + return nil, fmt.Errorf("error marshal refresh response. %w", err) + } + + return out, nil +} diff --git a/backend/internal/interface/rest/server.go b/backend/internal/interface/rest/server.go index 9e1a290..386ee0a 100644 --- a/backend/internal/interface/rest/server.go +++ b/backend/internal/interface/rest/server.go @@ -96,6 +96,7 @@ func (s *Server) buildRouter() { router.Post("/join", s.handle(s.controllers.Auth.Join, "join")) router.Post("/login", s.handle(s.controllers.Auth.Login, "login")) + router.Get("/refresh", s.handle(s.controllers.Auth.Refresh, "refresh")) router.Route("/organizations", func(r chi.Router) { r = r.With(s.withAuthorization) diff --git a/backend/internal/usecase/interactors/jwt/jwt.go b/backend/internal/usecase/interactors/jwt/jwt.go index f8a7b7c..569dca3 100644 --- a/backend/internal/usecase/interactors/jwt/jwt.go +++ b/backend/internal/usecase/interactors/jwt/jwt.go @@ -2,12 +2,15 @@ package jwt import ( "context" + "crypto/sha512" + "encoding/base64" "errors" "fmt" "time" "github.com/emochka2007/block-accounting/internal/pkg/models" "github.com/emochka2007/block-accounting/internal/usecase/interactors/users" + "github.com/emochka2007/block-accounting/internal/usecase/repository/auth" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) @@ -18,38 +21,59 @@ var ( ) type JWTInteractor interface { - NewToken(user models.UserIdentity, duration time.Duration) (string, error) + NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error) User(token string) (*models.User, error) + RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error) } type jwtInteractor struct { secret []byte usersInteractor users.UsersInteractor + authRepository auth.Repository } -func NewWardenJWT(secret []byte, usersInteractor users.UsersInteractor) JWTInteractor { +func NewJWT( + secret []byte, + usersInteractor users.UsersInteractor, + authRepository auth.Repository, +) JWTInteractor { return &jwtInteractor{ secret: secret, usersInteractor: usersInteractor, + authRepository: authRepository, } } +type AccessToken struct { + Token string + ExpiredAt time.Time + + RefreshToken string + RTExpiredAt time.Time +} + // NewToken creates new JWT token for given user -func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (string, error) { - token := jwt.New(jwt.SigningMethodHS256) - - claims := token.Claims.(jwt.MapClaims) - claims["uid"] = user.Id().String() - claims["exp"] = time.Now().Add(duration).UnixMilli() - - secret := w.secret - - tokenString, err := token.SignedString([]byte(secret)) +func (w *jwtInteractor) NewToken(user models.UserIdentity, duration time.Duration) (AccessToken, error) { + tokens, err := w.newTokens(user.Id(), duration) if err != nil { - return "", fmt.Errorf("error sign token. %w", err) + return AccessToken{}, fmt.Errorf("error create new tokens. %w", err) } - return tokenString, nil + ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) + defer cancel() + + if err := w.authRepository.AddToken(ctx, auth.AddTokenParams{ + UserId: user.Id(), + Token: tokens.Token, + TokenExpiredAt: tokens.ExpiredAt, + RefreshToken: tokens.RefreshToken, + RefreshTokenExpiredAt: tokens.RTExpiredAt, + CreatedAt: time.Now(), + }); err != nil { + return AccessToken{}, fmt.Errorf("error save tokens into repository. %w", err) + } + + return tokens, nil } func (w *jwtInteractor) User(tokenStr string) (*models.User, error) { @@ -85,8 +109,24 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) { ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) defer cancel() + tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{ + UserId: userId, + Token: tokenStr, + }) + if err != nil { + return nil, fmt.Errorf("error fetch token from repository. %w", err) + } + + if tokens.TokenExpiredAt.Before(time.Now()) { + return nil, fmt.Errorf("error token expired. %w", ErrorTokenExpired) + } + + if tokens.UserId != userId { + return nil, errors.Join(fmt.Errorf("error invalid user id. %w", err), ErrorInvalidTokenClaims) + } + users, err := w.usersInteractor.Get(ctx, users.GetParams{ - Ids: uuid.UUIDs{userId}, + Ids: uuid.UUIDs{tokens.UserId}, }) if err != nil || len(users) == 0 { return nil, fmt.Errorf("error fetch user from repository. %w", err) @@ -94,3 +134,153 @@ func (w *jwtInteractor) User(tokenStr string) (*models.User, error) { return users[0], nil } + +func (w *jwtInteractor) RefreshToken(ctx context.Context, token string, rToken string) (AccessToken, error) { + claims := make(jwt.MapClaims) + + _, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (interface{}, error) { + return w.secret, nil + }) + if err != nil { + return AccessToken{}, errors.Join(fmt.Errorf("error parse jwt token. %w", err), ErrorInvalidTokenClaims) + } + + var userIdString string + var ok bool + + if userIdString, ok = claims["uid"].(string); !ok { + return AccessToken{}, ErrorInvalidTokenClaims + } + + userId, err := uuid.Parse(userIdString) + if err != nil { + return AccessToken{}, errors.Join(fmt.Errorf("error parse user id. %w", err), ErrorInvalidTokenClaims) + } + + _, err = jwt.ParseWithClaims(rToken, claims, func(t *jwt.Token) (interface{}, error) { + return w.secret, nil + }) + if err != nil { + return AccessToken{}, errors.Join(fmt.Errorf("error parse refresh jwt token. %w", err), ErrorInvalidTokenClaims) + } + + if expDate, ok := claims["exp"].(float64); ok { + if time.UnixMilli(int64(expDate)).Before(time.Now()) { + return AccessToken{}, fmt.Errorf("error refresh token expired. %w", ErrorTokenExpired) + } + } else { + return AccessToken{}, errors.Join(fmt.Errorf("error parse exp date. %w", err), ErrorInvalidTokenClaims) + } + + if userIdString, ok = claims["uid"].(string); !ok { + return AccessToken{}, ErrorInvalidTokenClaims + } + + rTokenUserId, err := uuid.Parse(userIdString) + if err != nil { + return AccessToken{}, errors.Join( + fmt.Errorf("error parse user id from refresh token. %w", err), + ErrorInvalidTokenClaims, + ) + } + + if userId != rTokenUserId { + return AccessToken{}, fmt.Errorf("error user ids corrupted. %w", ErrorInvalidTokenClaims) + } + + ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) + defer cancel() + + tokens, err := w.authRepository.GetTokens(ctx, auth.GetTokenParams{ + UserId: userId, + Token: token, + RefreshToken: rToken, + }) + if err != nil { + return AccessToken{}, fmt.Errorf("error fetch token from repository. %w", err) + } + + if tokens.RefreshTokenExpiredAt.Before(time.Now()) { + return AccessToken{}, fmt.Errorf("error token expired. %w", ErrorTokenExpired) + } + + rtHash := sha512.New() + rtHash.Write([]byte(tokens.Token)) + + rtHashStringValid := base64.StdEncoding.EncodeToString(rtHash.Sum(nil)) + + rtHashRaw, ok := claims["rt_hash"] + if !ok { + return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims) + } + + rtHashString, ok := rtHashRaw.(string) + if !ok { + return AccessToken{}, fmt.Errorf("error refresh token claims corrupted. %w", ErrorInvalidTokenClaims) + } + + if rtHashString != rtHashStringValid { + return AccessToken{}, fmt.Errorf("error refresh token hash corrupted. %w", ErrorInvalidTokenClaims) + } + + newTokens, err := w.newTokens(userId, 24*time.Hour) + if err != nil { + return AccessToken{}, fmt.Errorf("error create new tokens. %w", err) + } + + if err = w.authRepository.RefreshToken(ctx, auth.RefreshTokenParams{ + UserId: userId, + OldToken: token, + Token: newTokens.Token, + TokenExpiredAt: newTokens.ExpiredAt, + OldRefreshToken: rToken, + RefreshToken: newTokens.RefreshToken, + RefreshTokenExpiredAt: newTokens.RTExpiredAt, + }); err != nil { + return AccessToken{}, fmt.Errorf("error update tokens. %w", err) + } + + return newTokens, nil +} + +func (w *jwtInteractor) newTokens(userId uuid.UUID, duration time.Duration) (AccessToken, error) { + token := jwt.New(jwt.SigningMethodHS256) + + expAt := time.Now().Add(duration) + + claims := token.Claims.(jwt.MapClaims) + claims["uid"] = userId.String() + claims["exp"] = expAt.UnixMilli() + + secret := w.secret + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return AccessToken{}, fmt.Errorf("error sign token. %w", err) + } + + refreshToken := jwt.New(jwt.SigningMethodHS256) + + rtHash := sha512.New() + + rtHash.Write([]byte(tokenString)) + + rtExpAt := expAt.Add(time.Hour * 24 * 5) + + claims = refreshToken.Claims.(jwt.MapClaims) + claims["uid"] = userId.String() + claims["exp"] = rtExpAt.UnixMilli() + claims["rt_hash"] = base64.StdEncoding.EncodeToString(rtHash.Sum(nil)) + + rtokenString, err := refreshToken.SignedString([]byte(secret)) + if err != nil { + return AccessToken{}, fmt.Errorf("error sign refresh token. %w", err) + } + + return AccessToken{ + Token: tokenString, + ExpiredAt: expAt, + RefreshToken: rtokenString, + RTExpiredAt: rtExpAt, + }, nil +} diff --git a/backend/internal/usecase/repository/auth/repository.go b/backend/internal/usecase/repository/auth/repository.go new file mode 100644 index 0000000..c5ebe6d --- /dev/null +++ b/backend/internal/usecase/repository/auth/repository.go @@ -0,0 +1,191 @@ +package auth + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + sqltools "github.com/emochka2007/block-accounting/internal/pkg/sqlutils" + "github.com/google/uuid" +) + +type AddTokenParams struct { + UserId uuid.UUID + + Token string + TokenExpiredAt time.Time + + RefreshToken string + RefreshTokenExpiredAt time.Time + + CreatedAt time.Time +} + +type GetTokenParams struct { + UserId uuid.UUID + Token string + RefreshToken string +} + +type RefreshTokenParams struct { + UserId uuid.UUID + + OldToken string + Token string + TokenExpiredAt time.Time + + OldRefreshToken string + RefreshToken string + RefreshTokenExpiredAt time.Time +} + +type AccessToken struct { + UserId uuid.UUID + + Token string + TokenExpiredAt time.Time + + RefreshToken string + RefreshTokenExpiredAt time.Time + + CreatedAt time.Time +} + +type Repository interface { + AddToken(ctx context.Context, params AddTokenParams) error + GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error) + RefreshToken(ctx context.Context, params RefreshTokenParams) error +} + +type repositorySQL struct { + db *sql.DB +} + +func (r *repositorySQL) AddToken(ctx context.Context, params AddTokenParams) error { + if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error { + query := sq.Insert("access_tokens"). + Columns( + "user_id", + "token", + "refresh_token", + "token_expired_at", + "refresh_token_expired_at", + ). + Values( + params.UserId, + params.Token, + params.RefreshToken, + params.TokenExpiredAt, + params.RefreshTokenExpiredAt, + ).PlaceholderFormat(sq.Dollar) + + if _, err := query.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil { + return fmt.Errorf("error add tokens. %w", err) + } + + return nil + }); err != nil { + return err + } + + return nil +} + +func (r *repositorySQL) RefreshToken(ctx context.Context, params RefreshTokenParams) error { + if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error { + updateQuery := sq.Update("access_tokens"). + SetMap(sq.Eq{ + "token": params.Token, + "refresh_token": params.RefreshToken, + "token_expired_at": params.TokenExpiredAt, + "refresh_token_expired_at": params.RefreshTokenExpiredAt, + }). + Where(sq.Eq{ + "user_id": params.UserId, + "token": params.OldToken, + "refresh_token": params.OldRefreshToken, + }).PlaceholderFormat(sq.Dollar) + + if _, err := updateQuery.RunWith(r.Conn(ctx)).ExecContext(ctx); err != nil { + return fmt.Errorf("error update tokens. %w", err) + } + + return nil + }); err != nil { + return err + } + + return nil +} + +func (r *repositorySQL) GetTokens(ctx context.Context, params GetTokenParams) (*AccessToken, error) { + var token *AccessToken = new(AccessToken) + + if err := sqltools.Transaction(ctx, r.db, func(ctx context.Context) error { + query := sq.Select( + "user_id", + "token", + "token_expired_at", + "refresh_token", + "refresh_token_expired_at", + "created_at", + ).From("access_tokens"). + Where(sq.Eq{ + "token": params.Token, + "user_id": params.UserId, + }).PlaceholderFormat(sq.Dollar) + + if params.RefreshToken != "" { + query = query.Where(sq.Eq{ + "refresh_token": params.RefreshToken, + }) + } + + rows, err := query.RunWith(r.Conn(ctx)).QueryContext(ctx) + if err != nil { + return fmt.Errorf("error fetch token from database. %w", err) + } + + defer func() { + if cErr := rows.Close(); cErr != nil { + err = errors.Join(fmt.Errorf("error close database rows. %w", cErr), err) + } + }() + + for rows.Next() { + if err := rows.Scan( + &token.UserId, + &token.Token, + &token.TokenExpiredAt, + &token.RefreshToken, + &token.RefreshTokenExpiredAt, + &token.CreatedAt, + ); err != nil { + return fmt.Errorf("error scan row. %w", err) + } + } + + return nil + }); err != nil { + return nil, err + } + + return token, nil +} + +func NewRepository(db *sql.DB) Repository { + return &repositorySQL{ + db: db, + } +} + +func (s *repositorySQL) Conn(ctx context.Context) sqltools.DBTX { + if tx, ok := ctx.Value(sqltools.TxCtxKey).(*sql.Tx); ok { + return tx + } + + return s.db +} diff --git a/backend/migrations/blockd.sql b/backend/migrations/blockd.sql index 52c195f..5d2be93 100644 --- a/backend/migrations/blockd.sql +++ b/backend/migrations/blockd.sql @@ -24,6 +24,22 @@ create index if not exists index_users_phone create index if not exists index_users_seed on users using hash (seed); +create table if not exists access_tokens ( + user_id uuid not null references users(id), + token varchar(350) not null, + token_expired_at timestamp, + refresh_token varchar(350) not null, + refresh_token_expired_at timestamp, + created_at timestamp default current_timestamp, + remote_addr string +); + +create index if not exists index_access_tokens_token_refresh_token + on access_tokens (token, refresh_token); + +create index if not exists index_access_tokens_token_refresh_token_exp + on access_tokens (token, refresh_token, token_expired_at, refresh_token_expired_at); + create table if not exists organizations ( id uuid primary key unique, name varchar(300) default 'My Organization' not null, @@ -120,9 +136,12 @@ create table contracts ( title varchar(250) default 'New Contract', description text not null, + address bytea not null, + created_by uuid not null references users(id), organization_id uuid not null references organizations(id), created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp ); +