From 8b558eaf5fd4db7e6b262f5a5a04bc0af512d5a6 Mon Sep 17 00:00:00 2001 From: Nikolai Papin Date: Tue, 15 Jul 2025 23:32:25 +0300 Subject: [PATCH] feat: fully implemented Refresh method; fix: Improve error handling in Refresh method for token validation; fix: Update Refresh route to use correct request model; fix: Correct request model for password reset complete route; fix: Redis pipeline error handling in AuthService constructor; fix: Refresh method wanted access token; refactor: Enhance error handling for unexpected token validation errors; refactor: Simplify claims extraction in ValidateToken method; fix: Ensure session termination state is correctly dereferenced; refactor: Return structured session info in ValidateToken method; feat: New util method to check if an error is one of multiple given ones; --- backend/internal/controllers/auth.go | 16 +++++++++---- backend/internal/services/auth.go | 36 ++++++++++++++++++++++------ backend/internal/utils/errors.go | 29 ++++++++++++++++++++++ 3 files changed, 70 insertions(+), 11 deletions(-) create mode 100644 backend/internal/utils/errors.go diff --git a/backend/internal/controllers/auth.go b/backend/internal/controllers/auth.go index 62e8701..742b093 100644 --- a/backend/internal/controllers/auth.go +++ b/backend/internal/controllers/auth.go @@ -155,8 +155,16 @@ func (a *authControllerImpl) Refresh(c *gin.Context) { response, err := a.auth.Refresh(request.Body) if err != nil { - if errors.Is(err, errs.ErrUnauthorized) { - c.Status(http.StatusUnauthorized) + if errors.Is(err, errs.ErrTokenExpired) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"}) + } else if errors.Is(err, errs.ErrTokenInvalid) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"}) + } else if errors.Is(err, errs.ErrWrongTokenType) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"}) + } else if errors.Is(err, errs.ErrSessionNotFound) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"}) + } else if errors.Is(err, errs.ErrSessionTerminated) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"}) } else { c.Status(http.StatusInternalServerError) } @@ -233,7 +241,7 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) { group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin) group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete) group.POST("/login", middleware.RequestMiddleware[models.LoginRequest](enums.GuestRole), a.Login) - group.POST("/refresh", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.UserRole), a.Refresh) + group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh) group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin) - group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.PasswordResetComplete) + group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete) } diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index 85d1c53..bfb5c9e 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -71,8 +71,8 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C // FIXME: review possible problems due to a large pipeline request pipe := _redis.Pipeline() for _, guid := range guids { - if err := pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, 0); err != nil { - panic("Failed to cache terminated session: " + err.Err().Error()) + if err := pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, 0).Err(); err != nil { + panic("Failed to cache terminated session: " + err.Error()) } } _log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids))) @@ -509,8 +509,23 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { - sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtAccessTokenType); if err != nil { - return nil, err + sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil { + + if utils.ErrorIsOneOf( + err, + errs.ErrInvalidToken, + errs.ErrTokenExpired, + errs.ErrWrongTokenType, + errs.ErrSessionNotFound, + errs.ErrSessionTerminated, + ) { + return nil, err + } else { + a.log.Error( + "Encountered an unexpected error while validating token", + zap.Error(err)) + return nil, errs.ErrServerError + } } accessToken, refreshToken, err := utils.GenerateTokens( @@ -558,7 +573,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke return nil, errs.ErrInvalidToken } - if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid { + claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid { if claims.Type != tokenType { return nil, errs.ErrWrongTokenType @@ -597,7 +612,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke if err := a.redis.Set( ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session), - session.Terminated, + *session.Terminated, time.Duration(8 * time.Hour), // XXX: magic number ).Err(); err != nil { a.log.Error( @@ -615,7 +630,14 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke return nil, errs.ErrSessionTerminated } } - return nil, errs.ErrNotImplemented + + sessionInfo := dto.SessionInfo{ + Username: claims.Username, + Session: claims.Session, + Role: claims.Role, + } + + return &sessionInfo, nil } func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) { diff --git a/backend/internal/utils/errors.go b/backend/internal/utils/errors.go new file mode 100644 index 0000000..55b0452 --- /dev/null +++ b/backend/internal/utils/errors.go @@ -0,0 +1,29 @@ +// Copyright (c) 2025 Nikolai Papin +// +// This file is part of Easywish +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See +// the GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package utils + +import "errors" + +func ErrorIsOneOf(err error, ignoreErrors ...error) bool { + for _, ignore := range ignoreErrors { + if errors.Is(err, ignore) { + return false + } + } + return true +}