Compare commits

...

5 Commits

Author SHA1 Message Date
7ce2b41ce7 Merge pull request 'transport layer refactoring' (#1) from transport-layer-refactoring into master
Reviewed-on: #1
2025-01-02 23:35:16 +00:00
d0f10cec62 +1 2025-01-02 15:34:47 -08:00
f44813431d +1 2025-01-02 10:44:28 -08:00
0202bd5dbb refactoring and fixes. test with /auth/register handler 2024-12-30 15:35:31 -08:00
286a0fe826 +1 2024-12-29 15:15:44 -08:00
12 changed files with 406 additions and 102 deletions

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/spf13/viper v1.19.0 github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.28.0 golang.org/x/crypto v0.28.0
golang.org/x/sync v0.8.0
) )
require ( require (

View File

@ -5,8 +5,11 @@ import (
"errors" "errors"
"net/http" "net/http"
"git.optclblast.xyz/draincloud/draincloud-core/internal/app/handlers"
"git.optclblast.xyz/draincloud/draincloud-core/internal/domain" "git.optclblast.xyz/draincloud/draincloud-core/internal/domain"
filesengine "git.optclblast.xyz/draincloud/draincloud-core/internal/files_engine" filesengine "git.optclblast.xyz/draincloud/draincloud-core/internal/files_engine"
"git.optclblast.xyz/draincloud/draincloud-core/internal/processor"
resolvedispatcher "git.optclblast.xyz/draincloud/draincloud-core/internal/resolve_dispatcher"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage" "git.optclblast.xyz/draincloud/draincloud-core/internal/storage"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -15,6 +18,8 @@ type DrainCloud struct {
mux *gin.Engine mux *gin.Engine
database storage.Database database storage.Database
filesEngine *filesengine.FilesEngine filesEngine *filesengine.FilesEngine
ginProcessor processor.Processor[gin.HandlerFunc]
} }
func New( func New(
@ -23,15 +28,21 @@ func New(
) *DrainCloud { ) *DrainCloud {
mux := gin.Default() mux := gin.Default()
dispatcher := resolvedispatcher.New()
d := &DrainCloud{ d := &DrainCloud{
database: database, database: database,
filesEngine: filesEngine, filesEngine: filesEngine,
ginProcessor: processor.NewGinProcessor(database, dispatcher),
} }
// Built-in auth component of DrainCloud-Core // Built-in auth component of DrainCloud-Core
authGroup := mux.Group("/auth") authGroup := mux.Group("/auth")
{ {
authGroup.POST("/register", d.Register) // authGroup.POST("/register", d.Register)
authGroup.POST("/register", d.ginProcessor.Process(
handlers.NewRegisterHandler(database),
))
authGroup.POST("/logon", d.Login) authGroup.POST("/logon", d.Login)
} }

View File

@ -0,0 +1,37 @@
package handlers
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
)
const (
csrfTokenCookie = "__Csrf_token"
sessionTokenCookie = "__Session_token"
)
var (
ErrorUnauthorized = errors.New("unauthorized")
)
func validateLoginAndPassword(login, password string) error {
if len(login) < 4 {
return fmt.Errorf("login must be longer than 8 chars")
}
if len(password) < 6 {
return fmt.Errorf("password must be longer than 8 chars")
}
return nil
}
func generateSessionToken(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes), nil
}

View File

@ -1,46 +1,61 @@
package app package handlers
import ( import (
"context"
"encoding/json"
"fmt" "fmt"
"net/http"
"time" "time"
"git.optclblast.xyz/draincloud/draincloud-core/internal/common"
"git.optclblast.xyz/draincloud/draincloud-core/internal/domain" "git.optclblast.xyz/draincloud/draincloud-core/internal/domain"
"git.optclblast.xyz/draincloud/draincloud-core/internal/handler"
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger" "git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models" "git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func (d *DrainCloud) Register(ctx *gin.Context) { type RegisterHandler struct {
logger.Debug(ctx, "[register] new request") *handler.BaseHandler
authStorage storage.AuthStorage
// TODO check if registration is enabled
req := new(domain.RegisterRequest)
err := ctx.BindJSON(req)
if err != nil {
logger.Error(ctx, "[register] failed to bind request", logger.Err(err))
ctx.JSON(http.StatusBadRequest, map[string]string{
"error": "bad request",
})
return
}
resp, err := d.register(ctx, req)
if err != nil {
logger.Error(ctx, "[register] failed to register user", logger.Err(err))
ctx.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
return
}
ctx.JSON(http.StatusOK, resp)
} }
func (d *DrainCloud) register(ctx *gin.Context, req *domain.RegisterRequest) (*domain.RegisterResponse, error) { func NewRegisterHandler(
authStorage storage.AuthStorage,
) *RegisterHandler {
h := &RegisterHandler{
authStorage: authStorage,
BaseHandler: handler.New().
WithName("registerv1").
WithRequiredResolveParams(),
}
h.WithProcessFunc(h.process)
return h
}
func (h *RegisterHandler) process(ctx context.Context, req *common.Request, w handler.Writer) error {
regReq := new(domain.RegisterRequest)
if err := json.Unmarshal(req.Body, regReq); err != nil {
return err
}
resp, err := h.register(ctx, regReq, w)
if err != nil {
return fmt.Errorf("failed to register user: %w", err)
}
w.Write(ctx, resp)
return nil
}
func (d *RegisterHandler) register(
ctx context.Context,
req *domain.RegisterRequest,
w handler.Writer,
) (*domain.RegisterResponse, error) {
if err := validateLoginAndPassword(req.Login, req.Password); err != nil { if err := validateLoginAndPassword(req.Login, req.Password); err != nil {
return nil, fmt.Errorf("invalid creds: %w", err) return nil, fmt.Errorf("invalid creds: %w", err)
} }
@ -63,7 +78,7 @@ func (d *DrainCloud) register(ctx *gin.Context, req *domain.RegisterRequest) (*d
PasswordHash: passwordHash, PasswordHash: passwordHash,
} }
err = d.database.AddUser(ctx, userID, user.Login, user.Username, user.PasswordHash) err = d.authStorage.AddUser(ctx, userID, user.Login, user.Username, user.PasswordHash)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add new user: %w", err) return nil, fmt.Errorf("failed to add new user: %w", err)
} }
@ -75,20 +90,20 @@ func (d *DrainCloud) register(ctx *gin.Context, req *domain.RegisterRequest) (*d
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate a session token: %w", err) return nil, fmt.Errorf("failed to generate a session token: %w", err)
} }
ctx.SetCookie(sessionTokenCookie, sessionToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, true) w.SetCookie(sessionTokenCookie, sessionToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, true)
csrfToken, err := generateSessionToken(100) csrfToken, err := generateSessionToken(100)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate a csrf token: %w", err) return nil, fmt.Errorf("failed to generate a csrf token: %w", err)
} }
ctx.SetCookie(csrfTokenCookie, csrfToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, false) w.SetCookie(csrfTokenCookie, csrfToken, int(sessionExpiredAt.Sub(sessionCreatedAt).Seconds()), "_path", "_domain", true, false)
sessionID, err := uuid.NewV7() sessionID, err := uuid.NewV7()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate session id: %w", err) return nil, fmt.Errorf("failed to generate session id: %w", err)
} }
if _, err = d.database.AddSession(ctx, &models.Session{ if _, err = d.authStorage.AddSession(ctx, &models.Session{
ID: sessionID, ID: sessionID,
SessionToken: sessionToken, SessionToken: sessionToken,
CsrfToken: csrfToken, CsrfToken: csrfToken,
@ -103,3 +118,4 @@ func (d *DrainCloud) register(ctx *gin.Context, req *domain.RegisterRequest) (*d
Ok: true, Ok: true,
}, nil }, nil
} }

View File

@ -1,10 +1,13 @@
package common package common
import ( import (
"context"
"fmt" "fmt"
"io"
"net/http" "net/http"
"sync" "sync"
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models" "git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -13,43 +16,82 @@ type RequestPool struct {
sp sync.Pool sp sync.Pool
} }
type Request struct { func (p *RequestPool) Get() *Request {
ID string r, _ := p.sp.Get().(*Request)
Session *models.Session return r
User *models.User
ResolveValues sync.Map
Metadata map[string]any
Body []byte
} }
// NewRequest builds a new *Request struct from raw http Request. No auth data validated. func (p *RequestPool) Put(r *Request) {
func NewRequest(pool *RequestPool, req *http.Request) *Request { r.ID = ""
r.Metadata = &sync.Map{}
r.ResolveValues = &sync.Map{}
r.Session = nil
r.User = nil
r.Body = nil
p.sp.Put(r)
}
func NewRequestPool() *RequestPool {
return &RequestPool{
sp: sync.Pool{
New: func() any {
return &Request{
ResolveValues: &sync.Map{},
Metadata: &sync.Map{},
}
},
},
}
}
type Request struct {
ID string
Session *models.Session
User *models.User
// ResolveValues - data required to process request.
ResolveValues *sync.Map
// Metadata - an additional data, usually added with preprocessing.
Metadata *sync.Map
// Request body
Body []byte
RawReq *http.Request
}
// NewRequestFromHttp builds a new *Request struct from raw http Request. No auth data validated.
func NewRequestFromHttp(pool *RequestPool, req *http.Request) *Request {
out := pool.sp.Get().(*Request) out := pool.sp.Get().(*Request)
cookies := req.Cookies() cookies := req.Cookies()
headers := req.Header headers := req.Header
out.Metadata = make(map[string]any, len(cookies)) out.Metadata = &sync.Map{}
out.RawReq = req
for _, cookie := range cookies { for _, cookie := range cookies {
out.Metadata[cookie.Name] = cookie.Value out.Metadata.Store(cookie.Name, cookie.Value)
} }
for hname, hval := range headers { for hname, hval := range headers {
out.Metadata[hname] = hval out.Metadata.Store(hname, hval)
} }
body, err := io.ReadAll(req.Body)
if err != nil {
logger.Error(context.TODO(), "failed to read request body", logger.Err(err))
}
out.Body = body
reqID := uuid.NewString() reqID := uuid.NewString()
out.ID = reqID out.ID = reqID
return out return out
} }
func GetValue[T any](vals map[string]any, key string) (T, error) { func GetValue[T any](vals *sync.Map, key string) (T, error) {
var out T var out T
if vals == nil { if vals == nil {
return out, fmt.Errorf("nil vals map") return out, fmt.Errorf("nil vals map")
} }
rawVal, ok := vals[key] rawVal, ok := vals.Load(key)
if !ok { if !ok {
return out, fmt.Errorf("value not found in resolve values set") return out, fmt.Errorf("value not found in resolve values set")
} }

View File

@ -2,6 +2,7 @@ package common
import ( import (
"reflect" "reflect"
"sync"
"testing" "testing"
) )
@ -65,7 +66,7 @@ func TestGetValue_string(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := GetValue[string](tt.args.vals, tt.args.key) got, err := GetValue[string](_mapToSyncMap(tt.args.vals), tt.args.key)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -151,7 +152,7 @@ func TestGetValue_struct(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := GetValue[val](tt.args.vals, tt.args.key) got, err := GetValue[val](_mapToSyncMap(tt.args.vals), tt.args.key)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -237,7 +238,7 @@ func TestGetValue_structptr(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := GetValue[*val](tt.args.vals, tt.args.key) got, err := GetValue[*val](_mapToSyncMap(tt.args.vals), tt.args.key)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GetValue() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -248,3 +249,13 @@ func TestGetValue_structptr(t *testing.T) {
}) })
} }
} }
func _mapToSyncMap(m map[string]any) *sync.Map {
out := &sync.Map{}
for k, v := range m {
out.Store(k, v)
}
return out
}

View File

@ -8,10 +8,59 @@ import (
type Writer interface { type Writer interface {
Write(ctx context.Context, resp any) Write(ctx context.Context, resp any)
SetCookie(name string, value string, maxAge int, path string, domain string, secure bool, httpOnly bool)
} }
type Handler struct { type Handler interface {
GetName() string
GetRequiredResolveParams() []string
GetProcessFn() func(ctx context.Context, req *common.Request, w Writer) error
GetPreprocessFn() func(ctx context.Context, req *common.Request, w Writer) error
}
type BaseHandler struct {
Name string Name string
RequiredResolveParams []string RequiredResolveParams []string
ProcessFn func(ctx context.Context, req *common.Request, w Writer) error ProcessFn func(ctx context.Context, req *common.Request, w Writer) error
PreprocessFn func(ctx context.Context, req *common.Request, w Writer) error
}
func New() *BaseHandler {
return new(BaseHandler)
}
func (h *BaseHandler) WithName(name string) *BaseHandler {
h.Name = name
return h
}
func (h *BaseHandler) WithRequiredResolveParams(params ...string) *BaseHandler {
h.RequiredResolveParams = params
return h
}
func (h *BaseHandler) WithProcessFunc(fn func(ctx context.Context, req *common.Request, w Writer) error) *BaseHandler {
h.ProcessFn = fn
return h
}
func (h *BaseHandler) WithPreprocessFunc(fn func(ctx context.Context, req *common.Request, w Writer) error) *BaseHandler {
h.PreprocessFn = fn
return h
}
func (h *BaseHandler) GetName() string {
return h.Name
}
func (h *BaseHandler) GetRequiredResolveParams() []string {
return h.RequiredResolveParams
}
func (h *BaseHandler) GetProcessFn() func(ctx context.Context, req *common.Request, w Writer) error {
return h.ProcessFn
}
func (h *BaseHandler) GetPreprocessFn() func(ctx context.Context, req *common.Request, w Writer) error {
return h.PreprocessFn
} }

View File

@ -0,0 +1,107 @@
package processor
import (
"context"
"errors"
"fmt"
"net/http"
"git.optclblast.xyz/draincloud/draincloud-core/internal/common"
"git.optclblast.xyz/draincloud/draincloud-core/internal/domain"
"git.optclblast.xyz/draincloud/draincloud-core/internal/errs"
"git.optclblast.xyz/draincloud/draincloud-core/internal/handler"
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
resolvedispatcher "git.optclblast.xyz/draincloud/draincloud-core/internal/resolve_dispatcher"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
type GinProcessor struct {
rp *common.RequestPool
authStorage storage.AuthStorage
resolveDispatcher *resolvedispatcher.ResolveDispatcher
}
func NewGinProcessor(
authStorage storage.AuthStorage,
resolveDispatcher *resolvedispatcher.ResolveDispatcher,
) *GinProcessor {
return &GinProcessor{
rp: common.NewRequestPool(),
authStorage: authStorage,
resolveDispatcher: resolveDispatcher,
}
}
func (p *GinProcessor) Process(handler handler.Handler) gin.HandlerFunc {
return func(ctx *gin.Context) {
req := common.NewRequestFromHttp(p.rp, ctx.Request)
ctx.Request = ctx.Request.WithContext(context.WithValue(ctx.Request.Context(), "__request_id", req.ID))
// 1. Resolve the resolvers, collect all data required
// 2. Try process oprional resolvers
err := p.resolve(ctx, handler, req)
if err != nil {
p.writeError(ctx, err)
return
}
// 3. Call preprocessing fn's, middlewares etc.
if err = handler.GetPreprocessFn()(ctx, req, wrapGin(ctx)); err != nil {
p.writeError(ctx, err)
return
}
// 4. Call handler.ProcessFn
if err = handler.GetProcessFn()(ctx, req, wrapGin(ctx)); err != nil {
p.writeError(ctx, err)
return
}
}
}
func (p *GinProcessor) resolve(ctx context.Context, h handler.Handler, req *common.Request) error {
eg, ctx := errgroup.WithContext(ctx)
for _, r := range h.GetRequiredResolveParams() {
resolver, err := p.resolveDispatcher.GetResolver(r)
if err != nil {
return fmt.Errorf("failed to resolve '%s' param: no resolver provided: %w", r, err)
}
resolveValueName := r
eg.Go(func() error {
if resolveErr := resolver.Resolve(ctx, req); resolveErr != nil {
return fmt.Errorf("failed to resolve '%s' value: %w", resolveValueName, resolveErr)
}
return nil
})
}
if err := eg.Wait(); err != nil {
return err
}
return nil
}
func (p *GinProcessor) writeError(ctx *gin.Context, err error) {
logger.Error(ctx, "error process request", logger.Err(err))
switch {
case errors.Is(err, errs.ErrorAccessDenied):
ctx.JSON(http.StatusInternalServerError, domain.ErrorJson{
Code: http.StatusForbidden,
Message: err.Error(),
})
case errors.Is(err, errs.ErrorSessionExpired):
ctx.JSON(http.StatusInternalServerError, domain.ErrorJson{
Code: http.StatusForbidden,
Message: err.Error(),
})
default:
ctx.JSON(http.StatusInternalServerError, domain.ErrorJson{
Code: http.StatusInternalServerError,
Message: "Internal Error",
})
}
}

View File

@ -0,0 +1,26 @@
package processor
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
)
type ginWriter struct {
ctx *gin.Context
}
func wrapGin(ctx *gin.Context) ginWriter {
return ginWriter{
ctx: ctx,
}
}
func (w ginWriter) Write(ctx context.Context, resp any) {
w.ctx.JSON(http.StatusOK, resp)
}
func (w ginWriter) SetCookie(name string, value string, maxAge int, path string, domain string, secure bool, httpOnly bool) {
w.ctx.SetCookie(name, value, maxAge, path, domain, secure, httpOnly)
}

View File

@ -1,51 +1,7 @@
package processor package processor
import ( import "git.optclblast.xyz/draincloud/draincloud-core/internal/handler"
"errors"
"net/http"
"git.optclblast.xyz/draincloud/draincloud-core/internal/common" type Processor[H any] interface {
"git.optclblast.xyz/draincloud/draincloud-core/internal/domain" Process(handler.Handler) H
"git.optclblast.xyz/draincloud/draincloud-core/internal/errs"
"git.optclblast.xyz/draincloud/draincloud-core/internal/handler"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage"
"github.com/gin-gonic/gin"
)
type Processor struct {
rp *common.RequestPool
authStorage storage.AuthStorage
}
func (p *Processor) Process(handler *handler.Handler) gin.HandlerFunc {
return func(ctx *gin.Context) {
//req := common.NewRequest(p.rp, ctx.Request)
// if handler.WithAuth {
// if err := p.authorize(ctx, req); err != nil {
// p.writeError(ctx, err)
// return
// }
// }
}
}
func (p *Processor) writeError(ctx *gin.Context, err error) {
switch {
case errors.Is(err, errs.ErrorAccessDenied):
ctx.JSON(http.StatusInternalServerError, domain.ErrorJson{
Code: http.StatusForbidden,
Message: err.Error(),
})
case errors.Is(err, errs.ErrorSessionExpired):
ctx.JSON(http.StatusInternalServerError, domain.ErrorJson{
Code: http.StatusForbidden,
Message: err.Error(),
})
default:
ctx.JSON(http.StatusInternalServerError, domain.ErrorJson{
Code: http.StatusInternalServerError,
Message: "Internal Error",
})
}
} }

View File

@ -0,0 +1,48 @@
package resolvedispatcher
import (
"context"
"fmt"
"sync"
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
"git.optclblast.xyz/draincloud/draincloud-core/internal/resolvers"
)
type ResolveDispatcher struct {
m sync.RWMutex
router map[string]resolvers.Resolver
}
func New() *ResolveDispatcher {
return &ResolveDispatcher{
router: map[string]resolvers.Resolver{},
}
}
func (d *ResolveDispatcher) RegisterResolver(
ctx context.Context,
resolverName string,
resolver resolvers.Resolver,
) {
d.m.Lock()
defer d.m.Unlock()
if _, ok := d.router[resolverName]; ok {
logger.Fatal(ctx, fmt.Sprintf("resolver '%s' is already registered in router", resolverName))
}
d.router[resolverName] = resolver
}
func (d *ResolveDispatcher) GetResolver(name string) (resolvers.Resolver, error) {
d.m.RLock()
defer d.m.RUnlock()
res, ok := d.router[name]
if !ok {
return nil, fmt.Errorf("resolver '%s' not found", name)
}
return res, nil
}

View File

@ -73,7 +73,7 @@ func (d *AuthResolver) getSession(ctx context.Context, req *common.Request) (*mo
return session, nil return session, nil
} }
func validateSession(ctx context.Context, req *common.Request, session *models.Session) error { func validateSession(_ context.Context, req *common.Request, session *models.Session) error {
if session == nil { if session == nil {
return errs.ErrorAccessDenied return errs.ErrorAccessDenied
} }