From d8ea9f79c60e46d2acc208c79ca6f7310abd047f Mon Sep 17 00:00:00 2001 From: Nikolai Papin Date: Mon, 14 Jul 2025 20:44:30 +0300 Subject: [PATCH] feat: add session expiration tracking and validation feat: implement Redis caching for terminated sessions feat: add new session GUID queries for validation refactor: extend Session model with last_refresh_exp_time refactor: update token generation to include role and session refactor: modify auth middleware to validate session status refactor: replace GetUserSessions with GetValidUserSessions chore: add uuid/v5 dependency fix: update router to pass dependencies to auth middleware chore: update SQL schema and queries for new expiration field --- backend/go.mod | 1 + backend/go.sum | 2 + backend/internal/database/models.go | 19 ++-- backend/internal/database/query.sql.go | 147 +++++++++++++++++-------- backend/internal/middleware/auth.go | 81 +++++++++++++- backend/internal/routes/router.go | 7 +- backend/internal/services/auth.go | 31 +++++- backend/internal/utils/jwt.go | 12 +- sqlc/query.sql | 21 +++- sqlc/schema.sql | 1 + 10 files changed, 248 insertions(+), 74 deletions(-) diff --git a/backend/go.mod b/backend/go.mod index 09e54d1..2e3c5b9 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -34,6 +34,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/gofrs/uuid/v5 v5.3.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 58bd2be..a3584bd 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -49,6 +49,8 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0= +github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= diff --git a/backend/internal/database/models.go b/backend/internal/database/models.go index 48d0f2d..10a2833 100644 --- a/backend/internal/database/models.go +++ b/backend/internal/database/models.go @@ -63,15 +63,16 @@ type ProfileSetting struct { } type Session struct { - ID int64 - UserID int64 - Guid pgtype.UUID - Name *string - Platform *string - LatestIp *string - LoginTime pgtype.Timestamp - LastSeenDate pgtype.Timestamp - Terminated *bool + ID int64 + UserID int64 + Guid pgtype.UUID + Name *string + Platform *string + LatestIp *string + LoginTime pgtype.Timestamp + LastRefreshExpTime pgtype.Timestamp + LastSeenDate pgtype.Timestamp + Terminated *bool } type User struct { diff --git a/backend/internal/database/query.sql.go b/backend/internal/database/query.sql.go index e36d045..fd7e669 100644 --- a/backend/internal/database/query.sql.go +++ b/backend/internal/database/query.sql.go @@ -201,7 +201,7 @@ func (q *Queries) CreateProfileSettings(ctx context.Context, profileID int64) (P const createSession = `-- name: CreateSession :one INSERT INTO sessions(user_id, name, platform, latest_ip) -VALUES ($1, $2, $3, $4) RETURNING id, user_id, guid, name, platform, latest_ip, login_time, last_seen_date, terminated +VALUES ($1, $2, $3, $4) RETURNING id, user_id, guid, name, platform, latest_ip, login_time, last_refresh_exp_time, last_seen_date, terminated ` type CreateSessionParams struct { @@ -227,6 +227,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S &i.Platform, &i.LatestIp, &i.LoginTime, + &i.LastRefreshExpTime, &i.LastSeenDate, &i.Terminated, ) @@ -484,6 +485,56 @@ func (q *Queries) GetProfilesRestricted(ctx context.Context, arg GetProfilesRest return items, nil } +const getSessionByGuid = `-- name: GetSessionByGuid :one +SELECT id, user_id, guid, name, platform, latest_ip, login_time, last_refresh_exp_time, last_seen_date, terminated FROM sessions +WHERE guid = ($1::text)::uuid +` + +func (q *Queries) GetSessionByGuid(ctx context.Context, guid string) (Session, error) { + row := q.db.QueryRow(ctx, getSessionByGuid, guid) + var i Session + err := row.Scan( + &i.ID, + &i.UserID, + &i.Guid, + &i.Name, + &i.Platform, + &i.LatestIp, + &i.LoginTime, + &i.LastRefreshExpTime, + &i.LastSeenDate, + &i.Terminated, + ) + return i, err +} + +const getUnexpiredTerminatedSessionsGuids = `-- name: GetUnexpiredTerminatedSessionsGuids :many +SELECT guid FROM sessions +WHERE + terminated IS TRUE AND + last_refresh_exp_time > CURRENT_TIMESTAMP +` + +func (q *Queries) GetUnexpiredTerminatedSessionsGuids(ctx context.Context) ([]pgtype.UUID, error) { + rows, err := q.db.Query(ctx, getUnexpiredTerminatedSessionsGuids) + if err != nil { + return nil, err + } + defer rows.Close() + var items []pgtype.UUID + for rows.Next() { + var guid pgtype.UUID + if err := rows.Scan(&guid); err != nil { + return nil, err + } + items = append(items, guid) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getUser = `-- name: GetUser :one SELECT id, username, verified, registration_date, deleted FROM users WHERE id = $1 @@ -608,41 +659,6 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, return i, err } -const getUserSessions = `-- name: GetUserSessions :many -SELECT id, user_id, guid, name, platform, latest_ip, login_time, last_seen_date, terminated FROM sessions -WHERE user_id = $1 AND terminated IS FALSE -` - -func (q *Queries) GetUserSessions(ctx context.Context, userID int64) ([]Session, error) { - rows, err := q.db.Query(ctx, getUserSessions, userID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Session - for rows.Next() { - var i Session - if err := rows.Scan( - &i.ID, - &i.UserID, - &i.Guid, - &i.Name, - &i.Platform, - &i.LatestIp, - &i.LoginTime, - &i.LastSeenDate, - &i.Terminated, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getValidConfirmationCodeByCode = `-- name: GetValidConfirmationCodeByCode :one SELECT id, user_id, code_type, code_hash, expires_at, used, deleted FROM confirmation_codes WHERE @@ -778,6 +794,44 @@ func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetVal return i, err } +const getValidUserSessions = `-- name: GetValidUserSessions :many +SELECT id, user_id, guid, name, platform, latest_ip, login_time, last_refresh_exp_time, last_seen_date, terminated FROM sessions +WHERE + user_id = $1 AND terminated IS FALSE AND + last_refresh_exp_time > CURRENT_TIMESTAMP +` + +func (q *Queries) GetValidUserSessions(ctx context.Context, userID int64) ([]Session, error) { + rows, err := q.db.Query(ctx, getValidUserSessions, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.Guid, + &i.Name, + &i.Platform, + &i.LatestIp, + &i.LoginTime, + &i.LastRefreshExpTime, + &i.LastSeenDate, + &i.Terminated, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const pruneExpiredConfirmationCodes = `-- name: PruneExpiredConfirmationCodes :exec DELETE FROM confirmation_codes WHERE expires_at < CURRENT_TIMESTAMP @@ -975,19 +1029,21 @@ SET platform = COALESCE($3, platform), latest_ip = COALESCE($4, latest_ip), login_time = COALESCE($5, login_time), - last_seen_date = COALESCE($6, last_seen_date), - terminated = COALESCE($7, terminated) + last_refresh_exp_time = COALESCE($6, last_refresh_exp_time), + last_seen_date = COALESCE($7, last_seen_date), + terminated = COALESCE($8, terminated) WHERE id = $1 ` type UpdateSessionParams struct { - ID int64 - Name *string - Platform *string - LatestIp *string - LoginTime pgtype.Timestamp - LastSeenDate pgtype.Timestamp - Terminated *bool + ID int64 + Name *string + Platform *string + LatestIp *string + LoginTime pgtype.Timestamp + LastRefreshExpTime pgtype.Timestamp + LastSeenDate pgtype.Timestamp + Terminated *bool } func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) error { @@ -997,6 +1053,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) er arg.Platform, arg.LatestIp, arg.LoginTime, + arg.LastRefreshExpTime, arg.LastSeenDate, arg.Terminated, ) diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index 889f1b8..284381b 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -18,25 +18,32 @@ package middleware import ( + "context" "easywish/config" + "easywish/internal/database" "easywish/internal/utils/enums" "errors" "fmt" "net/http" + "time" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v5" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" ) type Claims struct { - Username string `json:"username"` - Role enums.Role `json:"role"` + Username string `json:"username"` + Role enums.Role `json:"role"` + Type enums.JwtTokenType `json:"type"` + Session string `json:"session"` jwt.RegisteredClaims } -// TODO: validate token type -// TODO: validate session guid -func AuthMiddleware() gin.HandlerFunc { +// XXX: cluttered; move cache & database check to auth service +func AuthMiddleware(log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client) gin.HandlerFunc { return func(c *gin.Context) { cfg := config.GetConfig() authHeader := c.GetHeader("Authorization") @@ -71,6 +78,70 @@ func AuthMiddleware() gin.HandlerFunc { } if claims, ok := token.Claims.(*Claims); ok && token.Valid { + + if claims.Type != enums.JwtAccessTokenType { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not an access token"}) + return + } + + ctx := context.TODO() + isTerminated, redisErr := redisClient.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() + if redisErr != nil && redisErr != redis.Nil { + log.Error( + "Failed to lookup cache to check whether session is not terminated", + zap.Error(redisErr)) + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + // Cache if nil + if redisErr == redis.Nil { + db := database.NewDbHelper(dbctx) + + session, err := db.Queries.GetSessionByGuid(db.CTX, claims.Session) + if err != nil { + + if errors.Is(err, pgx.ErrNoRows) { + log.Warn( + "Session does not exist or was deleted", + zap.String("session", claims.Session)) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + log.Error( + "Failed to lookup session in database", + zap.String("session", claims.Session), + zap.Error(err)) + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + if err := redisClient.Set( + ctx, + fmt.Sprintf("session::%s::is_terminated", claims.Session), + session.Terminated, + time.Duration(8 * time.Hour), // XXX: magic number + ).Err(); err != nil { + log.Error( + "Failed to cache session's is_terminated state", + zap.String("session", claims.Session), + zap.Error(err)) + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + isTerminated = *session.Terminated + } + + if isTerminated { + log.Warn( + "Attempt to access resource from a terminated session", + zap.String("session", claims.Session)) + c.AbortWithStatus(http.StatusUnauthorized) + return + } + c.Set("username", claims.Username) c.Set("role", claims.Role) c.Next() diff --git a/backend/internal/routes/router.go b/backend/internal/routes/router.go index aa60339..ab2d029 100644 --- a/backend/internal/routes/router.go +++ b/backend/internal/routes/router.go @@ -19,14 +19,17 @@ package routes import ( "easywish/internal/controllers" + "easywish/internal/database" "easywish/internal/middleware" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "go.uber.org/zap" ) -func NewRouter(engine *gin.Engine, groups []RouteGroup) *gin.Engine { +func NewRouter(engine *gin.Engine, log *zap.Logger, dbctx database.DbContext, redisClient *redis.Client, groups []RouteGroup) *gin.Engine { apiGroup := engine.Group("/api") - apiGroup.Use(middleware.AuthMiddleware()) + apiGroup.Use(middleware.AuthMiddleware(log, dbctx, redisClient)) for _, group := range groups { subgroup := apiGroup.Group(group.BasePath) subgroup.Use(group.Middleware...) diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index 81bdc02..cf8676e 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -53,7 +53,27 @@ type authServiceImpl struct { } func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService { - return &authServiceImpl{log: _log, dbctx: _dbctx, redis: _redis, smtp: _smtp} + + authService := &authServiceImpl{log: _log, dbctx: _dbctx, redis: _redis, smtp: _smtp} + + // Cache terminated sessions + // FIXME: review possible RAM overflow + db := database.NewDbHelper(_dbctx) + guids, err := db.Queries.GetUnexpiredTerminatedSessionsGuids(db.CTX) + if err != nil { + panic("Failed to load terminated sessions' GUIDs") + } + + ctx := context.TODO() + // FIXME: review possible problems due to a large pipeline request + pipe := _redis.Pipeline() + for _, guid := range guids { + if err := pipe.Set(ctx, fmt.Sprint("session:%s:is_terminated", guid), true, 0); err != nil { + panic("Failed to cache terminated session: " + err.Err().Error()) + } + } + _log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids))) + return authService } func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) { @@ -324,7 +344,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple return nil, errs.ErrServerError } - accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String()) + // TODO: get user role + accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole) if err != nil { a.log.Error( @@ -415,7 +436,8 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo return nil, errs.ErrServerError } - accessToken, refreshToken, err := utils.GenerateTokens(userRow.Username, session.Guid.String()) + // TODO: get user role + accessToken, refreshToken, err := utils.GenerateTokens(userRow.Username, session.Guid.String(), enums.UserRole) if err != nil { a.log.Error( "Failed to generate tokens for a new login", @@ -654,7 +676,8 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp zap.Error(err)) } - if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String()); err != nil { + // TODO: get user role + if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil { a.log.Error( "Failed to generate tokens as part of user password reset", zap.String("email", request.Email), diff --git a/backend/internal/utils/jwt.go b/backend/internal/utils/jwt.go index 6b38273..363052e 100644 --- a/backend/internal/utils/jwt.go +++ b/backend/internal/utils/jwt.go @@ -25,22 +25,24 @@ import ( "github.com/golang-jwt/jwt/v5" ) -func GenerateTokens(username string, sessionGuid string) (accessToken, refreshToken string, err error) { +func GenerateTokens(username string, sessionGuid string, role enums.Role) (accessToken, refreshToken string, err error) { cfg := config.GetConfig() accessClaims := jwt.MapClaims{ "username": username, - "guid": sessionGuid, + "role": role, + "session": sessionGuid, "type": enums.JwtAccessTokenType, - "exp": time.Now().Add(time.Minute * time.Duration(cfg.JwtExpAccess)).Unix(), + "exp": time.Now().Add(time.Minute * time.Duration(cfg.JwtExpAccess)).Unix(), } accessToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims).SignedString([]byte(cfg.JwtSecret)) refreshClaims := jwt.MapClaims{ "username": username, - "guid": sessionGuid, + "role": role, + "session": sessionGuid, "type": enums.JwtRefreshTokenType, - "exp": time.Now().Add(time.Hour * time.Duration(cfg.JwtExpRefresh)).Unix(), + "exp": time.Now().Add(time.Hour * time.Duration(cfg.JwtExpRefresh)).Unix(), } refreshToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims).SignedString([]byte(cfg.JwtSecret)) diff --git a/sqlc/query.sql b/sqlc/query.sql index b91c5a5..a640c10 100644 --- a/sqlc/query.sql +++ b/sqlc/query.sql @@ -238,13 +238,26 @@ SET platform = COALESCE($3, platform), latest_ip = COALESCE($4, latest_ip), login_time = COALESCE($5, login_time), - last_seen_date = COALESCE($6, last_seen_date), - terminated = COALESCE($7, terminated) + last_refresh_exp_time = COALESCE($6, last_refresh_exp_time), + last_seen_date = COALESCE($7, last_seen_date), + terminated = COALESCE($8, terminated) WHERE id = $1; -;-- name: GetUserSessions :many +;-- name: GetSessionByGuid :one SELECT * FROM sessions -WHERE user_id = $1 AND terminated IS FALSE; +WHERE guid = (@guid::text)::uuid; + +;-- name: GetValidUserSessions :many +SELECT * FROM sessions +WHERE + user_id = $1 AND terminated IS FALSE AND + last_refresh_exp_time > CURRENT_TIMESTAMP; + +;-- name: GetUnexpiredTerminatedSessionsGuids :many +SELECT guid FROM sessions +WHERE + terminated IS TRUE AND + last_refresh_exp_time > CURRENT_TIMESTAMP; ;-- name: TerminateAllSessionsForUserByUsername :exec UPDATE sessions diff --git a/sqlc/schema.sql b/sqlc/schema.sql index cb9b186..cd60b90 100644 --- a/sqlc/schema.sql +++ b/sqlc/schema.sql @@ -66,6 +66,7 @@ CREATE TABLE IF NOT EXISTS "sessions" ( platform VARCHAR(32), latest_ip VARCHAR(16), login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds', last_seen_date TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, terminated BOOLEAN DEFAULT FALSE );