90 lines
2.6 KiB
Go
90 lines
2.6 KiB
Go
|
package postgres
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"time"
|
||
|
|
||
|
"git.optclblast.xyz/draincloud/draincloud-light/internal/closer"
|
||
|
"git.optclblast.xyz/draincloud/draincloud-light/internal/logger"
|
||
|
"git.optclblast.xyz/draincloud/draincloud-light/internal/storage/models"
|
||
|
"github.com/jackc/pgx/v5"
|
||
|
"github.com/jackc/pgx/v5/pgconn"
|
||
|
)
|
||
|
|
||
|
type Database struct {
|
||
|
db *pgx.Conn
|
||
|
}
|
||
|
|
||
|
func New(ctx context.Context, dsn string) *Database {
|
||
|
db, err := pgx.Connect(ctx, dsn)
|
||
|
if err != nil {
|
||
|
logger.Fatal(ctx, "failed to connect to postgres", logger.Err(err))
|
||
|
}
|
||
|
|
||
|
closer.Add(func() error {
|
||
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||
|
defer cancel()
|
||
|
return db.Close(ctx)
|
||
|
})
|
||
|
|
||
|
return &Database{db: db}
|
||
|
}
|
||
|
|
||
|
type dbtx interface {
|
||
|
Exec(ctx context.Context, stmt string, args ...any) (pgconn.CommandTag, error)
|
||
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||
|
}
|
||
|
|
||
|
func (d *Database) AddUser(ctx context.Context, login string, username string, passwordHash []byte) (uint64, error) {
|
||
|
return addUser(ctx, d.db, login, username, passwordHash)
|
||
|
}
|
||
|
|
||
|
func (d *Database) GetUserByID(ctx context.Context, id uint64) (*models.User, error) {
|
||
|
return getUserByID(ctx, d.db, id)
|
||
|
}
|
||
|
|
||
|
func (d *Database) GetUserByLogin(ctx context.Context, login string) (*models.User, error) {
|
||
|
return getUserByLogin(ctx, d.db, login)
|
||
|
}
|
||
|
|
||
|
func addUser(ctx context.Context, conn dbtx, login string, username string, passwordHash []byte) (uint64, error) {
|
||
|
const stmt = `INSERT INTO users (login,username,password)
|
||
|
VALUES ($1,$2,$3,$4) RETURNING id`
|
||
|
|
||
|
row := conn.QueryRow(ctx, stmt, login, username, passwordHash)
|
||
|
|
||
|
var id uint64
|
||
|
|
||
|
if err := row.Scan(&id); err != nil {
|
||
|
return 0, fmt.Errorf("failed to insert user data into users table: %w", err)
|
||
|
}
|
||
|
|
||
|
return id, nil
|
||
|
}
|
||
|
|
||
|
func getUserByID(ctx context.Context, conn dbtx, id uint64) (*models.User, error) {
|
||
|
const stmt = `SELECT * FROM users WHERE id = $1 LIMIT 1`
|
||
|
u := new(models.User)
|
||
|
|
||
|
row := conn.QueryRow(ctx, stmt, id)
|
||
|
if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||
|
return nil, fmt.Errorf("failed to fetch user by id: %w", err)
|
||
|
}
|
||
|
|
||
|
return u, nil
|
||
|
}
|
||
|
|
||
|
func getUserByLogin(ctx context.Context, conn dbtx, login string) (*models.User, error) {
|
||
|
const stmt = `SELECT * FROM users WHERE login = $1 LIMIT 1`
|
||
|
u := new(models.User)
|
||
|
|
||
|
row := conn.QueryRow(ctx, stmt, login)
|
||
|
if err := row.Scan(&u.ID, &u.Login, &u.Username, &u.PasswordHash, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||
|
return nil, fmt.Errorf("failed to fetch user by login: %w", err)
|
||
|
}
|
||
|
|
||
|
return u, nil
|
||
|
}
|