+1
This commit is contained in:
parent
0202bd5dbb
commit
f44813431d
@ -1,105 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.optclblast.xyz/draincloud/draincloud-core/internal/domain"
|
|
||||||
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
|
|
||||||
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *DrainCloud) Register(ctx *gin.Context) {
|
|
||||||
logger.Debug(ctx, "[register] new request")
|
|
||||||
|
|
||||||
// TODO check if registration is enabled
|
|
||||||
|
|
||||||
req := new(domain.RegisterRequest)
|
|
||||||
err := ctx.BindJSON(req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(ctx, "[register] failed to bind request", logger.Err(err))
|
|
||||||
ctx.JSON(http.StatusBadRequest, map[string]string{
|
|
||||||
"error": "bad request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := d.register(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(ctx, "[register] failed to register user", logger.Err(err))
|
|
||||||
ctx.JSON(http.StatusInternalServerError, map[string]string{
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.JSON(http.StatusOK, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DrainCloud) register(ctx *gin.Context, req *domain.RegisterRequest) (*domain.RegisterResponse, error) {
|
|
||||||
if err := validateLoginAndPassword(req.Login, req.Password); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid creds: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(ctx, "[register] failed to generate password hash", logger.Err(err))
|
|
||||||
return nil, fmt.Errorf("failed to generate password hash: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
userID, err := uuid.NewV7()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate user id: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user := &models.User{
|
|
||||||
ID: userID,
|
|
||||||
Username: req.Login,
|
|
||||||
Login: req.Login,
|
|
||||||
PasswordHash: passwordHash,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = d.database.AddUser(ctx, userID, user.Login, user.Username, user.PasswordHash)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add new user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionCreatedAt := time.Now()
|
|
||||||
sessionExpiredAt := sessionCreatedAt.Add(time.Hour * 24 * 7)
|
|
||||||
|
|
||||||
sessionToken, err := generateSessionToken(100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate a session token: %w", err)
|
|
||||||
}
|
|
||||||
ctx.SetCookie(sessionTokenCookie, sessionToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, true)
|
|
||||||
|
|
||||||
csrfToken, err := generateSessionToken(100)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate a csrf token: %w", err)
|
|
||||||
}
|
|
||||||
ctx.SetCookie(csrfTokenCookie, csrfToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, false)
|
|
||||||
|
|
||||||
sessionID, err := uuid.NewV7()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate session id: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = d.database.AddSession(ctx, &models.Session{
|
|
||||||
ID: sessionID,
|
|
||||||
SessionToken: sessionToken,
|
|
||||||
CsrfToken: csrfToken,
|
|
||||||
UserID: user.ID,
|
|
||||||
CreatedAt: sessionCreatedAt,
|
|
||||||
ExpiredAt: sessionExpiredAt,
|
|
||||||
}); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to save session: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domain.RegisterResponse{
|
|
||||||
Ok: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
@ -23,8 +23,8 @@ func (p *RequestPool) Get() *Request {
|
|||||||
|
|
||||||
func (p *RequestPool) Put(r *Request) {
|
func (p *RequestPool) Put(r *Request) {
|
||||||
r.ID = ""
|
r.ID = ""
|
||||||
r.Metadata = make(map[string]any)
|
r.Metadata = &sync.Map{}
|
||||||
r.ResolveValues = sync.Map{}
|
r.ResolveValues = &sync.Map{}
|
||||||
r.Session = nil
|
r.Session = nil
|
||||||
r.User = nil
|
r.User = nil
|
||||||
r.Body = nil
|
r.Body = nil
|
||||||
@ -36,8 +36,8 @@ func NewRequestPool() *RequestPool {
|
|||||||
sp: sync.Pool{
|
sp: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return &Request{
|
return &Request{
|
||||||
ResolveValues: sync.Map{},
|
ResolveValues: &sync.Map{},
|
||||||
Metadata: make(map[string]any),
|
Metadata: &sync.Map{},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -48,9 +48,10 @@ type Request struct {
|
|||||||
ID string
|
ID string
|
||||||
Session *models.Session
|
Session *models.Session
|
||||||
User *models.User
|
User *models.User
|
||||||
ResolveValues sync.Map
|
ResolveValues *sync.Map
|
||||||
Metadata map[string]any
|
Metadata *sync.Map
|
||||||
Body []byte
|
Body []byte
|
||||||
|
RawReq *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequestFromHttp builds a new *Request struct from raw http Request. No auth data validated.
|
// NewRequestFromHttp builds a new *Request struct from raw http Request. No auth data validated.
|
||||||
@ -60,14 +61,15 @@ func NewRequestFromHttp(pool *RequestPool, req *http.Request) *Request {
|
|||||||
cookies := req.Cookies()
|
cookies := req.Cookies()
|
||||||
headers := req.Header
|
headers := req.Header
|
||||||
|
|
||||||
out.Metadata = make(map[string]any, len(cookies))
|
out.Metadata = &sync.Map{}
|
||||||
|
out.RawReq = req
|
||||||
|
|
||||||
for _, cookie := range cookies {
|
for _, cookie := range cookies {
|
||||||
out.Metadata[cookie.Name] = cookie.Value
|
out.Metadata.Store(cookie.Name, cookie.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
for hname, hval := range headers {
|
for hname, hval := range headers {
|
||||||
out.Metadata[hname] = hval
|
out.Metadata.Store(hname, hval)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(req.Body)
|
body, err := io.ReadAll(req.Body)
|
||||||
@ -81,12 +83,12 @@ func NewRequestFromHttp(pool *RequestPool, req *http.Request) *Request {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetValue[T any](vals map[string]any, key string) (T, error) {
|
func GetValue[T any](vals *sync.Map, key string) (T, error) {
|
||||||
var out T
|
var out T
|
||||||
if vals == nil {
|
if vals == nil {
|
||||||
return out, fmt.Errorf("nil vals map")
|
return out, fmt.Errorf("nil vals map")
|
||||||
}
|
}
|
||||||
rawVal, ok := vals[key]
|
rawVal, ok := vals.Load(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return out, fmt.Errorf("value not found in resolve values set")
|
return out, fmt.Errorf("value not found in resolve values set")
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ func TestGetValue_string(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := GetValue[string](tt.args.vals, tt.args.key)
|
got, err := GetValue[string](_mapToSyncMap(tt.args.vals), tt.args.key)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
@ -151,7 +152,7 @@ func TestGetValue_struct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := GetValue[val](tt.args.vals, tt.args.key)
|
got, err := GetValue[val](_mapToSyncMap(tt.args.vals), tt.args.key)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
@ -237,7 +238,7 @@ func TestGetValue_structptr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := GetValue[*val](tt.args.vals, tt.args.key)
|
got, err := GetValue[*val](_mapToSyncMap(tt.args.vals), tt.args.key)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
@ -248,3 +249,13 @@ func TestGetValue_structptr(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _mapToSyncMap(m map[string]any) *sync.Map {
|
||||||
|
out := &sync.Map{}
|
||||||
|
|
||||||
|
for k, v := range m {
|
||||||
|
out.Store(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
package ip
|
|
||||||
|
|
||||||
type IpResolver struct{}
|
|
||||||
|
|
||||||
func New() *IpResolver {
|
|
||||||
return new(IpResolver)
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user