package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"git.optclblast.xyz/draincloud/draincloud-core/internal/closer"
	"git.optclblast.xyz/draincloud/draincloud-core/internal/logger"
	"git.optclblast.xyz/draincloud/draincloud-core/internal/storage/models"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

// Database is a PostgreSQL-backed store for users and sessions, built on a
// single pgx connection.
//
// NOTE(review): pgx documents *pgx.Conn as not safe for concurrent use. If a
// Database is shared between goroutines (e.g. HTTP handlers), calls must be
// serialized by the caller, or the field should become a *pgxpool.Pool —
// confirm how this type is used.
type Database struct {
	db *pgx.Conn
}

// New connects to PostgreSQL using dsn and returns a Database wrapping that
// connection. A connection failure is reported via logger.Fatal instead of
// being returned (presumably terminating the process — confirm). The
// connection is registered with closer so it is closed, with a 2-second
// timeout, when the registered closers run.
func New(ctx context.Context, dsn string) *Database {
	db, err := pgx.Connect(ctx, dsn)
	if err != nil {
		logger.Fatal(ctx, "failed to connect to postgres", logger.Err(err))
	}

	closer.Add(func() error {
		// Detach from ctx's cancellation. ctx may already be canceled by the
		// time closers run at shutdown, which would cut the graceful close
		// short. WithoutCancel still carries ctx's values (e.g. for logging).
		closeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 2*time.Second)
		defer cancel()
		return db.Close(closeCtx)
	})

	return &Database{db: db}
}

// dbtx is the set of query methods the package-level helpers need. It lets
// those helpers run against a *pgx.Conn or any other type with the same
// methods (for example a transaction handle).
type dbtx interface {
	Exec(ctx context.Context, stmt string, args ...any) (pgconn.CommandTag, error)
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
	Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
}

// Compile-time check that the connection type used by Database satisfies dbtx.
var _ dbtx = (*pgx.Conn)(nil)

// AddUser inserts a user with the given login, username and password hash,
// and returns the id generated by the database.
func (d *Database) AddUser(ctx context.Context, login string, username string, passwordHash []byte) (int64, error) {
	return addUser(ctx, d.db, login, username, passwordHash)
}

// GetUserByID returns the user with the given id. Errors are wrapped, so use
// errors.Is(err, pgx.ErrNoRows) to detect a missing user.
func (d *Database) GetUserByID(ctx context.Context, id uint64) (*models.User, error) {
	return getUserByID(ctx, d.db, id)
}

// GetUserByLogin returns the user with the given login. Errors are wrapped, so
// use errors.Is(err, pgx.ErrNoRows) to detect a missing user.
func (d *Database) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
	return getUserByLogin(ctx, d.db, login)
}

// AddSession stores ses and returns the id generated by the database. The ID
// field of ses is not read; the returned id is not written back into ses.
func (d *Database) AddSession(ctx context.Context, ses *models.Session) (int64, error) {
	return addSession(ctx, d.db, ses)
}

// GetSession returns the session whose session_token equals sessionToken.
// The query/scan error is returned unwrapped (pgx reports a missing row as
// pgx.ErrNoRows). NULL created_at / expired_at columns come back as the zero
// time.Time.
func (d *Database) GetSession(ctx context.Context, sessionToken string) (*models.Session, error) {
	return getSession(ctx, d.db, sessionToken)
}

// RemoveSession deletes the session with the given id. Deleting an id that
// does not exist is not an error.
func (d *Database) RemoveSession(ctx context.Context, id int64) error {
	const stmt = `DELETE FROM sessions WHERE id = $1;`
	_, err := d.db.Exec(ctx, stmt, id)
	return err
}

// RemoveExpiredSessions deletes every session whose expired_at is earlier than
// the current time (taken from the Go process clock, not the database's).
// On success it logs how many rows were removed.
func (d *Database) RemoveExpiredSessions(ctx context.Context) error {
	const stmt = `DELETE FROM sessions WHERE expired_at < $1;`
	res, err := d.db.Exec(ctx, stmt, time.Now())
	if err != nil {
		// Return before touching res: on failure it carries no meaningful
		// row count, and logging "removed 0" would be misleading.
		return err
	}
	logger.Notice(ctx, "[Database][RemoveExpiredSessions] sessions cleanup", slog.Int64("removed", res.RowsAffected()))
	return nil
}

// addUser inserts a row into users and returns the id from RETURNING id.
func addUser(ctx context.Context, conn dbtx, login string, username string, passwordHash []byte) (int64, error) {
	const stmt = `INSERT INTO users (login,username,password) VALUES ($1,$2,$3) RETURNING id`
	row := conn.QueryRow(ctx, stmt, login, username, passwordHash)

	var id int64
	if err := row.Scan(&id); err != nil {
		return 0, fmt.Errorf("failed to insert user data into users table: %w", err)
	}
	return id, nil
}

// getUserByID fetches a single user by primary key.
//
// NOTE(review): SELECT * couples the Scan order below to the physical column
// order of users (assumed id, login, username, password, created_at,
// updated_at). An explicit column list would be safer — confirm column names.
func getUserByID(ctx context.Context, conn dbtx, id uint64) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE id = $1 LIMIT 1`
	u := new(models.User)
	row := conn.QueryRow(ctx, stmt, id)
	if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
		return nil, fmt.Errorf("failed to fetch user by id: %w", err)
	}
	return u, nil
}

// getUserByLogin fetches a single user by login. It shares the SELECT *
// column-order assumption documented on getUserByID.
func getUserByLogin(ctx context.Context, conn dbtx, login string) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE login = $1 LIMIT 1`
	u := new(models.User)
	row := conn.QueryRow(ctx, stmt, login)
	if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
		return nil, fmt.Errorf("failed to fetch user by login: %w", err)
	}
	return u, nil
}

// getSession looks up a session by its session token. The Scan error is
// returned as-is (not wrapped), so callers can compare it against
// pgx.ErrNoRows directly.
func getSession(ctx context.Context, conn dbtx, sessionToken string) (*models.Session, error) {
	const stmt = `SELECT s.id, s.session_token, s.csrf_token, s.user_id, s.created_at, s.expired_at
		FROM sessions as s
		WHERE s.session_token = $1;`
	row := conn.QueryRow(ctx, stmt, sessionToken)

	var (
		id        int64
		sesToken  string
		csrfToken string
		userID    int64
		// Nullable in the scan so NULL timestamps become the zero time
		// instead of a Scan error.
		createdAt sql.NullTime
		expiredAt sql.NullTime
	)
	if err := row.Scan(&id, &sesToken, &csrfToken, &userID, &createdAt, &expiredAt); err != nil {
		return nil, err
	}

	return &models.Session{
		ID:           id,
		SessionToken: sesToken,
		CsrfToken:    csrfToken,
		UserID:       userID,
		CreatedAt:    createdAt.Time,
		ExpiredAt:    expiredAt.Time,
	}, nil
}

// addSession inserts session into sessions and returns the id generated by
// the database. A nil session is rejected with an error instead of panicking.
func addSession(ctx context.Context, conn dbtx, session *models.Session) (int64, error) {
	if session == nil {
		return 0, errors.New("failed to insert new session: session is nil")
	}

	const stmt = `INSERT INTO sessions (session_token, csrf_token, user_id, created_at, expired_at)
		VALUES ($1, $2, $3, $4, $5) RETURNING id;`

	var id int64
	row := conn.QueryRow(ctx, stmt, session.SessionToken, session.CsrfToken, session.UserID, session.CreatedAt, session.ExpiredAt)
	if err := row.Scan(&id); err != nil {
		return 0, fmt.Errorf("failed to insert new session: %w", err)
	}
	return id, nil
}