148 lines
4.5 KiB
Go
148 lines
4.5 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"git.optclblast.xyz/draincloud/draincloud-core/internal/closer"
|
|
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
|
|
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
)
|
|
|
|
// Database is the Postgres-backed storage implementation. Every method visible
// in this file runs its queries through the single pgx connection held in db.
//
// NOTE(review): pgx documents *pgx.Conn as not safe for concurrent use. If a
// Database is shared between goroutines (e.g. HTTP handlers), a pool such as
// pgxpool.Pool is likely required — confirm how callers use it.
type Database struct {
	// db is the connection opened by New.
	db *pgx.Conn
	// cluster is not set by New; NOTE(review): its initialisation and use are
	// not visible in this file.
	cluster *ShardCluster
}
|
|
|
|
func New(ctx context.Context, dsn string) *Database {
|
|
db, err := pgx.Connect(ctx, dsn)
|
|
if err != nil {
|
|
logger.Fatal(ctx, "failed to connect to postgres", logger.Err(err))
|
|
}
|
|
|
|
closer.Add(func() error {
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
return db.Close(ctx)
|
|
})
|
|
|
|
return &Database{db: db}
|
|
}
|
|
|
|
type dbtx interface {
|
|
Exec(ctx context.Context, stmt string, args ...any) (pgconn.CommandTag, error)
|
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
|
}
|
|
|
|
func (d *Database) AddUser(ctx context.Context, id uuid.UUID, login string, username string, passwordHash []byte) error {
|
|
return addUser(ctx, d.db, id, login, username, passwordHash)
|
|
}
|
|
|
|
func (d *Database) GetUserByID(ctx context.Context, id uuid.UUID) (*models.User, error) {
|
|
return getUserByID(ctx, d.db, id)
|
|
}
|
|
|
|
func (d *Database) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
|
|
return getUserByLogin(ctx, d.db, login)
|
|
}
|
|
|
|
func (d *Database) AddSession(ctx context.Context, ses *models.Session) (uuid.UUID, error) {
|
|
return addSession(ctx, d.db, ses)
|
|
}
|
|
|
|
func (d *Database) GetSession(ctx context.Context, sessionToken string) (*models.Session, error) {
|
|
const stmt = `SELECT
|
|
s.id, s.session_token, s.csrf_token, s.user_id, s.created_at, s.expired_at
|
|
FROM sessions as s
|
|
WHERE s.session_token = $1;`
|
|
|
|
row := d.db.QueryRow(ctx, stmt, sessionToken)
|
|
|
|
var (
|
|
id uuid.UUID
|
|
sesToken, csrfToken string
|
|
userID uuid.UUID
|
|
createdAt sql.NullTime
|
|
expiredAt sql.NullTime
|
|
)
|
|
|
|
if err := row.Scan(&id, &sesToken, &csrfToken, &userID, &createdAt, &expiredAt); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &models.Session{
|
|
ID: id,
|
|
SessionToken: sesToken,
|
|
CsrfToken: csrfToken,
|
|
UserID: userID,
|
|
CreatedAt: createdAt.Time,
|
|
ExpiredAt: expiredAt.Time,
|
|
}, nil
|
|
}
|
|
|
|
func (d *Database) RemoveSession(ctx context.Context, id uuid.UUID) error {
|
|
const stmt = `DELETE FROM sessions WHERE id = $1;`
|
|
_, err := d.db.Exec(ctx, stmt, id)
|
|
return err
|
|
}
|
|
|
|
func (d *Database) RemoveExpiredSessions(ctx context.Context) error {
|
|
const stmt = `DELETE FROM sessions WHERE expired_at < $1;`
|
|
res, err := d.db.Exec(ctx, stmt, time.Now())
|
|
logger.Notice(ctx, "[Database][RemoveExpiredSessions] sessions cleanup", slog.Int64("removed", res.RowsAffected()))
|
|
return err
|
|
}
|
|
|
|
func addUser(ctx context.Context, conn dbtx, id uuid.UUID, login string, username string, passwordHash []byte) error {
|
|
const stmt = `INSERT INTO users (id,login,username,password)
|
|
VALUES ($1,$2,$3,$4);`
|
|
|
|
_, err := conn.Exec(ctx, stmt, id, login, username, passwordHash)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to insert user data into users table: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getUserByID(ctx context.Context, conn dbtx, id uuid.UUID) (*models.User, error) {
|
|
const stmt = `SELECT * FROM users WHERE id = $1 LIMIT 1`
|
|
u := new(models.User)
|
|
row := conn.QueryRow(ctx, stmt, id)
|
|
if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("failed to fetch user by id: %w", err)
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func getUserByLogin(ctx context.Context, conn dbtx, login string) (*models.User, error) {
|
|
const stmt = `SELECT * FROM users WHERE login = $1 LIMIT 1`
|
|
u := new(models.User)
|
|
row := conn.QueryRow(ctx, stmt, login)
|
|
if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("failed to fetch user by login: %w", err)
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func addSession(ctx context.Context, conn dbtx, session *models.Session) (uuid.UUID, error) {
|
|
const stmt = `INSERT INTO sessions (id,session_token, csrf_token, user_id,
|
|
created_at, expired_at) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id;`
|
|
var id uuid.UUID
|
|
row := conn.QueryRow(ctx, stmt, session.ID, session.SessionToken, session.CsrfToken, session.UserID, session.CreatedAt, session.ExpiredAt)
|
|
if err := row.Scan(&id); err != nil {
|
|
return uuid.Nil, fmt.Errorf("failed to insert new session: %w", err)
|
|
}
|
|
|
|
return id, nil
|
|
}
|