feat: add change password endpoint using old password;
feat: implement change password service method with validation; fix: correct ErrorIsOneOf function logic to return true on match; refactor: rename 'log_out_accounts' to 'log_out_sessions' for clarity; refactor: update session termination to return GUIDs and cache in Redis; fix: ensure RollbackOnError only rolls back uncommitted transactions; fix: handle transaction commit errors properly in dbHelper; refactor: add helper methods for session termination and registration; refactor: pass client info to login and registration complete methods; fix: improve token validation error handling in refresh endpoint; refactor: update auth middleware to set session info correctly; chore: remove unused ClientInfo DTO; fix: correct password reset complete to use session termination helper; refactor: adjust database queries for session management; chore: update SQL schema and queries for sessions; docs: update swagger docs with new endpoint and model changes
This commit is contained in:
@@ -40,19 +40,20 @@ import (
|
||||
|
||||
type AuthService interface {
|
||||
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
||||
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
||||
Login(request models.LoginRequest) (*models.LoginResponse, error)
|
||||
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
||||
Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
|
||||
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
|
||||
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
|
||||
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
|
||||
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
|
||||
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
|
||||
}
|
||||
|
||||
type authServiceImpl struct {
|
||||
log *zap.Logger
|
||||
log *zap.Logger
|
||||
dbctx database.DbContext
|
||||
redis *redis.Client
|
||||
smtp SmtpService
|
||||
smtp SmtpService
|
||||
}
|
||||
|
||||
func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService {
|
||||
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
|
||||
panic("Failed to cache terminated session: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
|
||||
}
|
||||
|
||||
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
|
||||
return authService
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
|
||||
|
||||
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to terminate older sessions for user trying to log in",
|
||||
zap.String("username", username),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := a.redis.Pipeline()
|
||||
for _, guid := range sessionGuids {
|
||||
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to cache terminated sessions",
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
|
||||
|
||||
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
|
||||
UserID: userID,
|
||||
Name: utils.NewPointer(cinfo.UserAgent),
|
||||
Platform: utils.NewPointer(cinfo.UserAgent),
|
||||
LatestIp: utils.NewPointer(cinfo.IP),
|
||||
}); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to add session to database",
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.log.Info(
|
||||
"Registered a new user session",
|
||||
zap.String("username", cinfo.Username),
|
||||
zap.String("session", cinfo.Session))
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
|
||||
|
||||
var occupationStatus database.CheckUserRegistrationAvailabilityRow
|
||||
@@ -91,18 +143,17 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
defer helper.RollbackOnError(err)
|
||||
|
||||
defer helper.RollbackOnError(err)
|
||||
|
||||
if isInProgress, err := a.redis.Get(
|
||||
isInProgress, err := a.redis.Get(
|
||||
context.TODO(),
|
||||
fmt.Sprintf("email::%s::registration_in_progress",
|
||||
request.Email),
|
||||
).Bool(); err != nil {
|
||||
request.Email),
|
||||
).Bool(); if err != nil {
|
||||
if err != redis.Nil {
|
||||
a.log.Error(
|
||||
"Failed to look up cached registration_in_progress state of email as part of registration procedure",
|
||||
@@ -113,13 +164,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
isInProgress = false
|
||||
} else if isInProgress {
|
||||
a.log.Warn(
|
||||
"Attempted to begin registration on email that is in progress of registration or on cooldown",
|
||||
"Attempted to begin registration on email that is in progress of registration or on cooldown",
|
||||
zap.String("email", request.Email))
|
||||
return false, errs.ErrTooManyRequests
|
||||
}
|
||||
|
||||
if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{
|
||||
Email: request.Email,
|
||||
Email: request.Email,
|
||||
Username: request.Username,
|
||||
}); err != nil {
|
||||
a.log.Error(
|
||||
@@ -141,12 +192,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
// Falsely confirm in order to avoid disclosing registered email addresses
|
||||
if err := a.redis.Set(
|
||||
context.TODO(),
|
||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||
true,
|
||||
time.Duration(10 * time.Minute), // XXX: magic number
|
||||
time.Duration(10*time.Minute), // XXX: magic number
|
||||
).Err(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
||||
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
||||
zap.String("email", request.Email),
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
@@ -161,7 +212,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
} else {
|
||||
if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{
|
||||
Username: request.Username,
|
||||
Email: request.Email,
|
||||
Email: request.Email,
|
||||
}); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to purge unverified accounts as part of registration",
|
||||
@@ -184,8 +235,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
}
|
||||
|
||||
if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{
|
||||
UserID: user.ID,
|
||||
Email: utils.NewPointer(request.Email),
|
||||
UserID: user.ID,
|
||||
Email: utils.NewPointer(request.Email),
|
||||
PasswordHash: passwordHash, // Hashed in database
|
||||
}); err != nil {
|
||||
|
||||
@@ -208,9 +259,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
}
|
||||
|
||||
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
||||
UserID: user.ID,
|
||||
CodeType: int32(enums.RegistrationCodeType),
|
||||
CodeHash: generatedCodeHash, // Hashed in database
|
||||
UserID: user.ID,
|
||||
CodeType: int32(enums.RegistrationCodeType),
|
||||
CodeHash: generatedCodeHash, // Hashed in database
|
||||
}); err != nil {
|
||||
a.log.Error("Failed to add registration code to database", zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
@@ -224,8 +275,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
if config.GetConfig().SmtpEnabled {
|
||||
|
||||
if err := a.smtp.SendEmail(
|
||||
request.Email,
|
||||
"Easywish",
|
||||
request.Email,
|
||||
"Easywish",
|
||||
fmt.Sprintf("Your registration code is %s", generatedCode),
|
||||
); err != nil {
|
||||
a.log.Error(
|
||||
@@ -246,12 +297,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
|
||||
if err := a.redis.Set(
|
||||
context.TODO(),
|
||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||
true,
|
||||
time.Duration(10 * time.Minute), // XXX: magic number
|
||||
time.Duration(10*time.Minute), // XXX: magic number
|
||||
).Err(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to cache registration_in_progress state for email",
|
||||
"Failed to cache registration_in_progress state for email",
|
||||
zap.String("email", request.Email),
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
@@ -267,19 +318,18 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
||||
|
||||
func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
||||
|
||||
var user database.User
|
||||
var profile database.Profile
|
||||
var session database.Session
|
||||
var confirmationCode database.ConfirmationCode
|
||||
var confirmationCode database.ConfirmationCode
|
||||
var accessToken, refreshToken string
|
||||
var err error
|
||||
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
@@ -290,8 +340,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
a.log.Warn(
|
||||
"Could not find user attempting to complete registration with given username",
|
||||
zap.String("username", request.Username),
|
||||
"Could not find user attempting to complete registration with given username",
|
||||
zap.String("username", request.Username),
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrUserNotFound
|
||||
}
|
||||
@@ -304,9 +354,9 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
}
|
||||
|
||||
confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
||||
UserID: user.ID,
|
||||
UserID: user.ID,
|
||||
CodeType: int32(enums.RegistrationCodeType),
|
||||
Code: request.VerificationCode,
|
||||
Code: request.VerificationCode,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -327,10 +377,10 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
}
|
||||
|
||||
err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
||||
ID: confirmationCode.ID,
|
||||
ID: confirmationCode.ID,
|
||||
Used: utils.NewPointer(true),
|
||||
})
|
||||
|
||||
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to update the user's registration code used state",
|
||||
@@ -340,27 +390,26 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
|
||||
ID: user.ID,
|
||||
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
|
||||
ID: user.ID,
|
||||
Verified: utils.NewPointer(true),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.log.Error("Failed to update verified status for user",
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
|
||||
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
|
||||
UserID: user.ID,
|
||||
Name: request.Name,
|
||||
Name: request.Name,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.log.Error("Failed to create profile for user",
|
||||
zap.String("username", user.Username),
|
||||
|
||||
)
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
@@ -369,29 +418,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
|
||||
if err != nil {
|
||||
a.log.Error("Failed to create profile settings for user",
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
// TODO: session info
|
||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
UserID: user.ID,
|
||||
Name: utils.NewPointer("First device"),
|
||||
Platform: utils.NewPointer("Unknown"),
|
||||
LatestIp: utils.NewPointer("Unknown"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to create a new session during registration, rolling back registration",
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
|
||||
a.log.Error("", zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
// TODO: get user role
|
||||
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
|
||||
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
|
||||
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
@@ -413,7 +451,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
zap.String("username", request.Username))
|
||||
|
||||
response := models.RegistrationCompleteResponse{Tokens: models.Tokens{
|
||||
AccessToken: accessToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}}
|
||||
|
||||
@@ -421,19 +459,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
}
|
||||
|
||||
// TODO: totp
|
||||
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) {
|
||||
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
|
||||
var userRow database.GetValidUserByLoginCredentialsRow
|
||||
var session database.Session
|
||||
var err error
|
||||
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
defer helper.RollbackOnError(err)
|
||||
defer helper.RollbackOnError(err)
|
||||
|
||||
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
|
||||
Username: request.Username,
|
||||
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
||||
}
|
||||
|
||||
// Until release 4, only 1 session at a time is supported
|
||||
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil {
|
||||
err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to terminate older sessions for user trying to log in",
|
||||
zap.String("username", request.Username),
|
||||
"Failed to terminate user's sessions during login",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
// TODO: use actual values for session metadata
|
||||
UserID: userRow.ID,
|
||||
Name: utils.NewPointer("New device"),
|
||||
Platform: utils.NewPointer("Unknown"),
|
||||
LatestIp: utils.NewPointer("Unknown"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to create session for a new login",
|
||||
zap.String("username", userRow.Username),
|
||||
zap.Error(err))
|
||||
session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
|
||||
a.log.Error("", zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
@@ -500,7 +525,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
||||
}
|
||||
|
||||
response := models.LoginResponse{Tokens: models.Tokens{
|
||||
AccessToken: accessToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}}
|
||||
|
||||
@@ -509,12 +534,13 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
||||
|
||||
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
||||
|
||||
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil {
|
||||
|
||||
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
|
||||
if err != nil {
|
||||
|
||||
if utils.ErrorIsOneOf(
|
||||
err,
|
||||
err,
|
||||
errs.ErrInvalidToken,
|
||||
errs.ErrTokenExpired,
|
||||
errs.ErrTokenExpired,
|
||||
errs.ErrWrongTokenType,
|
||||
errs.ErrSessionNotFound,
|
||||
errs.ErrSessionTerminated,
|
||||
@@ -522,19 +548,20 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
||||
return nil, err
|
||||
} else {
|
||||
a.log.Error(
|
||||
"Encountered an unexpected error while validating token",
|
||||
"Encountered an unexpected error while validating token",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := utils.GenerateTokens(
|
||||
sessionInfo.Username,
|
||||
sessionInfo.Session,
|
||||
sessionInfo.Role,
|
||||
); if err != nil {
|
||||
)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to generate tokens for user during refresh",
|
||||
"Failed to generate tokens for user during refresh",
|
||||
zap.String("username", sessionInfo.Username),
|
||||
zap.String("session", sessionInfo.Session),
|
||||
zap.Error(err))
|
||||
@@ -543,7 +570,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
||||
|
||||
response := models.RefreshResponse{
|
||||
Tokens: models.Tokens{
|
||||
AccessToken: accessToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
}
|
||||
@@ -551,13 +578,13 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
||||
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
||||
|
||||
var err error
|
||||
|
||||
token, err := jwt.ParseWithClaims(
|
||||
jwtToken,
|
||||
&dto.UserClaims{},
|
||||
jwtToken,
|
||||
&dto.UserClaims{},
|
||||
func(token *jwt.Token) (any, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
@@ -569,11 +596,12 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, errs.ErrTokenExpired
|
||||
}
|
||||
}
|
||||
return nil, errs.ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid {
|
||||
claims, ok := token.Claims.(*dto.UserClaims)
|
||||
if ok && token.Valid {
|
||||
|
||||
if claims.Type != tokenType {
|
||||
return nil, errs.ErrWrongTokenType
|
||||
@@ -583,7 +611,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
||||
isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
|
||||
if redisErr != nil && redisErr != redis.Nil {
|
||||
a.log.Error(
|
||||
"Failed to lookup cache to check whether session is not terminated",
|
||||
"Failed to lookup cache to check whether session is not terminated",
|
||||
zap.Error(redisErr))
|
||||
return nil, redisErr
|
||||
}
|
||||
@@ -603,20 +631,20 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
||||
}
|
||||
|
||||
a.log.Error(
|
||||
"Failed to lookup session in database",
|
||||
"Failed to lookup session in database",
|
||||
zap.String("session", claims.Session),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.redis.Set(
|
||||
ctx,
|
||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
||||
*session.Terminated,
|
||||
time.Duration(8 * time.Hour), // XXX: magic number
|
||||
ctx,
|
||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
||||
*session.Terminated,
|
||||
time.Duration(8*time.Hour), // XXX: magic number
|
||||
).Err(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to cache session's is_terminated state",
|
||||
"Failed to cache session's is_terminated state",
|
||||
zap.String("session", claims.Session),
|
||||
zap.Error(err))
|
||||
// c.AbortWithStatus(http.StatusInternalServerError)
|
||||
@@ -633,15 +661,15 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
||||
|
||||
sessionInfo := dto.SessionInfo{
|
||||
Username: claims.Username,
|
||||
Session: claims.Session,
|
||||
Role: claims.Role,
|
||||
Session: claims.Session,
|
||||
Role: claims.Role,
|
||||
}
|
||||
|
||||
return &sessionInfo, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) {
|
||||
|
||||
|
||||
var user database.User
|
||||
var generatedCode, hashedCode string
|
||||
var err error
|
||||
@@ -649,7 +677,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
@@ -660,7 +688,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64()
|
||||
if redisErr != nil && redisErr != redis.Nil {
|
||||
a.log.Error(
|
||||
"Failed to get reset_cooldown state for user",
|
||||
"Failed to get reset_cooldown state for user",
|
||||
zap.String("email", request.Email),
|
||||
zap.Error(redisErr))
|
||||
return false, errs.ErrServerError
|
||||
@@ -677,7 +705,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
// Enable cooldown for the email despite that account does not exist
|
||||
err := a.redis.Set(
|
||||
ctx,
|
||||
ctx,
|
||||
fmt.Sprintf("email::%s::reset_cooldown", request.Email),
|
||||
time.Now().Add(10*time.Minute),
|
||||
time.Duration(10*time.Minute),
|
||||
@@ -687,7 +715,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
a.log.Error(
|
||||
"Failed to set reset cooldown for email",
|
||||
zap.Error(err))
|
||||
return false, err
|
||||
return false, err
|
||||
}
|
||||
|
||||
a.log.Warn(
|
||||
@@ -712,7 +740,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
}
|
||||
|
||||
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
||||
UserID: user.ID,
|
||||
UserID: user.ID,
|
||||
CodeType: int32(enums.PasswordResetCodeType),
|
||||
CodeHash: hashedCode,
|
||||
}); err != nil {
|
||||
@@ -723,7 +751,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
}
|
||||
|
||||
err = a.redis.Set(
|
||||
ctx,
|
||||
ctx,
|
||||
fmt.Sprintf("email::%s::reset_cooldown", request.Email),
|
||||
time.Now().Add(10*time.Minute),
|
||||
time.Duration(10*time.Minute),
|
||||
@@ -733,7 +761,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
||||
a.log.Error(
|
||||
"Failed to set reset cooldown for email. Cancelling password reset",
|
||||
zap.Error(err))
|
||||
return false, err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err = helper.Commit(); err != nil {
|
||||
@@ -757,7 +785,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
@@ -779,13 +807,13 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
}
|
||||
|
||||
if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
||||
UserID: user.ID,
|
||||
UserID: user.ID,
|
||||
CodeType: int32(enums.PasswordResetCodeType),
|
||||
Code: request.VerificationCode,
|
||||
Code: request.VerificationCode,
|
||||
}); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
a.log.Warn(
|
||||
"Attempted to reset password for user using incorrect confirmation code",
|
||||
"Attempted to reset password for user using incorrect confirmation code",
|
||||
zap.String("email", request.Email),
|
||||
zap.String("username", user.Username),
|
||||
zap.String("provided_code", request.VerificationCode),
|
||||
@@ -795,7 +823,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
}
|
||||
|
||||
if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
||||
ID: resetCode.ID,
|
||||
ID: resetCode.ID,
|
||||
Used: utils.NewPointer(true),
|
||||
}); err != nil {
|
||||
a.log.Error(
|
||||
@@ -817,7 +845,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
}
|
||||
|
||||
if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
||||
Username: user.Username,
|
||||
Username: user.Username,
|
||||
PasswordHash: hashedPassword,
|
||||
}); err != nil {
|
||||
a.log.Error(
|
||||
@@ -828,24 +856,22 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
}
|
||||
|
||||
if request.LogOutSessions {
|
||||
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil {
|
||||
err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to log out older sessions as part of user password reset",
|
||||
zap.String("email", request.Email),
|
||||
zap.String("username", user.Username),
|
||||
"Failed to terminate user's sessions during login",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
}
|
||||
|
||||
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
UserID: user.ID,
|
||||
Name: utils.NewPointer("First device"),
|
||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
UserID: user.ID,
|
||||
Name: utils.NewPointer("First device"),
|
||||
Platform: utils.NewPointer("Unknown"),
|
||||
LatestIp: utils.NewPointer("Unknown"),
|
||||
}); err != nil {
|
||||
}); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to create new session for user as part of user password reset",
|
||||
"Failed to create new session for user as part of user password reset",
|
||||
zap.String("email", request.Email),
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
@@ -854,7 +880,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
// TODO: get user role
|
||||
if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to generate tokens as part of user password reset",
|
||||
"Failed to generate tokens as part of user password reset",
|
||||
zap.String("email", request.Email),
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
@@ -863,7 +889,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
|
||||
response := models.PasswordResetCompleteResponse{
|
||||
Tokens: models.Tokens{
|
||||
AccessToken: accessToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
}
|
||||
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
|
||||
|
||||
var err error
|
||||
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
defer helper.RollbackOnError(err)
|
||||
|
||||
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to get user login information",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
|
||||
a.log.Warn(
|
||||
"Provided invalid old password while changing password",
|
||||
zap.String("username", uinfo.Username))
|
||||
return false, errs.ErrForbidden
|
||||
}
|
||||
|
||||
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to hash new password while changing password",
|
||||
zap.String("username", uinfo.Username),
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
||||
Username: uinfo.Username,
|
||||
PasswordHash: newPasswordHash,
|
||||
}); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to save new password into database",
|
||||
zap.String("username", uinfo.Username),
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
if err := helper.Commit(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to commit transaction",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user