feat: add change password endpoint using old password;

feat: implement change password service method with validation;
fix: correct ErrorIsOneOf function logic to return true on match;
refactor: rename 'log_out_accounts' to 'log_out_sessions' for clarity;
refactor: update session termination to return GUIDs and cache in Redis;
fix: ensure RollbackOnError only rolls back uncommitted transactions;
fix: handle transaction commit errors properly in dbHelper;
refactor: add helper methods for session termination and registration;
refactor: pass client info to login and registration complete methods;
fix: improve token validation error handling in refresh endpoint;
refactor: update auth middleware to set session info correctly;
chore: remove unused ClientInfo DTO;
fix: correct password reset complete to use session termination helper;
refactor: adjust database queries for session management;
chore: update SQL schema and queries for sessions;
docs: update swagger docs with new endpoint and model changes
This commit is contained in:
2025-07-17 03:44:22 +03:00
parent 8b558eaf5f
commit 827928178e
14 changed files with 454 additions and 173 deletions

View File

@@ -38,6 +38,7 @@ type AuthController interface {
Refresh(c *gin.Context)
PasswordResetBegin(c *gin.Context)
PasswordResetComplete(c *gin.Context)
ChangePassword(c *gin.Context)
Router
}
@@ -65,7 +66,7 @@ func (a *authControllerImpl) Login(c *gin.Context) {
return
}
response, err := a.auth.Login(request.Body)
response, err := a.auth.Login(request.User, request.Body)
if err != nil {
if errors.Is(err, errs.ErrForbidden) {
@@ -155,18 +156,18 @@ func (a *authControllerImpl) Refresh(c *gin.Context) {
response, err := a.auth.Refresh(request.Body)
if err != nil {
if errors.Is(err, errs.ErrTokenExpired) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
} else if errors.Is(err, errs.ErrTokenInvalid) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
} else if errors.Is(err, errs.ErrWrongTokenType) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
} else if errors.Is(err, errs.ErrSessionNotFound) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
} else if errors.Is(err, errs.ErrSessionTerminated) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
if utils.ErrorIsOneOf(
err,
errs.ErrTokenExpired,
errs.ErrTokenInvalid,
errs.ErrInvalidToken,
errs.ErrWrongTokenType,
errs.ErrSessionNotFound,
errs.ErrSessionTerminated,
) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
} else {
c.Status(http.StatusInternalServerError)
c.JSON(http.StatusInternalServerError, err.Error())
}
return
}
@@ -221,7 +222,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
return
}
response, err := a.auth.RegistrationComplete(request.Body)
response, err := a.auth.RegistrationComplete(request.User, request.Body)
if err != nil {
if errors.Is(err, errs.ErrForbidden) {
@@ -237,6 +238,36 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
c.JSON(http.StatusOK, response)
}
// @Summary Set new password using the old password
// @Tags Auth
// @Accept json
// @Produce json
// @Security JWT
// @Param request body models.ChangePasswordRequest true " "
// @Success 200 "Password successfully changed"
// @Failure 403 "Invalid old password"
// @Router /auth/changePassword [post]
func (a *authControllerImpl) ChangePassword(c *gin.Context) {
request, ok := utils.GetRequest[models.ChangePasswordRequest](c)
if !ok {
c.Status(http.StatusBadRequest)
return
}
response, err := a.auth.ChangePassword(request.Body, request.User)
if err != nil {
if errors.Is(err, errs.ErrForbidden) {
c.Status(http.StatusForbidden)
} else {
c.Status(http.StatusInternalServerError)
}
return
}
c.JSON(http.StatusOK, response)
}
func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
@@ -244,4 +275,5 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
group.POST("/changePassword", middleware.RequestMiddleware[models.ChangePasswordRequest](enums.UserRole), a.ChangePassword)
}

View File

@@ -39,6 +39,7 @@ type dbHelperTransactionImpl struct {
TXlessQueries Queries
TX pgx.Tx
TXQueries Queries
isCommited bool
}
func NewDbHelper(dbContext DbContext) DbHelper {
@@ -79,30 +80,24 @@ func (d *dbHelperTransactionImpl) Commit() error {
errCommit := d.TX.Commit(d.CTX)
if errCommit != nil {
errRollback := d.TX.Rollback(d.CTX)
if errRollback != nil {
return errRollback
}
return errCommit
d.isCommited = true
}
return nil
return errCommit
}
// Rollback implements DbHelperTransaction.
func (d *dbHelperTransactionImpl) Rollback() error {
err := d.TX.Rollback(d.CTX)
if err != nil {
return err
if d.isCommited {
return nil
}
return nil
return d.TX.Rollback(d.CTX);
}
// RollbackOnError implements DbHelperTransaction.
func (d *dbHelperTransactionImpl) RollbackOnError(err error) error {
if err != nil {
if d.isCommited || err == nil {
return d.Rollback()
}
return nil

View File

@@ -852,16 +852,32 @@ func (q *Queries) PruneTerminatedSessions(ctx context.Context) error {
return err
}
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :exec
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :many
UPDATE sessions
SET terminated = TRUE
FROM users
WHERE sessions.user_id = users.id AND users.username = $1::text
RETURNING sessions.guid
`
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) error {
_, err := q.db.Exec(ctx, terminateAllSessionsForUserByUsername, username)
return err
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) ([]pgtype.UUID, error) {
rows, err := q.db.Query(ctx, terminateAllSessionsForUserByUsername, username)
if err != nil {
return nil, err
}
defer rows.Close()
var items []pgtype.UUID
for rows.Next() {
var guid pgtype.UUID
if err := rows.Scan(&guid); err != nil {
return nil, err
}
items = append(items, guid)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateBannedUser = `-- name: UpdateBannedUser :exec

View File

@@ -31,7 +31,7 @@ var (
ErrServerError = errors.New("Internal server error")
ErrTokenExpired = errors.New("Token is expired")
ErrTokenInvalid = errors.New("Token is invalid")
ErrTokenInvalid = ErrInvalidToken
ErrWrongTokenType = errors.New("Invalid token type")
ErrSessionNotFound = errors.New("Could not find session in database")
ErrSessionTerminated = errors.New("Session is terminated")

View File

@@ -61,7 +61,7 @@ func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc
}
return
} else {
c.Set("session_info", sessionInfo)
c.Set("session_info", *sessionInfo)
c.Next()
}

View File

@@ -72,3 +72,13 @@ type PasswordResetCompleteRequest struct {
type PasswordResetCompleteResponse struct {
Tokens
}
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"password" binding:"required" validate:"password"`
TOTP string `json:"totp"`
}
type ChangePasswordResponse struct {
Tokens
}

View File

@@ -40,19 +40,20 @@ import (
type AuthService interface {
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
Login(request models.LoginRequest) (*models.LoginResponse, error)
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
}
type authServiceImpl struct {
log *zap.Logger
log *zap.Logger
dbctx database.DbContext
redis *redis.Client
smtp SmtpService
smtp SmtpService
}
func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService {
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
panic("Failed to cache terminated session: " + err.Error())
}
}
if _, err := pipe.Exec(ctx); err != nil {
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
}
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
return authService
}
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
a.log.Error(
"Failed to terminate older sessions for user trying to log in",
zap.String("username", username),
zap.Error(err))
return err
}
pipe := a.redis.Pipeline()
for _, guid := range sessionGuids {
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
}
if _, err := pipe.Exec(ctx); err != nil {
a.log.Error(
"Failed to cache terminated sessions",
zap.Error(err))
return err
}
return nil
}
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
UserID: userID,
Name: utils.NewPointer(cinfo.UserAgent),
Platform: utils.NewPointer(cinfo.UserAgent),
LatestIp: utils.NewPointer(cinfo.IP),
}); if err != nil {
a.log.Error(
"Failed to add session to database",
zap.Error(err))
return nil, err
}
a.log.Info(
"Registered a new user session",
zap.String("username", cinfo.Username),
zap.String("session", cinfo.Session))
return &session, nil
}
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
var occupationStatus database.CheckUserRegistrationAvailabilityRow
@@ -91,18 +143,17 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil {
a.log.Error(
"Failed to open a transaction",
"Failed to open a transaction",
zap.Error(err))
return false, errs.ErrServerError
}
defer helper.RollbackOnError(err)
defer helper.RollbackOnError(err)
if isInProgress, err := a.redis.Get(
isInProgress, err := a.redis.Get(
context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress",
request.Email),
).Bool(); err != nil {
request.Email),
).Bool(); if err != nil {
if err != redis.Nil {
a.log.Error(
"Failed to look up cached registration_in_progress state of email as part of registration procedure",
@@ -113,13 +164,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
isInProgress = false
} else if isInProgress {
a.log.Warn(
"Attempted to begin registration on email that is in progress of registration or on cooldown",
"Attempted to begin registration on email that is in progress of registration or on cooldown",
zap.String("email", request.Email))
return false, errs.ErrTooManyRequests
}
if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{
Email: request.Email,
Email: request.Email,
Username: request.Username,
}); err != nil {
a.log.Error(
@@ -141,12 +192,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
// Falsely confirm in order to avoid disclosing registered email addresses
if err := a.redis.Set(
context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
true,
time.Duration(10 * time.Minute), // XXX: magic number
time.Duration(10*time.Minute), // XXX: magic number
).Err(); err != nil {
a.log.Error(
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
zap.String("email", request.Email),
zap.Error(err))
return false, errs.ErrServerError
@@ -161,7 +212,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
} else {
if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{
Username: request.Username,
Email: request.Email,
Email: request.Email,
}); err != nil {
a.log.Error(
"Failed to purge unverified accounts as part of registration",
@@ -184,8 +235,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
}
if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{
UserID: user.ID,
Email: utils.NewPointer(request.Email),
UserID: user.ID,
Email: utils.NewPointer(request.Email),
PasswordHash: passwordHash, // Hashed in database
}); err != nil {
@@ -208,9 +259,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
}
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
UserID: user.ID,
CodeType: int32(enums.RegistrationCodeType),
CodeHash: generatedCodeHash, // Hashed in database
UserID: user.ID,
CodeType: int32(enums.RegistrationCodeType),
CodeHash: generatedCodeHash, // Hashed in database
}); err != nil {
a.log.Error("Failed to add registration code to database", zap.Error(err))
return false, errs.ErrServerError
@@ -224,8 +275,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
if config.GetConfig().SmtpEnabled {
if err := a.smtp.SendEmail(
request.Email,
"Easywish",
request.Email,
"Easywish",
fmt.Sprintf("Your registration code is %s", generatedCode),
); err != nil {
a.log.Error(
@@ -246,12 +297,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
if err := a.redis.Set(
context.TODO(),
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
true,
time.Duration(10 * time.Minute), // XXX: magic number
time.Duration(10*time.Minute), // XXX: magic number
).Err(); err != nil {
a.log.Error(
"Failed to cache registration_in_progress state for email",
"Failed to cache registration_in_progress state for email",
zap.String("email", request.Email),
zap.Error(err))
return false, errs.ErrServerError
@@ -267,19 +318,18 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
return true, nil
}
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
var user database.User
var profile database.Profile
var session database.Session
var confirmationCode database.ConfirmationCode
var confirmationCode database.ConfirmationCode
var accessToken, refreshToken string
var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil {
a.log.Error(
"Failed to open a transaction",
"Failed to open a transaction",
zap.Error(err))
return nil, errs.ErrServerError
}
@@ -290,8 +340,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
a.log.Warn(
"Could not find user attempting to complete registration with given username",
zap.String("username", request.Username),
"Could not find user attempting to complete registration with given username",
zap.String("username", request.Username),
zap.Error(err))
return nil, errs.ErrUserNotFound
}
@@ -304,9 +354,9 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
}
confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
UserID: user.ID,
UserID: user.ID,
CodeType: int32(enums.RegistrationCodeType),
Code: request.VerificationCode,
Code: request.VerificationCode,
})
if err != nil {
@@ -327,10 +377,10 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
}
err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
ID: confirmationCode.ID,
ID: confirmationCode.ID,
Used: utils.NewPointer(true),
})
if err != nil {
a.log.Error(
"Failed to update the user's registration code used state",
@@ -340,27 +390,26 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
return nil, errs.ErrServerError
}
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
ID: user.ID,
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
ID: user.ID,
Verified: utils.NewPointer(true),
})
if err != nil {
a.log.Error("Failed to update verified status for user",
zap.String("username", user.Username),
zap.Error(err))
zap.String("username", user.Username),
zap.Error(err))
return nil, errs.ErrServerError
}
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
UserID: user.ID,
Name: request.Name,
Name: request.Name,
})
if err != nil {
a.log.Error("Failed to create profile for user",
zap.String("username", user.Username),
)
return nil, errs.ErrServerError
}
@@ -369,29 +418,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
if err != nil {
a.log.Error("Failed to create profile settings for user",
zap.String("username", user.Username),
zap.Error(err))
return nil, errs.ErrServerError
}
// TODO: session info
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
UserID: user.ID,
Name: utils.NewPointer("First device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
})
if err != nil {
a.log.Error(
"Failed to create a new session during registration, rolling back registration",
zap.String("username", user.Username),
zap.Error(err))
return nil, errs.ErrServerError
}
session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
a.log.Error("", zap.Error(err))
return nil, errs.ErrServerError
}
// TODO: get user role
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
if err != nil {
a.log.Error(
@@ -413,7 +451,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
zap.String("username", request.Username))
response := models.RegistrationCompleteResponse{Tokens: models.Tokens{
AccessToken: accessToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
}}
@@ -421,19 +459,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
}
// TODO: totp
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) {
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
var userRow database.GetValidUserByLoginCredentialsRow
var session database.Session
var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil {
a.log.Error(
"Failed to open a transaction",
"Failed to open a transaction",
zap.Error(err))
return nil, errs.ErrServerError
}
defer helper.RollbackOnError(err)
defer helper.RollbackOnError(err)
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
Username: request.Username,
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
}
// Until release 4, only 1 session at a time is supported
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil {
err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
a.log.Error(
"Failed to terminate older sessions for user trying to log in",
zap.String("username", request.Username),
"Failed to terminate user's sessions during login",
zap.Error(err))
return nil, errs.ErrServerError
}
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
// TODO: use actual values for session metadata
UserID: userRow.ID,
Name: utils.NewPointer("New device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
})
if err != nil {
a.log.Error(
"Failed to create session for a new login",
zap.String("username", userRow.Username),
zap.Error(err))
session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
a.log.Error("", zap.Error(err))
return nil, errs.ErrServerError
}
@@ -500,7 +525,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
}
response := models.LoginResponse{Tokens: models.Tokens{
AccessToken: accessToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
}}
@@ -509,12 +534,13 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil {
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
if err != nil {
if utils.ErrorIsOneOf(
err,
err,
errs.ErrInvalidToken,
errs.ErrTokenExpired,
errs.ErrTokenExpired,
errs.ErrWrongTokenType,
errs.ErrSessionNotFound,
errs.ErrSessionTerminated,
@@ -522,19 +548,20 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
return nil, err
} else {
a.log.Error(
"Encountered an unexpected error while validating token",
"Encountered an unexpected error while validating token",
zap.Error(err))
return nil, errs.ErrServerError
}
}
}
accessToken, refreshToken, err := utils.GenerateTokens(
sessionInfo.Username,
sessionInfo.Session,
sessionInfo.Role,
); if err != nil {
)
if err != nil {
a.log.Error(
"Failed to generate tokens for user during refresh",
"Failed to generate tokens for user during refresh",
zap.String("username", sessionInfo.Username),
zap.String("session", sessionInfo.Session),
zap.Error(err))
@@ -543,7 +570,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
response := models.RefreshResponse{
Tokens: models.Tokens{
AccessToken: accessToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
}
@@ -551,13 +578,13 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
return &response, nil
}
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
var err error
token, err := jwt.ParseWithClaims(
jwtToken,
&dto.UserClaims{},
jwtToken,
&dto.UserClaims{},
func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@@ -569,11 +596,12 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, errs.ErrTokenExpired
}
}
return nil, errs.ErrInvalidToken
}
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid {
claims, ok := token.Claims.(*dto.UserClaims)
if ok && token.Valid {
if claims.Type != tokenType {
return nil, errs.ErrWrongTokenType
@@ -583,7 +611,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
if redisErr != nil && redisErr != redis.Nil {
a.log.Error(
"Failed to lookup cache to check whether session is not terminated",
"Failed to lookup cache to check whether session is not terminated",
zap.Error(redisErr))
return nil, redisErr
}
@@ -603,20 +631,20 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
}
a.log.Error(
"Failed to lookup session in database",
"Failed to lookup session in database",
zap.String("session", claims.Session),
zap.Error(err))
return nil, err
}
if err := a.redis.Set(
ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session),
*session.Terminated,
time.Duration(8 * time.Hour), // XXX: magic number
ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session),
*session.Terminated,
time.Duration(8*time.Hour), // XXX: magic number
).Err(); err != nil {
a.log.Error(
"Failed to cache session's is_terminated state",
"Failed to cache session's is_terminated state",
zap.String("session", claims.Session),
zap.Error(err))
// c.AbortWithStatus(http.StatusInternalServerError)
@@ -633,15 +661,15 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
sessionInfo := dto.SessionInfo{
Username: claims.Username,
Session: claims.Session,
Role: claims.Role,
Session: claims.Session,
Role: claims.Role,
}
return &sessionInfo, nil
}
func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) {
var user database.User
var generatedCode, hashedCode string
var err error
@@ -649,7 +677,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil {
a.log.Error(
"Failed to open a transaction",
"Failed to open a transaction",
zap.Error(err))
return false, errs.ErrServerError
}
@@ -660,7 +688,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64()
if redisErr != nil && redisErr != redis.Nil {
a.log.Error(
"Failed to get reset_cooldown state for user",
"Failed to get reset_cooldown state for user",
zap.String("email", request.Email),
zap.Error(redisErr))
return false, errs.ErrServerError
@@ -677,7 +705,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
if errors.Is(err, pgx.ErrNoRows) {
// Enable cooldown for the email despite that account does not exist
err := a.redis.Set(
ctx,
ctx,
fmt.Sprintf("email::%s::reset_cooldown", request.Email),
time.Now().Add(10*time.Minute),
time.Duration(10*time.Minute),
@@ -687,7 +715,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
a.log.Error(
"Failed to set reset cooldown for email",
zap.Error(err))
return false, err
return false, err
}
a.log.Warn(
@@ -712,7 +740,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
}
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
UserID: user.ID,
UserID: user.ID,
CodeType: int32(enums.PasswordResetCodeType),
CodeHash: hashedCode,
}); err != nil {
@@ -723,7 +751,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
}
err = a.redis.Set(
ctx,
ctx,
fmt.Sprintf("email::%s::reset_cooldown", request.Email),
time.Now().Add(10*time.Minute),
time.Duration(10*time.Minute),
@@ -733,7 +761,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
a.log.Error(
"Failed to set reset cooldown for email. Cancelling password reset",
zap.Error(err))
return false, err
return false, err
}
if err = helper.Commit(); err != nil {
@@ -757,7 +785,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
if err != nil {
a.log.Error(
"Failed to open a transaction",
"Failed to open a transaction",
zap.Error(err))
return nil, errs.ErrServerError
}
@@ -779,13 +807,13 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
}
if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
UserID: user.ID,
UserID: user.ID,
CodeType: int32(enums.PasswordResetCodeType),
Code: request.VerificationCode,
Code: request.VerificationCode,
}); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
a.log.Warn(
"Attempted to reset password for user using incorrect confirmation code",
"Attempted to reset password for user using incorrect confirmation code",
zap.String("email", request.Email),
zap.String("username", user.Username),
zap.String("provided_code", request.VerificationCode),
@@ -795,7 +823,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
}
if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
ID: resetCode.ID,
ID: resetCode.ID,
Used: utils.NewPointer(true),
}); err != nil {
a.log.Error(
@@ -817,7 +845,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
}
if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
Username: user.Username,
Username: user.Username,
PasswordHash: hashedPassword,
}); err != nil {
a.log.Error(
@@ -828,24 +856,22 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
}
if request.LogOutSessions {
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil {
err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
a.log.Error(
"Failed to log out older sessions as part of user password reset",
zap.String("email", request.Email),
zap.String("username", user.Username),
"Failed to terminate user's sessions during login",
zap.Error(err))
return nil, errs.ErrServerError
}
}
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
UserID: user.ID,
Name: utils.NewPointer("First device"),
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
UserID: user.ID,
Name: utils.NewPointer("First device"),
Platform: utils.NewPointer("Unknown"),
LatestIp: utils.NewPointer("Unknown"),
}); err != nil {
}); if err != nil {
a.log.Error(
"Failed to create new session for user as part of user password reset",
"Failed to create new session for user as part of user password reset",
zap.String("email", request.Email),
zap.String("username", user.Username),
zap.Error(err))
@@ -854,7 +880,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
// TODO: get user role
if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil {
a.log.Error(
"Failed to generate tokens as part of user password reset",
"Failed to generate tokens as part of user password reset",
zap.String("email", request.Email),
zap.String("username", user.Username),
zap.Error(err))
@@ -863,7 +889,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
response := models.PasswordResetCompleteResponse{
Tokens: models.Tokens{
AccessToken: accessToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
}
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
return &response, nil
}
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
var err error
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
a.log.Error(
"Failed to open a transaction",
zap.Error(err))
return false, errs.ErrServerError
}
defer helper.RollbackOnError(err)
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
a.log.Error(
"Failed to get user login information",
zap.Error(err))
return false, errs.ErrServerError
}
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
a.log.Warn(
"Provided invalid old password while changing password",
zap.String("username", uinfo.Username))
return false, errs.ErrForbidden
}
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
a.log.Error(
"Failed to hash new password while changing password",
zap.String("username", uinfo.Username),
zap.Error(err))
return false, errs.ErrServerError
}
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
Username: uinfo.Username,
PasswordHash: newPasswordHash,
}); if err != nil {
a.log.Error(
"Failed to save new password into database",
zap.String("username", uinfo.Username),
zap.Error(err))
return false, errs.ErrServerError
}
if err := helper.Commit(); err != nil {
a.log.Error(
"Failed to commit transaction",
zap.Error(err))
return false, errs.ErrServerError
}
return true, nil
}

View File

@@ -22,8 +22,8 @@ import "errors"
func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
for _, ignore := range ignoreErrors {
if errors.Is(err, ignore) {
return false
return true
}
}
return true
return false
}