// draincloud-core/internal/storage/postgres/database.go
//
// 148 lines
// 4.5 KiB
// Go

package postgres
import (
"context"
"database/sql"
"fmt"
"log/slog"
"time"
"git.optclblast.xyz/draincloud/draincloud-core/internal/closer"
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// Database is the PostgreSQL-backed storage type of this package. The methods
// in this file issue their queries over the single pgx connection held in db.
type Database struct {
	// db is the connection opened by New and closed via the closer hook
	// registered there.
	db *pgx.Conn
	// cluster is not assigned by New in this file.
	// NOTE(review): presumably populated elsewhere in the package — confirm.
	cluster *ShardCluster
}
// New opens a pgx connection to the PostgreSQL server described by dsn and
// returns a Database that uses it. A connection failure is reported through
// logger.Fatal. A close hook is registered with the closer package; it tries
// to close the connection within 2 seconds.
func New(ctx context.Context, dsn string) *Database {
	db, err := pgx.Connect(ctx, dsn)
	if err != nil {
		logger.Fatal(ctx, "failed to connect to postgres", logger.Err(err))
	}

	closer.Add(func() error {
		// The ctx passed to New may already be cancelled by the time closers
		// run at shutdown. A timeout derived from it would expire immediately
		// and defeat the graceful close. WithoutCancel keeps ctx's values
		// (e.g. for logging) but drops its cancellation, so only the 2s
		// timeout bounds the close.
		closeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 2*time.Second)
		defer cancel()
		return db.Close(closeCtx)
	})

	return &Database{db: db}
}
type dbtx interface {
Exec(ctx context.Context, stmt string, args ...any) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
}
// AddUser inserts a new user row over the Database's connection by delegating
// to addUser. Errors are returned as produced by addUser.
func (d *Database) AddUser(ctx context.Context, id uuid.UUID, login string, username string, passwordHash []byte) error {
	var q dbtx = d.db
	return addUser(ctx, q, id, login, username, passwordHash)
}
// GetUserByID loads the user with the given id over the Database's
// connection by delegating to getUserByID.
func (d *Database) GetUserByID(ctx context.Context, id uuid.UUID) (*models.User, error) {
	var q dbtx = d.db
	return getUserByID(ctx, q, id)
}
// GetUserByLogin loads the user with the given login over the Database's
// connection by delegating to getUserByLogin.
func (d *Database) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
	var q dbtx = d.db
	return getUserByLogin(ctx, q, login)
}
// AddSession stores ses over the Database's connection and returns the id the
// database reports for the inserted row (see addSession).
func (d *Database) AddSession(ctx context.Context, ses *models.Session) (uuid.UUID, error) {
	var q dbtx = d.db
	return addSession(ctx, q, ses)
}
// GetSession fetches the session whose session_token equals sessionToken.
//
// created_at and expired_at are read through sql.NullTime, so a NULL in
// either column is returned as the zero time.Time. Scan errors are returned
// unwrapped (e.g. pgx.ErrNoRows when no session matches), so callers can
// compare them directly.
func (d *Database) GetSession(ctx context.Context, sessionToken string) (*models.Session, error) {
	const stmt = `SELECT
s.id, s.session_token, s.csrf_token, s.user_id, s.created_at, s.expired_at
FROM sessions as s
WHERE s.session_token = $1;`

	var (
		sessionID        uuid.UUID
		token, csrf      string
		ownerID          uuid.UUID
		created, expires sql.NullTime
	)
	err := d.db.QueryRow(ctx, stmt, sessionToken).
		Scan(&sessionID, &token, &csrf, &ownerID, &created, &expires)
	if err != nil {
		return nil, err
	}

	ses := &models.Session{
		ID:           sessionID,
		SessionToken: token,
		CsrfToken:    csrf,
		UserID:       ownerID,
		CreatedAt:    created.Time,
		ExpiredAt:    expires.Time,
	}
	return ses, nil
}
// RemoveSession deletes the session row with the given id. Deleting an id
// that does not exist is not treated as an error; only Exec failures are
// returned, unwrapped.
func (d *Database) RemoveSession(ctx context.Context, id uuid.UUID) error {
	if _, err := d.db.Exec(ctx, `DELETE FROM sessions WHERE id = $1;`, id); err != nil {
		return err
	}
	return nil
}
// RemoveExpiredSessions deletes every session whose expired_at is earlier
// than the current time of this process (time.Now, not the database clock).
// On success it logs how many rows were removed.
// On failure it returns the Exec error wrapped with context and logs nothing.
func (d *Database) RemoveExpiredSessions(ctx context.Context) error {
	const stmt = `DELETE FROM sessions WHERE expired_at < $1;`
	res, err := d.db.Exec(ctx, stmt, time.Now())
	if err != nil {
		// Check before reading res: after a failed Exec the command tag holds
		// no meaningful row count, and logging "removed 0" would misreport a
		// failed cleanup as a successful no-op.
		return fmt.Errorf("failed to remove expired sessions: %w", err)
	}
	logger.Notice(ctx, "[Database][RemoveExpiredSessions] sessions cleanup", slog.Int64("removed", res.RowsAffected()))
	return nil
}
// addUser inserts one row into users, writing id, login, username and
// passwordHash into the id, login, username and password columns. Any Exec
// error is wrapped with %w, so errors.Is / errors.As still see the cause.
func addUser(ctx context.Context, conn dbtx, id uuid.UUID, login string, username string, passwordHash []byte) error {
	const stmt = `INSERT INTO users (id,login,username,password)
VALUES ($1,$2,$3,$4);`
	if _, err := conn.Exec(ctx, stmt, id, login, username, passwordHash); err != nil {
		return fmt.Errorf("failed to insert user data into users table: %w", err)
	}
	return nil
}
// getUserByID fetches the users row whose id equals id. Errors, including
// no-row errors from Scan, are wrapped with %w.
//
// NOTE(review): the query uses SELECT *, so it relies on the users table's
// column order matching scanUserRow's scan targets. An explicit column list
// would be more robust once the schema's column names are confirmed.
func getUserByID(ctx context.Context, conn dbtx, id uuid.UUID) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE id = $1 LIMIT 1`
	u, err := scanUserRow(conn.QueryRow(ctx, stmt, id))
	if err != nil {
		return nil, fmt.Errorf("failed to fetch user by id: %w", err)
	}
	return u, nil
}

// getUserByLogin fetches the users row whose login equals login. Errors,
// including no-row errors from Scan, are wrapped with %w.
//
// NOTE(review): same SELECT * column-order caveat as getUserByID.
func getUserByLogin(ctx context.Context, conn dbtx, login string) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE login = $1 LIMIT 1`
	u, err := scanUserRow(conn.QueryRow(ctx, stmt, login))
	if err != nil {
		return nil, fmt.Errorf("failed to fetch user by login: %w", err)
	}
	return u, nil
}

// scanUserRow scans a single users row into a freshly allocated models.User,
// in the order ID, Login, Username, PasswordHash, CreatedAt, UpdatedAt.
// Keeping this order in one place stops the two lookups above from drifting
// apart. The Scan error is returned unwrapped; callers add context.
func scanUserRow(row pgx.Row) (*models.User, error) {
	u := new(models.User)
	if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
		return nil, err
	}
	return u, nil
}
// addSession inserts session into the sessions table and returns the id that
// the statement's RETURNING clause yields. On failure it returns uuid.Nil and
// the error wrapped with %w. session must be non-nil; its fields are read
// directly.
func addSession(ctx context.Context, conn dbtx, session *models.Session) (uuid.UUID, error) {
	const stmt = `INSERT INTO sessions (id,session_token, csrf_token, user_id,
created_at, expired_at) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id;`

	args := []any{
		session.ID,
		session.SessionToken,
		session.CsrfToken,
		session.UserID,
		session.CreatedAt,
		session.ExpiredAt,
	}

	var inserted uuid.UUID
	if err := conn.QueryRow(ctx, stmt, args...).Scan(&inserted); err != nil {
		return uuid.Nil, fmt.Errorf("failed to insert new session: %w", err)
	}
	return inserted, nil
}