diff --git a/backend/internal/database/query.sql.go b/backend/internal/database/query.sql.go index 059edfc..a99d26d 100644 --- a/backend/internal/database/query.sql.go +++ b/backend/internal/database/query.sql.go @@ -508,15 +508,22 @@ func (q *Queries) GetSessionByGuid(ctx context.Context, guid string) (Session, e return i, err } -const getUnexpiredTerminatedSessionsGuids = `-- name: GetUnexpiredTerminatedSessionsGuids :many +const getUnexpiredTerminatedSessionsGuidsPaginated = `-- name: GetUnexpiredTerminatedSessionsGuidsPaginated :many SELECT guid FROM sessions WHERE terminated IS TRUE AND last_refresh_exp_time > CURRENT_TIMESTAMP +LIMIT $1::integer +OFFSET $2 ` -func (q *Queries) GetUnexpiredTerminatedSessionsGuids(ctx context.Context) ([]pgtype.UUID, error) { - rows, err := q.db.Query(ctx, getUnexpiredTerminatedSessionsGuids) +type GetUnexpiredTerminatedSessionsGuidsPaginatedParams struct { + BatchSize int32 + Offset int64 +} + +func (q *Queries) GetUnexpiredTerminatedSessionsGuidsPaginated(ctx context.Context, arg GetUnexpiredTerminatedSessionsGuidsPaginatedParams) ([]pgtype.UUID, error) { + rows, err := q.db.Query(ctx, getUnexpiredTerminatedSessionsGuidsPaginated, arg.BatchSize, arg.Offset) if err != nil { return nil, err } diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index 356aece..a22c46c 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -57,34 +57,59 @@ type authServiceImpl struct { } func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService { - authService := &authServiceImpl{log: _log, dbctx: _dbctx, redis: _redis, smtp: _smtp} - - // Cache terminated sessions - // FIXME: review possible RAM overflow - db := database.NewDbHelper(_dbctx) - guids, err := db.Queries.GetUnexpiredTerminatedSessionsGuids(db.CTX) - if err != nil { - panic("Failed to load terminated sessions' GUIDs") - } - ctx := context.TODO() - // FIXME: review possible problems due to a large pipeline request - pipe := _redis.Pipeline() - for _, guid := range guids { - if err := pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, 0).Err(); err != nil { - panic("Failed to cache terminated session: " + err.Error()) + db := database.NewDbHelper(_dbctx) + + // Batch processing parameters + batchSize := 1000 + offset := 0 + totalCached := 0 + + for { + guids, err := db.Queries.GetUnexpiredTerminatedSessionsGuidsPaginated( + db.CTX, + database.GetUnexpiredTerminatedSessionsGuidsPaginatedParams{ + BatchSize: int32(batchSize), + Offset: int64(offset), + }, + ) + if err != nil { + panic("Failed to load terminated sessions' GUIDs: " + err.Error()) } - } - if _, err := pipe.Exec(ctx); err != nil { - panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error()) - } + // Break loop when no more records + if len(guids) == 0 { + break + } - _log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids))) - return authService + // Process batch in Redis pipeline + pipe := _redis.Pipeline() + for _, guid := range guids { + key := fmt.Sprintf("session::%s::is_terminated", guid) + pipe.Set(ctx, key, true, time.Duration(8 * time.Hour)) // XXX: magic number + } + + if _, err := pipe.Exec(ctx); err != nil { + panic("Failed to cache terminated sessions: " + err.Error()) + } + + totalCached += len(guids) + offset += len(guids) + + _log.Info( + "Cached batch of terminated sessions", + zap.Int("batch_size", len(guids)), + zap.Int("total_cached", totalCached)) } +_log.Info("Finished caching terminated sessions", + zap.Int("total_sessions", totalCached), + ) + + return authService + } + func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error { sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil { diff --git a/sqlc/query.sql b/sqlc/query.sql index 6a60051..e7e6074 100644 --- a/sqlc/query.sql +++ b/sqlc/query.sql @@ -253,11 +253,13 @@ WHERE user_id = $1 AND terminated IS FALSE AND last_refresh_exp_time > CURRENT_TIMESTAMP; -;-- name: GetUnexpiredTerminatedSessionsGuids :many +-- name: GetUnexpiredTerminatedSessionsGuidsPaginated :many SELECT guid FROM sessions WHERE terminated IS TRUE AND - last_refresh_exp_time > CURRENT_TIMESTAMP; + last_refresh_exp_time > CURRENT_TIMESTAMP +LIMIT @batch_size::integer +OFFSET $2; ;-- name: TerminateAllSessionsForUserByUsername :many UPDATE sessions