feat: add user role support to database and queries;

fix: add max length validation for refresh token in RefreshRequest;
refactor: use named constants for cache durations in AuthService;
refactor: select all user columns in GetValidUserByLoginCredentials query;
This commit is contained in:
2025-07-17 04:31:25 +03:00
parent b986d45d82
commit 249bbe4a98
6 changed files with 59 additions and 35 deletions

View File

@@ -80,5 +80,6 @@ type User struct {
Username string Username string
Verified *bool Verified *bool
RegistrationDate pgtype.Timestamp RegistrationDate pgtype.Timestamp
Role int32
Deleted *bool Deleted *bool
} }

View File

@@ -236,7 +236,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
const createUser = `-- name: CreateUser :one const createUser = `-- name: CreateUser :one
INSERT INTO users(username, verified) INSERT INTO users(username, verified)
VALUES ($1, false) RETURNING id, username, verified, registration_date, deleted VALUES ($1, false) RETURNING id, username, verified, registration_date, role, deleted
` `
func (q *Queries) CreateUser(ctx context.Context, username string) (User, error) { func (q *Queries) CreateUser(ctx context.Context, username string) (User, error) {
@@ -247,6 +247,7 @@ func (q *Queries) CreateUser(ctx context.Context, username string) (User, error)
&i.Username, &i.Username,
&i.Verified, &i.Verified,
&i.RegistrationDate, &i.RegistrationDate,
&i.Role,
&i.Deleted, &i.Deleted,
) )
return i, err return i, err
@@ -264,7 +265,7 @@ WITH deleted_rows AS (
AND linfo.email = $2::text AND linfo.email = $2::text
)) ))
AND verified IS FALSE AND verified IS FALSE
RETURNING id, username, verified, registration_date, deleted RETURNING id, username, verified, registration_date, role, deleted
) )
SELECT COUNT(*) AS deleted_count FROM deleted_rows SELECT COUNT(*) AS deleted_count FROM deleted_rows
` `
@@ -543,7 +544,7 @@ func (q *Queries) GetUnexpiredTerminatedSessionsGuidsPaginated(ctx context.Conte
} }
const getUser = `-- name: GetUser :one const getUser = `-- name: GetUser :one
SELECT id, username, verified, registration_date, deleted FROM users SELECT id, username, verified, registration_date, role, deleted FROM users
WHERE id = $1 WHERE id = $1
` `
@@ -555,6 +556,7 @@ func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) {
&i.Username, &i.Username,
&i.Verified, &i.Verified,
&i.RegistrationDate, &i.RegistrationDate,
&i.Role,
&i.Deleted, &i.Deleted,
) )
return i, err return i, err
@@ -630,7 +632,7 @@ func (q *Queries) GetUserBansByUsername(ctx context.Context, username string) ([
} }
const getUserByEmail = `-- name: GetUserByEmail :one const getUserByEmail = `-- name: GetUserByEmail :one
SELECT users.id, users.username, users.verified, users.registration_date, users.deleted FROM users SELECT users.id, users.username, users.verified, users.registration_date, users.role, users.deleted FROM users
JOIN login_informations linfo ON linfo.user_id = users.id JOIN login_informations linfo ON linfo.user_id = users.id
WHERE linfo.email = $1::text WHERE linfo.email = $1::text
` `
@@ -643,13 +645,14 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
&i.Username, &i.Username,
&i.Verified, &i.Verified,
&i.RegistrationDate, &i.RegistrationDate,
&i.Role,
&i.Deleted, &i.Deleted,
) )
return i, err return i, err
} }
const getUserByUsername = `-- name: GetUserByUsername :one const getUserByUsername = `-- name: GetUserByUsername :one
SELECT id, username, verified, registration_date, deleted FROM users SELECT id, username, verified, registration_date, role, deleted FROM users
WHERE username = $1 WHERE username = $1
` `
@@ -661,6 +664,7 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User,
&i.Username, &i.Username,
&i.Verified, &i.Verified,
&i.RegistrationDate, &i.RegistrationDate,
&i.Role,
&i.Deleted, &i.Deleted,
) )
return i, err return i, err
@@ -698,7 +702,7 @@ func (q *Queries) GetValidConfirmationCodeByCode(ctx context.Context, arg GetVal
} }
const getValidConfirmationCodesByUsername = `-- name: GetValidConfirmationCodesByUsername :many const getValidConfirmationCodesByUsername = `-- name: GetValidConfirmationCodesByUsername :many
SELECT confirmation_codes.id, user_id, code_type, code_hash, expires_at, used, confirmation_codes.deleted, users.id, username, verified, registration_date, users.deleted FROM confirmation_codes SELECT confirmation_codes.id, user_id, code_type, code_hash, expires_at, used, confirmation_codes.deleted, users.id, username, verified, registration_date, role, users.deleted FROM confirmation_codes
JOIN users on users.id = confirmation_codes.user_id JOIN users on users.id = confirmation_codes.user_id
WHERE WHERE
users.username = $1::text AND users.username = $1::text AND
@@ -724,6 +728,7 @@ type GetValidConfirmationCodesByUsernameRow struct {
Username string Username string
Verified *bool Verified *bool
RegistrationDate pgtype.Timestamp RegistrationDate pgtype.Timestamp
Role int32
Deleted_2 *bool Deleted_2 *bool
} }
@@ -748,6 +753,7 @@ func (q *Queries) GetValidConfirmationCodesByUsername(ctx context.Context, arg G
&i.Username, &i.Username,
&i.Verified, &i.Verified,
&i.RegistrationDate, &i.RegistrationDate,
&i.Role,
&i.Deleted_2, &i.Deleted_2,
); err != nil { ); err != nil {
return nil, err return nil, err
@@ -762,8 +768,7 @@ func (q *Queries) GetValidConfirmationCodesByUsername(ctx context.Context, arg G
const getValidUserByLoginCredentials = `-- name: GetValidUserByLoginCredentials :one const getValidUserByLoginCredentials = `-- name: GetValidUserByLoginCredentials :one
SELECT SELECT
users.id, users.id, users.username, users.verified, users.registration_date, users.role, users.deleted,
users.username,
linfo.password_hash, linfo.password_hash,
linfo.totp_encrypted linfo.totp_encrypted
FROM users FROM users
@@ -783,10 +788,14 @@ type GetValidUserByLoginCredentialsParams struct {
} }
type GetValidUserByLoginCredentialsRow struct { type GetValidUserByLoginCredentialsRow struct {
ID int64 ID int64
Username string Username string
PasswordHash string Verified *bool
TotpEncrypted *string RegistrationDate pgtype.Timestamp
Role int32
Deleted *bool
PasswordHash string
TotpEncrypted *string
} }
func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetValidUserByLoginCredentialsParams) (GetValidUserByLoginCredentialsRow, error) { func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetValidUserByLoginCredentialsParams) (GetValidUserByLoginCredentialsRow, error) {
@@ -795,6 +804,10 @@ func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetVal
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
&i.Username, &i.Username,
&i.Verified,
&i.RegistrationDate,
&i.Role,
&i.Deleted,
&i.PasswordHash, &i.PasswordHash,
&i.TotpEncrypted, &i.TotpEncrypted,
) )

View File

@@ -49,9 +49,8 @@ type LoginResponse struct {
Tokens Tokens
} }
// TODO: length check
type RefreshRequest struct { type RefreshRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token" binding:"required,max=2000"`
} }
type RefreshResponse struct { type RefreshResponse struct {

View File

@@ -38,6 +38,11 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
var (
AuthTerminatedSessionCacheDuration = time.Duration(8 * time.Hour)
AuthRegistrationCooldownCacheDuration = time.Duration(10 * time.Minute)
)
type AuthService interface { type AuthService interface {
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
@@ -87,7 +92,7 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
pipe := _redis.Pipeline() pipe := _redis.Pipeline()
for _, guid := range guids { for _, guid := range guids {
key := fmt.Sprintf("session::%s::is_terminated", guid) key := fmt.Sprintf("session::%s::is_terminated", guid)
pipe.Set(ctx, key, true, time.Duration(8 * time.Hour)) // XXX: magic number pipe.Set(ctx, key, true, AuthTerminatedSessionCacheDuration)
} }
if _, err := pipe.Exec(ctx); err != nil { if _, err := pipe.Exec(ctx); err != nil {
@@ -122,7 +127,12 @@ func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, usern
pipe := a.redis.Pipeline() pipe := a.redis.Pipeline()
for _, guid := range sessionGuids { for _, guid := range sessionGuids {
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number pipe.Set(
ctx,
fmt.Sprintf("session::%s::is_terminated", guid),
true,
AuthTerminatedSessionCacheDuration,
)
} }
if _, err := pipe.Exec(ctx); err != nil { if _, err := pipe.Exec(ctx); err != nil {
@@ -219,7 +229,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
context.TODO(), context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", request.Email), fmt.Sprintf("email::%s::registration_in_progress", request.Email),
true, true,
time.Duration(10*time.Minute), // XXX: magic number AuthRegistrationCooldownCacheDuration,
).Err(); err != nil { ).Err(); err != nil {
a.log.Error( a.log.Error(
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration", "Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
@@ -324,7 +334,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
context.TODO(), context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", request.Email), fmt.Sprintf("email::%s::registration_in_progress", request.Email),
true, true,
time.Duration(10*time.Minute), // XXX: magic number AuthTerminatedSessionCacheDuration,
).Err(); err != nil { ).Err(); err != nil {
a.log.Error( a.log.Error(
"Failed to cache registration_in_progress state for email", "Failed to cache registration_in_progress state for email",
@@ -453,8 +463,7 @@ func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request mod
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
// TODO: get user role accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.Role(user.Role))
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
@@ -483,7 +492,6 @@ func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request mod
return &response, nil return &response, nil
} }
// TODO: totp
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) { func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
var userRow database.GetValidUserByLoginCredentialsRow var userRow database.GetValidUserByLoginCredentialsRow
var err error var err error
@@ -500,12 +508,7 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{ userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
Username: request.Username, Username: request.Username,
Password: request.Password, Password: request.Password,
}) }); if err != nil {
if err != nil {
a.log.Warn(
"Failed login attempt",
zap.Error(err))
var returnedError error var returnedError error
@@ -516,6 +519,9 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques
returnedError = errs.ErrServerError returnedError = errs.ErrServerError
} }
a.log.Warn(
"Failed login attempt",
zap.Error(err))
return nil, returnedError return nil, returnedError
} }
@@ -532,9 +538,11 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
// TODO: get user role accessToken, refreshToken, err := utils.GenerateTokens(
accessToken, refreshToken, err := utils.GenerateTokens(userRow.Username, session.Guid.String(), enums.UserRole) userRow.Username,
if err != nil { session.Guid.String(),
enums.Role(userRow.Role),
); if err != nil {
a.log.Error( a.log.Error(
"Failed to generate tokens for a new login", "Failed to generate tokens for a new login",
zap.String("username", userRow.Username), zap.String("username", userRow.Username),
@@ -666,7 +674,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
ctx, ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session), fmt.Sprintf("session::%s::is_terminated", claims.Session),
*session.Terminated, *session.Terminated,
time.Duration(8*time.Hour), // XXX: magic number AuthTerminatedSessionCacheDuration,
).Err(); err != nil { ).Err(); err != nil {
a.log.Error( a.log.Error(
"Failed to cache session's is_terminated state", "Failed to cache session's is_terminated state",
@@ -902,8 +910,11 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
zap.Error(err)) zap.Error(err))
} }
// TODO: get user role if accessToken, refreshToken, err = utils.GenerateTokens(
if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil { user.Username,
session.Guid.String(),
enums.UserRole,
); err != nil {
a.log.Error( a.log.Error(
"Failed to generate tokens as part of user password reset", "Failed to generate tokens as part of user password reset",
zap.String("email", request.Email), zap.String("email", request.Email),

View File

@@ -81,8 +81,7 @@ WHERE linfo.email = @email::text;
;-- name: GetValidUserByLoginCredentials :one ;-- name: GetValidUserByLoginCredentials :one
SELECT SELECT
users.id, users.*,
users.username,
linfo.password_hash, linfo.password_hash,
linfo.totp_encrypted linfo.totp_encrypted
FROM users FROM users

View File

@@ -24,6 +24,7 @@ CREATE TABLE IF NOT EXISTS "users" (
username VARCHAR(20) UNIQUE NOT NULL, username VARCHAR(20) UNIQUE NOT NULL,
verified BOOLEAN DEFAULT FALSE, verified BOOLEAN DEFAULT FALSE,
registration_date TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, registration_date TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
role INTEGER NOT NULL DEFAULT 0,
deleted BOOLEAN DEFAULT FALSE deleted BOOLEAN DEFAULT FALSE
); );