refactor: Simplify AuthMiddleware;

refactor: Move token validation logic to AuthService;
refactor: Remove Redis cache checks from middleware;
fix: Improve error handling for token validation;
refactor: Update Refresh method to use new validation logic;
chore: Clean up unused imports and comments
This commit is contained in:
2025-07-15 22:37:41 +03:00
parent a582b75c82
commit e465da6854
3 changed files with 43 additions and 192 deletions

View File

@@ -18,27 +18,19 @@
package middleware package middleware
import ( import (
"context"
"easywish/config"
"easywish/internal/database"
"easywish/internal/dto" "easywish/internal/dto"
"easywish/internal/services"
"easywish/internal/utils/enums" "easywish/internal/utils/enums"
"errors" "errors"
"fmt"
"net/http" "net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5"
"go.uber.org/zap" "go.uber.org/zap"
errs "easywish/internal/errors"
) )
// XXX: cluttered; move cache & database check to auth service func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc {
func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
cfg := config.GetConfig()
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
@@ -53,101 +45,26 @@ func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redi
} }
tokenString := authHeader tokenString := authHeader
if sessionInfo, err := auth.ValidateToken(tokenString, enums.JwtAccessTokenType); err != nil {
token, err := jwt.ParseWithClaims( if errors.Is(err, errs.ErrTokenExpired) {
tokenString, c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
&dto.UserClaims{}, } else if errors.Is(err, errs.ErrTokenInvalid) {
func(token *jwt.Token) (any, error) { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { } else if errors.Is(err, errs.ErrWrongTokenType) {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
} } else if errors.Is(err, errs.ErrSessionNotFound) {
return []byte(cfg.JwtSecret), nil c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
}, } else if errors.Is(err, errs.ErrSessionTerminated) {
) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
} else { } else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
} }
return return
} } else {
c.Set("session_info", sessionInfo)
if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid {
if claims.Type != enums.JwtAccessTokenType {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"})
return
}
ctx := context.TODO()
isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
if redisErr != nil && redisErr != redis.Nil {
log.Error(
"Failed to lookup cache to check whether session is not terminated",
zap.Error(redisErr))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
// Cache if nil
if redisErr == redis.Nil {
db := database.NewDbHelper(dbctx)
session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
log.Warn(
"Session does not exist or was deleted",
zap.String("session", claims.Session))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
log.Error(
"Failed to lookup session in database",
zap.String("session", claims.Session),
zap.Error(err))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
if err := redisClient.Set(
ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session),
session.Terminated,
time.Duration(8 * time.Hour), // XXX: magic number
).Err(); err != nil {
log.Error(
"Failed to cache session's is_terminated state",
zap.String("session", claims.Session),
zap.Error(err))
c.AbortWithStatus(http.StatusInternalServerError)
return
}
isTerminated = *session.Terminated
}
if isTerminated {
log.Warn(
"Attempt to access resource from a terminated session",
zap.String("session", claims.Session))
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.Set("session_info", dto.SessionInfo{
Username: claims.Username,
Session: claims.Session,
Role: claims.Role,
})
c.Next() c.Next()
} else { }
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"})
} return
} }
} }

View File

@@ -19,17 +19,16 @@ package routes
import ( import (
"easywish/internal/controllers" "easywish/internal/controllers"
"easywish/internal/database"
"easywish/internal/middleware" "easywish/internal/middleware"
"easywish/internal/services"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"go.uber.org/zap" "go.uber.org/zap"
) )
func NewRouter(engine *gin.Engine, log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client, groups []RouteGroup) *gin.Engine { func NewRouter(engine *gin.Engine, log *zap.Logger, auth services.AuthService, groups []RouteGroup) *gin.Engine {
apiGroup := engine.Group("/api") apiGroup := engine.Group("/api")
apiGroup.Use(middleware.AuthMiddleware(log, dbctx, redisClient)) apiGroup.Use(middleware.AuthMiddleware(log, auth))
for _, group := range groups { for _, group := range groups {
subgroup := apiGroup.Group(group.BasePath) subgroup := apiGroup.Group(group.BasePath)
subgroup.Use(group.Middleware...) subgroup.Use(group.Middleware...)

View File

@@ -509,96 +509,31 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
var err error sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtAccessTokenType); if err != nil {
return nil, err
token, err := jwt.ParseWithClaims(
request.RefreshToken,
&dto.UserClaims{},
func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
} }
return []byte(config.GetConfig().JwtSecret), nil
accessToken, refreshToken, err := utils.GenerateTokens(
sessionInfo.Username,
sessionInfo.Session,
sessionInfo.Role,
); if err != nil {
a.log.Error(
"Failed to generate tokens for user during refresh",
zap.String("username", sessionInfo.Username),
zap.String("session", sessionInfo.Session),
zap.Error(err))
return nil, errs.ErrServerError
}
response := models.RefreshResponse{
Tokens: models.Tokens{
AccessToken: accessToken,
RefreshToken: refreshToken,
}, },
)
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
// AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
} else {
// c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
}
return nil, errs.ErrUnauthorized
} }
if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid { return &response, nil
if claims.Type != enums.JwtAccessTokenType {
// c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"})
return nil, errs.ErrUnauthorized
}
ctx := context.TODO()
isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
if redisErr != nil && redisErr != redis.Nil {
a.log.Error(
"Failed to lookup cache to check whether session is not terminated",
zap.Error(redisErr))
// c.AbortWithStatus(http.StatusInternalServerError)
return nil, errs.ErrServerError
}
// Cache if nil
if redisErr == redis.Nil {
db := database.NewDbHelper(a.dbctx)
session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
a.log.Warn(
"Session does not exist or was deleted",
zap.String("session", claims.Session))
// c.AbortWithStatus(http.StatusUnauthorized)
return nil, errs.ErrUnauthorized
}
a.log.Error(
"Failed to lookup session in database",
zap.String("session", claims.Session),
zap.Error(err))
// c.AbortWithStatus(http.StatusInternalServerError)
return nil, errs.ErrServerError
}
if err := a.redis.Set(
ctx,
fmt.Sprintf("session::%s::is_terminated", claims.Session),
session.Terminated,
time.Duration(8 * time.Hour), // XXX: magic number
).Err(); err != nil {
a.log.Error(
"Failed to cache session's is_terminated state",
zap.String("session", claims.Session),
zap.Error(err))
// c.AbortWithStatus(http.StatusInternalServerError)
return nil, errs.ErrServerError
}
isTerminated = *session.Terminated
}
if isTerminated {
a.log.Warn(
"Attempt to access resource from a terminated session",
zap.String("session", claims.Session))
// c.AbortWithStatus(http.StatusUnauthorized)
return nil, errs.ErrUnauthorized
}
}
// TODO: generate some tokens
return nil, errs.ErrNotImplemented
} }
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) { func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {