draincloud-core/internal/storage/postgres/database.go

90 lines
2.6 KiB
Go
Raw Normal View History

2024-09-27 22:37:58 +00:00
package postgres
import (
"context"
"fmt"
"time"
"git.optclblast.xyz/draincloud/draincloud-light/internal/closer"
"git.optclblast.xyz/draincloud/draincloud-light/internal/logger"
"git.optclblast.xyz/draincloud/draincloud-light/internal/storage/models"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// Database provides the PostgreSQL-backed storage operations defined in this
// file. It holds a single pgx connection, which its methods hand to the
// package-level query helpers.
//
// NOTE(review): per the pgx documentation, *pgx.Conn is a single connection
// and is not safe for concurrent use. Confirm that a Database is never shared
// across goroutines, or switch to a pgxpool.Pool.
type Database struct {
db *pgx.Conn
}
// New opens a single pgx connection to the PostgreSQL server described by dsn
// and returns a Database that uses it.
//
// If the connection cannot be established, the error is reported via
// logger.Fatal (presumably terminating the process; confirm). New also
// registers a callback with closer (presumably run at shutdown) that closes
// the connection within a 2-second budget.
func New(ctx context.Context, dsn string) *Database {
	db, err := pgx.Connect(ctx, dsn)
	if err != nil {
		logger.Fatal(ctx, "failed to connect to postgres", logger.Err(err))
	}

	closer.Add(func() error {
		// Use a fresh context for the shutdown deadline instead of the ctx
		// given to New. That ctx is typically a startup or root context that
		// may already be cancelled or expired when shutdown runs, which would
		// abort Close before the connection is terminated cleanly.
		closeCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		return db.Close(closeCtx)
	})

	return &Database{db: db}
}
type dbtx interface {
Exec(ctx context.Context, stmt string, args ...any) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
}
// AddUser inserts a user built from login, username and passwordHash using
// the Database's connection and returns the id produced by the insert.
func (d *Database) AddUser(ctx context.Context, login string, username string, passwordHash []byte) (uint64, error) {
	newID, err := addUser(ctx, d.db, login, username, passwordHash)
	return newID, err
}
// GetUserByID looks up the user whose id equals id using the Database's
// connection. Errors from the lookup are returned unchanged.
func (d *Database) GetUserByID(ctx context.Context, id uint64) (*models.User, error) {
	user, err := getUserByID(ctx, d.db, id)
	return user, err
}
// GetUserByLogin looks up the user whose login equals login using the
// Database's connection. Errors from the lookup are returned unchanged.
func (d *Database) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
	user, err := getUserByLogin(ctx, d.db, login)
	return user, err
}
// addUser inserts a row into the users table with the given login, username
// and password hash, and returns the id reported by the RETURNING clause.
//
// passwordHash is written to the "password" column exactly as given; this
// function does not hash anything itself. Any insert or scan error is
// returned wrapped with %w.
func addUser(ctx context.Context, conn dbtx, login string, username string, passwordHash []byte) (uint64, error) {
	// Exactly one placeholder per target column and per bound argument.
	// A mismatch makes PostgreSQL reject the statement.
	const stmt = `INSERT INTO users (login, username, password)
VALUES ($1, $2, $3) RETURNING id`

	var id uint64
	if err := conn.QueryRow(ctx, stmt, login, username, passwordHash).Scan(&id); err != nil {
		return 0, fmt.Errorf("failed to insert user data into users table: %w", err)
	}
	return id, nil
}
// getUserByID loads the users row whose id column equals id.
//
// Query and scan errors are returned wrapped with %w, so callers can match
// them with errors.Is. This includes pgx.ErrNoRows, which pgx reports when no
// row matches.
func getUserByID(ctx context.Context, conn dbtx, id uint64) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE id = $1 LIMIT 1`
	u, err := scanUser(conn.QueryRow(ctx, stmt, id))
	if err != nil {
		return nil, fmt.Errorf("failed to fetch user by id: %w", err)
	}
	return u, nil
}

// getUserByLogin loads the users row whose login column equals login.
//
// Query and scan errors are returned wrapped with %w, so callers can match
// them with errors.Is. This includes pgx.ErrNoRows, which pgx reports when no
// row matches.
func getUserByLogin(ctx context.Context, conn dbtx, login string) (*models.User, error) {
	const stmt = `SELECT * FROM users WHERE login = $1 LIMIT 1`
	u, err := scanUser(conn.QueryRow(ctx, stmt, login))
	if err != nil {
		return nil, fmt.Errorf("failed to fetch user by login: %w", err)
	}
	return u, nil
}

// scanUser reads one users row into a new models.User. It is the single
// place that maps columns to struct fields.
//
// The Scan targets are positional, so they must match the column order
// produced by the callers' `SELECT *`. NOTE(review): listing the columns
// explicitly in those queries would remove the dependence on the table's
// physical column order.
func scanUser(row pgx.Row) (*models.User, error) {
	u := new(models.User)
	if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
		return nil, err
	}
	return u, nil
}