refactor: Simplify AuthMiddleware;

refactor: Move token validation logic to AuthService;
refactor: Remove Redis cache checks from middleware;
fix: Improve error handling for token validation;
refactor: Update Refresh method to use new validation logic;
chore: Clean up unused imports and comments
This commit is contained in:
2025-07-15 22:37:41 +03:00
parent a582b75c82
commit e465da6854
3 changed files with 43 additions and 192 deletions

View File

@@ -18,27 +18,19 @@
package middleware
import (
"context"
"easywish/config"
"easywish/internal/database"
"easywish/internal/dto"
"easywish/internal/services"
"easywish/internal/utils/enums"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
errs "easywish/internal/errors"
)
// XXX: cluttered; move cache & database check to auth service
func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc {
func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
cfg := config.GetConfig()
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
@@ -53,101 +45,26 @@ func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redi
}
tokenString := authHeader
token, err := jwt.ParseWithClaims(
tokenString,
&dto.UserClaims{},
func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(cfg.JwtSecret), nil
},
)
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
if sessionInfo, err := auth.ValidateToken(tokenString, enums.JwtAccessTokenType); err != nil {
if errors.Is(err, errs.ErrTokenExpired) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
} else if errors.Is(err, errs.ErrTokenInvalid) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
} else if errors.Is(err, errs.ErrWrongTokenType) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
} else if errors.Is(err, errs.ErrSessionNotFound) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
} else if errors.Is(err, errs.ErrSessionTerminated) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
} else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
}
return
}
if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid {
if claims.Type != enums.JwtAccessTokenType {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"})
return
}
ctx := context.TODO()
isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
if redisErr != nil && redisErr != redis.Nil {
log.Error(
"Failed to lookup cache to check whether session is not terminated",
zap.Error(redisErr))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
// Cache if nil
if redisErr == redis.Nil {
db := database.NewDbHelper(dbctx)
session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
log.Warn(
"Session does not exist or was deleted",
zap.String("session", claims.Session))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
log.Error(
"Failed to lookup session in database",
zap.String("session", claims.Session),
zap.Error(err))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
if err := redisClient.Set(
ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session),
session.Terminated,
time.Duration(8 * time.Hour), // XXX: magic number
).Err(); err != nil {
log.Error(
"Failed to cache session's is_terminated state",
zap.String("session", claims.Session),
zap.Error(err))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
isTerminated = *session.Terminated
}
if isTerminated {
log.Warn(
"Attempt to access resource from a terminated session",
zap.String("session", claims.Session))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.Set("session_info", dto.SessionInfo{
Username: claims.Username,
Session: claims.Session,
Role: claims.Role,
})
c.Next()
} else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"})
c.Set("session_info", sessionInfo)
c.Next()
}
return
}
}