|
|
|
|
@@ -40,19 +40,20 @@ import (
|
|
|
|
|
|
|
|
|
|
type AuthService interface {
|
|
|
|
|
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
|
|
|
|
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
|
|
|
|
Login(request models.LoginRequest) (*models.LoginResponse, error)
|
|
|
|
|
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
|
|
|
|
Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
|
|
|
|
|
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
|
|
|
|
|
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
|
|
|
|
|
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
|
|
|
|
|
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
|
|
|
|
|
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type authServiceImpl struct {
|
|
|
|
|
log *zap.Logger
|
|
|
|
|
log *zap.Logger
|
|
|
|
|
dbctx database.DbContext
|
|
|
|
|
redis *redis.Client
|
|
|
|
|
smtp SmtpService
|
|
|
|
|
smtp SmtpService
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService {
|
|
|
|
|
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
|
|
|
|
|
panic("Failed to cache terminated session: " + err.Error())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
|
|
|
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
|
|
|
|
|
return authService
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
|
|
|
|
|
|
|
|
|
|
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to terminate older sessions for user trying to log in",
|
|
|
|
|
zap.String("username", username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pipe := a.redis.Pipeline()
|
|
|
|
|
for _, guid := range sessionGuids {
|
|
|
|
|
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to cache terminated sessions",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
|
|
|
|
|
|
|
|
|
|
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
|
|
|
|
|
UserID: userID,
|
|
|
|
|
Name: utils.NewPointer(cinfo.UserAgent),
|
|
|
|
|
Platform: utils.NewPointer(cinfo.UserAgent),
|
|
|
|
|
LatestIp: utils.NewPointer(cinfo.IP),
|
|
|
|
|
}); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to add session to database",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
a.log.Info(
|
|
|
|
|
"Registered a new user session",
|
|
|
|
|
zap.String("username", cinfo.Username),
|
|
|
|
|
zap.String("session", cinfo.Session))
|
|
|
|
|
return &session, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
|
|
|
|
|
|
|
|
|
|
var occupationStatus database.CheckUserRegistrationAvailabilityRow
|
|
|
|
|
@@ -91,18 +143,17 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
defer helper.RollbackOnError(err)
|
|
|
|
|
|
|
|
|
|
defer helper.RollbackOnError(err)
|
|
|
|
|
|
|
|
|
|
if isInProgress, err := a.redis.Get(
|
|
|
|
|
isInProgress, err := a.redis.Get(
|
|
|
|
|
context.TODO(),
|
|
|
|
|
fmt.Sprintf("email::%s::registration_in_progress",
|
|
|
|
|
request.Email),
|
|
|
|
|
).Bool(); err != nil {
|
|
|
|
|
request.Email),
|
|
|
|
|
).Bool(); if err != nil {
|
|
|
|
|
if err != redis.Nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to look up cached registration_in_progress state of email as part of registration procedure",
|
|
|
|
|
@@ -113,13 +164,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
isInProgress = false
|
|
|
|
|
} else if isInProgress {
|
|
|
|
|
a.log.Warn(
|
|
|
|
|
"Attempted to begin registration on email that is in progress of registration or on cooldown",
|
|
|
|
|
"Attempted to begin registration on email that is in progress of registration or on cooldown",
|
|
|
|
|
zap.String("email", request.Email))
|
|
|
|
|
return false, errs.ErrTooManyRequests
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{
|
|
|
|
|
Email: request.Email,
|
|
|
|
|
Email: request.Email,
|
|
|
|
|
Username: request.Username,
|
|
|
|
|
}); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
@@ -141,12 +192,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
// Falsely confirm in order to avoid disclosing registered email addresses
|
|
|
|
|
if err := a.redis.Set(
|
|
|
|
|
context.TODO(),
|
|
|
|
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
|
|
|
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
|
|
|
|
true,
|
|
|
|
|
time.Duration(10 * time.Minute), // XXX: magic number
|
|
|
|
|
time.Duration(10*time.Minute), // XXX: magic number
|
|
|
|
|
).Err(); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
|
|
|
|
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
@@ -161,7 +212,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
} else {
|
|
|
|
|
if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{
|
|
|
|
|
Username: request.Username,
|
|
|
|
|
Email: request.Email,
|
|
|
|
|
Email: request.Email,
|
|
|
|
|
}); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to purge unverified accounts as part of registration",
|
|
|
|
|
@@ -184,8 +235,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
Email: utils.NewPointer(request.Email),
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
Email: utils.NewPointer(request.Email),
|
|
|
|
|
PasswordHash: passwordHash, // Hashed in database
|
|
|
|
|
}); err != nil {
|
|
|
|
|
|
|
|
|
|
@@ -208,9 +259,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
CodeType: int32(enums.RegistrationCodeType),
|
|
|
|
|
CodeHash: generatedCodeHash, // Hashed in database
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
CodeType: int32(enums.RegistrationCodeType),
|
|
|
|
|
CodeHash: generatedCodeHash, // Hashed in database
|
|
|
|
|
}); err != nil {
|
|
|
|
|
a.log.Error("Failed to add registration code to database", zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
@@ -224,8 +275,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
if config.GetConfig().SmtpEnabled {
|
|
|
|
|
|
|
|
|
|
if err := a.smtp.SendEmail(
|
|
|
|
|
request.Email,
|
|
|
|
|
"Easywish",
|
|
|
|
|
request.Email,
|
|
|
|
|
"Easywish",
|
|
|
|
|
fmt.Sprintf("Your registration code is %s", generatedCode),
|
|
|
|
|
); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
@@ -246,12 +297,12 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
|
|
|
|
|
if err := a.redis.Set(
|
|
|
|
|
context.TODO(),
|
|
|
|
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
|
|
|
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
|
|
|
|
true,
|
|
|
|
|
time.Duration(10 * time.Minute), // XXX: magic number
|
|
|
|
|
time.Duration(10*time.Minute), // XXX: magic number
|
|
|
|
|
).Err(); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to cache registration_in_progress state for email",
|
|
|
|
|
"Failed to cache registration_in_progress state for email",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
@@ -267,19 +318,18 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|
|
|
|
return true, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
|
|
|
|
|
|
|
|
|
var user database.User
|
|
|
|
|
var profile database.Profile
|
|
|
|
|
var session database.Session
|
|
|
|
|
var confirmationCode database.ConfirmationCode
|
|
|
|
|
var confirmationCode database.ConfirmationCode
|
|
|
|
|
var accessToken, refreshToken string
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
@@ -290,8 +340,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
if err != nil {
|
|
|
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
|
|
|
a.log.Warn(
|
|
|
|
|
"Could not find user attempting to complete registration with given username",
|
|
|
|
|
zap.String("username", request.Username),
|
|
|
|
|
"Could not find user attempting to complete registration with given username",
|
|
|
|
|
zap.String("username", request.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrUserNotFound
|
|
|
|
|
}
|
|
|
|
|
@@ -304,9 +354,9 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
CodeType: int32(enums.RegistrationCodeType),
|
|
|
|
|
Code: request.VerificationCode,
|
|
|
|
|
Code: request.VerificationCode,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
@@ -327,10 +377,10 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
|
|
|
|
ID: confirmationCode.ID,
|
|
|
|
|
ID: confirmationCode.ID,
|
|
|
|
|
Used: utils.NewPointer(true),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to update the user's registration code used state",
|
|
|
|
|
@@ -340,27 +390,26 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
|
|
|
|
|
ID: user.ID,
|
|
|
|
|
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
|
|
|
|
|
ID: user.ID,
|
|
|
|
|
Verified: utils.NewPointer(true),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error("Failed to update verified status for user",
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
|
|
|
|
|
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
Name: request.Name,
|
|
|
|
|
Name: request.Name,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error("Failed to create profile for user",
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
|
|
|
|
|
)
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
@@ -369,29 +418,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error("Failed to create profile settings for user",
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: session info
|
|
|
|
|
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
Name: utils.NewPointer("First device"),
|
|
|
|
|
Platform: utils.NewPointer("Unknown"),
|
|
|
|
|
LatestIp: utils.NewPointer("Unknown"),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to create a new session during registration, rolling back registration",
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
|
|
|
|
|
a.log.Error("", zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: get user role
|
|
|
|
|
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
|
|
|
|
|
accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole)
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
@@ -413,7 +451,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
zap.String("username", request.Username))
|
|
|
|
|
|
|
|
|
|
response := models.RegistrationCompleteResponse{Tokens: models.Tokens{
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
RefreshToken: refreshToken,
|
|
|
|
|
}}
|
|
|
|
|
|
|
|
|
|
@@ -421,19 +459,18 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: totp
|
|
|
|
|
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) {
|
|
|
|
|
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
|
|
|
|
|
var userRow database.GetValidUserByLoginCredentialsRow
|
|
|
|
|
var session database.Session
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
defer helper.RollbackOnError(err)
|
|
|
|
|
defer helper.RollbackOnError(err)
|
|
|
|
|
|
|
|
|
|
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
|
|
|
|
|
Username: request.Username,
|
|
|
|
|
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Until release 4, only 1 session at a time is supported
|
|
|
|
|
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil {
|
|
|
|
|
err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to terminate older sessions for user trying to log in",
|
|
|
|
|
zap.String("username", request.Username),
|
|
|
|
|
"Failed to terminate user's sessions during login",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
|
|
|
|
// TODO: use actual values for session metadata
|
|
|
|
|
UserID: userRow.ID,
|
|
|
|
|
Name: utils.NewPointer("New device"),
|
|
|
|
|
Platform: utils.NewPointer("Unknown"),
|
|
|
|
|
LatestIp: utils.NewPointer("Unknown"),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to create session for a new login",
|
|
|
|
|
zap.String("username", userRow.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
|
|
|
|
|
a.log.Error("", zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -500,7 +525,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
response := models.LoginResponse{Tokens: models.Tokens{
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
RefreshToken: refreshToken,
|
|
|
|
|
}}
|
|
|
|
|
|
|
|
|
|
@@ -509,12 +534,13 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
|
|
|
|
|
|
|
|
|
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil {
|
|
|
|
|
|
|
|
|
|
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
|
|
if utils.ErrorIsOneOf(
|
|
|
|
|
err,
|
|
|
|
|
err,
|
|
|
|
|
errs.ErrInvalidToken,
|
|
|
|
|
errs.ErrTokenExpired,
|
|
|
|
|
errs.ErrTokenExpired,
|
|
|
|
|
errs.ErrWrongTokenType,
|
|
|
|
|
errs.ErrSessionNotFound,
|
|
|
|
|
errs.ErrSessionTerminated,
|
|
|
|
|
@@ -522,19 +548,20 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
|
|
|
|
return nil, err
|
|
|
|
|
} else {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Encountered an unexpected error while validating token",
|
|
|
|
|
"Encountered an unexpected error while validating token",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
accessToken, refreshToken, err := utils.GenerateTokens(
|
|
|
|
|
sessionInfo.Username,
|
|
|
|
|
sessionInfo.Session,
|
|
|
|
|
sessionInfo.Role,
|
|
|
|
|
); if err != nil {
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to generate tokens for user during refresh",
|
|
|
|
|
"Failed to generate tokens for user during refresh",
|
|
|
|
|
zap.String("username", sessionInfo.Username),
|
|
|
|
|
zap.String("session", sessionInfo.Session),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
@@ -543,7 +570,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
|
|
|
|
|
|
|
|
|
response := models.RefreshResponse{
|
|
|
|
|
Tokens: models.Tokens{
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
RefreshToken: refreshToken,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
@@ -551,13 +578,13 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
|
|
|
|
return &response, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
|
|
|
|
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
|
|
|
|
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
token, err := jwt.ParseWithClaims(
|
|
|
|
|
jwtToken,
|
|
|
|
|
&dto.UserClaims{},
|
|
|
|
|
jwtToken,
|
|
|
|
|
&dto.UserClaims{},
|
|
|
|
|
func(token *jwt.Token) (any, error) {
|
|
|
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
|
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
|
|
|
@@ -569,11 +596,12 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|
|
|
|
if err != nil {
|
|
|
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
|
|
|
return nil, errs.ErrTokenExpired
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil, errs.ErrInvalidToken
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid {
|
|
|
|
|
claims, ok := token.Claims.(*dto.UserClaims)
|
|
|
|
|
if ok && token.Valid {
|
|
|
|
|
|
|
|
|
|
if claims.Type != tokenType {
|
|
|
|
|
return nil, errs.ErrWrongTokenType
|
|
|
|
|
@@ -583,7 +611,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|
|
|
|
isTerminated, redisErr := a.redis.Get(ctx, fmt.Sprintf("session::%s::is_terminated", claims.Session)).Bool()
|
|
|
|
|
if redisErr != nil && redisErr != redis.Nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to lookup cache to check whether session is not terminated",
|
|
|
|
|
"Failed to lookup cache to check whether session is not terminated",
|
|
|
|
|
zap.Error(redisErr))
|
|
|
|
|
return nil, redisErr
|
|
|
|
|
}
|
|
|
|
|
@@ -603,20 +631,20 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to lookup session in database",
|
|
|
|
|
"Failed to lookup session in database",
|
|
|
|
|
zap.String("session", claims.Session),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := a.redis.Set(
|
|
|
|
|
ctx,
|
|
|
|
|
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
|
|
|
|
*session.Terminated,
|
|
|
|
|
time.Duration(8 * time.Hour), // XXX: magic number
|
|
|
|
|
ctx,
|
|
|
|
|
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
|
|
|
|
*session.Terminated,
|
|
|
|
|
time.Duration(8*time.Hour), // XXX: magic number
|
|
|
|
|
).Err(); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to cache session's is_terminated state",
|
|
|
|
|
"Failed to cache session's is_terminated state",
|
|
|
|
|
zap.String("session", claims.Session),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
// c.AbortWithStatus(http.StatusInternalServerError)
|
|
|
|
|
@@ -633,15 +661,15 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|
|
|
|
|
|
|
|
|
sessionInfo := dto.SessionInfo{
|
|
|
|
|
Username: claims.Username,
|
|
|
|
|
Session: claims.Session,
|
|
|
|
|
Role: claims.Role,
|
|
|
|
|
Session: claims.Session,
|
|
|
|
|
Role: claims.Role,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return &sessionInfo, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
var user database.User
|
|
|
|
|
var generatedCode, hashedCode string
|
|
|
|
|
var err error
|
|
|
|
|
@@ -649,7 +677,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
@@ -660,7 +688,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64()
|
|
|
|
|
if redisErr != nil && redisErr != redis.Nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to get reset_cooldown state for user",
|
|
|
|
|
"Failed to get reset_cooldown state for user",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.Error(redisErr))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
@@ -677,7 +705,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
|
|
|
// Enable cooldown for the email despite that account does not exist
|
|
|
|
|
err := a.redis.Set(
|
|
|
|
|
ctx,
|
|
|
|
|
ctx,
|
|
|
|
|
fmt.Sprintf("email::%s::reset_cooldown", request.Email),
|
|
|
|
|
time.Now().Add(10*time.Minute),
|
|
|
|
|
time.Duration(10*time.Minute),
|
|
|
|
|
@@ -687,7 +715,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to set reset cooldown for email",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, err
|
|
|
|
|
return false, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
a.log.Warn(
|
|
|
|
|
@@ -712,7 +740,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
CodeType: int32(enums.PasswordResetCodeType),
|
|
|
|
|
CodeHash: hashedCode,
|
|
|
|
|
}); err != nil {
|
|
|
|
|
@@ -723,7 +751,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = a.redis.Set(
|
|
|
|
|
ctx,
|
|
|
|
|
ctx,
|
|
|
|
|
fmt.Sprintf("email::%s::reset_cooldown", request.Email),
|
|
|
|
|
time.Now().Add(10*time.Minute),
|
|
|
|
|
time.Duration(10*time.Minute),
|
|
|
|
|
@@ -733,7 +761,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to set reset cooldown for email. Cancelling password reset",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, err
|
|
|
|
|
return false, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err = helper.Commit(); err != nil {
|
|
|
|
|
@@ -757,7 +785,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
@@ -779,13 +807,13 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
CodeType: int32(enums.PasswordResetCodeType),
|
|
|
|
|
Code: request.VerificationCode,
|
|
|
|
|
Code: request.VerificationCode,
|
|
|
|
|
}); err != nil {
|
|
|
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
|
|
|
a.log.Warn(
|
|
|
|
|
"Attempted to reset password for user using incorrect confirmation code",
|
|
|
|
|
"Attempted to reset password for user using incorrect confirmation code",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.String("provided_code", request.VerificationCode),
|
|
|
|
|
@@ -795,7 +823,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
|
|
|
|
ID: resetCode.ID,
|
|
|
|
|
ID: resetCode.ID,
|
|
|
|
|
Used: utils.NewPointer(true),
|
|
|
|
|
}); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
@@ -817,7 +845,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
|
|
|
|
Username: user.Username,
|
|
|
|
|
Username: user.Username,
|
|
|
|
|
PasswordHash: hashedPassword,
|
|
|
|
|
}); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
@@ -828,24 +856,22 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if request.LogOutSessions {
|
|
|
|
|
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil {
|
|
|
|
|
err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to log out older sessions as part of user password reset",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
"Failed to terminate user's sessions during login",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return nil, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
Name: utils.NewPointer("First device"),
|
|
|
|
|
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
Name: utils.NewPointer("First device"),
|
|
|
|
|
Platform: utils.NewPointer("Unknown"),
|
|
|
|
|
LatestIp: utils.NewPointer("Unknown"),
|
|
|
|
|
}); err != nil {
|
|
|
|
|
}); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to create new session for user as part of user password reset",
|
|
|
|
|
"Failed to create new session for user as part of user password reset",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
@@ -854,7 +880,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
// TODO: get user role
|
|
|
|
|
if accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String(), enums.UserRole); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to generate tokens as part of user password reset",
|
|
|
|
|
"Failed to generate tokens as part of user password reset",
|
|
|
|
|
zap.String("email", request.Email),
|
|
|
|
|
zap.String("username", user.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
@@ -863,7 +889,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
|
|
|
|
|
response := models.PasswordResetCompleteResponse{
|
|
|
|
|
Tokens: models.Tokens{
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
AccessToken: accessToken,
|
|
|
|
|
RefreshToken: refreshToken,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|
|
|
|
return &response, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
|
|
|
|
|
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to open a transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
defer helper.RollbackOnError(err)
|
|
|
|
|
|
|
|
|
|
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to get user login information",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
|
|
|
|
|
a.log.Warn(
|
|
|
|
|
"Provided invalid old password while changing password",
|
|
|
|
|
zap.String("username", uinfo.Username))
|
|
|
|
|
return false, errs.ErrForbidden
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to hash new password while changing password",
|
|
|
|
|
zap.String("username", uinfo.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
|
|
|
|
Username: uinfo.Username,
|
|
|
|
|
PasswordHash: newPasswordHash,
|
|
|
|
|
}); if err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to save new password into database",
|
|
|
|
|
zap.String("username", uinfo.Username),
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := helper.Commit(); err != nil {
|
|
|
|
|
a.log.Error(
|
|
|
|
|
"Failed to commit transaction",
|
|
|
|
|
zap.Error(err))
|
|
|
|
|
return false, errs.ErrServerError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true, nil
|
|
|
|
|
}
|
|
|
|
|
|