From a582b75c825706428eab9165ca45b78567a1d593 Mon Sep 17 00:00:00 2001 From: Nikolai Papin Date: Tue, 15 Jul 2025 21:59:05 +0300 Subject: [PATCH] feat: new ValidateToken method for AuthService, based on code from the monolithic implementation of auth middleware; feat: add detailed authentication error types; --- backend/internal/errors/auth.go | 6 +++ backend/internal/services/auth.go | 80 ++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/backend/internal/errors/auth.go b/backend/internal/errors/auth.go index 427f259..51a707a 100644 --- a/backend/internal/errors/auth.go +++ b/backend/internal/errors/auth.go @@ -29,4 +29,10 @@ var ( ErrInvalidCredentials = errors.New("Invalid username, password or TOTP code") ErrInvalidToken = errors.New("Token is invalid or expired") ErrServerError = errors.New("Internal server error") + + ErrTokenExpired = errors.New("Token is expired") + ErrTokenInvalid = errors.New("Token is invalid") + ErrWrongTokenType = errors.New("Invalid token type") + ErrSessionNotFound = errors.New("Could not find session in database") + ErrSessionTerminated = errors.New("Session is terminated") ) diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index f14d6fe..430c425 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -601,7 +601,85 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres return nil, errs.ErrNotImplemented } -func (a *authServiceImpl) ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) { +func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) { + + var err error + + token, err := jwt.ParseWithClaims( + jwtToken, + &dto.UserClaims{}, + func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(config.GetConfig().JwtSecret), nil + }, + ) + + if err != nil { + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, errs.ErrTokenExpired + } + return nil, errs.ErrInvalidToken + } + + if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid { + + if claims.Type != tokenType { + return nil, errs.ErrWrongTokenType + } + + ctx := context.TODO() + isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() + if redisErr != nil && redisErr != redis.Nil { + a.log.Error( + "Failed to lookup cache to check whether session is not terminated", + zap.Error(redisErr)) + return nil, redisErr + } + + // Cache if nil + if redisErr == redis.Nil { + db := database.NewDbHelper(a.dbctx) + + session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session) + if err != nil { + + if errors.Is(err, pgx.ErrNoRows) { + a.log.Warn( + "Session does not exist or was deleted", + zap.String("session", claims.Session)) + return nil, errs.ErrSessionNotFound + } + + a.log.Error( + "Failed to lookup session in database", + zap.String("session", claims.Session), + zap.Error(err)) + return nil, err + } + + if err := a.redis.Set( + ctx, + fmt.Sprintf("session::%s::is_terminated", claims.Session), + session.Terminated, + time.Duration(8 * time.Hour), // XXX: magic number + ).Err(); err != nil { + a.log.Error( + "Failed to cache session's is_terminated state", + zap.String("session", claims.Session), + zap.Error(err)) + // c.AbortWithStatus(http.StatusInternalServerError) + return nil, err + } + + isTerminated = *session.Terminated + } + + if isTerminated { + return nil, errs.ErrSessionTerminated + } + } return nil, errs.ErrNotImplemented }