feat: add change password endpoint using old password;
feat: implement change password service method with validation; fix: correct ErrorIsOneOf function logic to return true on match; refactor: rename 'log_out_accounts' to 'log_out_sessions' for clarity; refactor: update session termination to return GUIDs and cache in Redis; fix: ensure RollbackOnError only rolls back uncommitted transactions; fix: handle transaction commit errors properly in dbHelper; refactor: add helper methods for session termination and registration; refactor: pass client info to login and registration complete methods; fix: improve token validation error handling in refresh endpoint; refactor: update auth middleware to set session info correctly; chore: remove unused ClientInfo DTO; fix: correct password reset complete to use session termination helper; refactor: adjust database queries for session management; chore: update SQL schema and queries for sessions; docs: update swagger docs with new endpoint and model changes
This commit is contained in:
@@ -38,6 +38,44 @@ const docTemplate = `{
|
||||
"responses": {}
|
||||
}
|
||||
},
|
||||
"/auth/changePassword": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"JWT": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Auth"
|
||||
],
|
||||
"summary": "Set new password using the old password",
|
||||
"parameters": [
|
||||
{
|
||||
"description": " ",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/models.ChangePasswordRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Password successfully changed"
|
||||
},
|
||||
"403": {
|
||||
"description": "Invalid old password"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/auth/login": {
|
||||
"post": {
|
||||
"consumes": [
|
||||
@@ -391,6 +429,24 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"models.ChangePasswordRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"old_password",
|
||||
"password"
|
||||
],
|
||||
"properties": {
|
||||
"old_password": {
|
||||
"type": "string"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
"totp": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"models.LoginRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
@@ -445,7 +501,7 @@ const docTemplate = `{
|
||||
"email": {
|
||||
"type": "string"
|
||||
},
|
||||
"log_out_accounts": {
|
||||
"log_out_sessions": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"password": {
|
||||
|
||||
@@ -34,6 +34,44 @@
|
||||
"responses": {}
|
||||
}
|
||||
},
|
||||
"/auth/changePassword": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"JWT": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Auth"
|
||||
],
|
||||
"summary": "Set new password using the old password",
|
||||
"parameters": [
|
||||
{
|
||||
"description": " ",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/models.ChangePasswordRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Password successfully changed"
|
||||
},
|
||||
"403": {
|
||||
"description": "Invalid old password"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/auth/login": {
|
||||
"post": {
|
||||
"consumes": [
|
||||
@@ -387,6 +425,24 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"models.ChangePasswordRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"old_password",
|
||||
"password"
|
||||
],
|
||||
"properties": {
|
||||
"old_password": {
|
||||
"type": "string"
|
||||
},
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
"totp": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"models.LoginRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
@@ -441,7 +497,7 @@
|
||||
"email": {
|
||||
"type": "string"
|
||||
},
|
||||
"log_out_accounts": {
|
||||
"log_out_sessions": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"password": {
|
||||
|
||||
@@ -5,6 +5,18 @@ definitions:
|
||||
healthy:
|
||||
type: boolean
|
||||
type: object
|
||||
models.ChangePasswordRequest:
|
||||
properties:
|
||||
old_password:
|
||||
type: string
|
||||
password:
|
||||
type: string
|
||||
totp:
|
||||
type: string
|
||||
required:
|
||||
- old_password
|
||||
- password
|
||||
type: object
|
||||
models.LoginRequest:
|
||||
properties:
|
||||
password:
|
||||
@@ -38,7 +50,7 @@ definitions:
|
||||
properties:
|
||||
email:
|
||||
type: string
|
||||
log_out_accounts:
|
||||
log_out_sessions:
|
||||
type: boolean
|
||||
password:
|
||||
type: string
|
||||
@@ -125,6 +137,29 @@ paths:
|
||||
summary: Change account password
|
||||
tags:
|
||||
- Account
|
||||
/auth/changePassword:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
parameters:
|
||||
- description: ' '
|
||||
in: body
|
||||
name: request
|
||||
required: true
|
||||
schema:
|
||||
$ref: '#/definitions/models.ChangePasswordRequest'
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Password successfully changed
|
||||
"403":
|
||||
description: Invalid old password
|
||||
security:
|
||||
- JWT: []
|
||||
summary: Set new password using the old password
|
||||
tags:
|
||||
- Auth
|
||||
/auth/login:
|
||||
post:
|
||||
consumes:
|
||||
|
||||
@@ -38,6 +38,7 @@ type AuthController interface {
|
||||
Refresh(c *gin.Context)
|
||||
PasswordResetBegin(c *gin.Context)
|
||||
PasswordResetComplete(c *gin.Context)
|
||||
ChangePassword(c *gin.Context)
|
||||
Router
|
||||
}
|
||||
|
||||
@@ -65,7 +66,7 @@ func (a *authControllerImpl) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := a.auth.Login(request.Body)
|
||||
response, err := a.auth.Login(request.User, request.Body)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, errs.ErrForbidden) {
|
||||
@@ -155,18 +156,18 @@ func (a *authControllerImpl) Refresh(c *gin.Context) {
|
||||
|
||||
response, err := a.auth.Refresh(request.Body)
|
||||
if err != nil {
|
||||
if errors.Is(err, errs.ErrTokenExpired) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
|
||||
} else if errors.Is(err, errs.ErrTokenInvalid) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
|
||||
} else if errors.Is(err, errs.ErrWrongTokenType) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
|
||||
} else if errors.Is(err, errs.ErrSessionNotFound) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
|
||||
} else if errors.Is(err, errs.ErrSessionTerminated) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
|
||||
if utils.ErrorIsOneOf(
|
||||
err,
|
||||
errs.ErrTokenExpired,
|
||||
errs.ErrTokenInvalid,
|
||||
errs.ErrInvalidToken,
|
||||
errs.ErrWrongTokenType,
|
||||
errs.ErrSessionNotFound,
|
||||
errs.ErrSessionTerminated,
|
||||
) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
} else {
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.JSON(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -221,7 +222,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := a.auth.RegistrationComplete(request.Body)
|
||||
response, err := a.auth.RegistrationComplete(request.User, request.Body)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, errs.ErrForbidden) {
|
||||
@@ -237,6 +238,36 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// @Summary Set new password using the old password
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security JWT
|
||||
// @Param request body models.ChangePasswordRequest true " "
|
||||
// @Success 200 "Password successfully changed"
|
||||
// @Failure 403 "Invalid old password"
|
||||
// @Router /auth/changePassword [post]
|
||||
func (a *authControllerImpl) ChangePassword(c *gin.Context) {
|
||||
request, ok := utils.GetRequest[models.ChangePasswordRequest](c)
|
||||
if !ok {
|
||||
c.Status(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := a.auth.ChangePassword(request.Body, request.User)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, errs.ErrForbidden) {
|
||||
c.Status(http.StatusForbidden)
|
||||
} else {
|
||||
c.Status(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
|
||||
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
|
||||
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
|
||||
@@ -244,4 +275,5 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
|
||||
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
|
||||
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
|
||||
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
|
||||
group.POST("/changePassword", middleware.RequestMiddleware[models.ChangePasswordRequest](enums.UserRole), a.ChangePassword)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type dbHelperTransactionImpl struct {
|
||||
TXlessQueries Queries
|
||||
TX pgx.Tx
|
||||
TXQueries Queries
|
||||
isCommited bool
|
||||
}
|
||||
|
||||
func NewDbHelper(dbContext DbContext) DbHelper {
|
||||
@@ -79,30 +80,24 @@ func (d *dbHelperTransactionImpl) Commit() error {
|
||||
errCommit := d.TX.Commit(d.CTX)
|
||||
|
||||
if errCommit != nil {
|
||||
errRollback := d.TX.Rollback(d.CTX)
|
||||
|
||||
if errRollback != nil {
|
||||
return errRollback
|
||||
d.isCommited = true
|
||||
}
|
||||
|
||||
return errCommit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback implements DbHelperTransaction.
|
||||
func (d *dbHelperTransactionImpl) Rollback() error {
|
||||
err := d.TX.Rollback(d.CTX)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.isCommited {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.TX.Rollback(d.CTX);
|
||||
}
|
||||
|
||||
// RollbackOnError implements DbHelperTransaction.
|
||||
func (d *dbHelperTransactionImpl) RollbackOnError(err error) error {
|
||||
if err != nil {
|
||||
if d.isCommited || err == nil {
|
||||
return d.Rollback()
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -852,16 +852,32 @@ func (q *Queries) PruneTerminatedSessions(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :exec
|
||||
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :many
|
||||
UPDATE sessions
|
||||
SET terminated = TRUE
|
||||
FROM users
|
||||
WHERE sessions.user_id = users.id AND users.username = $1::text
|
||||
RETURNING sessions.guid
|
||||
`
|
||||
|
||||
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) error {
|
||||
_, err := q.db.Exec(ctx, terminateAllSessionsForUserByUsername, username)
|
||||
return err
|
||||
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) ([]pgtype.UUID, error) {
|
||||
rows, err := q.db.Query(ctx, terminateAllSessionsForUserByUsername, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []pgtype.UUID
|
||||
for rows.Next() {
|
||||
var guid pgtype.UUID
|
||||
if err := rows.Scan(&guid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, guid)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateBannedUser = `-- name: UpdateBannedUser :exec
|
||||
|
||||
@@ -31,7 +31,7 @@ var (
|
||||
ErrServerError = errors.New("Internal server error")
|
||||
|
||||
ErrTokenExpired = errors.New("Token is expired")
|
||||
ErrTokenInvalid = errors.New("Token is invalid")
|
||||
ErrTokenInvalid = ErrInvalidToken
|
||||
ErrWrongTokenType = errors.New("Invalid token type")
|
||||
ErrSessionNotFound = errors.New("Could not find session in database")
|
||||
ErrSessionTerminated = errors.New("Session is terminated")
|
||||
|
||||
@@ -61,7 +61,7 @@ func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc
|
||||
}
|
||||
return
|
||||
} else {
|
||||
c.Set("session_info", sessionInfo)
|
||||
c.Set("session_info", *sessionInfo)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
|
||||
@@ -72,3 +72,13 @@ type PasswordResetCompleteRequest struct {
|
||||
type PasswordResetCompleteResponse struct {
|
||||
Tokens
|
||||
}
|
||||
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"password" binding:"required" validate:"password"`
|
||||
TOTP string `json:"totp"`
|
||||
}
|
||||
|
||||
type ChangePasswordResponse struct {
|
||||
Tokens
|
||||
}
|
||||
|
||||
@@ -40,11 +40,12 @@ import (
|
||||
|
||||
type AuthService interface {
|
||||
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
||||
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
||||
Login(request models.LoginRequest) (*models.LoginResponse, error)
|
||||
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
||||
Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
|
||||
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
|
||||
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
|
||||
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
|
||||
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
|
||||
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
|
||||
}
|
||||
|
||||
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
|
||||
panic("Failed to cache terminated session: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
|
||||
}
|
||||
|
||||
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
|
||||
return authService
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
|
||||
|
||||
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to terminate older sessions for user trying to log in",
|
||||
zap.String("username", username),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := a.redis.Pipeline()
|
||||
for _, guid := range sessionGuids {
|
||||
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to cache terminated sessions",
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
|
||||
|
||||
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
|
||||
UserID: userID,
|
||||
Name: utils.NewPointer(cinfo.UserAgent),
|
||||
Platform: utils.NewPointer(cinfo.UserAgent),
|
||||
LatestIp: utils.NewPointer(cinfo.IP),
|
||||
}); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to add session to database",
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.log.Info(
|
||||
"Registered a new user session",
|
||||
zap.String("username", cinfo.Username),
|
||||
zap.String("session", cinfo.Session))
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
|
||||
|
||||
var occupationStatus database.CheckUserRegistrationAvailabilityRow
|
||||
@@ -95,14 +147,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
defer helper.RollbackOnError(err)
|
||||
|
||||
if isInProgress, err := a.redis.Get(
|
||||
isInProgress, err := a.redis.Get(
|
||||
context.TODO(),
|
||||
fmt.Sprintf("email::%s::registration_in_progress",
|
||||
request.Email),
|
||||
).Bool(); err != nil {
|
||||
).Bool(); if err != nil {
|
||||
if err != redis.Nil {
|
||||
a.log.Error(
|
||||
"Failed to look up cached registration_in_progress state of email as part of registration procedure",
|
||||
@@ -143,7 +194,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
context.TODO(),
|
||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||
true,
|
||||
time.Duration(10 * time.Minute), // XXX: magic number
|
||||
time.Duration(10*time.Minute), // XXX: magic number
|
||||
).Err(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
||||
@@ -248,7 +299,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
context.TODO(),
|
||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||
true,
|
||||
time.Duration(10 * time.Minute), // XXX: magic number
|
||||
time.Duration(10*time.Minute), // XXX: magic number
|
||||
).Err(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to cache registration_in_progress state for email",
|
||||
@@ -267,11 +318,10 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
||||
func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
||||
|
||||
var user database.User
|
||||
var profile database.Profile
|
||||
var session database.Session
|
||||
var confirmationCode database.ConfirmationCode
|
||||
var accessToken, refreshToken string
|
||||
var err error
|
||||
@@ -360,7 +410,6 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
if err != nil {
|
||||
a.log.Error("Failed to create profile for user",
|
||||
zap.String("username", user.Username),
|
||||
|
||||
)
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
@@ -374,19 +423,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
// TODO: session info
|
||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
UserID: user.ID,
|
||||
Name: utils.NewPointer("First device"),
|
||||
Platform: utils.NewPointer("Unknown"),
|
||||
LatestIp: utils.NewPointer("Unknown"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to create a new session during registration, rolling back registration",
|
||||
zap.String("username", user.Username),
|
||||
zap.Error(err))
|
||||
session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
|
||||
a.log.Error("", zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
@@ -421,9 +459,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
||||
}
|
||||
|
||||
// TODO: totp
|
||||
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) {
|
||||
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
|
||||
var userRow database.GetValidUserByLoginCredentialsRow
|
||||
var session database.Session
|
||||
var err error
|
||||
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
||||
}
|
||||
|
||||
// Until release 4, only 1 session at a time is supported
|
||||
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil {
|
||||
err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to terminate older sessions for user trying to log in",
|
||||
zap.String("username", request.Username),
|
||||
"Failed to terminate user's sessions during login",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
// TODO: use actual values for session metadata
|
||||
UserID: userRow.ID,
|
||||
Name: utils.NewPointer("New device"),
|
||||
Platform: utils.NewPointer("Unknown"),
|
||||
LatestIp: utils.NewPointer("Unknown"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to create session for a new login",
|
||||
zap.String("username", userRow.Username),
|
||||
zap.Error(err))
|
||||
session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
|
||||
a.log.Error("", zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
|
||||
@@ -509,7 +534,8 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
||||
|
||||
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
||||
|
||||
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil {
|
||||
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
|
||||
if err != nil {
|
||||
|
||||
if utils.ErrorIsOneOf(
|
||||
err,
|
||||
@@ -532,7 +558,8 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
||||
sessionInfo.Username,
|
||||
sessionInfo.Session,
|
||||
sessionInfo.Role,
|
||||
); if err != nil {
|
||||
)
|
||||
if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to generate tokens for user during refresh",
|
||||
zap.String("username", sessionInfo.Username),
|
||||
@@ -573,7 +600,8 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
||||
return nil, errs.ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid {
|
||||
claims, ok := token.Claims.(*dto.UserClaims)
|
||||
if ok && token.Valid {
|
||||
|
||||
if claims.Type != tokenType {
|
||||
return nil, errs.ErrWrongTokenType
|
||||
@@ -613,7 +641,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
||||
ctx,
|
||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
||||
*session.Terminated,
|
||||
time.Duration(8 * time.Hour), // XXX: magic number
|
||||
time.Duration(8*time.Hour), // XXX: magic number
|
||||
).Err(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to cache session's is_terminated state",
|
||||
@@ -828,22 +856,20 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
}
|
||||
|
||||
if request.LogOutSessions {
|
||||
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil {
|
||||
err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to log out older sessions as part of user password reset",
|
||||
zap.String("email", request.Email),
|
||||
zap.String("username", user.Username),
|
||||
"Failed to terminate user's sessions during login",
|
||||
zap.Error(err))
|
||||
return nil, errs.ErrServerError
|
||||
}
|
||||
}
|
||||
|
||||
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||
UserID: user.ID,
|
||||
Name: utils.NewPointer("First device"),
|
||||
Platform: utils.NewPointer("Unknown"),
|
||||
LatestIp: utils.NewPointer("Unknown"),
|
||||
}); err != nil {
|
||||
}); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to create new session for user as part of user password reset",
|
||||
zap.String("email", request.Email),
|
||||
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
|
||||
|
||||
var err error
|
||||
|
||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to open a transaction",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
defer helper.RollbackOnError(err)
|
||||
|
||||
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to get user login information",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
|
||||
a.log.Warn(
|
||||
"Provided invalid old password while changing password",
|
||||
zap.String("username", uinfo.Username))
|
||||
return false, errs.ErrForbidden
|
||||
}
|
||||
|
||||
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to hash new password while changing password",
|
||||
zap.String("username", uinfo.Username),
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
||||
Username: uinfo.Username,
|
||||
PasswordHash: newPasswordHash,
|
||||
}); if err != nil {
|
||||
a.log.Error(
|
||||
"Failed to save new password into database",
|
||||
zap.String("username", uinfo.Username),
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
if err := helper.Commit(); err != nil {
|
||||
a.log.Error(
|
||||
"Failed to commit transaction",
|
||||
zap.Error(err))
|
||||
return false, errs.ErrServerError
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -22,8 +22,8 @@ import "errors"
|
||||
func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
|
||||
for _, ignore := range ignoreErrors {
|
||||
if errors.Is(err, ignore) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -259,11 +259,12 @@ WHERE
|
||||
terminated IS TRUE AND
|
||||
last_refresh_exp_time > CURRENT_TIMESTAMP;
|
||||
|
||||
;-- name: TerminateAllSessionsForUserByUsername :exec
|
||||
;-- name: TerminateAllSessionsForUserByUsername :many
|
||||
UPDATE sessions
|
||||
SET terminated = TRUE
|
||||
FROM users
|
||||
WHERE sessions.user_id = users.id AND users.username = @username::text;
|
||||
WHERE sessions.user_id = users.id AND users.username = @username::text
|
||||
RETURNING sessions.guid;
|
||||
|
||||
;-- name: PruneTerminatedSessions :exec
|
||||
DELETE FROM sessions
|
||||
|
||||
@@ -62,8 +62,8 @@ CREATE TABLE IF NOT EXISTS "sessions" (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
guid UUID NOT NULL DEFAULT gen_random_uuid(),
|
||||
name VARCHAR(100),
|
||||
platform VARCHAR(32),
|
||||
name VARCHAR(175),
|
||||
platform VARCHAR(175),
|
||||
latest_ip VARCHAR(16),
|
||||
login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds',
|
||||
|
||||
Reference in New Issue
Block a user