// Copyright (c) 2025 Nikolai Papin // // This file is part of Easywish // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See // the GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . package middleware import ( "context" "easywish/config" "easywish/internal/database" "easywish/internal/utils/enums" "errors" "fmt" "net/http" "time" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5" "go.uber.org/zap" ) type Claims struct { Username string `json:"username"` Role enums.Role `json:"role"` Type enums.JwtTokenType `json:"type"` Session string `json:"session"` jwt.RegisteredClaims } // XXX: cluttered; move cache & database check to auth service func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc { return func(c *gin.Context) { cfg := config.GetConfig() authHeader := c.GetHeader("Authorization") if authHeader == "" { c.Set("username", nil) c.Set("role", enums.GuestRole) c.Next() return } tokenString := authHeader token, err := jwt.ParseWithClaims( tokenString, &Claims{}, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(cfg.JwtSecret), nil }, ) if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Token expired"}) } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) } return } if claims, ok := token.Claims.(*Claims); ok && token.Valid { if claims.Type != enums.JwtAccessTokenType { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"}) return } ctx := context.TODO() isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() if redisErr != nil && redisErr != redis.Nil { log.Error( "Failed to lookup cache to check whether session is not terminated", zap.Error(redisErr)) c.AbortWithStatus(http.StatusInternalServerError) return } // Cache if nil if redisErr == redis.Nil { db := database.NewDbHelper(dbctx) session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session) if err != nil { if errors.Is(err, pgx.ErrNoRows) { log.Warn( "Session does not exist or was deleted", zap.String("session", claims.Session)) c.AbortWithStatus(http.StatusUnauthorized) return } log.Error( "Failed to lookup session in database", zap.String("session", claims.Session), zap.Error(err)) c.AbortWithStatus(http.StatusInternalServerError) return } if err := redisClient.Set( ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session), session.Terminated, time.Duration(8 * time.Hour), // XXX: magic number ).Err(); err != nil { log.Error( "Failed to cache session's is_terminated state", zap.String("session", claims.Session), zap.Error(err)) c.AbortWithStatus(http.StatusInternalServerError) return } isTerminated = *session.Terminated } if isTerminated { log.Warn( "Attempt to access resource from a terminated session", zap.String("session", claims.Session)) c.AbortWithStatus(http.StatusUnauthorized) return } c.Set("username", claims.Username) c.Set("role", claims.Role) c.Next() } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"}) } } }