Files
easywish/backend/internal/middleware/auth.go
Nikolai Papin d8ea9f79c6 feat: add session expiration tracking and validation
feat: implement Redis caching for terminated sessions
feat: add new session GUID queries for validation
refactor: extend Session model with last_refresh_exp_time
refactor: update token generation to include role and session
refactor: modify auth middleware to validate session status
refactor: replace GetUserSessions with GetValidUserSessions
chore: add uuid/v5 dependency
fix: update router to pass dependencies to auth middleware
chore: update SQL schema and queries for new expiration field
2025-07-14 20:44:30 +03:00

153 lines
4.2 KiB
Go

// Copyright (c) 2025 Nikolai Papin
//
// This file is part of Easywish
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
// the GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package middleware
import (
"context"
"easywish/config"
"easywish/internal/database"
"easywish/internal/utils/enums"
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
)
// Claims is the JWT payload used by this package: application-specific
// fields plus the standard registered claims embedded from the jwt library.
type Claims struct {
// Username is copied into the gin context under the "username" key by
// AuthMiddleware.
Username string `json:"username"`
// Role is copied into the gin context under the "role" key by
// AuthMiddleware.
Role enums.Role `json:"role"`
// Type is the kind of token; AuthMiddleware only accepts
// enums.JwtAccessTokenType.
Type enums.JwtTokenType `json:"type"`
// Session is the GUID of the session the token belongs to; it is used to
// look up whether that session has been terminated.
Session string `json:"session"`
jwt.RegisteredClaims
}
// AuthMiddleware returns a gin handler that authenticates requests by the JWT
// carried in the "Authorization" header.
//
// Behaviour:
//   - No Authorization header: the request continues as a guest, with the
//     context key "username" set to nil and "role" set to enums.GuestRole.
//   - Otherwise the header value is parsed verbatim as an HMAC-signed JWT (no
//     scheme prefix is stripped). An expired, malformed or wrongly signed
//     token, or a token that is not an access token, aborts with 401 and a
//     JSON error body.
//   - The token's session is then checked for termination (Redis cache first,
//     database as fallback). A terminated or unknown session aborts with 401;
//     a cache or database failure aborts with 500.
//   - On success "username" and "role" are set from the claims and the next
//     handler runs.
func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := config.GetConfig()

		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			// Anonymous request: let it through with guest privileges.
			c.Set("username", nil)
			c.Set("role", enums.GuestRole)
			c.Next()
			return
		}

		token, err := jwt.ParseWithClaims(
			authHeader,
			&Claims{},
			func(token *jwt.Token) (any, error) {
				// Only HMAC signatures are accepted, so a token cannot pick
				// a different algorithm for itself.
				if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
				}
				return []byte(cfg.JwtSecret), nil
			},
		)
		if err != nil {
			message := "Invalid token"
			if errors.Is(err, jwt.ErrTokenExpired) {
				message = "Token expired"
			}
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": message})
			return
		}

		claims, ok := token.Claims.(*Claims)
		if !ok || !token.Valid {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid claims"})
			return
		}
		if claims.Type != enums.JwtAccessTokenType {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"})
			return
		}

		// Use the request context so the cache calls are cancelled when the
		// client disconnects.
		terminated, abortStatus := sessionTerminated(c.Request.Context(), log, dbctx, redisClient, claims.Session)
		if abortStatus != 0 {
			c.AbortWithStatus(abortStatus)
			return
		}
		if terminated {
			log.Warn(
				"Attempt to access resource from a terminated session",
				zap.String("session", claims.Session))
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		c.Set("username", claims.Username)
		c.Set("role", claims.Role)
		c.Next()
	}
}

// sessionTerminated reports whether the session identified by sessionGUID is
// terminated. It consults the Redis key "session::<guid>::is_terminated"
// first; on a cache miss it loads the session from the database and caches
// the flag.
//
// The second result is the HTTP status the request must be aborted with, or 0
// when the lookup succeeded: 401 if the session row does not exist, 500 if the
// cache or the database failed. Failures are logged here.
//
// XXX: this cache & database check belongs in the auth service.
func sessionTerminated(
	ctx context.Context,
	log *zap.Logger,
	dbctx database.DbContext,
	redisClient *redis.Client,
	sessionGUID string,
) (bool, int) {
	// How long a session's terminated flag is kept in the cache.
	const cacheTTL = 8 * time.Hour

	cacheKey := fmt.Sprintf("session::%s::is_terminated", sessionGUID)

	terminated, err := redisClient.Get(ctx, cacheKey).Bool()
	if err == nil {
		return terminated, 0
	}
	if !errors.Is(err, redis.Nil) {
		log.Error(
			"Failed to lookup cache to check whether session is not terminated",
			zap.Error(err))
		return false, http.StatusInternalServerError
	}

	// Cache miss: the database is the source of truth.
	db := database.NewDbHelper(dbctx)
	session, err := db.Queries.GetSessionByGuid(db.CTX, sessionGUID)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			log.Warn(
				"Session does not exist or was deleted",
				zap.String("session", sessionGUID))
			return false, http.StatusUnauthorized
		}
		log.Error(
			"Failed to lookup session in database",
			zap.String("session", sessionGUID),
			zap.Error(err))
		return false, http.StatusInternalServerError
	}

	// Terminated is a pointer: dereference it once, guarding against nil, and
	// cache the plain bool rather than the pointer itself.
	// NOTE(review): a nil flag is treated as "not terminated" — confirm this
	// matches the schema's intent for NULL.
	terminated = session.Terminated != nil && *session.Terminated

	if err := redisClient.Set(ctx, cacheKey, terminated, cacheTTL).Err(); err != nil {
		log.Error(
			"Failed to cache session's is_terminated state",
			zap.String("session", sessionGUID),
			zap.Error(err))
		return false, http.StatusInternalServerError
	}

	return terminated, 0
}