From f44813431d433147b89785d8d2e6484129af5fc4 Mon Sep 17 00:00:00 2001 From: optclblast Date: Thu, 2 Jan 2025 10:44:28 -0800 Subject: [PATCH] +1 --- .../app/handlers/{auth.go => auth_common.go} | 0 internal/app/register.go | 105 ------------------ internal/common/request.go | 24 ++-- internal/common/request_test.go | 17 ++- internal/resolvers/ip/ip.go | 7 -- 5 files changed, 27 insertions(+), 126 deletions(-) rename internal/app/handlers/{auth.go => auth_common.go} (100%) delete mode 100644 internal/app/register.go delete mode 100644 internal/resolvers/ip/ip.go diff --git a/internal/app/handlers/auth.go b/internal/app/handlers/auth_common.go similarity index 100% rename from internal/app/handlers/auth.go rename to internal/app/handlers/auth_common.go diff --git a/internal/app/register.go b/internal/app/register.go deleted file mode 100644 index f5eb28d..0000000 --- a/internal/app/register.go +++ /dev/null @@ -1,105 +0,0 @@ -package app - -import ( - "fmt" - "net/http" - "time" - - "git.optclblast.xyz/draincloud/draincloud-core/internal/domain" - "git.optclblast.xyz/draincloud/draincloud-core/internal/logger" - "git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models" - "github.com/gin-gonic/gin" - "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" -) - -func (d *DrainCloud) Register(ctx *gin.Context) { - logger.Debug(ctx, "[register] new request") - - // TODO check if registration is enabled - - req := new(domain.RegisterRequest) - err := ctx.BindJSON(req) - if err != nil { - logger.Error(ctx, "[register] failed to bind request", logger.Err(err)) - ctx.JSON(http.StatusBadRequest, map[string]string{ - "error": "bad request", - }) - return - } - - resp, err := d.register(ctx, req) - if err != nil { - logger.Error(ctx, "[register] failed to register user", logger.Err(err)) - ctx.JSON(http.StatusInternalServerError, map[string]string{ - "error": err.Error(), - }) - return - } - - ctx.JSON(http.StatusOK, resp) -} - -func (d *DrainCloud) register(ctx *gin.Context, req *domain.RegisterRequest) (*domain.RegisterResponse, error) { - if err := validateLoginAndPassword(req.Login, req.Password); err != nil { - return nil, fmt.Errorf("invalid creds: %w", err) - } - - passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10) - if err != nil { - logger.Error(ctx, "[register] failed to generate password hash", logger.Err(err)) - return nil, fmt.Errorf("failed to generate password hash: %w", err) - } - - userID, err := uuid.NewV7() - if err != nil { - return nil, fmt.Errorf("failed to generate user id: %w", err) - } - - user := &models.User{ - ID: userID, - Username: req.Login, - Login: req.Login, - PasswordHash: passwordHash, - } - - err = d.database.AddUser(ctx, userID, user.Login, user.Username, user.PasswordHash) - if err != nil { - return nil, fmt.Errorf("failed to add new user: %w", err) - } - - sessionCreatedAt := time.Now() - sessionExpiredAt := sessionCreatedAt.Add(time.Hour * 24 * 7) - - sessionToken, err := generateSessionToken(100) - if err != nil { - return nil, fmt.Errorf("failed to generate a session token: %w", err) - } - ctx.SetCookie(sessionTokenCookie, sessionToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, true) - - csrfToken, err := generateSessionToken(100) - if err != nil { - return nil, fmt.Errorf("failed to generate a csrf token: %w", err) - } - ctx.SetCookie(csrfTokenCookie, csrfToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, false) - - sessionID, err := uuid.NewV7() - if err != nil { - return nil, fmt.Errorf("failed to generate session id: %w", err) - } - - if _, err = d.database.AddSession(ctx, &models.Session{ - ID: sessionID, - SessionToken: sessionToken, - CsrfToken: csrfToken, - UserID: user.ID, - CreatedAt: sessionCreatedAt, - ExpiredAt: sessionExpiredAt, - }); err != nil { - return nil, fmt.Errorf("failed to save session: %w", err) - } - - return &domain.RegisterResponse{ - Ok: true, - }, nil -} diff --git a/internal/common/request.go b/internal/common/request.go index 602769d..aab27ed 100644 --- a/internal/common/request.go +++ b/internal/common/request.go @@ -23,8 +23,8 @@ func (p *RequestPool) Get() *Request { func (p *RequestPool) Put(r *Request) { r.ID = "" - r.Metadata = make(map[string]any) - r.ResolveValues = sync.Map{} + r.Metadata = &sync.Map{} + r.ResolveValues = &sync.Map{} r.Session = nil r.User = nil r.Body = nil @@ -36,8 +36,8 @@ func NewRequestPool() *RequestPool { sp: sync.Pool{ New: func() any { return &Request{ - ResolveValues: sync.Map{}, - Metadata: make(map[string]any), + ResolveValues: &sync.Map{}, + Metadata: &sync.Map{}, } }, }, @@ -48,9 +48,10 @@ type Request struct { ID string Session *models.Session User *models.User - ResolveValues sync.Map - Metadata map[string]any + ResolveValues *sync.Map + Metadata *sync.Map Body []byte + RawReq *http.Request } // NewRequestFromHttp builds a new *Request struct from raw http Request. No auth data validated. @@ -60,14 +61,15 @@ func NewRequestFromHttp(pool *RequestPool, req *http.Request) *Request { cookies := req.Cookies() headers := req.Header - out.Metadata = make(map[string]any, len(cookies)) + out.Metadata = &sync.Map{} + out.RawReq = req for _, cookie := range cookies { - out.Metadata[cookie.Name] = cookie.Value + out.Metadata.Store(cookie.Name, cookie.Value) } for hname, hval := range headers { - out.Metadata[hname] = hval + out.Metadata.Store(hname, hval) } body, err := io.ReadAll(req.Body) @@ -81,12 +83,12 @@ func NewRequestFromHttp(pool *RequestPool, req *http.Request) *Request { return out } -func GetValue[T any](vals map[string]any, key string) (T, error) { +func GetValue[T any](vals *sync.Map, key string) (T, error) { var out T if vals == nil { return out, fmt.Errorf("nil vals map") } - rawVal, ok := vals[key] + rawVal, ok := vals.Load(key) if !ok { return out, fmt.Errorf("value not found in resolve values set") } diff --git a/internal/common/request_test.go b/internal/common/request_test.go index 9898dd3..ae7f93c 100644 --- a/internal/common/request_test.go +++ b/internal/common/request_test.go @@ -2,6 +2,7 @@ package common import ( "reflect" + "sync" "testing" ) @@ -65,7 +66,7 @@ func TestGetValue_string(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetValue[string](tt.args.vals, tt.args.key) + got, err := GetValue[string](_mapToSyncMap(tt.args.vals), tt.args.key) if (err != nil) != tt.wantErr { t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr) return @@ -151,7 +152,7 @@ func TestGetValue_struct(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetValue[val](tt.args.vals, tt.args.key) + got, err := GetValue[val](_mapToSyncMap(tt.args.vals), tt.args.key) if (err != nil) != tt.wantErr { t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr) return @@ -237,7 +238,7 @@ func TestGetValue_structptr(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetValue[*val](tt.args.vals, tt.args.key) + got, err := GetValue[*val](_mapToSyncMap(tt.args.vals), tt.args.key) if (err != nil) != tt.wantErr { t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr) return @@ -248,3 +249,13 @@ func TestGetValue_structptr(t *testing.T) { }) } } + +func _mapToSyncMap(m map[string]any) *sync.Map { + out := &sync.Map{} + + for k, v := range m { + out.Store(k, v) + } + + return out +} diff --git a/internal/resolvers/ip/ip.go b/internal/resolvers/ip/ip.go deleted file mode 100644 index c8a4531..0000000 --- a/internal/resolvers/ip/ip.go +++ /dev/null @@ -1,7 +0,0 @@ -package ip - -type IpResolver struct{} - -func New() *IpResolver { - return new(IpResolver) -}