refactor: Simplify AuthMiddleware;
refactor: Move token validation logic to AuthService; refactor: Remove Redis cache checks from middleware; fix: Improve error handling for token validation; refactor: Update Refresh method to use new validation logic; chore: Clean up unused imports and comments
This commit is contained in:
@@ -18,27 +18,19 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"easywish/config"
|
|
||||||
"easywish/internal/database"
|
|
||||||
"easywish/internal/dto"
|
"easywish/internal/dto"
|
||||||
|
"easywish/internal/services"
|
||||||
"easywish/internal/utils/enums"
|
"easywish/internal/utils/enums"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
errs "easywish/internal/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// XXX: cluttered; move cache & database check to auth service
|
func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc {
|
||||||
func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
cfg := config.GetConfig()
|
|
||||||
authHeader := c.GetHeader("Authorization")
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
@@ -53,101 +45,26 @@ func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redi
|
|||||||
}
|
}
|
||||||
|
|
||||||
tokenString := authHeader
|
tokenString := authHeader
|
||||||
|
if sessionInfo, err := auth.ValidateToken(tokenString, enums.JwtAccessTokenType); err != nil {
|
||||||
token, err := jwt.ParseWithClaims(
|
if errors.Is(err, errs.ErrTokenExpired) {
|
||||||
tokenString,
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
|
||||||
&dto.UserClaims{},
|
} else if errors.Is(err, errs.ErrTokenInvalid) {
|
||||||
func(token *jwt.Token) (any, error) {
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
} else if errors.Is(err, errs.ErrWrongTokenType) {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
|
||||||
}
|
} else if errors.Is(err, errs.ErrSessionNotFound) {
|
||||||
return []byte(cfg.JwtSecret), nil
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
|
||||||
},
|
} else if errors.Is(err, errs.ErrSessionTerminated) {
|
||||||
)
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
|
|
||||||
} else {
|
} else {
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
} else {
|
||||||
|
c.Set("session_info", sessionInfo)
|
||||||
if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid {
|
|
||||||
|
|
||||||
if claims.Type != enums.JwtAccessTokenType {
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.TODO()
|
|
||||||
isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
|
|
||||||
if redisErr != nil && redisErr != redis.Nil {
|
|
||||||
log.Error(
|
|
||||||
"Failed to lookup cache to check whether session is not terminated",
|
|
||||||
zap.Error(redisErr))
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache if nil
|
|
||||||
if redisErr == redis.Nil {
|
|
||||||
db := database.NewDbHelper(dbctx)
|
|
||||||
|
|
||||||
session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
log.Warn(
|
|
||||||
"Session does not exist or was deleted",
|
|
||||||
zap.String("session", claims.Session))
|
|
||||||
c.AbortWithStatus(http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Error(
|
|
||||||
"Failed to lookup session in database",
|
|
||||||
zap.String("session", claims.Session),
|
|
||||||
zap.Error(err))
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := redisClient.Set(
|
|
||||||
ctx,
|
|
||||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
|
||||||
session.Terminated,
|
|
||||||
time.Duration(8 * time.Hour), // XXX: magic number
|
|
||||||
).Err(); err != nil {
|
|
||||||
log.Error(
|
|
||||||
"Failed to cache session's is_terminated state",
|
|
||||||
zap.String("session", claims.Session),
|
|
||||||
zap.Error(err))
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
isTerminated = *session.Terminated
|
|
||||||
}
|
|
||||||
|
|
||||||
if isTerminated {
|
|
||||||
log.Warn(
|
|
||||||
"Attempt to access resource from a terminated session",
|
|
||||||
zap.String("session", claims.Session))
|
|
||||||
c.AbortWithStatus(http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("session_info", dto.SessionInfo{
|
|
||||||
Username: claims.Username,
|
|
||||||
Session: claims.Session,
|
|
||||||
Role: claims.Role,
|
|
||||||
})
|
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
} else {
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,17 +19,16 @@ package routes
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"easywish/internal/controllers"
|
"easywish/internal/controllers"
|
||||||
"easywish/internal/database"
|
|
||||||
"easywish/internal/middleware"
|
"easywish/internal/middleware"
|
||||||
|
"easywish/internal/services"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRouter(engine *gin.Engine, log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client, groups []RouteGroup) *gin.Engine {
|
func NewRouter(engine *gin.Engine, log *zap.Logger, auth services.AuthService, groups []RouteGroup) *gin.Engine {
|
||||||
apiGroup := engine.Group("/api")
|
apiGroup := engine.Group("/api")
|
||||||
apiGroup.Use(middleware.AuthMiddleware(log, dbctx, redisClient))
|
apiGroup.Use(middleware.AuthMiddleware(log, auth))
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
subgroup := apiGroup.Group(group.BasePath)
|
subgroup := apiGroup.Group(group.BasePath)
|
||||||
subgroup.Use(group.Middleware...)
|
subgroup.Use(group.Middleware...)
|
||||||
|
|||||||
@@ -509,96 +509,31 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|||||||
|
|
||||||
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
||||||
|
|
||||||
var err error
|
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtAccessTokenType); if err != nil {
|
||||||
|
return nil, err
|
||||||
token, err := jwt.ParseWithClaims(
|
|
||||||
request.RefreshToken,
|
|
||||||
&dto.UserClaims{},
|
|
||||||
func(token *jwt.Token) (any, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
}
|
||||||
return []byte(config.GetConfig().JwtSecret), nil
|
|
||||||
|
accessToken, refreshToken, err := utils.GenerateTokens(
|
||||||
|
sessionInfo.Username,
|
||||||
|
sessionInfo.Session,
|
||||||
|
sessionInfo.Role,
|
||||||
|
); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to generate tokens for user during refresh",
|
||||||
|
zap.String("username", sessionInfo.Username),
|
||||||
|
zap.String("session", sessionInfo.Session),
|
||||||
|
zap.Error(err))
|
||||||
|
return nil, errs.ErrServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
response := models.RefreshResponse{
|
||||||
|
Tokens: models.Tokens{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
},
|
},
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
||||||
// AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"})
|
|
||||||
} else {
|
|
||||||
// c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
|
||||||
}
|
|
||||||
return nil, errs.ErrUnauthorized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid {
|
return &response, nil
|
||||||
|
|
||||||
if claims.Type != enums.JwtAccessTokenType {
|
|
||||||
// c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"})
|
|
||||||
return nil, errs.ErrUnauthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.TODO()
|
|
||||||
isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
|
|
||||||
if redisErr != nil && redisErr != redis.Nil {
|
|
||||||
a.log.Error(
|
|
||||||
"Failed to lookup cache to check whether session is not terminated",
|
|
||||||
zap.Error(redisErr))
|
|
||||||
// c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return nil, errs.ErrServerError
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache if nil
|
|
||||||
if redisErr == redis.Nil {
|
|
||||||
db := database.NewDbHelper(a.dbctx)
|
|
||||||
|
|
||||||
session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
a.log.Warn(
|
|
||||||
"Session does not exist or was deleted",
|
|
||||||
zap.String("session", claims.Session))
|
|
||||||
// c.AbortWithStatus(http.StatusUnauthorized)
|
|
||||||
return nil, errs.ErrUnauthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
a.log.Error(
|
|
||||||
"Failed to lookup session in database",
|
|
||||||
zap.String("session", claims.Session),
|
|
||||||
zap.Error(err))
|
|
||||||
// c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return nil, errs.ErrServerError
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := a.redis.Set(
|
|
||||||
ctx,
|
|
||||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
|
||||||
session.Terminated,
|
|
||||||
time.Duration(8 * time.Hour), // XXX: magic number
|
|
||||||
).Err(); err != nil {
|
|
||||||
a.log.Error(
|
|
||||||
"Failed to cache session's is_terminated state",
|
|
||||||
zap.String("session", claims.Session),
|
|
||||||
zap.Error(err))
|
|
||||||
// c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return nil, errs.ErrServerError
|
|
||||||
}
|
|
||||||
|
|
||||||
isTerminated = *session.Terminated
|
|
||||||
}
|
|
||||||
|
|
||||||
if isTerminated {
|
|
||||||
a.log.Warn(
|
|
||||||
"Attempt to access resource from a terminated session",
|
|
||||||
zap.String("session", claims.Session))
|
|
||||||
// c.AbortWithStatus(http.StatusUnauthorized)
|
|
||||||
return nil, errs.ErrUnauthorized
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// TODO: generate some tokens
|
|
||||||
|
|
||||||
return nil, errs.ErrNotImplemented
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user