feat: add user role support to database and queries;
fix: add max length validation for refresh token in RefreshRequest; refactor: use named constants for cache durations in AuthService; refactor: select all user columns in GetValidUserByLoginCredentials query;
This commit is contained in:
@@ -80,5 +80,6 @@ type User struct {
|
|||||||
Username string
|
Username string
|
||||||
Verified *bool
|
Verified *bool
|
||||||
RegistrationDate pgtype.Timestamp
|
RegistrationDate pgtype.Timestamp
|
||||||
|
Role int32
|
||||||
Deleted *bool
|
Deleted *bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
|||||||
|
|
||||||
const createUser = `-- name: CreateUser :one
|
const createUser = `-- name: CreateUser :one
|
||||||
INSERT INTO users(username, verified)
|
INSERT INTO users(username, verified)
|
||||||
VALUES ($1, false) RETURNING id, username, verified, registration_date, deleted
|
VALUES ($1, false) RETURNING id, username, verified, registration_date, role, deleted
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) CreateUser(ctx context.Context, username string) (User, error) {
|
func (q *Queries) CreateUser(ctx context.Context, username string) (User, error) {
|
||||||
@@ -247,6 +247,7 @@ func (q *Queries) CreateUser(ctx context.Context, username string) (User, error)
|
|||||||
&i.Username,
|
&i.Username,
|
||||||
&i.Verified,
|
&i.Verified,
|
||||||
&i.RegistrationDate,
|
&i.RegistrationDate,
|
||||||
|
&i.Role,
|
||||||
&i.Deleted,
|
&i.Deleted,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
@@ -264,7 +265,7 @@ WITH deleted_rows AS (
|
|||||||
AND linfo.email = $2::text
|
AND linfo.email = $2::text
|
||||||
))
|
))
|
||||||
AND verified IS FALSE
|
AND verified IS FALSE
|
||||||
RETURNING id, username, verified, registration_date, deleted
|
RETURNING id, username, verified, registration_date, role, deleted
|
||||||
)
|
)
|
||||||
SELECT COUNT(*) AS deleted_count FROM deleted_rows
|
SELECT COUNT(*) AS deleted_count FROM deleted_rows
|
||||||
`
|
`
|
||||||
@@ -543,7 +544,7 @@ func (q *Queries) GetUnexpiredTerminatedSessionsGuidsPaginated(ctx context.Conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUser = `-- name: GetUser :one
|
const getUser = `-- name: GetUser :one
|
||||||
SELECT id, username, verified, registration_date, deleted FROM users
|
SELECT id, username, verified, registration_date, role, deleted FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -555,6 +556,7 @@ func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) {
|
|||||||
&i.Username,
|
&i.Username,
|
||||||
&i.Verified,
|
&i.Verified,
|
||||||
&i.RegistrationDate,
|
&i.RegistrationDate,
|
||||||
|
&i.Role,
|
||||||
&i.Deleted,
|
&i.Deleted,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
@@ -630,7 +632,7 @@ func (q *Queries) GetUserBansByUsername(ctx context.Context, username string) ([
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||||
SELECT users.id, users.username, users.verified, users.registration_date, users.deleted FROM users
|
SELECT users.id, users.username, users.verified, users.registration_date, users.role, users.deleted FROM users
|
||||||
JOIN login_informations linfo ON linfo.user_id = users.id
|
JOIN login_informations linfo ON linfo.user_id = users.id
|
||||||
WHERE linfo.email = $1::text
|
WHERE linfo.email = $1::text
|
||||||
`
|
`
|
||||||
@@ -643,13 +645,14 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
|
|||||||
&i.Username,
|
&i.Username,
|
||||||
&i.Verified,
|
&i.Verified,
|
||||||
&i.RegistrationDate,
|
&i.RegistrationDate,
|
||||||
|
&i.Role,
|
||||||
&i.Deleted,
|
&i.Deleted,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getUserByUsername = `-- name: GetUserByUsername :one
|
const getUserByUsername = `-- name: GetUserByUsername :one
|
||||||
SELECT id, username, verified, registration_date, deleted FROM users
|
SELECT id, username, verified, registration_date, role, deleted FROM users
|
||||||
WHERE username = $1
|
WHERE username = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -661,6 +664,7 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User,
|
|||||||
&i.Username,
|
&i.Username,
|
||||||
&i.Verified,
|
&i.Verified,
|
||||||
&i.RegistrationDate,
|
&i.RegistrationDate,
|
||||||
|
&i.Role,
|
||||||
&i.Deleted,
|
&i.Deleted,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
@@ -698,7 +702,7 @@ func (q *Queries) GetValidConfirmationCodeByCode(ctx context.Context, arg GetVal
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getValidConfirmationCodesByUsername = `-- name: GetValidConfirmationCodesByUsername :many
|
const getValidConfirmationCodesByUsername = `-- name: GetValidConfirmationCodesByUsername :many
|
||||||
SELECT confirmation_codes.id, user_id, code_type, code_hash, expires_at, used, confirmation_codes.deleted, users.id, username, verified, registration_date, users.deleted FROM confirmation_codes
|
SELECT confirmation_codes.id, user_id, code_type, code_hash, expires_at, used, confirmation_codes.deleted, users.id, username, verified, registration_date, role, users.deleted FROM confirmation_codes
|
||||||
JOIN users on users.id = confirmation_codes.user_id
|
JOIN users on users.id = confirmation_codes.user_id
|
||||||
WHERE
|
WHERE
|
||||||
users.username = $1::text AND
|
users.username = $1::text AND
|
||||||
@@ -724,6 +728,7 @@ type GetValidConfirmationCodesByUsernameRow struct {
|
|||||||
Username string
|
Username string
|
||||||
Verified *bool
|
Verified *bool
|
||||||
RegistrationDate pgtype.Timestamp
|
RegistrationDate pgtype.Timestamp
|
||||||
|
Role int32
|
||||||
Deleted_2 *bool
|
Deleted_2 *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -748,6 +753,7 @@ func (q *Queries) GetValidConfirmationCodesByUsername(ctx context.Context, arg G
|
|||||||
&i.Username,
|
&i.Username,
|
||||||
&i.Verified,
|
&i.Verified,
|
||||||
&i.RegistrationDate,
|
&i.RegistrationDate,
|
||||||
|
&i.Role,
|
||||||
&i.Deleted_2,
|
&i.Deleted_2,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -762,8 +768,7 @@ func (q *Queries) GetValidConfirmationCodesByUsername(ctx context.Context, arg G
|
|||||||
|
|
||||||
const getValidUserByLoginCredentials = `-- name: GetValidUserByLoginCredentials :one
|
const getValidUserByLoginCredentials = `-- name: GetValidUserByLoginCredentials :one
|
||||||
SELECT
|
SELECT
|
||||||
users.id,
|
users.id, users.username, users.verified, users.registration_date, users.role, users.deleted,
|
||||||
users.username,
|
|
||||||
linfo.password_hash,
|
linfo.password_hash,
|
||||||
linfo.totp_encrypted
|
linfo.totp_encrypted
|
||||||
FROM users
|
FROM users
|
||||||
@@ -785,6 +790,10 @@ type GetValidUserByLoginCredentialsParams struct {
|
|||||||
type GetValidUserByLoginCredentialsRow struct {
|
type GetValidUserByLoginCredentialsRow struct {
|
||||||
ID int64
|
ID int64
|
||||||
Username string
|
Username string
|
||||||
|
Verified *bool
|
||||||
|
RegistrationDate pgtype.Timestamp
|
||||||
|
Role int32
|
||||||
|
Deleted *bool
|
||||||
PasswordHash string
|
PasswordHash string
|
||||||
TotpEncrypted *string
|
TotpEncrypted *string
|
||||||
}
|
}
|
||||||
@@ -795,6 +804,10 @@ func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetVal
|
|||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.ID,
|
&i.ID,
|
||||||
&i.Username,
|
&i.Username,
|
||||||
|
&i.Verified,
|
||||||
|
&i.RegistrationDate,
|
||||||
|
&i.Role,
|
||||||
|
&i.Deleted,
|
||||||
&i.PasswordHash,
|
&i.PasswordHash,
|
||||||
&i.TotpEncrypted,
|
&i.TotpEncrypted,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,9 +49,8 @@ type LoginResponse struct {
|
|||||||
Tokens
|
Tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: length check
|
|
||||||
type RefreshRequest struct {
|
type RefreshRequest struct {
|
||||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
RefreshToken string `json:"refresh_token" binding:"required,max=2000"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RefreshResponse struct {
|
type RefreshResponse struct {
|
||||||
|
|||||||
@@ -38,6 +38,11 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
AuthTerminatedSessionCacheDuration = time.Duration(8 * time.Hour)
|
||||||
|
AuthRegistrationCooldownCacheDuration = time.Duration(10 * time.Minute)
|
||||||
|
)
|
||||||
|
|
||||||
type AuthService interface {
|
type AuthService interface {
|
||||||
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
||||||
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
||||||
@@ -87,7 +92,7 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
|
|||||||
pipe := _redis.Pipeline()
|
pipe := _redis.Pipeline()
|
||||||
for _, guid := range guids {
|
for _, guid := range guids {
|
||||||
key := fmt.Sprintf("session::%s::is_terminated", guid)
|
key := fmt.Sprintf("session::%s::is_terminated", guid)
|
||||||
pipe.Set(ctx, key, true, time.Duration(8 * time.Hour)) // XXX: magic number
|
pipe.Set(ctx, key, true, AuthTerminatedSessionCacheDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := pipe.Exec(ctx); err != nil {
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
@@ -122,7 +127,12 @@ func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, usern
|
|||||||
|
|
||||||
pipe := a.redis.Pipeline()
|
pipe := a.redis.Pipeline()
|
||||||
for _, guid := range sessionGuids {
|
for _, guid := range sessionGuids {
|
||||||
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
|
pipe.Set(
|
||||||
|
ctx,
|
||||||
|
fmt.Sprintf("session::%s::is_terminated", guid),
|
||||||
|
true,
|
||||||
|
AuthTerminatedSessionCacheDuration,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := pipe.Exec(ctx); err != nil {
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
@@ -219,7 +229,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
context.TODO(),
|
context.TODO(),
|
||||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||||
true,
|
true,
|
||||||
time.Duration(10*time.Minute), // XXX: magic number
|
AuthRegistrationCooldownCacheDuration,
|
||||||
).Err(); err != nil {
|
).Err(); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
||||||
@@ -324,7 +334,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
context.TODO(),
|
context.TODO(),
|
||||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||||
true,
|
true,
|
||||||
time.Duration(10*time.Minute), // XXX: magic number
|
AuthTerminatedSessionCacheDuration,
|
||||||
).Err(); err != nil {
|
).Err(); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to cache registration_in_progress state for email",
|
"Failed to cache registration_in_progress state for email",
|
||||||
@@ -453,8 +463,7 @@ func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request mod
|
|||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: get user role
|
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.Role(user.Role))
|
||||||
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
@@ -483,7 +492,6 @@ func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request mod
|
|||||||
return &response, nil
|
return &response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: totp
|
|
||||||
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
|
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
|
||||||
var userRow database.GetValidUserByLoginCredentialsRow
|
var userRow database.GetValidUserByLoginCredentialsRow
|
||||||
var err error
|
var err error
|
||||||
@@ -500,12 +508,7 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques
|
|||||||
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
|
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
|
||||||
Username: request.Username,
|
Username: request.Username,
|
||||||
Password: request.Password,
|
Password: request.Password,
|
||||||
})
|
}); if err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
a.log.Warn(
|
|
||||||
"Failed login attempt",
|
|
||||||
zap.Error(err))
|
|
||||||
|
|
||||||
var returnedError error
|
var returnedError error
|
||||||
|
|
||||||
@@ -516,6 +519,9 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques
|
|||||||
returnedError = errs.ErrServerError
|
returnedError = errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
a.log.Warn(
|
||||||
|
"Failed login attempt",
|
||||||
|
zap.Error(err))
|
||||||
return nil, returnedError
|
return nil, returnedError
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -532,9 +538,11 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques
|
|||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: get user role
|
accessToken, refreshToken, err := utils.GenerateTokens(
|
||||||
accessToken, refreshToken, err := utils.GenerateTokens(userRow.Username, session.Guid.String(), enums.UserRole)
|
userRow.Username,
|
||||||
if err != nil {
|
session.Guid.String(),
|
||||||
|
enums.Role(userRow.Role),
|
||||||
|
); if err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to generate tokens for a new login",
|
"Failed to generate tokens for a new login",
|
||||||
zap.String("username", userRow.Username),
|
zap.String("username", userRow.Username),
|
||||||
@@ -666,7 +674,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|||||||
ctx,
|
ctx,
|
||||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
||||||
*session.Terminated,
|
*session.Terminated,
|
||||||
time.Duration(8*time.Hour), // XXX: magic number
|
AuthTerminatedSessionCacheDuration,
|
||||||
).Err(); err != nil {
|
).Err(); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to cache session's is_terminated state",
|
"Failed to cache session's is_terminated state",
|
||||||
@@ -902,8 +910,11 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: get user role
|
if accessToken, refreshToken, err = utils.GenerateTokens(
|
||||||
if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil {
|
user.Username,
|
||||||
|
session.Guid.String(),
|
||||||
|
enums.UserRole,
|
||||||
|
); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to generate tokens as part of user password reset",
|
"Failed to generate tokens as part of user password reset",
|
||||||
zap.String("email", request.Email),
|
zap.String("email", request.Email),
|
||||||
|
|||||||
@@ -81,8 +81,7 @@ WHERE linfo.email = @email::text;
|
|||||||
|
|
||||||
;-- name: GetValidUserByLoginCredentials :one
|
;-- name: GetValidUserByLoginCredentials :one
|
||||||
SELECT
|
SELECT
|
||||||
users.id,
|
users.*,
|
||||||
users.username,
|
|
||||||
linfo.password_hash,
|
linfo.password_hash,
|
||||||
linfo.totp_encrypted
|
linfo.totp_encrypted
|
||||||
FROM users
|
FROM users
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ CREATE TABLE IF NOT EXISTS "users" (
|
|||||||
username VARCHAR(20) UNIQUE NOT NULL,
|
username VARCHAR(20) UNIQUE NOT NULL,
|
||||||
verified BOOLEAN DEFAULT FALSE,
|
verified BOOLEAN DEFAULT FALSE,
|
||||||
registration_date TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
registration_date TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
role INTEGER NOT NULL DEFAULT 0,
|
||||||
deleted BOOLEAN DEFAULT FALSE
|
deleted BOOLEAN DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user