From e465da68549eee39af36bd24bb9b7e0ce8f4c449 Mon Sep 17 00:00:00 2001 From: Nikolai Papin Date: Tue, 15 Jul 2025 22:37:41 +0300 Subject: [PATCH] refactor: Simplify AuthMiddleware; refactor: Move token validation logic to AuthService; refactor: Remove Redis cache checks from middleware; fix: Improve error handling for token validation; refactor: Update Refresh method to use new validation logic; chore: Clean up unused imports and comments --- backend/internal/middleware/auth.go | 121 +++++----------------------- backend/internal/routes/router.go | 7 +- backend/internal/services/auth.go | 107 +++++------------------- 3 files changed, 43 insertions(+), 192 deletions(-) diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index 358d76f..2dc0874 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -18,27 +18,19 @@ package middleware import ( - "context" - "easywish/config" - "easywish/internal/database" "easywish/internal/dto" + "easywish/internal/services" "easywish/internal/utils/enums" "errors" - "fmt" "net/http" - "time" "github.com/gin-gonic/gin" - "github.com/go-redis/redis/v8" - "github.com/golang-jwt/jwt/v5" - "github.com/jackc/pgx/v5" "go.uber.org/zap" + errs "easywish/internal/errors" ) -// XXX: cluttered; move cache & database check to auth service -func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc { +func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc { return func(c *gin.Context) { - cfg := config.GetConfig() authHeader := c.GetHeader("Authorization") if authHeader == "" { @@ -53,101 +45,26 @@ func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redi } tokenString := authHeader - - token, err := jwt.ParseWithClaims( - tokenString, - &dto.UserClaims{}, - func(token *jwt.Token) (any, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return []byte(cfg.JwtSecret), nil - }, - ) - - if err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"}) + if sessionInfo, err := auth.ValidateToken(tokenString, enums.JwtAccessTokenType); err != nil { + if errors.Is(err, errs.ErrTokenExpired) { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"}) + } else if errors.Is(err, errs.ErrTokenInvalid) { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"}) + } else if errors.Is(err, errs.ErrWrongTokenType) { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"}) + } else if errors.Is(err, errs.ErrSessionNotFound) { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"}) + } else if errors.Is(err, errs.ErrSessionTerminated) { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"}) } else { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"}) } return - } - - if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid { - - if claims.Type != enums.JwtAccessTokenType { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"}) - return - } - - ctx := context.TODO() - isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() - if redisErr != nil && redisErr != redis.Nil { - log.Error( - "Failed to lookup cache to check whether session is not terminated", - zap.Error(redisErr)) - c.AbortWithStatus(http.StatusInternalServerError) - return - } - - // Cache if nil - if redisErr == redis.Nil { - db := database.NewDbHelper(dbctx) - - session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session) - if err != nil { - - if errors.Is(err, pgx.ErrNoRows) { - log.Warn( - "Session does not exist or was deleted", - zap.String("session", claims.Session)) - c.AbortWithStatus(http.StatusUnauthorized) - return - } - - log.Error( - "Failed to lookup session in database", - zap.String("session", claims.Session), - zap.Error(err)) - c.AbortWithStatus(http.StatusInternalServerError) - return - } - - if err := redisClient.Set( - ctx, - fmt.Sprintf("session::%s::is_terminated", claims.Session), - session.Terminated, - time.Duration(8 * time.Hour), // XXX: magic number - ).Err(); err != nil { - log.Error( - "Failed to cache session's is_terminated state", - zap.String("session", claims.Session), - zap.Error(err)) - c.AbortWithStatus(http.StatusInternalServerError) - return - } - - isTerminated = *session.Terminated - } - - if isTerminated { - log.Warn( - "Attempt to access resource from a terminated session", - zap.String("session", claims.Session)) - c.AbortWithStatus(http.StatusUnauthorized) - return - } - - c.Set("session_info", dto.SessionInfo{ - Username: claims.Username, - Session: claims.Session, - Role: claims.Role, - }) - - c.Next() } else { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"}) + c.Set("session_info", sessionInfo) + c.Next() } + + return } } diff --git a/backend/internal/routes/router.go b/backend/internal/routes/router.go index ab2d029..7a263f7 100644 --- a/backend/internal/routes/router.go +++ b/backend/internal/routes/router.go @@ -19,17 +19,16 @@ package routes import ( "easywish/internal/controllers" - "easywish/internal/database" "easywish/internal/middleware" + "easywish/internal/services" "github.com/gin-gonic/gin" - "github.com/go-redis/redis/v8" "go.uber.org/zap" ) -func NewRouter(engine *gin.Engine, log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client, groups []RouteGroup) *gin.Engine { +func NewRouter(engine *gin.Engine, log *zap.Logger, auth services.AuthService, groups []RouteGroup) *gin.Engine { apiGroup := engine.Group("/api") - apiGroup.Use(middleware.AuthMiddleware(log, dbctx, redisClient)) + apiGroup.Use(middleware.AuthMiddleware(log, auth)) for _, group := range groups { subgroup := apiGroup.Group(group.BasePath) subgroup.Use(group.Middleware...) diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index 430c425..85d1c53 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -509,96 +509,31 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { - var err error + sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtAccessTokenType); if err != nil { + return nil, err + } - token, err := jwt.ParseWithClaims( - request.RefreshToken, - &dto.UserClaims{}, - func(token *jwt.Token) (any, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return []byte(config.GetConfig().JwtSecret), nil + accessToken, refreshToken, err := utils.GenerateTokens( + sessionInfo.Username, + sessionInfo.Session, + sessionInfo.Role, + ); if err != nil { + a.log.Error( + "Failed to generate tokens for user during refresh", + zap.String("username", sessionInfo.Username), + zap.String("session", sessionInfo.Session), + zap.Error(err)) + return nil, errs.ErrServerError + } + + response := models.RefreshResponse{ + Tokens: models.Tokens{ + AccessToken: accessToken, + RefreshToken: refreshToken, }, - ) - - if err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - // AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"}) - } else { - // c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) - } - return nil, errs.ErrUnauthorized } - if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid { - - if claims.Type != enums.JwtAccessTokenType { - // c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"}) - return nil, errs.ErrUnauthorized - } - - ctx := context.TODO() - isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() - if redisErr != nil && redisErr != redis.Nil { - a.log.Error( - "Failed to lookup cache to check whether session is not terminated", - zap.Error(redisErr)) - // c.AbortWithStatus(http.StatusInternalServerError) - return nil, errs.ErrServerError - } - - // Cache if nil - if redisErr == redis.Nil { - db := database.NewDbHelper(a.dbctx) - - session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session) - if err != nil { - - if errors.Is(err, pgx.ErrNoRows) { - a.log.Warn( - "Session does not exist or was deleted", - zap.String("session", claims.Session)) - // c.AbortWithStatus(http.StatusUnauthorized) - return nil, errs.ErrUnauthorized - } - - a.log.Error( - "Failed to lookup session in database", - zap.String("session", claims.Session), - zap.Error(err)) - // c.AbortWithStatus(http.StatusInternalServerError) - return nil, errs.ErrServerError - } - - if err := a.redis.Set( - ctx, - fmt.Sprintf("session::%s::is_terminated", claims.Session), - session.Terminated, - time.Duration(8 * time.Hour), // XXX: magic number - ).Err(); err != nil { - a.log.Error( - "Failed to cache session's is_terminated state", - zap.String("session", claims.Session), - zap.Error(err)) - // c.AbortWithStatus(http.StatusInternalServerError) - return nil, errs.ErrServerError - } - - isTerminated = *session.Terminated - } - - if isTerminated { - a.log.Warn( - "Attempt to access resource from a terminated session", - zap.String("session", claims.Session)) - // c.AbortWithStatus(http.StatusUnauthorized) - return nil, errs.ErrUnauthorized - } - } - // TODO: generate some tokens - - return nil, errs.ErrNotImplemented + return &response, nil } func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {