// Copyright (c) 2025 Nikolai Papin // // This file is part of Easywish // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See // the GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . package middleware import ( "context" "easywish/config" "easywish/internal/database" "easywish/internal/dto" "easywish/internal/utils/enums" "errors" "fmt" "net/http" "time" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5" "go.uber.org/zap" ) // XXX: cluttered; move cache & database check to auth service func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc { return func(c *gin.Context) { cfg := config.GetConfig() authHeader := c.GetHeader("Authorization") if authHeader == "" { c.Set("session_info", dto.SessionInfo{ Username: "", Session: "", Role: enums.GuestRole}, ) c.Next() return } tokenString := authHeader token, err := jwt.ParseWithClaims( tokenString, &dto.UserClaims{}, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(cfg.JwtSecret), nil }, ) if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"}) } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) } return } if claims, ok := token.Claims.(*dto.UserClaims); ok && token.Valid { if claims.Type != enums.JwtAccessTokenType { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"}) return } ctx := context.TODO() isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() if redisErr != nil && redisErr != redis.Nil { log.Error( "Failed to lookup cache to check whether session is not terminated", zap.Error(redisErr)) c.AbortWithStatus(http.StatusInternalServerError) return } // Cache if nil if redisErr == redis.Nil { db := database.NewDbHelper(dbctx) session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session) if err != nil { if errors.Is(err, pgx.ErrNoRows) { log.Warn( "Session does not exist or was deleted", zap.String("session", claims.Session)) c.AbortWithStatus(http.StatusUnauthorized) return } log.Error( "Failed to lookup session in database", zap.String("session", claims.Session), zap.Error(err)) c.AbortWithStatus(http.StatusInternalServerError) return } if err := redisClient.Set( ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session), session.Terminated, time.Duration(8 * time.Hour), // XXX: magic number ).Err(); err != nil { log.Error( "Failed to cache session's is_terminated state", zap.String("session", claims.Session), zap.Error(err)) c.AbortWithStatus(http.StatusInternalServerError) return } isTerminated = *session.Terminated } if isTerminated { log.Warn( "Attempt to access resource from a terminated session", zap.String("session", claims.Session)) c.AbortWithStatus(http.StatusUnauthorized) return } c.Set("session_info", dto.SessionInfo{ Username: claims.Username, Session: claims.Session, Role: claims.Role, }) c.Next() } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"}) } } }