feat: add change password endpoint using old password;

feat: implement change password service method with validation;
fix: correct ErrorIsOneOf function logic to return true on match;
refactor: rename 'log_out_accounts' to 'log_out_sessions' for clarity;
refactor: update session termination to return GUIDs and cache in Redis;
fix: ensure RollbackOnError only rolls back uncommitted transactions;
fix: handle transaction commit errors properly in dbHelper;
refactor: add helper methods for session termination and registration;
refactor: pass client info to login and registration complete methods;
fix: improve token validation error handling in refresh endpoint;
refactor: update auth middleware to set session info correctly;
chore: remove unused ClientInfo DTO;
fix: correct password reset complete to use session termination helper;
refactor: adjust database queries for session management;
chore: update SQL schema and queries for sessions;
docs: update swagger docs with new endpoint and model changes
This commit is contained in:
2025-07-17 03:44:22 +03:00
parent 8b558eaf5f
commit 827928178e
14 changed files with 454 additions and 173 deletions

View File

@@ -38,6 +38,44 @@ const docTemplate = `{
"responses": {} "responses": {}
} }
}, },
"/auth/changePassword": {
"post": {
"security": [
{
"JWT": []
}
],
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Auth"
],
"summary": "Set new password using the old password",
"parameters": [
{
"description": " ",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/models.ChangePasswordRequest"
}
}
],
"responses": {
"200": {
"description": "Password successfully changed"
},
"403": {
"description": "Invalid old password"
}
}
}
},
"/auth/login": { "/auth/login": {
"post": { "post": {
"consumes": [ "consumes": [
@@ -391,6 +429,24 @@ const docTemplate = `{
} }
} }
}, },
"models.ChangePasswordRequest": {
"type": "object",
"required": [
"old_password",
"password"
],
"properties": {
"old_password": {
"type": "string"
},
"password": {
"type": "string"
},
"totp": {
"type": "string"
}
}
},
"models.LoginRequest": { "models.LoginRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@@ -445,7 +501,7 @@ const docTemplate = `{
"email": { "email": {
"type": "string" "type": "string"
}, },
"log_out_accounts": { "log_out_sessions": {
"type": "boolean" "type": "boolean"
}, },
"password": { "password": {

View File

@@ -34,6 +34,44 @@
"responses": {} "responses": {}
} }
}, },
"/auth/changePassword": {
"post": {
"security": [
{
"JWT": []
}
],
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Auth"
],
"summary": "Set new password using the old password",
"parameters": [
{
"description": " ",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/models.ChangePasswordRequest"
}
}
],
"responses": {
"200": {
"description": "Password successfully changed"
},
"403": {
"description": "Invalid old password"
}
}
}
},
"/auth/login": { "/auth/login": {
"post": { "post": {
"consumes": [ "consumes": [
@@ -387,6 +425,24 @@
} }
} }
}, },
"models.ChangePasswordRequest": {
"type": "object",
"required": [
"old_password",
"password"
],
"properties": {
"old_password": {
"type": "string"
},
"password": {
"type": "string"
},
"totp": {
"type": "string"
}
}
},
"models.LoginRequest": { "models.LoginRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@@ -441,7 +497,7 @@
"email": { "email": {
"type": "string" "type": "string"
}, },
"log_out_accounts": { "log_out_sessions": {
"type": "boolean" "type": "boolean"
}, },
"password": { "password": {

View File

@@ -5,6 +5,18 @@ definitions:
healthy: healthy:
type: boolean type: boolean
type: object type: object
models.ChangePasswordRequest:
properties:
old_password:
type: string
password:
type: string
totp:
type: string
required:
- old_password
- password
type: object
models.LoginRequest: models.LoginRequest:
properties: properties:
password: password:
@@ -38,7 +50,7 @@ definitions:
properties: properties:
email: email:
type: string type: string
log_out_accounts: log_out_sessions:
type: boolean type: boolean
password: password:
type: string type: string
@@ -125,6 +137,29 @@ paths:
summary: Change account password summary: Change account password
tags: tags:
- Account - Account
/auth/changePassword:
post:
consumes:
- application/json
parameters:
- description: ' '
in: body
name: request
required: true
schema:
$ref: '#/definitions/models.ChangePasswordRequest'
produces:
- application/json
responses:
"200":
description: Password successfully changed
"403":
description: Invalid old password
security:
- JWT: []
summary: Set new password using the old password
tags:
- Auth
/auth/login: /auth/login:
post: post:
consumes: consumes:

View File

@@ -38,6 +38,7 @@ type AuthController interface {
Refresh(c *gin.Context) Refresh(c *gin.Context)
PasswordResetBegin(c *gin.Context) PasswordResetBegin(c *gin.Context)
PasswordResetComplete(c *gin.Context) PasswordResetComplete(c *gin.Context)
ChangePassword(c *gin.Context)
Router Router
} }
@@ -65,7 +66,7 @@ func (a *authControllerImpl) Login(c *gin.Context) {
return return
} }
response, err := a.auth.Login(request.Body) response, err := a.auth.Login(request.User, request.Body)
if err != nil { if err != nil {
if errors.Is(err, errs.ErrForbidden) { if errors.Is(err, errs.ErrForbidden) {
@@ -155,18 +156,18 @@ func (a *authControllerImpl) Refresh(c *gin.Context) {
response, err := a.auth.Refresh(request.Body) response, err := a.auth.Refresh(request.Body)
if err != nil { if err != nil {
if errors.Is(err, errs.ErrTokenExpired) { if utils.ErrorIsOneOf(
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"}) err,
} else if errors.Is(err, errs.ErrTokenInvalid) { errs.ErrTokenExpired,
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"}) errs.ErrTokenInvalid,
} else if errors.Is(err, errs.ErrWrongTokenType) { errs.ErrInvalidToken,
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"}) errs.ErrWrongTokenType,
} else if errors.Is(err, errs.ErrSessionNotFound) { errs.ErrSessionNotFound,
c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"}) errs.ErrSessionTerminated,
} else if errors.Is(err, errs.ErrSessionTerminated) { ) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
} else { } else {
c.Status(http.StatusInternalServerError) c.JSON(http.StatusInternalServerError, err.Error())
} }
return return
} }
@@ -221,7 +222,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
return return
} }
response, err := a.auth.RegistrationComplete(request.Body) response, err := a.auth.RegistrationComplete(request.User, request.Body)
if err != nil { if err != nil {
if errors.Is(err, errs.ErrForbidden) { if errors.Is(err, errs.ErrForbidden) {
@@ -237,6 +238,36 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)
} }
// @Summary Set new password using the old password
// @Tags Auth
// @Accept json
// @Produce json
// @Security JWT
// @Param request body models.ChangePasswordRequest true " "
// @Success 200 "Password successfully changed"
// @Failure 403 "Invalid old password"
// @Router /auth/changePassword [post]
func (a *authControllerImpl) ChangePassword(c *gin.Context) {
request, ok := utils.GetRequest[models.ChangePasswordRequest](c)
if !ok {
c.Status(http.StatusBadRequest)
return
}
response, err := a.auth.ChangePassword(request.Body, request.User)
if err != nil {
if errors.Is(err, errs.ErrForbidden) {
c.Status(http.StatusForbidden)
} else {
c.Status(http.StatusInternalServerError)
}
return
}
c.JSON(http.StatusOK, response)
}
func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) { func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin) group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete) group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
@@ -244,4 +275,5 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh) group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin) group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete) group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
group.POST("/changePassword", middleware.RequestMiddleware[models.ChangePasswordRequest](enums.UserRole), a.ChangePassword)
} }

View File

@@ -39,6 +39,7 @@ type dbHelperTransactionImpl struct {
TXlessQueries Queries TXlessQueries Queries
TX pgx.Tx TX pgx.Tx
TXQueries Queries TXQueries Queries
isCommited bool
} }
func NewDbHelper(dbContext DbContext) DbHelper { func NewDbHelper(dbContext DbContext) DbHelper {
@@ -79,30 +80,24 @@ func (d *dbHelperTransactionImpl) Commit() error {
errCommit := d.TX.Commit(d.CTX) errCommit := d.TX.Commit(d.CTX)
if errCommit != nil { if errCommit != nil {
errRollback := d.TX.Rollback(d.CTX) d.isCommited = true
if errRollback != nil {
return errRollback
} }
return errCommit return errCommit
} }
return nil
}
// Rollback implements DbHelperTransaction. // Rollback implements DbHelperTransaction.
func (d *dbHelperTransactionImpl) Rollback() error { func (d *dbHelperTransactionImpl) Rollback() error {
err := d.TX.Rollback(d.CTX) if d.isCommited {
if err != nil {
return err
}
return nil return nil
} }
return d.TX.Rollback(d.CTX);
}
// RollbackOnError implements DbHelperTransaction. // RollbackOnError implements DbHelperTransaction.
func (d *dbHelperTransactionImpl) RollbackOnError(err error) error { func (d *dbHelperTransactionImpl) RollbackOnError(err error) error {
if err != nil { if d.isCommited || err == nil {
return d.Rollback() return d.Rollback()
} }
return nil return nil

View File

@@ -852,16 +852,32 @@ func (q *Queries) PruneTerminatedSessions(ctx context.Context) error {
return err return err
} }
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :exec const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :many
UPDATE sessions UPDATE sessions
SET terminated = TRUE SET terminated = TRUE
FROM users FROM users
WHERE sessions.user_id = users.id AND users.username = $1::text WHERE sessions.user_id = users.id AND users.username = $1::text
RETURNING sessions.guid
` `
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) error { func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) ([]pgtype.UUID, error) {
_, err := q.db.Exec(ctx, terminateAllSessionsForUserByUsername, username) rows, err := q.db.Query(ctx, terminateAllSessionsForUserByUsername, username)
return err if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var guid pgtype.UUID
if err := rows.Scan(&guid); err != nil {
return nil, err
}
items = append(items, guid)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
} }
const updateBannedUser = `-- name: UpdateBannedUser :exec const updateBannedUser = `-- name: UpdateBannedUser :exec

View File

@@ -31,7 +31,7 @@ var (
ErrServerError = errors.New("Internal server error") ErrServerError = errors.New("Internal server error")
ErrTokenExpired = errors.New("Token is expired") ErrTokenExpired = errors.New("Token is expired")
ErrTokenInvalid = errors.New("Token is invalid") ErrTokenInvalid = ErrInvalidToken
ErrWrongTokenType = errors.New("Invalid token type") ErrWrongTokenType = errors.New("Invalid token type")
ErrSessionNotFound = errors.New("Could not find session in database") ErrSessionNotFound = errors.New("Could not find session in database")
ErrSessionTerminated = errors.New("Session is terminated") ErrSessionTerminated = errors.New("Session is terminated")

View File

@@ -61,7 +61,7 @@ func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc
} }
return return
} else { } else {
c.Set("session_info", sessionInfo) c.Set("session_info", *sessionInfo)
c.Next() c.Next()
} }

View File

@@ -72,3 +72,13 @@ type PasswordResetCompleteRequest struct {
type PasswordResetCompleteResponse struct { type PasswordResetCompleteResponse struct {
Tokens Tokens
} }
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"password" binding:"required" validate:"password"`
TOTP string `json:"totp"`
}
type ChangePasswordResponse struct {
Tokens
}

View File

@@ -40,11 +40,12 @@ import (
type AuthService interface { type AuthService interface {
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
Login(request models.LoginRequest) (*models.LoginResponse, error) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error) PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
} }
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
panic("Failed to cache terminated session: " + err.Error()) panic("Failed to cache terminated session: " + err.Error())
} }
} }
if _, err := pipe.Exec(ctx); err != nil {
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
}
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids))) _log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
return authService return authService
} }
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
a.log.Error(
"Failed to terminate older sessions for user trying to log in",
zap.String("username", username),
zap.Error(err))
return err
}
pipe := a.redis.Pipeline()
for _, guid := range sessionGuids {
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
}
if _, err := pipe.Exec(ctx); err != nil {
a.log.Error(
"Failed to cache terminated sessions",
zap.Error(err))
return err
}
return nil
}
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
UserID: userID,
Name: utils.NewPointer(cinfo.UserAgent),
Platform: utils.NewPointer(cinfo.UserAgent),
LatestIp: utils.NewPointer(cinfo.IP),
}); if err != nil {
a.log.Error(
"Failed to add session to database",
zap.Error(err))
return nil, err
}
a.log.Info(
"Registered a new user session",
zap.String("username", cinfo.Username),
zap.String("session", cinfo.Session))
return &session, nil
}
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) { func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
var occupationStatus database.CheckUserRegistrationAvailabilityRow var occupationStatus database.CheckUserRegistrationAvailabilityRow
@@ -95,14 +147,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
zap.Error(err)) zap.Error(err))
return false, errs.ErrServerError return false, errs.ErrServerError
} }
defer helper.RollbackOnError(err) defer helper.RollbackOnError(err)
if isInProgress, err := a.redis.Get( isInProgress, err := a.redis.Get(
context.TODO(), context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", fmt.Sprintf("email::%s::registration_in_progress",
request.Email), request.Email),
).Bool(); err != nil { ).Bool(); if err != nil {
if err != redis.Nil { if err != redis.Nil {
a.log.Error( a.log.Error(
"Failed to look up cached registration_in_progress state of email as part of registration procedure", "Failed to look up cached registration_in_progress state of email as part of registration procedure",
@@ -267,11 +318,10 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
return true, nil return true, nil
} }
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) { func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
var user database.User var user database.User
var profile database.Profile var profile database.Profile
var session database.Session
var confirmationCode database.ConfirmationCode var confirmationCode database.ConfirmationCode
var accessToken, refreshToken string var accessToken, refreshToken string
var err error var err error
@@ -360,7 +410,6 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
if err != nil { if err != nil {
a.log.Error("Failed to create profile for user", a.log.Error("Failed to create profile for user",
zap.String("username", user.Username), zap.String("username", user.Username),
) )
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -374,19 +423,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
// TODO: session info session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ a.log.Error("", zap.Error(err))
UserID: user.ID,
Name: utils.NewPointer("First device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
})
if err != nil {
a.log.Error(
"Failed to create a new session during registration, rolling back registration",
zap.String("username", user.Username),
zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -421,9 +459,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
} }
// TODO: totp // TODO: totp
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) { func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
var userRow database.GetValidUserByLoginCredentialsRow var userRow database.GetValidUserByLoginCredentialsRow
var session database.Session
var err error var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx) helper, db, err := database.NewDbHelperTransaction(a.dbctx)
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
} }
// Until release 4, only 1 session at a time is supported // Until release 4, only 1 session at a time is supported
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil { err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
a.log.Error( a.log.Error(
"Failed to terminate older sessions for user trying to log in", "Failed to terminate user's sessions during login",
zap.String("username", request.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
// TODO: use actual values for session metadata a.log.Error("", zap.Error(err))
UserID: userRow.ID,
Name: utils.NewPointer("New device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
})
if err != nil {
a.log.Error(
"Failed to create session for a new login",
zap.String("username", userRow.Username),
zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
@@ -509,7 +534,8 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil { sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
if err != nil {
if utils.ErrorIsOneOf( if utils.ErrorIsOneOf(
err, err,
@@ -532,7 +558,8 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
sessionInfo.Username, sessionInfo.Username,
sessionInfo.Session, sessionInfo.Session,
sessionInfo.Role, sessionInfo.Role,
); if err != nil { )
if err != nil {
a.log.Error( a.log.Error(
"Failed to generate tokens for user during refresh", "Failed to generate tokens for user during refresh",
zap.String("username", sessionInfo.Username), zap.String("username", sessionInfo.Username),
@@ -573,7 +600,8 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
return nil, errs.ErrInvalidToken return nil, errs.ErrInvalidToken
} }
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid { claims, ok := token.Claims.(*dto.UserClaims)
if ok && token.Valid {
if claims.Type != tokenType { if claims.Type != tokenType {
return nil, errs.ErrWrongTokenType return nil, errs.ErrWrongTokenType
@@ -828,22 +856,20 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
} }
if request.LogOutSessions { if request.LogOutSessions {
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil { err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
a.log.Error( a.log.Error(
"Failed to log out older sessions as part of user password reset", "Failed to terminate user's sessions during login",
zap.String("email", request.Email),
zap.String("username", user.Username),
zap.Error(err)) zap.Error(err))
return nil, errs.ErrServerError return nil, errs.ErrServerError
} }
} }
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
UserID: user.ID, UserID: user.ID,
Name: utils.NewPointer("First device"), Name: utils.NewPointer("First device"),
Platform: utils.NewPointer("Unknown"), Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"), LatestIp: utils.NewPointer("Unknown"),
}); err != nil { }); if err != nil {
a.log.Error( a.log.Error(
"Failed to create new session for user as part of user password reset", "Failed to create new session for user as part of user password reset",
zap.String("email", request.Email), zap.String("email", request.Email),
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
return &response, nil return &response, nil
} }
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
a.log.Error(
"Failed to open a transaction",
zap.Error(err))
return false, errs.ErrServerError
}
defer helper.RollbackOnError(err)
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
a.log.Error(
"Failed to get user login information",
zap.Error(err))
return false, errs.ErrServerError
}
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
a.log.Warn(
"Provided invalid old password while changing password",
zap.String("username", uinfo.Username))
return false, errs.ErrForbidden
}
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
a.log.Error(
"Failed to hash new password while changing password",
zap.String("username", uinfo.Username),
zap.Error(err))
return false, errs.ErrServerError
}
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
Username: uinfo.Username,
PasswordHash: newPasswordHash,
}); if err != nil {
a.log.Error(
"Failed to save new password into database",
zap.String("username", uinfo.Username),
zap.Error(err))
return false, errs.ErrServerError
}
if err := helper.Commit(); err != nil {
a.log.Error(
"Failed to commit transaction",
zap.Error(err))
return false, errs.ErrServerError
}
return true, nil
}

View File

@@ -22,8 +22,8 @@ import "errors"
func ErrorIsOneOf(err error, ignoreErrors ...error) bool { func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
for _, ignore := range ignoreErrors { for _, ignore := range ignoreErrors {
if errors.Is(err, ignore) { if errors.Is(err, ignore) {
return false
}
}
return true return true
} }
}
return false
}

View File

@@ -259,11 +259,12 @@ WHERE
terminated IS TRUE AND terminated IS TRUE AND
last_refresh_exp_time > CURRENT_TIMESTAMP; last_refresh_exp_time > CURRENT_TIMESTAMP;
;-- name: TerminateAllSessionsForUserByUsername :exec ;-- name: TerminateAllSessionsForUserByUsername :many
UPDATE sessions UPDATE sessions
SET terminated = TRUE SET terminated = TRUE
FROM users FROM users
WHERE sessions.user_id = users.id AND users.username = @username::text; WHERE sessions.user_id = users.id AND users.username = @username::text
RETURNING sessions.guid;
;-- name: PruneTerminatedSessions :exec ;-- name: PruneTerminatedSessions :exec
DELETE FROM sessions DELETE FROM sessions

View File

@@ -62,8 +62,8 @@ CREATE TABLE IF NOT EXISTS "sessions" (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
guid UUID NOT NULL DEFAULT gen_random_uuid(), guid UUID NOT NULL DEFAULT gen_random_uuid(),
name VARCHAR(100), name VARCHAR(175),
platform VARCHAR(32), platform VARCHAR(175),
latest_ip VARCHAR(16), latest_ip VARCHAR(16),
login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds', last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds',