feat: fully implemented Refresh method;

fix: Improve error handling in Refresh method for token validation;
fix: Update Refresh route to use correct request model;
fix: Correct request model for password reset complete route;
fix: Redis pipeline error handling in AuthService constructor;
fix: Refresh method wanted access token;
refactor: Enhance error handling for unexpected token validation errors;
refactor: Simplify claims extraction in ValidateToken method;
fix: Ensure session termination state is correctly dereferenced;
refactor: Return structured session info in ValidateToken method;
feat: New util method to check if an error is one of multiple given ones;
This commit is contained in:
2025-07-15 23:32:25 +03:00
parent e465da6854
commit 8b558eaf5f
3 changed files with 70 additions and 11 deletions

View File

@@ -155,8 +155,16 @@ func (a *authControllerImpl) Refresh(c *gin.Context) {
response, err := a.auth.Refresh(request.Body)
if err != nil {
if errors.Is(err, errs.ErrUnauthorized) {
c.Status(http.StatusUnauthorized)
if errors.Is(err, errs.ErrTokenExpired) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
} else if errors.Is(err, errs.ErrTokenInvalid) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
} else if errors.Is(err, errs.ErrWrongTokenType) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
} else if errors.Is(err, errs.ErrSessionNotFound) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
} else if errors.Is(err, errs.ErrSessionTerminated) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
} else {
c.Status(http.StatusInternalServerError)
}
@@ -233,7 +241,7 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
group.POST("/login", middleware.RequestMiddleware[models.LoginRequest](enums.GuestRole), a.Login)
group.POST("/refresh", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.UserRole), a.Refresh)
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.PasswordResetComplete)
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
}

View File

@@ -71,8 +71,8 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
// FIXME: review possible problems due to a large pipeline request
pipe := _redis.Pipeline()
for _, guid := range guids {
if err := pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, 0); err != nil {
panic("Failed to cache terminated session: " + err.Err().Error())
if err := pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, 0).Err(); err != nil {
panic("Failed to cache terminated session: " + err.Error())
}
}
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
@@ -509,8 +509,23 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtAccessTokenType); if err != nil {
return nil, err
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil {
if utils.ErrorIsOneOf(
err,
errs.ErrInvalidToken,
errs.ErrTokenExpired,
errs.ErrWrongTokenType,
errs.ErrSessionNotFound,
errs.ErrSessionTerminated,
) {
return nil, err
} else {
a.log.Error(
"Encountered an unexpected error while validating token",
zap.Error(err))
return nil, errs.ErrServerError
}
}
accessToken, refreshToken, err := utils.GenerateTokens(
@@ -558,7 +573,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
return nil, errs.ErrInvalidToken
}
if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid {
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid {
if claims.Type != tokenType {
return nil, errs.ErrWrongTokenType
@@ -597,7 +612,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
if err := a.redis.Set(
ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session),
session.Terminated,
*session.Terminated,
time.Duration(8 * time.Hour), // XXX: magic number
).Err(); err != nil {
a.log.Error(
@@ -615,7 +630,14 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
return nil, errs.ErrSessionTerminated
}
}
return nil, errs.ErrNotImplemented
sessionInfo := dto.SessionInfo{
Username: claims.Username,
Session: claims.Session,
Role: claims.Role,
}
return &sessionInfo, nil
}
func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) {

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2025 Nikolai Papin
//
// This file is part of Easywish
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
// the GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package utils
import "errors"
func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
for _, ignore := range ignoreErrors {
if errors.Is(err, ignore) {
return false
}
}
return true
}