// Source: draincloud-core/internal/storage/postgres/database.go

package postgres
import (
"context"
2024-11-23 08:52:06 +00:00
"database/sql"
2024-09-27 22:37:58 +00:00
"fmt"
2024-11-23 08:52:06 +00:00
"log/slog"
2024-09-27 22:37:58 +00:00
"time"
2024-10-10 21:36:51 +00:00
"git.optclblast.xyz/draincloud/draincloud-core/internal/closer"
"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
2024-12-15 16:56:03 +00:00
"github.com/google/uuid"
2024-09-27 22:37:58 +00:00
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// Database provides user and session storage on top of a PostgreSQL
// connection.
//
// NOTE(review): db is a single *pgx.Conn rather than a pool; pgx documents
// Conn as not safe for concurrent use — confirm callers serialize access.
type Database struct {
db *pgx.Conn // connection opened by New
}
// New connects to PostgreSQL using dsn and returns a Database wrapping the
// resulting connection.
//
// A connection failure is reported through logger.Fatal.
// NOTE(review): the code after logger.Fatal assumes it does not return —
// confirm against the logger package.
//
// A callback that closes the connection is registered with closer.Add.
func New(ctx context.Context, dsn string) *Database {
	db, err := pgx.Connect(ctx, dsn)
	if err != nil {
		logger.Fatal(ctx, "failed to connect to postgres", logger.Err(err))
	}

	closer.Add(func() error {
		// The callback may run after ctx has been cancelled (e.g. during
		// shutdown). Detach from its cancellation so the 2s timeout below is
		// the only deadline applied to Close, while keeping ctx's values.
		closeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 2*time.Second)
		defer cancel()

		return db.Close(closeCtx)
	})

	return &Database{db: db}
}
type dbtx interface {
Exec(ctx context.Context, stmt string, args ...any) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
}
2024-12-15 16:56:03 +00:00
func (d *Database) AddUser(ctx context.Context, login string, username string, passwordHash []byte) (uuid.UUID, error) {
2024-09-27 22:37:58 +00:00
return addUser(ctx, d.db, login, username, passwordHash)
}
2024-12-15 16:56:03 +00:00
func (d *Database) GetUserByID(ctx context.Context, id uuid.UUID) (*models.User, error) {
2024-09-27 22:37:58 +00:00
return getUserByID(ctx, d.db, id)
}
// GetUserByLogin fetches the user whose login equals login using the
// Database connection.
func (d *Database) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
	conn := d.db
	return getUserByLogin(ctx, conn, login)
}
2024-12-15 16:56:03 +00:00
func (d *Database) AddSession(ctx context.Context, ses *models.Session) (uuid.UUID, error) {
2024-10-27 16:44:05 +00:00
return addSession(ctx, d.db, ses)
}
2024-11-23 08:52:06 +00:00
func (d *Database) GetSession(ctx context.Context, sessionToken string) (*models.Session, error) {
const stmt = `SELECT
s.id, s.session_token, s.csrf_token, s.user_id, s.created_at, s.expired_at
FROM sessions as s
WHERE s.session_token = $1;`
row := d.db.QueryRow(ctx, stmt, sessionToken)
var (
2024-12-15 16:56:03 +00:00
id uuid.UUID
2024-11-23 08:52:06 +00:00
sesToken, csrfToken string
2024-12-15 16:56:03 +00:00
userID uuid.UUID
2024-11-23 08:52:06 +00:00
createdAt sql.NullTime
expiredAt sql.NullTime
)
if err := row.Scan(&id, &sesToken, &csrfToken, &userID, &createdAt, &expiredAt); err != nil {
return nil, err
}
return &models.Session{
ID: id,
SessionToken: sesToken,
CsrfToken: csrfToken,
UserID: userID,
CreatedAt: createdAt.Time,
ExpiredAt: expiredAt.Time,
}, nil
}
2024-12-15 16:56:03 +00:00
func (d *Database) RemoveSession(ctx context.Context, id uuid.UUID) error {
2024-11-23 08:52:06 +00:00
const stmt = `DELETE FROM sessions WHERE id = $1;`
_, err := d.db.Exec(ctx, stmt, id)
return err
}
// RemoveExpiredSessions deletes every session whose expired_at is earlier
// than the current time and logs how many rows were removed.
//
// If the DELETE fails, the wrapped error is returned and nothing is logged.
func (d *Database) RemoveExpiredSessions(ctx context.Context) error {
	const stmt = `DELETE FROM sessions WHERE expired_at < $1;`

	res, err := d.db.Exec(ctx, stmt, time.Now())
	if err != nil {
		// Check the error before touching res: on failure the command tag
		// carries no meaningful row count, so a "removed 0" notice would be
		// misleading.
		return fmt.Errorf("failed to remove expired sessions: %w", err)
	}

	logger.Notice(ctx, "[Database][RemoveExpiredSessions] sessions cleanup", slog.Int64("removed", res.RowsAffected()))
	return nil
}
2024-12-15 16:56:03 +00:00
func addUser(ctx context.Context, conn dbtx, login string, username string, passwordHash []byte) (uuid.UUID, error) {
2024-09-27 22:37:58 +00:00
const stmt = `INSERT INTO users (login,username,password)
2024-10-27 05:07:27 +00:00
VALUES ($1,$2,$3) RETURNING id`
2024-09-27 22:37:58 +00:00
row := conn.QueryRow(ctx, stmt, login, username, passwordHash)
2024-12-15 16:56:03 +00:00
var id uuid.UUID
2024-09-27 22:37:58 +00:00
if err := row.Scan(&id); err != nil {
2024-12-15 16:56:03 +00:00
return uuid.Nil, fmt.Errorf("failed to insert user data into users table: %w", err)
2024-09-27 22:37:58 +00:00
}
return id, nil
}
2024-12-15 16:56:03 +00:00
func getUserByID(ctx context.Context, conn dbtx, id uuid.UUID) (*models.User, error) {
2024-09-27 22:37:58 +00:00
const stmt = `SELECT * FROM users WHERE id = $1 LIMIT 1`
u := new(models.User)
row := conn.QueryRow(ctx, stmt, id)
if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
return nil, fmt.Errorf("failed to fetch user by id: %w", err)
}
return u, nil
}
// getUserByLogin loads the user whose login equals login via conn.
//
// On failure it returns nil and an error wrapping the scan error.
func getUserByLogin(ctx context.Context, conn dbtx, login string) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE login = $1 LIMIT 1`

	var user models.User
	err := conn.QueryRow(ctx, stmt, login).Scan(
		&user.ID,
		&user.Login,
		&user.Username,
		&user.PasswordHash,
		&user.CreatedAt,
		&user.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch user by login: %w", err)
	}

	return &user, nil
}
2024-10-27 16:44:05 +00:00
2024-12-15 16:56:03 +00:00
func addSession(ctx context.Context, conn dbtx, session *models.Session) (uuid.UUID, error) {
2024-10-27 16:44:05 +00:00
const stmt = `INSERT INTO sessions (session_token, csrf_token, user_id,
created_at, expired_at) VALUES ($1, $2, $3, $4, $5) RETURNING id;`
2024-12-15 16:56:03 +00:00
var id uuid.UUID
2024-11-23 08:52:06 +00:00
row := conn.QueryRow(ctx, stmt, session.SessionToken, session.CsrfToken, session.UserID, session.CreatedAt, session.ExpiredAt)
2024-10-27 16:44:05 +00:00
if err := row.Scan(&id); err != nil {
2024-12-15 16:56:03 +00:00
return uuid.Nil, fmt.Errorf("failed to insert new session: %w", err)
2024-10-27 16:44:05 +00:00
}
return id, nil
}