feat: add change password endpoint using old password;

feat: implement change password service method with validation;
fix: correct ErrorIsOneOf function logic to return true on match;
refactor: rename 'log_out_accounts' to 'log_out_sessions' for clarity;
refactor: update session termination to return GUIDs and cache in Redis;
fix: ensure RollbackOnError only rolls back uncommitted transactions;
fix: handle transaction commit errors properly in dbHelper;
refactor: add helper methods for session termination and registration;
refactor: pass client info to login and registration complete methods;
fix: improve token validation error handling in refresh endpoint;
refactor: update auth middleware to set session info correctly;
chore: remove unused ClientInfo DTO;
fix: correct password reset complete to use session termination helper;
refactor: adjust database queries for session management;
chore: update SQL schema and queries for sessions;
docs: update swagger docs with new endpoint and model changes
This commit is contained in:
2025-07-17 03:44:22 +03:00
parent 8b558eaf5f
commit 827928178e
14 changed files with 454 additions and 173 deletions

View File

@@ -38,6 +38,44 @@ const docTemplate = `{
"responses": {} "responses": {}
} }
}, },
"/auth/changePassword": {
"post": {
"security": [
{
"JWT": []
}
],
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Auth"
],
"summary": "Set new password using the old password",
"parameters": [
{
"description": " ",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/models.ChangePasswordRequest"
}
}
],
"responses": {
"200": {
"description": "Password successfully changed"
},
"403": {
"description": "Invalid old password"
}
}
}
},
"/auth/login": { "/auth/login": {
"post": { "post": {
"consumes": [ "consumes": [
@@ -391,6 +429,24 @@ const docTemplate = `{
} }
} }
}, },
"models.ChangePasswordRequest": {
"type": "object",
"required": [
"old_password",
"password"
],
"properties": {
"old_password": {
"type": "string"
},
"password": {
"type": "string"
},
"totp": {
"type": "string"
}
}
},
"models.LoginRequest": { "models.LoginRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@@ -445,7 +501,7 @@ const docTemplate = `{
"email": { "email": {
"type": "string" "type": "string"
}, },
"log_out_accounts": { "log_out_sessions": {
"type": "boolean" "type": "boolean"
}, },
"password": { "password": {

View File

@@ -34,6 +34,44 @@
"responses": {} "responses": {}
} }
}, },
"/auth/changePassword": {
"post": {
"security": [
{
"JWT": []
}
],
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Auth"
],
"summary": "Set new password using the old password",
"parameters": [
{
"description": " ",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/models.ChangePasswordRequest"
}
}
],
"responses": {
"200": {
"description": "Password successfully changed"
},
"403": {
"description": "Invalid old password"
}
}
}
},
"/auth/login": { "/auth/login": {
"post": { "post": {
"consumes": [ "consumes": [
@@ -387,6 +425,24 @@
} }
} }
}, },
"models.ChangePasswordRequest": {
"type": "object",
"required": [
"old_password",
"password"
],
"properties": {
"old_password": {
"type": "string"
},
"password": {
"type": "string"
},
"totp": {
"type": "string"
}
}
},
"models.LoginRequest": { "models.LoginRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@@ -441,7 +497,7 @@
"email": { "email": {
"type": "string" "type": "string"
}, },
"log_out_accounts": { "log_out_sessions": {
"type": "boolean" "type": "boolean"
}, },
"password": { "password": {

View File

@@ -5,6 +5,18 @@ definitions:
healthy: healthy:
type: boolean type: boolean
type: object type: object
models.ChangePasswordRequest:
properties:
old_password:
type: string
password:
type: string
totp:
type: string
required:
- old_password
- password
type: object
models.LoginRequest: models.LoginRequest:
properties: properties:
password: password:
@@ -38,7 +50,7 @@ definitions:
properties: properties:
email: email:
type: string type: string
log_out_accounts: log_out_sessions:
type: boolean type: boolean
password: password:
type: string type: string
@@ -125,6 +137,29 @@ paths:
summary: Change account password summary: Change account password
tags: tags:
- Account - Account
/auth/changePassword:
post:
consumes:
- application/json
parameters:
- description: ' '
in: body
name: request
required: true
schema:
$ref: '#/definitions/models.ChangePasswordRequest'
produces:
- application/json
responses:
"200":
description: Password successfully changed
"403":
description: Invalid old password
security:
- JWT: []
summary: Set new password using the old password
tags:
- Auth
/auth/login: /auth/login:
post: post:
consumes: consumes:

View File

@@ -38,6 +38,7 @@ type AuthController interface {
Refresh(c *gin.Context) Refresh(c *gin.Context)
PasswordResetBegin(c *gin.Context) PasswordResetBegin(c *gin.Context)
PasswordResetComplete(c *gin.Context) PasswordResetComplete(c *gin.Context)
ChangePassword(c *gin.Context)
Router Router
} }
@@ -65,7 +66,7 @@ func (a *authControllerImpl) Login(c *gin.Context) {
return return
} }
response, err := a.auth.Login(request.Body) response, err := a.auth.Login(request.User, request.Body)
if err != nil { if err != nil {
if errors.Is(err, errs.ErrForbidden) { if errors.Is(err, errs.ErrForbidden) {
@@ -155,18 +156,18 @@ func (a *authControllerImpl) Refresh(c *gin.Context) {
response, err := a.auth.Refresh(request.Body) response, err := a.auth.Refresh(request.Body)
if err != nil { if err != nil {
if errors.Is(err, errs.ErrTokenExpired) { if utils.ErrorIsOneOf(
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"}) err,
} else if errors.Is(err, errs.ErrTokenInvalid) { errs.ErrTokenExpired,
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"}) errs.ErrTokenInvalid,
} else if errors.Is(err, errs.ErrWrongTokenType) { errs.ErrInvalidToken,
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"}) errs.ErrWrongTokenType,
} else if errors.Is(err, errs.ErrSessionNotFound) { errs.ErrSessionNotFound,
c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"}) errs.ErrSessionTerminated,
} else if errors.Is(err, errs.ErrSessionTerminated) { ) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
} else { } else {
c.Status(http.StatusInternalServerError) c.JSON(http.StatusInternalServerError, err.Error())
} }
return return
} }
@@ -221,7 +222,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
return return
} }
response, err := a.auth.RegistrationComplete(request.Body) response, err := a.auth.RegistrationComplete(request.User, request.Body)
if err != nil { if err != nil {
if errors.Is(err, errs.ErrForbidden) { if errors.Is(err, errs.ErrForbidden) {
@@ -237,6 +238,36 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)
} }
// @Summary Set new password using the old password
// @Tags Auth
// @Accept json
// @Produce json
// @Security JWT
// @Param request body models.ChangePasswordRequest true " "
// @Success 200 "Password successfully changed"
// @Failure 403 "Invalid old password"
// @Router /auth/changePassword [post]
func (a *authControllerImpl) ChangePassword(c *gin.Context) {
request, ok := utils.GetRequest[models.ChangePasswordRequest](c)
if !ok {
c.Status(http.StatusBadRequest)
return
}
response, err := a.auth.ChangePassword(request.Body, request.User)
if err != nil {
if errors.Is(err, errs.ErrForbidden) {
c.Status(http.StatusForbidden)
} else {
c.Status(http.StatusInternalServerError)
}
return
}
c.JSON(http.StatusOK, response)
}
func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) { func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin) group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete) group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
@@ -244,4 +275,5 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh) group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin) group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete) group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
group.POST("/changePassword", middleware.RequestMiddleware[models.ChangePasswordRequest](enums.UserRole), a.ChangePassword)
} }

View File

@@ -39,6 +39,7 @@ type dbHelperTransactionImpl struct {
TXlessQueries Queries TXlessQueries Queries
TX pgx.Tx TX pgx.Tx
TXQueries Queries TXQueries Queries
isCommited bool
} }
func NewDbHelper(dbContext DbContext) DbHelper { func NewDbHelper(dbContext DbContext) DbHelper {
@@ -79,30 +80,24 @@ func (d *dbHelperTransactionImpl) Commit() error {
errCommit := d.TX.Commit(d.CTX) errCommit := d.TX.Commit(d.CTX)
if errCommit != nil { if errCommit != nil {
errRollback := d.TX.Rollback(d.CTX) d.isCommited = true
if errRollback != nil {
return errRollback
}
return errCommit
} }
return nil
return errCommit
} }
// Rollback implements DbHelperTransaction. // Rollback implements DbHelperTransaction.
func (d *dbHelperTransactionImpl) Rollback() error { func (d *dbHelperTransactionImpl) Rollback() error {
err := d.TX.Rollback(d.CTX) if d.isCommited {
return nil
if err != nil {
return err
} }
return nil
return d.TX.Rollback(d.CTX);
} }
// RollbackOnError implements DbHelperTransaction. // RollbackOnError implements DbHelperTransaction.
func (d *dbHelperTransactionImpl) RollbackOnError(err error) error { func (d *dbHelperTransactionImpl) RollbackOnError(err error) error {
if err != nil { if d.isCommited || err == nil {
return d.Rollback() return d.Rollback()
} }
return nil return nil

View File

@@ -852,16 +852,32 @@ func (q *Queries) PruneTerminatedSessions(ctx context.Context) error {
return err return err
} }
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :exec const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :many
UPDATE sessions UPDATE sessions
SET terminated = TRUE SET terminated = TRUE
FROM users FROM users
WHERE sessions.user_id = users.id AND users.username = $1::text WHERE sessions.user_id = users.id AND users.username = $1::text
RETURNING sessions.guid
` `
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) error { func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) ([]pgtype.UUID, error) {
_, err := q.db.Exec(ctx, terminateAllSessionsForUserByUsername, username) rows, err := q.db.Query(ctx, terminateAllSessionsForUserByUsername, username)
return err if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var guid pgtype.UUID
if err := rows.Scan(&guid); err != nil {
return nil, err
}
items = append(items, guid)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
} }
const updateBannedUser = `-- name: UpdateBannedUser :exec const updateBannedUser = `-- name: UpdateBannedUser :exec

View File

@@ -31,7 +31,7 @@ var (
ErrServerError = errors.New("Internal server error") ErrServerError = errors.New("Internal server error")
ErrTokenExpired = errors.New("Token is expired") ErrTokenExpired = errors.New("Token is expired")
ErrTokenInvalid = errors.New("Token is invalid") ErrTokenInvalid = ErrInvalidToken
ErrWrongTokenType = errors.New("Invalid token type") ErrWrongTokenType = errors.New("Invalid token type")
ErrSessionNotFound = errors.New("Could not find session in database") ErrSessionNotFound = errors.New("Could not find session in database")
ErrSessionTerminated = errors.New("Session is terminated") ErrSessionTerminated = errors.New("Session is terminated")

View File

@@ -61,7 +61,7 @@ func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc
} }
return return
} else { } else {
c.Set("session_info", sessionInfo) c.Set("session_info", *sessionInfo)
c.Next() c.Next()
} }

View File

@@ -72,3 +72,13 @@ type PasswordResetCompleteRequest struct {
type PasswordResetCompleteResponse struct { type PasswordResetCompleteResponse struct {
Tokens Tokens
} }
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"password" binding:"required" validate:"password"`
TOTP string `json:"totp"`
}
type ChangePasswordResponse struct {
Tokens
}

View File

@@ -40,19 +40,20 @@ import (
type AuthService interface { type AuthService interface {
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
Login(request models.LoginRequest) (*models.LoginResponse, error) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error) PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
} }
type authServiceImpl struct { type authServiceImpl struct {
log *zap.Logger log *zap.Logger
dbctx database.DbContext dbctx database.DbContext
redis *redis.Client redis *redis.Client
smtp SmtpService smtp SmtpService
} }
func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService { func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService {
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
panic("Failed to cache terminated session: " + err.Error()) panic("Failed to cache terminated session: " + err.Error())
} }
} }
if _, err := pipe.Exec(ctx); err != nil {
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
}
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids))) _log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
return authService return authService
} }
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
a.log.Error(
"Failed to terminate older sessions for user trying to log in",
zap.String("username", username),
zap.Error(err))
return err
}
pipe := a.redis.Pipeline()
for _, guid := range sessionGuids {
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
}
if _, err := pipe.Exec(ctx); err != nil {
a.log.Error(
"Failed to cache terminated sessions",
zap.Error(err))
return err
}
return nil
}
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
UserID: userID,
Name: utils.NewPointer(cinfo.UserAgent),
Platform: utils.NewPointer(cinfo.UserAgent),
LatestIp: utils.NewPointer(cinfo.IP),
}); if err != nil {
a.log.Error(
"Failed to add session to database",
zap.Error(err))
return nil, err
}
a.log.Info(
"Registered a new user session",
zap.String("username", cinfo.Username),
zap.String("session", cinfo.Session))
return &session, nil
}
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) { func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
var occupationStatus database.CheckUserRegistrationAvailabilityRow var occupationStatus database.CheckUserRegistrationAvailabilityRow
@@ -91,18 +143,17 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
helper, db, err := database.NewDbHelperTransaction(a.dbctx) helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
"Failed to open a transaction", "Failed to open a transaction",
zap.Error(err)) zap.Error(err))
return false, errs.ErrServerError return false, errs.ErrServerError
} }
defer helper.RollbackOnError(err)
defer helper.RollbackOnError(err) isInProgress, err := a.redis.Get(
if isInProgress, err := a.redis.Get(
context.TODO(), context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", fmt.Sprintf("email::%s::registration_in_progress",
request.Email), request.Email),
).Bool(); err != nil { ).Bool(); if err != nil {
if err != redis.Nil { if err != redis.Nil {
a.log.Error( a.log.Error(
"Failed to look up cached registration_in_progress state of email as part of registration procedure", "Failed to look up cached registration_in_progress state of email as part of registration procedure",
@@ -113,13 +164,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
isInProgress = false isInProgress = false
} else if isInProgress { } else if isInProgress {
a.log.Warn( a.log.Warn(
"Attempted to begin registration on email that is in progress of registration or on cooldown", "Attempted to begin registration on email that is in progress of registration or on cooldown",
zap.String("email", request.Email)) zap.String("email", request.Email))
return false, errs.ErrTooManyRequests return false, errs.ErrTooManyRequests
} }
if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{ if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{
Email: request.Email, Email: request.Email,
Username: request.Username, Username: request.Username,
}); err != nil { }); err != nil {
a.log.Error( a.log.Error(
@@ -141,12 +192,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
// Falsely confirm in order to avoid disclosing registered email addresses // Falsely confirm in order to avoid disclosing registered email addresses
if err := a.redis.Set( if err := a.redis.Set(
context.TODO(), context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", request.Email), fmt.Sprintf("email::%s::registration_in_progress", request.Email),
true, true,
time.Duration(10 * time.Minute), // XXX: magic number time.Duration(10*time.Minute), // XXX: magic number
).Err(); err != nil { ).Err(); err != nil {
a.log.Error( a.log.Error(
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration", "Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
zap.String("email", request.Email), zap.String("email", request.Email),
zap.Error(err)) zap.Error(err))
return false, errs.ErrServerError return false, errs.ErrServerError
@@ -161,7 +212,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
} else { } else {
if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{ if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{
Username: request.Username, Username: request.Username,
Email: request.Email, Email: request.Email,
}); err != nil { }); err != nil {
a.log.Error( a.log.Error(
"Failed to purge unverified accounts as part of registration", "Failed to purge unverified accounts as part of registration",
@@ -184,8 +235,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
} }
if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{ if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{
UserID: user.ID, UserID: user.ID,
Email: utils.NewPointer(request.Email), Email: utils.NewPointer(request.Email),
PasswordHash: passwordHash, // Hashed in database PasswordHash: passwordHash, // Hashed in database
}); err != nil { }); err != nil {
@@ -208,9 +259,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
} }
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{ if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
UserID: user.ID, UserID: user.ID,
CodeType: int32(enums.RegistrationCodeType), CodeType: int32(enums.RegistrationCodeType),
CodeHash: generatedCodeHash, // Hashed in database CodeHash: generatedCodeHash, // Hashed in database
}); err != nil { }); err != nil {
a.log.Error("Failed to add registration code to database", zap.Error(err)) a.log.Error("Failed to add registration code to database", zap.Error(err))
return false, errs.ErrServerError return false, errs.ErrServerError
@@ -224,8 +275,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
if config.GetConfig().SmtpEnabled { if config.GetConfig().SmtpEnabled {
if err := a.smtp.SendEmail( if err := a.smtp.SendEmail(
request.Email, request.Email,
"Easywish", "Easywish",
fmt.Sprintf("Your registration code is %s", generatedCode), fmt.Sprintf("Your registration code is %s", generatedCode),
); err != nil { ); err != nil {
a.log.Error( a.log.Error(
@@ -246,12 +297,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
if err := a.redis.Set( if err := a.redis.Set(
context.TODO(), context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", request.Email), fmt.Sprintf("email::%s::registration_in_progress", request.Email),
true, true,
time.Duration(10 * time.Minute), // XXX: magic number time.Duration(10*time.Minute), // XXX: magic number
).Err(); err != nil { ).Err(); err != nil {
a.log.Error( a.log.Error(
"Failed to cache registration_in_progress state for email", "Failed to cache registration_in_progress state for email",
zap.String("email", request.Email), zap.String("email", request.Email),
zap.Error(err)) zap.Error(err))
return false, errs.ErrServerError return false, errs.ErrServerError
@@ -267,19 +318,18 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
return true, nil return true, nil
} }
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) { func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
var user database.User var user database.User
var profile database.Profile var profile database.Profile
var session database.Session var confirmationCode database.ConfirmationCode
var confirmationCode database.ConfirmationCode
var accessToken, refreshToken string var accessToken, refreshToken string
var err error var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx) helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
"Failed to open a transaction", "Failed to open a transaction",
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -290,8 +340,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
a.log.Warn( a.log.Warn(
"Could not find user attempting to complete registration with given username", "Could not find user attempting to complete registration with given username",
zap.String("username", request.Username), zap.String("username", request.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrUserNotFound return nil, errs.ErrUserNotFound
} }
@@ -304,9 +354,9 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
} }
confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{ confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
UserID: user.ID, UserID: user.ID,
CodeType: int32(enums.RegistrationCodeType), CodeType: int32(enums.RegistrationCodeType),
Code: request.VerificationCode, Code: request.VerificationCode,
}) })
if err != nil { if err != nil {
@@ -327,10 +377,10 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
} }
err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{ err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
ID: confirmationCode.ID, ID: confirmationCode.ID,
Used: utils.NewPointer(true), Used: utils.NewPointer(true),
}) })
if err != nil { if err != nil {
a.log.Error( a.log.Error(
"Failed to update the user's registration code used state", "Failed to update the user's registration code used state",
@@ -340,27 +390,26 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{ err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
ID: user.ID, ID: user.ID,
Verified: utils.NewPointer(true), Verified: utils.NewPointer(true),
}) })
if err != nil { if err != nil {
a.log.Error("Failed to update verified status for user", a.log.Error("Failed to update verified status for user",
zap.String("username", user.Username), zap.String("username", user.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{ profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
UserID: user.ID, UserID: user.ID,
Name: request.Name, Name: request.Name,
}) })
if err != nil { if err != nil {
a.log.Error("Failed to create profile for user", a.log.Error("Failed to create profile for user",
zap.String("username", user.Username), zap.String("username", user.Username),
) )
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -369,29 +418,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
if err != nil { if err != nil {
a.log.Error("Failed to create profile settings for user", a.log.Error("Failed to create profile settings for user",
zap.String("username", user.Username),
zap.Error(err))
return nil, errs.ErrServerError
}
// TODO: session info
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
UserID: user.ID,
Name: utils.NewPointer("First device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
})
if err != nil {
a.log.Error(
"Failed to create a new session during registration, rolling back registration",
zap.String("username", user.Username), zap.String("username", user.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
a.log.Error("", zap.Error(err))
return nil, errs.ErrServerError
}
// TODO: get user role // TODO: get user role
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole) accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
@@ -413,7 +451,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
zap.String("username", request.Username)) zap.String("username", request.Username))
response := models.RegistrationCompleteResponse{Tokens: models.Tokens{ response := models.RegistrationCompleteResponse{Tokens: models.Tokens{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
}} }}
@@ -421,19 +459,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
} }
// TODO: totp // TODO: totp
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) { func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
var userRow database.GetValidUserByLoginCredentialsRow var userRow database.GetValidUserByLoginCredentialsRow
var session database.Session
var err error var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx) helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
"Failed to open a transaction", "Failed to open a transaction",
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
defer helper.RollbackOnError(err) defer helper.RollbackOnError(err)
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{ userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
Username: request.Username, Username: request.Username,
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
} }
// Until release 4, only 1 session at a time is supported // Until release 4, only 1 session at a time is supported
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil { err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
a.log.Error( a.log.Error(
"Failed to terminate older sessions for user trying to log in", "Failed to terminate user's sessions during login",
zap.String("username", request.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
// TODO: use actual values for session metadata a.log.Error("", zap.Error(err))
UserID: userRow.ID,
Name: utils.NewPointer("New device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
})
if err != nil {
a.log.Error(
"Failed to create session for a new login",
zap.String("username", userRow.Username),
zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -500,7 +525,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
} }
response := models.LoginResponse{Tokens: models.Tokens{ response := models.LoginResponse{Tokens: models.Tokens{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
}} }}
@@ -509,12 +534,13 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil { sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
if err != nil {
if utils.ErrorIsOneOf( if utils.ErrorIsOneOf(
err, err,
errs.ErrInvalidToken, errs.ErrInvalidToken,
errs.ErrTokenExpired, errs.ErrTokenExpired,
errs.ErrWrongTokenType, errs.ErrWrongTokenType,
errs.ErrSessionNotFound, errs.ErrSessionNotFound,
errs.ErrSessionTerminated, errs.ErrSessionTerminated,
@@ -522,19 +548,20 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
return nil, err return nil, err
} else { } else {
a.log.Error( a.log.Error(
"Encountered an unexpected error while validating token", "Encountered an unexpected error while validating token",
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
} }
accessToken, refreshToken, err := utils.GenerateTokens( accessToken, refreshToken, err := utils.GenerateTokens(
sessionInfo.Username, sessionInfo.Username,
sessionInfo.Session, sessionInfo.Session,
sessionInfo.Role, sessionInfo.Role,
); if err != nil { )
if err != nil {
a.log.Error( a.log.Error(
"Failed to generate tokens for user during refresh", "Failed to generate tokens for user during refresh",
zap.String("username", sessionInfo.Username), zap.String("username", sessionInfo.Username),
zap.String("session", sessionInfo.Session), zap.String("session", sessionInfo.Session),
zap.Error(err)) zap.Error(err))
@@ -543,7 +570,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
response := models.RefreshResponse{ response := models.RefreshResponse{
Tokens: models.Tokens{ Tokens: models.Tokens{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
}, },
} }
@@ -551,13 +578,13 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
return &response, nil return &response, nil
} }
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) { func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
var err error var err error
token, err := jwt.ParseWithClaims( token, err := jwt.ParseWithClaims(
jwtToken, jwtToken,
&dto.UserClaims{}, &dto.UserClaims{},
func(token *jwt.Token) (any, error) { func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -569,11 +596,12 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
if err != nil { if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) { if errors.Is(err, jwt.ErrTokenExpired) {
return nil, errs.ErrTokenExpired return nil, errs.ErrTokenExpired
} }
return nil, errs.ErrInvalidToken return nil, errs.ErrInvalidToken
} }
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid { claims, ok := token.Claims.(*dto.UserClaims)
if ok && token.Valid {
if claims.Type != tokenType { if claims.Type != tokenType {
return nil, errs.ErrWrongTokenType return nil, errs.ErrWrongTokenType
@@ -583,7 +611,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
if redisErr != nil && redisErr != redis.Nil { if redisErr != nil && redisErr != redis.Nil {
a.log.Error( a.log.Error(
"Failed to lookup cache to check whether session is not terminated", "Failed to lookup cache to check whether session is not terminated",
zap.Error(redisErr)) zap.Error(redisErr))
return nil, redisErr return nil, redisErr
} }
@@ -603,20 +631,20 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
} }
a.log.Error( a.log.Error(
"Failed to lookup session in database", "Failed to lookup session in database",
zap.String("session", claims.Session), zap.String("session", claims.Session),
zap.Error(err)) zap.Error(err))
return nil, err return nil, err
} }
if err := a.redis.Set( if err := a.redis.Set(
ctx, ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session), fmt.Sprintf("session::%s::is_terminated", claims.Session),
*session.Terminated, *session.Terminated,
time.Duration(8 * time.Hour), // XXX: magic number time.Duration(8*time.Hour), // XXX: magic number
).Err(); err != nil { ).Err(); err != nil {
a.log.Error( a.log.Error(
"Failed to cache session's is_terminated state", "Failed to cache session's is_terminated state",
zap.String("session", claims.Session), zap.String("session", claims.Session),
zap.Error(err)) zap.Error(err))
// c.AbortWithStatus(http.StatusInternalServerError) // c.AbortWithStatus(http.StatusInternalServerError)
@@ -633,15 +661,15 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
sessionInfo := dto.SessionInfo{ sessionInfo := dto.SessionInfo{
Username: claims.Username, Username: claims.Username,
Session: claims.Session, Session: claims.Session,
Role: claims.Role, Role: claims.Role,
} }
return &sessionInfo, nil return &sessionInfo, nil
} }
func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) { func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) {
var user database.User var user database.User
var generatedCode, hashedCode string var generatedCode, hashedCode string
var err error var err error
@@ -649,7 +677,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
helper, db, err := database.NewDbHelperTransaction(a.dbctx) helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
"Failed to open a transaction", "Failed to open a transaction",
zap.Error(err)) zap.Error(err))
return false, errs.ErrServerError return false, errs.ErrServerError
} }
@@ -660,7 +688,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64() cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64()
if redisErr != nil && redisErr != redis.Nil { if redisErr != nil && redisErr != redis.Nil {
a.log.Error( a.log.Error(
"Failed to get reset_cooldown state for user", "Failed to get reset_cooldown state for user",
zap.String("email", request.Email), zap.String("email", request.Email),
zap.Error(redisErr)) zap.Error(redisErr))
return false, errs.ErrServerError return false, errs.ErrServerError
@@ -677,7 +705,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
// Enable cooldown for the email despite that account does not exist // Enable cooldown for the email despite that account does not exist
err := a.redis.Set( err := a.redis.Set(
ctx, ctx,
fmt.Sprintf("email::%s::reset_cooldown", request.Email), fmt.Sprintf("email::%s::reset_cooldown", request.Email),
time.Now().Add(10*time.Minute), time.Now().Add(10*time.Minute),
time.Duration(10*time.Minute), time.Duration(10*time.Minute),
@@ -687,7 +715,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
a.log.Error( a.log.Error(
"Failed to set reset cooldown for email", "Failed to set reset cooldown for email",
zap.Error(err)) zap.Error(err))
return false, err return false, err
} }
a.log.Warn( a.log.Warn(
@@ -712,7 +740,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
} }
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{ if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
UserID: user.ID, UserID: user.ID,
CodeType: int32(enums.PasswordResetCodeType), CodeType: int32(enums.PasswordResetCodeType),
CodeHash: hashedCode, CodeHash: hashedCode,
}); err != nil { }); err != nil {
@@ -723,7 +751,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
} }
err = a.redis.Set( err = a.redis.Set(
ctx, ctx,
fmt.Sprintf("email::%s::reset_cooldown", request.Email), fmt.Sprintf("email::%s::reset_cooldown", request.Email),
time.Now().Add(10*time.Minute), time.Now().Add(10*time.Minute),
time.Duration(10*time.Minute), time.Duration(10*time.Minute),
@@ -733,7 +761,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
a.log.Error( a.log.Error(
"Failed to set reset cooldown for email. Cancelling password reset", "Failed to set reset cooldown for email. Cancelling password reset",
zap.Error(err)) zap.Error(err))
return false, err return false, err
} }
if err = helper.Commit(); err != nil { if err = helper.Commit(); err != nil {
@@ -757,7 +785,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
helper, db, err := database.NewDbHelperTransaction(a.dbctx) helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil { if err != nil {
a.log.Error( a.log.Error(
"Failed to open a transaction", "Failed to open a transaction",
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -779,13 +807,13 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
} }
if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{ if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
UserID: user.ID, UserID: user.ID,
CodeType: int32(enums.PasswordResetCodeType), CodeType: int32(enums.PasswordResetCodeType),
Code: request.VerificationCode, Code: request.VerificationCode,
}); err != nil { }); err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
a.log.Warn( a.log.Warn(
"Attempted to reset password for user using incorrect confirmation code", "Attempted to reset password for user using incorrect confirmation code",
zap.String("email", request.Email), zap.String("email", request.Email),
zap.String("username", user.Username), zap.String("username", user.Username),
zap.String("provided_code", request.VerificationCode), zap.String("provided_code", request.VerificationCode),
@@ -795,7 +823,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
} }
if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{ if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
ID: resetCode.ID, ID: resetCode.ID,
Used: utils.NewPointer(true), Used: utils.NewPointer(true),
}); err != nil { }); err != nil {
a.log.Error( a.log.Error(
@@ -817,7 +845,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
} }
if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{ if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
Username: user.Username, Username: user.Username,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
}); err != nil { }); err != nil {
a.log.Error( a.log.Error(
@@ -828,24 +856,22 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
} }
if request.LogOutSessions { if request.LogOutSessions {
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil { err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
a.log.Error( a.log.Error(
"Failed to log out older sessions as part of user password reset", "Failed to terminate user's sessions during login",
zap.String("email", request.Email),
zap.String("username", user.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
} }
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
UserID: user.ID, UserID: user.ID,
Name: utils.NewPointer("First device"), Name: utils.NewPointer("First device"),
Platform: utils.NewPointer("Unknown"), Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"), LatestIp: utils.NewPointer("Unknown"),
}); err != nil { }); if err != nil {
a.log.Error( a.log.Error(
"Failed to create new session for user as part of user password reset", "Failed to create new session for user as part of user password reset",
zap.String("email", request.Email), zap.String("email", request.Email),
zap.String("username", user.Username), zap.String("username", user.Username),
zap.Error(err)) zap.Error(err))
@@ -854,7 +880,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
// TODO: get user role // TODO: get user role
if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil { if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil {
a.log.Error( a.log.Error(
"Failed to generate tokens as part of user password reset", "Failed to generate tokens as part of user password reset",
zap.String("email", request.Email), zap.String("email", request.Email),
zap.String("username", user.Username), zap.String("username", user.Username),
zap.Error(err)) zap.Error(err))
@@ -863,7 +889,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
response := models.PasswordResetCompleteResponse{ response := models.PasswordResetCompleteResponse{
Tokens: models.Tokens{ Tokens: models.Tokens{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
}, },
} }
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
return &response, nil return &response, nil
} }
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
a.log.Error(
"Failed to open a transaction",
zap.Error(err))
return false, errs.ErrServerError
}
defer helper.RollbackOnError(err)
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
a.log.Error(
"Failed to get user login information",
zap.Error(err))
return false, errs.ErrServerError
}
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
a.log.Warn(
"Provided invalid old password while changing password",
zap.String("username", uinfo.Username))
return false, errs.ErrForbidden
}
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
a.log.Error(
"Failed to hash new password while changing password",
zap.String("username", uinfo.Username),
zap.Error(err))
return false, errs.ErrServerError
}
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
Username: uinfo.Username,
PasswordHash: newPasswordHash,
}); if err != nil {
a.log.Error(
"Failed to save new password into database",
zap.String("username", uinfo.Username),
zap.Error(err))
return false, errs.ErrServerError
}
if err := helper.Commit(); err != nil {
a.log.Error(
"Failed to commit transaction",
zap.Error(err))
return false, errs.ErrServerError
}
return true, nil
}

View File

@@ -22,8 +22,8 @@ import "errors"
func ErrorIsOneOf(err error, ignoreErrors ...error) bool { func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
for _, ignore := range ignoreErrors { for _, ignore := range ignoreErrors {
if errors.Is(err, ignore) { if errors.Is(err, ignore) {
return false return true
} }
} }
return true return false
} }

View File

@@ -259,11 +259,12 @@ WHERE
terminated IS TRUE AND terminated IS TRUE AND
last_refresh_exp_time > CURRENT_TIMESTAMP; last_refresh_exp_time > CURRENT_TIMESTAMP;
;-- name: TerminateAllSessionsForUserByUsername :exec ;-- name: TerminateAllSessionsForUserByUsername :many
UPDATE sessions UPDATE sessions
SET terminated = TRUE SET terminated = TRUE
FROM users FROM users
WHERE sessions.user_id = users.id AND users.username = @username::text; WHERE sessions.user_id = users.id AND users.username = @username::text
RETURNING sessions.guid;
;-- name: PruneTerminatedSessions :exec ;-- name: PruneTerminatedSessions :exec
DELETE FROM sessions DELETE FROM sessions

View File

@@ -62,8 +62,8 @@ CREATE TABLE IF NOT EXISTS "sessions" (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
guid UUID NOT NULL DEFAULT gen_random_uuid(), guid UUID NOT NULL DEFAULT gen_random_uuid(),
name VARCHAR(100), name VARCHAR(175),
platform VARCHAR(32), platform VARCHAR(175),
latest_ip VARCHAR(16), latest_ip VARCHAR(16),
login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds', last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds',