diff --git a/backend/internal/database/models.go b/backend/internal/database/models.go index 10a2833..83b3941 100644 --- a/backend/internal/database/models.go +++ b/backend/internal/database/models.go @@ -80,5 +80,6 @@ type User struct { Username string Verified *bool RegistrationDate pgtype.Timestamp + Role int32 Deleted *bool } diff --git a/backend/internal/database/query.sql.go b/backend/internal/database/query.sql.go index a99d26d..b85e226 100644 --- a/backend/internal/database/query.sql.go +++ b/backend/internal/database/query.sql.go @@ -236,7 +236,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S const createUser = `-- name: CreateUser :one INSERT INTO users(username, verified) -VALUES ($1, false) RETURNING id, username, verified, registration_date, deleted +VALUES ($1, false) RETURNING id, username, verified, registration_date, role, deleted ` func (q *Queries) CreateUser(ctx context.Context, username string) (User, error) { @@ -247,6 +247,7 @@ func (q *Queries) CreateUser(ctx context.Context, username string) (User, error) &i.Username, &i.Verified, &i.RegistrationDate, + &i.Role, &i.Deleted, ) return i, err @@ -264,7 +265,7 @@ WITH deleted_rows AS ( AND linfo.email = $2::text )) AND verified IS FALSE - RETURNING id, username, verified, registration_date, deleted + RETURNING id, username, verified, registration_date, role, deleted ) SELECT COUNT(*) AS deleted_count FROM deleted_rows ` @@ -543,7 +544,7 @@ func (q *Queries) GetUnexpiredTerminatedSessionsGuidsPaginated(ctx context.Conte } const getUser = `-- name: GetUser :one -SELECT id, username, verified, registration_date, deleted FROM users +SELECT id, username, verified, registration_date, role, deleted FROM users WHERE id = $1 ` @@ -555,6 +556,7 @@ func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) { &i.Username, &i.Verified, &i.RegistrationDate, + &i.Role, &i.Deleted, ) return i, err @@ -630,7 +632,7 @@ func (q *Queries) GetUserBansByUsername(ctx context.Context, username string) ([ } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT users.id, users.username, users.verified, users.registration_date, users.deleted FROM users +SELECT users.id, users.username, users.verified, users.registration_date, users.role, users.deleted FROM users JOIN login_informations linfo ON linfo.user_id = users.id WHERE linfo.email = $1::text ` @@ -643,13 +645,14 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.Username, &i.Verified, &i.RegistrationDate, + &i.Role, &i.Deleted, ) return i, err } const getUserByUsername = `-- name: GetUserByUsername :one -SELECT id, username, verified, registration_date, deleted FROM users +SELECT id, username, verified, registration_date, role, deleted FROM users WHERE username = $1 ` @@ -661,6 +664,7 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, &i.Username, &i.Verified, &i.RegistrationDate, + &i.Role, &i.Deleted, ) return i, err @@ -698,7 +702,7 @@ func (q *Queries) GetValidConfirmationCodeByCode(ctx context.Context, arg GetVal } const getValidConfirmationCodesByUsername = `-- name: GetValidConfirmationCodesByUsername :many -SELECT confirmation_codes.id, user_id, code_type, code_hash, expires_at, used, confirmation_codes.deleted, users.id, username, verified, registration_date, users.deleted FROM confirmation_codes +SELECT confirmation_codes.id, user_id, code_type, code_hash, expires_at, used, confirmation_codes.deleted, users.id, username, verified, registration_date, role, users.deleted FROM confirmation_codes JOIN users on users.id = confirmation_codes.user_id WHERE users.username = $1::text AND @@ -724,6 +728,7 @@ type GetValidConfirmationCodesByUsernameRow struct { Username string Verified *bool RegistrationDate pgtype.Timestamp + Role int32 Deleted_2 *bool } @@ -748,6 +753,7 @@ func (q *Queries) GetValidConfirmationCodesByUsername(ctx context.Context, arg G &i.Username, &i.Verified, &i.RegistrationDate, + &i.Role, &i.Deleted_2, ); err != nil { return nil, err @@ -762,8 +768,7 @@ func (q *Queries) GetValidConfirmationCodesByUsername(ctx context.Context, arg G const getValidUserByLoginCredentials = `-- name: GetValidUserByLoginCredentials :one SELECT - users.id, - users.username, + users.id, users.username, users.verified, users.registration_date, users.role, users.deleted, linfo.password_hash, linfo.totp_encrypted FROM users @@ -783,10 +788,14 @@ type GetValidUserByLoginCredentialsParams struct { } type GetValidUserByLoginCredentialsRow struct { - ID int64 - Username string - PasswordHash string - TotpEncrypted *string + ID int64 + Username string + Verified *bool + RegistrationDate pgtype.Timestamp + Role int32 + Deleted *bool + PasswordHash string + TotpEncrypted *string } func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetValidUserByLoginCredentialsParams) (GetValidUserByLoginCredentialsRow, error) { @@ -795,6 +804,10 @@ func (q *Queries) GetValidUserByLoginCredentials(ctx context.Context, arg GetVal err := row.Scan( &i.ID, &i.Username, + &i.Verified, + &i.RegistrationDate, + &i.Role, + &i.Deleted, &i.PasswordHash, &i.TotpEncrypted, ) diff --git a/backend/internal/models/auth.go b/backend/internal/models/auth.go index 9288f61..365f644 100644 --- a/backend/internal/models/auth.go +++ b/backend/internal/models/auth.go @@ -49,9 +49,8 @@ type LoginResponse struct { Tokens } -// TODO: length check type RefreshRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token" binding:"required,max=2000"` } type RefreshResponse struct { diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index a22c46c..04c884f 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -38,6 +38,11 @@ import ( "go.uber.org/zap" ) +var ( + AuthTerminatedSessionCacheDuration = time.Duration(8 * time.Hour) + AuthRegistrationCooldownCacheDuration = time.Duration(10 * time.Minute) +) + type AuthService interface { RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) @@ -87,7 +92,7 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C pipe := _redis.Pipeline() for _, guid := range guids { key := fmt.Sprintf("session::%s::is_terminated", guid) - pipe.Set(ctx, key, true, time.Duration(8 * time.Hour)) // XXX: magic number + pipe.Set(ctx, key, true, AuthTerminatedSessionCacheDuration) } if _, err := pipe.Exec(ctx); err != nil { @@ -122,7 +127,12 @@ func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, usern pipe := a.redis.Pipeline() for _, guid := range sessionGuids { - pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number + pipe.Set( + ctx, + fmt.Sprintf("session::%s::is_terminated", guid), + true, + AuthTerminatedSessionCacheDuration, + ) } if _, err := pipe.Exec(ctx); err != nil { @@ -219,7 +229,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ context.TODO(), fmt.Sprintf("email::%s::registration_in_progress", request.Email), true, - time.Duration(10*time.Minute), // XXX: magic number + AuthRegistrationCooldownCacheDuration, ).Err(); err != nil { a.log.Error( "Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration", @@ -324,7 +334,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ context.TODO(), fmt.Sprintf("email::%s::registration_in_progress", request.Email), true, - time.Duration(10*time.Minute), // XXX: magic number + AuthTerminatedSessionCacheDuration, ).Err(); err != nil { a.log.Error( "Failed to cache registration_in_progress state for email", @@ -453,8 +463,7 @@ func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request mod return nil, errs.ErrServerError } - // TODO: get user role - accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole) + accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.Role(user.Role)) if err != nil { a.log.Error( @@ -483,7 +492,6 @@ func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request mod return &response, nil } -// TODO: totp func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) { var userRow database.GetValidUserByLoginCredentialsRow var err error @@ -500,12 +508,7 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{ Username: request.Username, Password: request.Password, - }) - - if err != nil { - a.log.Warn( - "Failed login attempt", - zap.Error(err)) + }); if err != nil { var returnedError error @@ -516,6 +519,9 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques returnedError = errs.ErrServerError } + a.log.Warn( + "Failed login attempt", + zap.Error(err)) return nil, returnedError } @@ -532,9 +538,11 @@ func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginReques return nil, errs.ErrServerError } - // TODO: get user role - accessToken, refreshToken, err := utils.GenerateTokens(userRow.Username, session.Guid.String(), enums.UserRole) - if err != nil { + accessToken, refreshToken, err := utils.GenerateTokens( + userRow.Username, + session.Guid.String(), + enums.Role(userRow.Role), + ); if err != nil { a.log.Error( "Failed to generate tokens for a new login", zap.String("username", userRow.Username), @@ -666,7 +674,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session), *session.Terminated, - time.Duration(8*time.Hour), // XXX: magic number + AuthTerminatedSessionCacheDuration, ).Err(); err != nil { a.log.Error( "Failed to cache session's is_terminated state", @@ -902,8 +910,11 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp zap.Error(err)) } - // TODO: get user role - if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil { + if accessToken, refreshToken, err = utils.GenerateTokens( + user.Username, + session.Guid.String(), + enums.UserRole, + ); err != nil { a.log.Error( "Failed to generate tokens as part of user password reset", zap.String("email", request.Email), diff --git a/sqlc/query.sql b/sqlc/query.sql index e7e6074..81c0d83 100644 --- a/sqlc/query.sql +++ b/sqlc/query.sql @@ -81,8 +81,7 @@ WHERE linfo.email = @email::text; ;-- name: GetValidUserByLoginCredentials :one SELECT - users.id, - users.username, + users.*, linfo.password_hash, linfo.totp_encrypted FROM users diff --git a/sqlc/schema.sql b/sqlc/schema.sql index 8e9602c..2d731be 100644 --- a/sqlc/schema.sql +++ b/sqlc/schema.sql @@ -24,6 +24,7 @@ CREATE TABLE IF NOT EXISTS "users" ( username VARCHAR(20) UNIQUE NOT NULL, verified BOOLEAN DEFAULT FALSE, registration_date TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + role INTEGER NOT NULL DEFAULT 0, deleted BOOLEAN DEFAULT FALSE );