diff --git a/backend/docs/docs.go b/backend/docs/docs.go index d2b75e5..e1df226 100644 --- a/backend/docs/docs.go +++ b/backend/docs/docs.go @@ -38,6 +38,44 @@ const docTemplate = `{ "responses": {} } }, + "/auth/changePassword": { + "post": { + "security": [ + { + "JWT": [] + } + ], + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Auth" + ], + "summary": "Set new password using the old password", + "parameters": [ + { + "description": " ", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.ChangePasswordRequest" + } + } + ], + "responses": { + "200": { + "description": "Password successfully changed" + }, + "403": { + "description": "Invalid old password" + } + } + } + }, "/auth/login": { "post": { "consumes": [ @@ -391,6 +429,24 @@ const docTemplate = `{ } } }, + "models.ChangePasswordRequest": { + "type": "object", + "required": [ + "old_password", + "password" + ], + "properties": { + "old_password": { + "type": "string" + }, + "password": { + "type": "string" + }, + "totp": { + "type": "string" + } + } + }, "models.LoginRequest": { "type": "object", "required": [ @@ -445,7 +501,7 @@ const docTemplate = `{ "email": { "type": "string" }, - "log_out_accounts": { + "log_out_sessions": { "type": "boolean" }, "password": { diff --git a/backend/docs/swagger.json b/backend/docs/swagger.json index e757280..af5b4dd 100644 --- a/backend/docs/swagger.json +++ b/backend/docs/swagger.json @@ -34,6 +34,44 @@ "responses": {} } }, + "/auth/changePassword": { + "post": { + "security": [ + { + "JWT": [] + } + ], + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Auth" + ], + "summary": "Set new password using the old password", + "parameters": [ + { + "description": " ", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.ChangePasswordRequest" + } + } + ], + "responses": { + "200": { + "description": "Password successfully changed" + }, + "403": { + "description": "Invalid old password" + } + } + } + }, "/auth/login": { "post": { "consumes": [ @@ -387,6 +425,24 @@ } } }, + "models.ChangePasswordRequest": { + "type": "object", + "required": [ + "old_password", + "password" + ], + "properties": { + "old_password": { + "type": "string" + }, + "password": { + "type": "string" + }, + "totp": { + "type": "string" + } + } + }, "models.LoginRequest": { "type": "object", "required": [ @@ -441,7 +497,7 @@ "email": { "type": "string" }, - "log_out_accounts": { + "log_out_sessions": { "type": "boolean" }, "password": { diff --git a/backend/docs/swagger.yaml b/backend/docs/swagger.yaml index ee8f9b7..4454738 100644 --- a/backend/docs/swagger.yaml +++ b/backend/docs/swagger.yaml @@ -5,6 +5,18 @@ definitions: healthy: type: boolean type: object + models.ChangePasswordRequest: + properties: + old_password: + type: string + password: + type: string + totp: + type: string + required: + - old_password + - password + type: object models.LoginRequest: properties: password: @@ -38,7 +50,7 @@ definitions: properties: email: type: string - log_out_accounts: + log_out_sessions: type: boolean password: type: string @@ -125,6 +137,29 @@ paths: summary: Change account password tags: - Account + /auth/changePassword: + post: + consumes: + - application/json + parameters: + - description: ' ' + in: body + name: request + required: true + schema: + $ref: '#/definitions/models.ChangePasswordRequest' + produces: + - application/json + responses: + "200": + description: Password successfully changed + "403": + description: Invalid old password + security: + - JWT: [] + summary: Set new password using the old password + tags: + - Auth /auth/login: post: consumes: diff --git a/backend/internal/controllers/auth.go b/backend/internal/controllers/auth.go index 742b093..8e82813 100644 --- a/backend/internal/controllers/auth.go +++ b/backend/internal/controllers/auth.go @@ -38,6 +38,7 @@ type AuthController interface { Refresh(c *gin.Context) PasswordResetBegin(c *gin.Context) PasswordResetComplete(c *gin.Context) + ChangePassword(c *gin.Context) Router } @@ -65,7 +66,7 @@ func (a *authControllerImpl) Login(c *gin.Context) { return } - response, err := a.auth.Login(request.Body) + response, err := a.auth.Login(request.User, request.Body) if err != nil { if errors.Is(err, errs.ErrForbidden) { @@ -155,18 +156,18 @@ func (a *authControllerImpl) Refresh(c *gin.Context) { response, err := a.auth.Refresh(request.Body) if err != nil { - if errors.Is(err, errs.ErrTokenExpired) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"}) - } else if errors.Is(err, errs.ErrTokenInvalid) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"}) - } else if errors.Is(err, errs.ErrWrongTokenType) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"}) - } else if errors.Is(err, errs.ErrSessionNotFound) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"}) - } else if errors.Is(err, errs.ErrSessionTerminated) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"}) + if utils.ErrorIsOneOf( + err, + errs.ErrTokenExpired, + errs.ErrTokenInvalid, + errs.ErrInvalidToken, + errs.ErrWrongTokenType, + errs.ErrSessionNotFound, + errs.ErrSessionTerminated, + ) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) } else { - c.Status(http.StatusInternalServerError) + c.JSON(http.StatusInternalServerError, err.Error()) } return } @@ -221,7 +222,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) { return } - response, err := a.auth.RegistrationComplete(request.Body) + response, err := a.auth.RegistrationComplete(request.User, request.Body) if err != nil { if errors.Is(err, errs.ErrForbidden) { @@ -237,6 +238,36 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) { c.JSON(http.StatusOK, response) } +// @Summary Set new password using the old password +// @Tags Auth +// @Accept json +// @Produce json +// @Security JWT +// @Param request body models.ChangePasswordRequest true " " +// @Success 200 "Password successfully changed" +// @Failure 403 "Invalid old password" +// @Router /auth/changePassword [post] +func (a *authControllerImpl) ChangePassword(c *gin.Context) { + request, ok := utils.GetRequest[models.ChangePasswordRequest](c) + if !ok { + c.Status(http.StatusBadRequest) + return + } + + response, err := a.auth.ChangePassword(request.Body, request.User) + + if err != nil { + if errors.Is(err, errs.ErrForbidden) { + c.Status(http.StatusForbidden) + } else { + c.Status(http.StatusInternalServerError) + } + return + } + + c.JSON(http.StatusOK, response) +} + func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) { group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin) group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete) @@ -244,4 +275,5 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) { group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh) group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin) group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete) + group.POST("/changePassword", middleware.RequestMiddleware[models.ChangePasswordRequest](enums.UserRole), a.ChangePassword) } diff --git a/backend/internal/database/helper.go b/backend/internal/database/helper.go index 6056d1f..9b8fbc2 100644 --- a/backend/internal/database/helper.go +++ b/backend/internal/database/helper.go @@ -39,6 +39,7 @@ type dbHelperTransactionImpl struct { TXlessQueries Queries TX pgx.Tx TXQueries Queries + isCommited bool } func NewDbHelper(dbContext DbContext) DbHelper { @@ -79,30 +80,24 @@ func (d *dbHelperTransactionImpl) Commit() error { errCommit := d.TX.Commit(d.CTX) if errCommit != nil { - errRollback := d.TX.Rollback(d.CTX) - - if errRollback != nil { - return errRollback - } - - return errCommit + d.isCommited = true } - return nil + + return errCommit } // Rollback implements DbHelperTransaction. func (d *dbHelperTransactionImpl) Rollback() error { - err := d.TX.Rollback(d.CTX) - - if err != nil { - return err + if d.isCommited { + return nil } - return nil + + return d.TX.Rollback(d.CTX); } // RollbackOnError implements DbHelperTransaction. func (d *dbHelperTransactionImpl) RollbackOnError(err error) error { - if err != nil { + if d.isCommited || err == nil { return d.Rollback() } return nil diff --git a/backend/internal/database/query.sql.go b/backend/internal/database/query.sql.go index fd7e669..059edfc 100644 --- a/backend/internal/database/query.sql.go +++ b/backend/internal/database/query.sql.go @@ -852,16 +852,32 @@ func (q *Queries) PruneTerminatedSessions(ctx context.Context) error { return err } -const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :exec +const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :many UPDATE sessions SET terminated = TRUE FROM users WHERE sessions.user_id = users.id AND users.username = $1::text +RETURNING sessions.guid ` -func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) error { - _, err := q.db.Exec(ctx, terminateAllSessionsForUserByUsername, username) - return err +func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) ([]pgtype.UUID, error) { + rows, err := q.db.Query(ctx, terminateAllSessionsForUserByUsername, username) + if err != nil { + return nil, err + } + defer rows.Close() + var items []pgtype.UUID + for rows.Next() { + var guid pgtype.UUID + if err := rows.Scan(&guid); err != nil { + return nil, err + } + items = append(items, guid) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const updateBannedUser = `-- name: UpdateBannedUser :exec diff --git a/backend/internal/dto/userInfo.go b/backend/internal/dto/clientInfo.go similarity index 100% rename from backend/internal/dto/userInfo.go rename to backend/internal/dto/clientInfo.go diff --git a/backend/internal/errors/auth.go b/backend/internal/errors/auth.go index 51a707a..80f6fea 100644 --- a/backend/internal/errors/auth.go +++ b/backend/internal/errors/auth.go @@ -31,7 +31,7 @@ var ( ErrServerError = errors.New("Internal server error") ErrTokenExpired = errors.New("Token is expired") - ErrTokenInvalid = errors.New("Token is invalid") + ErrTokenInvalid = ErrInvalidToken ErrWrongTokenType = errors.New("Invalid token type") ErrSessionNotFound = errors.New("Could not find session in database") ErrSessionTerminated = errors.New("Session is terminated") diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index 2dc0874..09cd644 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -61,7 +61,7 @@ func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc } return } else { - c.Set("session_info", sessionInfo) + c.Set("session_info", *sessionInfo) c.Next() } diff --git a/backend/internal/models/auth.go b/backend/internal/models/auth.go index 7665b29..9288f61 100644 --- a/backend/internal/models/auth.go +++ b/backend/internal/models/auth.go @@ -72,3 +72,13 @@ type PasswordResetCompleteRequest struct { type PasswordResetCompleteResponse struct { Tokens } + +type ChangePasswordRequest struct { + OldPassword string `json:"old_password" binding:"required"` + NewPassword string `json:"password" binding:"required" validate:"password"` + TOTP string `json:"totp"` +} + +type ChangePasswordResponse struct { + Tokens +} diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index bfb5c9e..356aece 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -40,19 +40,20 @@ import ( type AuthService interface { RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) - RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) - Login(request models.LoginRequest) (*models.LoginResponse, error) + RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) + Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error) + ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error) ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) } type authServiceImpl struct { - log *zap.Logger + log *zap.Logger dbctx database.DbContext redis *redis.Client - smtp SmtpService + smtp SmtpService } func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService { @@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C panic("Failed to cache terminated session: " + err.Error()) } } + + if _, err := pipe.Exec(ctx); err != nil { + panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error()) + } + _log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids))) return authService } +func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error { + + sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil { + a.log.Error( + "Failed to terminate older sessions for user trying to log in", + zap.String("username", username), + zap.Error(err)) + return err + } + + pipe := a.redis.Pipeline() + for _, guid := range sessionGuids { + pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number + } + + if _, err := pipe.Exec(ctx); err != nil { + a.log.Error( + "Failed to cache terminated sessions", + zap.Error(err)) + return err + } + + return nil +} + +func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) { + + session, err := queries.CreateSession(ctx, database.CreateSessionParams{ + UserID: userID, + Name: utils.NewPointer(cinfo.UserAgent), + Platform: utils.NewPointer(cinfo.UserAgent), + LatestIp: utils.NewPointer(cinfo.IP), + }); if err != nil { + a.log.Error( + "Failed to add session to database", + zap.Error(err)) + return nil, err + } + + a.log.Info( + "Registered a new user session", + zap.String("username", cinfo.Username), + zap.String("session", cinfo.Session)) + return &session, nil +} + func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) { var occupationStatus database.CheckUserRegistrationAvailabilityRow @@ -91,18 +143,17 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ helper, db, err := database.NewDbHelperTransaction(a.dbctx) if err != nil { a.log.Error( - "Failed to open a transaction", + "Failed to open a transaction", zap.Error(err)) return false, errs.ErrServerError } + defer helper.RollbackOnError(err) - defer helper.RollbackOnError(err) - - if isInProgress, err := a.redis.Get( + isInProgress, err := a.redis.Get( context.TODO(), fmt.Sprintf("email::%s::registration_in_progress", - request.Email), - ).Bool(); err != nil { + request.Email), + ).Bool(); if err != nil { if err != redis.Nil { a.log.Error( "Failed to look up cached registration_in_progress state of email as part of registration procedure", @@ -113,13 +164,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ isInProgress = false } else if isInProgress { a.log.Warn( - "Attempted to begin registration on email that is in progress of registration or on cooldown", + "Attempted to begin registration on email that is in progress of registration or on cooldown", zap.String("email", request.Email)) return false, errs.ErrTooManyRequests } if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{ - Email: request.Email, + Email: request.Email, Username: request.Username, }); err != nil { a.log.Error( @@ -141,12 +192,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ // Falsely confirm in order to avoid disclosing registered email addresses if err := a.redis.Set( context.TODO(), - fmt.Sprintf("email::%s::registration_in_progress", request.Email), + fmt.Sprintf("email::%s::registration_in_progress", request.Email), true, - time.Duration(10 * time.Minute), // XXX: magic number + time.Duration(10*time.Minute), // XXX: magic number ).Err(); err != nil { a.log.Error( - "Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration", + "Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration", zap.String("email", request.Email), zap.Error(err)) return false, errs.ErrServerError @@ -161,7 +212,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ } else { if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{ Username: request.Username, - Email: request.Email, + Email: request.Email, }); err != nil { a.log.Error( "Failed to purge unverified accounts as part of registration", @@ -184,8 +235,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ } if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{ - UserID: user.ID, - Email: utils.NewPointer(request.Email), + UserID: user.ID, + Email: utils.NewPointer(request.Email), PasswordHash: passwordHash, // Hashed in database }); err != nil { @@ -208,9 +259,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ } if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{ - UserID: user.ID, - CodeType: int32(enums.RegistrationCodeType), - CodeHash: generatedCodeHash, // Hashed in database + UserID: user.ID, + CodeType: int32(enums.RegistrationCodeType), + CodeHash: generatedCodeHash, // Hashed in database }); err != nil { a.log.Error("Failed to add registration code to database", zap.Error(err)) return false, errs.ErrServerError @@ -224,8 +275,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ if config.GetConfig().SmtpEnabled { if err := a.smtp.SendEmail( - request.Email, - "Easywish", + request.Email, + "Easywish", fmt.Sprintf("Your registration code is %s", generatedCode), ); err != nil { a.log.Error( @@ -246,12 +297,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ if err := a.redis.Set( context.TODO(), - fmt.Sprintf("email::%s::registration_in_progress", request.Email), + fmt.Sprintf("email::%s::registration_in_progress", request.Email), true, - time.Duration(10 * time.Minute), // XXX: magic number + time.Duration(10*time.Minute), // XXX: magic number ).Err(); err != nil { a.log.Error( - "Failed to cache registration_in_progress state for email", + "Failed to cache registration_in_progress state for email", zap.String("email", request.Email), zap.Error(err)) return false, errs.ErrServerError @@ -267,19 +318,18 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ return true, nil } -func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) { - +func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) { + var user database.User var profile database.Profile - var session database.Session - var confirmationCode database.ConfirmationCode + var confirmationCode database.ConfirmationCode var accessToken, refreshToken string var err error helper, db, err := database.NewDbHelperTransaction(a.dbctx) if err != nil { a.log.Error( - "Failed to open a transaction", + "Failed to open a transaction", zap.Error(err)) return nil, errs.ErrServerError } @@ -290,8 +340,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple if err != nil { if errors.Is(err, pgx.ErrNoRows) { a.log.Warn( - "Could not find user attempting to complete registration with given username", - zap.String("username", request.Username), + "Could not find user attempting to complete registration with given username", + zap.String("username", request.Username), zap.Error(err)) return nil, errs.ErrUserNotFound } @@ -304,9 +354,9 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple } confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{ - UserID: user.ID, + UserID: user.ID, CodeType: int32(enums.RegistrationCodeType), - Code: request.VerificationCode, + Code: request.VerificationCode, }) if err != nil { @@ -327,10 +377,10 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple } err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{ - ID: confirmationCode.ID, + ID: confirmationCode.ID, Used: utils.NewPointer(true), }) - + if err != nil { a.log.Error( "Failed to update the user's registration code used state", @@ -340,27 +390,26 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple return nil, errs.ErrServerError } - err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{ - ID: user.ID, + err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{ + ID: user.ID, Verified: utils.NewPointer(true), }) if err != nil { a.log.Error("Failed to update verified status for user", - zap.String("username", user.Username), - zap.Error(err)) + zap.String("username", user.Username), + zap.Error(err)) return nil, errs.ErrServerError } - profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{ + profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{ UserID: user.ID, - Name: request.Name, + Name: request.Name, }) if err != nil { a.log.Error("Failed to create profile for user", zap.String("username", user.Username), - ) return nil, errs.ErrServerError } @@ -369,29 +418,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple if err != nil { a.log.Error("Failed to create profile settings for user", - zap.String("username", user.Username), - zap.Error(err)) - return nil, errs.ErrServerError - } - - // TODO: session info - session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ - UserID: user.ID, - Name: utils.NewPointer("First device"), - Platform: utils.NewPointer("Unknown"), - LatestIp: utils.NewPointer("Unknown"), - }) - - if err != nil { - a.log.Error( - "Failed to create a new session during registration, rolling back registration", zap.String("username", user.Username), zap.Error(err)) return nil, errs.ErrServerError } + session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil { + a.log.Error("", zap.Error(err)) + return nil, errs.ErrServerError + } + // TODO: get user role - accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole) + accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole) if err != nil { a.log.Error( @@ -413,7 +451,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple zap.String("username", request.Username)) response := models.RegistrationCompleteResponse{Tokens: models.Tokens{ - AccessToken: accessToken, + AccessToken: accessToken, RefreshToken: refreshToken, }} @@ -421,19 +459,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple } // TODO: totp -func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) { +func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) { var userRow database.GetValidUserByLoginCredentialsRow - var session database.Session var err error helper, db, err := database.NewDbHelperTransaction(a.dbctx) if err != nil { a.log.Error( - "Failed to open a transaction", + "Failed to open a transaction", zap.Error(err)) return nil, errs.ErrServerError } - defer helper.RollbackOnError(err) + defer helper.RollbackOnError(err) userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{ Username: request.Username, @@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo } // Until release 4, only 1 session at a time is supported - if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil { + err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil { a.log.Error( - "Failed to terminate older sessions for user trying to log in", - zap.String("username", request.Username), + "Failed to terminate user's sessions during login", zap.Error(err)) return nil, errs.ErrServerError } - session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ - // TODO: use actual values for session metadata - UserID: userRow.ID, - Name: utils.NewPointer("New device"), - Platform: utils.NewPointer("Unknown"), - LatestIp: utils.NewPointer("Unknown"), - }) - - if err != nil { - a.log.Error( - "Failed to create session for a new login", - zap.String("username", userRow.Username), - zap.Error(err)) + session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil { + a.log.Error("", zap.Error(err)) return nil, errs.ErrServerError } @@ -500,7 +525,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo } response := models.LoginResponse{Tokens: models.Tokens{ - AccessToken: accessToken, + AccessToken: accessToken, RefreshToken: refreshToken, }} @@ -509,12 +534,13 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { - sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil { - + sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType) + if err != nil { + if utils.ErrorIsOneOf( - err, + err, errs.ErrInvalidToken, - errs.ErrTokenExpired, + errs.ErrTokenExpired, errs.ErrWrongTokenType, errs.ErrSessionNotFound, errs.ErrSessionTerminated, @@ -522,19 +548,20 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres return nil, err } else { a.log.Error( - "Encountered an unexpected error while validating token", + "Encountered an unexpected error while validating token", zap.Error(err)) return nil, errs.ErrServerError } - } + } accessToken, refreshToken, err := utils.GenerateTokens( sessionInfo.Username, sessionInfo.Session, sessionInfo.Role, - ); if err != nil { + ) + if err != nil { a.log.Error( - "Failed to generate tokens for user during refresh", + "Failed to generate tokens for user during refresh", zap.String("username", sessionInfo.Username), zap.String("session", sessionInfo.Session), zap.Error(err)) @@ -543,7 +570,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres response := models.RefreshResponse{ Tokens: models.Tokens{ - AccessToken: accessToken, + AccessToken: accessToken, RefreshToken: refreshToken, }, } @@ -551,13 +578,13 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres return &response, nil } -func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) { +func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) { var err error token, err := jwt.ParseWithClaims( - jwtToken, - &dto.UserClaims{}, + jwtToken, + &dto.UserClaims{}, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -569,11 +596,12 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { return nil, errs.ErrTokenExpired - } + } return nil, errs.ErrInvalidToken } - claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid { + claims, ok := token.Claims.(*dto.UserClaims) + if ok && token.Valid { if claims.Type != tokenType { return nil, errs.ErrWrongTokenType @@ -583,7 +611,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool() if redisErr != nil && redisErr != redis.Nil { a.log.Error( - "Failed to lookup cache to check whether session is not terminated", + "Failed to lookup cache to check whether session is not terminated", zap.Error(redisErr)) return nil, redisErr } @@ -603,20 +631,20 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke } a.log.Error( - "Failed to lookup session in database", + "Failed to lookup session in database", zap.String("session", claims.Session), zap.Error(err)) return nil, err } if err := a.redis.Set( - ctx, - fmt.Sprintf("session::%s::is_terminated", claims.Session), - *session.Terminated, - time.Duration(8 * time.Hour), // XXX: magic number + ctx, + fmt.Sprintf("session::%s::is_terminated", claims.Session), + *session.Terminated, + time.Duration(8*time.Hour), // XXX: magic number ).Err(); err != nil { a.log.Error( - "Failed to cache session's is_terminated state", + "Failed to cache session's is_terminated state", zap.String("session", claims.Session), zap.Error(err)) // c.AbortWithStatus(http.StatusInternalServerError) @@ -633,15 +661,15 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke sessionInfo := dto.SessionInfo{ Username: claims.Username, - Session: claims.Session, - Role: claims.Role, + Session: claims.Session, + Role: claims.Role, } return &sessionInfo, nil } func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) { - + var user database.User var generatedCode, hashedCode string var err error @@ -649,7 +677,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe helper, db, err := database.NewDbHelperTransaction(a.dbctx) if err != nil { a.log.Error( - "Failed to open a transaction", + "Failed to open a transaction", zap.Error(err)) return false, errs.ErrServerError } @@ -660,7 +688,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64() if redisErr != nil && redisErr != redis.Nil { a.log.Error( - "Failed to get reset_cooldown state for user", + "Failed to get reset_cooldown state for user", zap.String("email", request.Email), zap.Error(redisErr)) return false, errs.ErrServerError @@ -677,7 +705,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe if errors.Is(err, pgx.ErrNoRows) { // Enable cooldown for the email despite that account does not exist err := a.redis.Set( - ctx, + ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email), time.Now().Add(10*time.Minute), time.Duration(10*time.Minute), @@ -687,7 +715,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe a.log.Error( "Failed to set reset cooldown for email", zap.Error(err)) - return false, err + return false, err } a.log.Warn( @@ -712,7 +740,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe } if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{ - UserID: user.ID, + UserID: user.ID, CodeType: int32(enums.PasswordResetCodeType), CodeHash: hashedCode, }); err != nil { @@ -723,7 +751,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe } err = a.redis.Set( - ctx, + ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email), time.Now().Add(10*time.Minute), time.Duration(10*time.Minute), @@ -733,7 +761,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe a.log.Error( "Failed to set reset cooldown for email. Cancelling password reset", zap.Error(err)) - return false, err + return false, err } if err = helper.Commit(); err != nil { @@ -757,7 +785,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp helper, db, err := database.NewDbHelperTransaction(a.dbctx) if err != nil { a.log.Error( - "Failed to open a transaction", + "Failed to open a transaction", zap.Error(err)) return nil, errs.ErrServerError } @@ -779,13 +807,13 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp } if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{ - UserID: user.ID, + UserID: user.ID, CodeType: int32(enums.PasswordResetCodeType), - Code: request.VerificationCode, + Code: request.VerificationCode, }); err != nil { if errors.Is(err, pgx.ErrNoRows) { a.log.Warn( - "Attempted to reset password for user using incorrect confirmation code", + "Attempted to reset password for user using incorrect confirmation code", zap.String("email", request.Email), zap.String("username", user.Username), zap.String("provided_code", request.VerificationCode), @@ -795,7 +823,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp } if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{ - ID: resetCode.ID, + ID: resetCode.ID, Used: utils.NewPointer(true), }); err != nil { a.log.Error( @@ -817,7 +845,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp } if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{ - Username: user.Username, + Username: user.Username, PasswordHash: hashedPassword, }); err != nil { a.log.Error( @@ -828,24 +856,22 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp } if request.LogOutSessions { - if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil { + err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil { a.log.Error( - "Failed to log out older sessions as part of user password reset", - zap.String("email", request.Email), - zap.String("username", user.Username), + "Failed to terminate user's sessions during login", zap.Error(err)) return nil, errs.ErrServerError } } - if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ - UserID: user.ID, - Name: utils.NewPointer("First device"), + session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ + UserID: user.ID, + Name: utils.NewPointer("First device"), Platform: utils.NewPointer("Unknown"), LatestIp: utils.NewPointer("Unknown"), - }); err != nil { + }); if err != nil { a.log.Error( - "Failed to create new session for user as part of user password reset", + "Failed to create new session for user as part of user password reset", zap.String("email", request.Email), zap.String("username", user.Username), zap.Error(err)) @@ -854,7 +880,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp // TODO: get user role if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil { a.log.Error( - "Failed to generate tokens as part of user password reset", + "Failed to generate tokens as part of user password reset", zap.String("email", request.Email), zap.String("username", user.Username), zap.Error(err)) @@ -863,7 +889,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp response := models.PasswordResetCompleteResponse{ Tokens: models.Tokens{ - AccessToken: accessToken, + AccessToken: accessToken, RefreshToken: refreshToken, }, } @@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp return &response, nil } +func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) { + + var err error + + helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil { + a.log.Error( + "Failed to open a transaction", + zap.Error(err)) + return false, errs.ErrServerError + } + defer helper.RollbackOnError(err) + + linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil { + a.log.Error( + "Failed to get user login information", + zap.Error(err)) + return false, errs.ErrServerError + } + + if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) { + a.log.Warn( + "Provided invalid old password while changing password", + zap.String("username", uinfo.Username)) + return false, errs.ErrForbidden + } + + newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil { + a.log.Error( + "Failed to hash new password while changing password", + zap.String("username", uinfo.Username), + zap.Error(err)) + return false, errs.ErrServerError + } + + err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{ + Username: uinfo.Username, + PasswordHash: newPasswordHash, + }); if err != nil { + a.log.Error( + "Failed to save new password into database", + zap.String("username", uinfo.Username), + zap.Error(err)) + return false, errs.ErrServerError + } + + if err := helper.Commit(); err != nil { + a.log.Error( + "Failed to commit transaction", + zap.Error(err)) + return false, errs.ErrServerError + } + + return true, nil +} diff --git a/backend/internal/utils/errors.go b/backend/internal/utils/errors.go index 55b0452..f8512e0 100644 --- a/backend/internal/utils/errors.go +++ b/backend/internal/utils/errors.go @@ -22,8 +22,8 @@ import "errors" func ErrorIsOneOf(err error, ignoreErrors ...error) bool { for _, ignore := range ignoreErrors { if errors.Is(err, ignore) { - return false + return true } } - return true + return false } diff --git a/sqlc/query.sql b/sqlc/query.sql index a640c10..6a60051 100644 --- a/sqlc/query.sql +++ b/sqlc/query.sql @@ -259,11 +259,12 @@ WHERE terminated IS TRUE AND last_refresh_exp_time > CURRENT_TIMESTAMP; -;-- name: TerminateAllSessionsForUserByUsername :exec +;-- name: TerminateAllSessionsForUserByUsername :many UPDATE sessions SET terminated = TRUE FROM users -WHERE sessions.user_id = users.id AND users.username = @username::text; +WHERE sessions.user_id = users.id AND users.username = @username::text +RETURNING sessions.guid; ;-- name: PruneTerminatedSessions :exec DELETE FROM sessions diff --git a/sqlc/schema.sql b/sqlc/schema.sql index cd60b90..8e9602c 100644 --- a/sqlc/schema.sql +++ b/sqlc/schema.sql @@ -62,8 +62,8 @@ CREATE TABLE IF NOT EXISTS "sessions" ( id BIGSERIAL PRIMARY KEY, user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, guid UUID NOT NULL DEFAULT gen_random_uuid(), - name VARCHAR(100), - platform VARCHAR(32), + name VARCHAR(175), + platform VARCHAR(175), latest_ip VARCHAR(16), login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds',