feat: add change password endpoint using old password;
feat: implement change password service method with validation; fix: correct ErrorIsOneOf function logic to return true on match; refactor: rename 'log_out_accounts' to 'log_out_sessions' for clarity; refactor: update session termination to return GUIDs and cache in Redis; fix: ensure RollbackOnError only rolls back uncommitted transactions; fix: handle transaction commit errors properly in dbHelper; refactor: add helper methods for session termination and registration; refactor: pass client info to login and registration complete methods; fix: improve token validation error handling in refresh endpoint; refactor: update auth middleware to set session info correctly; chore: remove unused ClientInfo DTO; fix: correct password reset complete to use session termination helper; refactor: adjust database queries for session management; chore: update SQL schema and queries for sessions; docs: update swagger docs with new endpoint and model changes
This commit is contained in:
@@ -38,6 +38,44 @@ const docTemplate = `{
|
|||||||
"responses": {}
|
"responses": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/auth/changePassword": {
|
||||||
|
"post": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"JWT": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"consumes": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"Auth"
|
||||||
|
],
|
||||||
|
"summary": "Set new password using the old password",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"description": " ",
|
||||||
|
"name": "request",
|
||||||
|
"in": "body",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/models.ChangePasswordRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Password successfully changed"
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Invalid old password"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/auth/login": {
|
"/auth/login": {
|
||||||
"post": {
|
"post": {
|
||||||
"consumes": [
|
"consumes": [
|
||||||
@@ -391,6 +429,24 @@ const docTemplate = `{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"models.ChangePasswordRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"old_password",
|
||||||
|
"password"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"old_password": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"password": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"totp": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"models.LoginRequest": {
|
"models.LoginRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
@@ -445,7 +501,7 @@ const docTemplate = `{
|
|||||||
"email": {
|
"email": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"log_out_accounts": {
|
"log_out_sessions": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
"password": {
|
"password": {
|
||||||
|
|||||||
@@ -34,6 +34,44 @@
|
|||||||
"responses": {}
|
"responses": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/auth/changePassword": {
|
||||||
|
"post": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"JWT": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"consumes": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"Auth"
|
||||||
|
],
|
||||||
|
"summary": "Set new password using the old password",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"description": " ",
|
||||||
|
"name": "request",
|
||||||
|
"in": "body",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/models.ChangePasswordRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Password successfully changed"
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Invalid old password"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/auth/login": {
|
"/auth/login": {
|
||||||
"post": {
|
"post": {
|
||||||
"consumes": [
|
"consumes": [
|
||||||
@@ -387,6 +425,24 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"models.ChangePasswordRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"old_password",
|
||||||
|
"password"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"old_password": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"password": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"totp": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"models.LoginRequest": {
|
"models.LoginRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
@@ -441,7 +497,7 @@
|
|||||||
"email": {
|
"email": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"log_out_accounts": {
|
"log_out_sessions": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
"password": {
|
"password": {
|
||||||
|
|||||||
@@ -5,6 +5,18 @@ definitions:
|
|||||||
healthy:
|
healthy:
|
||||||
type: boolean
|
type: boolean
|
||||||
type: object
|
type: object
|
||||||
|
models.ChangePasswordRequest:
|
||||||
|
properties:
|
||||||
|
old_password:
|
||||||
|
type: string
|
||||||
|
password:
|
||||||
|
type: string
|
||||||
|
totp:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- old_password
|
||||||
|
- password
|
||||||
|
type: object
|
||||||
models.LoginRequest:
|
models.LoginRequest:
|
||||||
properties:
|
properties:
|
||||||
password:
|
password:
|
||||||
@@ -38,7 +50,7 @@ definitions:
|
|||||||
properties:
|
properties:
|
||||||
email:
|
email:
|
||||||
type: string
|
type: string
|
||||||
log_out_accounts:
|
log_out_sessions:
|
||||||
type: boolean
|
type: boolean
|
||||||
password:
|
password:
|
||||||
type: string
|
type: string
|
||||||
@@ -125,6 +137,29 @@ paths:
|
|||||||
summary: Change account password
|
summary: Change account password
|
||||||
tags:
|
tags:
|
||||||
- Account
|
- Account
|
||||||
|
/auth/changePassword:
|
||||||
|
post:
|
||||||
|
consumes:
|
||||||
|
- application/json
|
||||||
|
parameters:
|
||||||
|
- description: ' '
|
||||||
|
in: body
|
||||||
|
name: request
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/models.ChangePasswordRequest'
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Password successfully changed
|
||||||
|
"403":
|
||||||
|
description: Invalid old password
|
||||||
|
security:
|
||||||
|
- JWT: []
|
||||||
|
summary: Set new password using the old password
|
||||||
|
tags:
|
||||||
|
- Auth
|
||||||
/auth/login:
|
/auth/login:
|
||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type AuthController interface {
|
|||||||
Refresh(c *gin.Context)
|
Refresh(c *gin.Context)
|
||||||
PasswordResetBegin(c *gin.Context)
|
PasswordResetBegin(c *gin.Context)
|
||||||
PasswordResetComplete(c *gin.Context)
|
PasswordResetComplete(c *gin.Context)
|
||||||
|
ChangePassword(c *gin.Context)
|
||||||
Router
|
Router
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ func (a *authControllerImpl) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := a.auth.Login(request.Body)
|
response, err := a.auth.Login(request.User, request.Body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errs.ErrForbidden) {
|
if errors.Is(err, errs.ErrForbidden) {
|
||||||
@@ -155,18 +156,18 @@ func (a *authControllerImpl) Refresh(c *gin.Context) {
|
|||||||
|
|
||||||
response, err := a.auth.Refresh(request.Body)
|
response, err := a.auth.Refresh(request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errs.ErrTokenExpired) {
|
if utils.ErrorIsOneOf(
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is expired"})
|
err,
|
||||||
} else if errors.Is(err, errs.ErrTokenInvalid) {
|
errs.ErrTokenExpired,
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token is invalid"})
|
errs.ErrTokenInvalid,
|
||||||
} else if errors.Is(err, errs.ErrWrongTokenType) {
|
errs.ErrInvalidToken,
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token type"})
|
errs.ErrWrongTokenType,
|
||||||
} else if errors.Is(err, errs.ErrSessionNotFound) {
|
errs.ErrSessionNotFound,
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Could not find session in database"})
|
errs.ErrSessionTerminated,
|
||||||
} else if errors.Is(err, errs.ErrSessionTerminated) {
|
) {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session is terminated"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||||
} else {
|
} else {
|
||||||
c.Status(http.StatusInternalServerError)
|
c.JSON(http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -221,7 +222,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := a.auth.RegistrationComplete(request.Body)
|
response, err := a.auth.RegistrationComplete(request.User, request.Body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, errs.ErrForbidden) {
|
if errors.Is(err, errs.ErrForbidden) {
|
||||||
@@ -237,6 +238,36 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, response)
|
c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @Summary Set new password using the old password
|
||||||
|
// @Tags Auth
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security JWT
|
||||||
|
// @Param request body models.ChangePasswordRequest true " "
|
||||||
|
// @Success 200 "Password successfully changed"
|
||||||
|
// @Failure 403 "Invalid old password"
|
||||||
|
// @Router /auth/changePassword [post]
|
||||||
|
func (a *authControllerImpl) ChangePassword(c *gin.Context) {
|
||||||
|
request, ok := utils.GetRequest[models.ChangePasswordRequest](c)
|
||||||
|
if !ok {
|
||||||
|
c.Status(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := a.auth.ChangePassword(request.Body, request.User)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errs.ErrForbidden) {
|
||||||
|
c.Status(http.StatusForbidden)
|
||||||
|
} else {
|
||||||
|
c.Status(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
|
func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
|
||||||
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
|
group.POST("/registrationBegin", middleware.RequestMiddleware[models.RegistrationBeginRequest](enums.GuestRole), a.RegistrationBegin)
|
||||||
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
|
group.POST("/registrationComplete", middleware.RequestMiddleware[models.RegistrationCompleteRequest](enums.GuestRole), a.RegistrationComplete)
|
||||||
@@ -244,4 +275,5 @@ func (a *authControllerImpl) RegisterRoutes(group *gin.RouterGroup) {
|
|||||||
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
|
group.POST("/refresh", middleware.RequestMiddleware[models.RefreshRequest](enums.GuestRole), a.Refresh)
|
||||||
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
|
group.POST("/passwordResetBegin", middleware.RequestMiddleware[models.PasswordResetBeginRequest](enums.GuestRole), a.PasswordResetBegin)
|
||||||
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
|
group.POST("/passwordResetComplete", middleware.RequestMiddleware[models.PasswordResetCompleteRequest](enums.GuestRole), a.PasswordResetComplete)
|
||||||
|
group.POST("/changePassword", middleware.RequestMiddleware[models.ChangePasswordRequest](enums.UserRole), a.ChangePassword)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type dbHelperTransactionImpl struct {
|
|||||||
TXlessQueries Queries
|
TXlessQueries Queries
|
||||||
TX pgx.Tx
|
TX pgx.Tx
|
||||||
TXQueries Queries
|
TXQueries Queries
|
||||||
|
isCommited bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDbHelper(dbContext DbContext) DbHelper {
|
func NewDbHelper(dbContext DbContext) DbHelper {
|
||||||
@@ -79,30 +80,24 @@ func (d *dbHelperTransactionImpl) Commit() error {
|
|||||||
errCommit := d.TX.Commit(d.CTX)
|
errCommit := d.TX.Commit(d.CTX)
|
||||||
|
|
||||||
if errCommit != nil {
|
if errCommit != nil {
|
||||||
errRollback := d.TX.Rollback(d.CTX)
|
d.isCommited = true
|
||||||
|
|
||||||
if errRollback != nil {
|
|
||||||
return errRollback
|
|
||||||
}
|
|
||||||
|
|
||||||
return errCommit
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return errCommit
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback implements DbHelperTransaction.
|
// Rollback implements DbHelperTransaction.
|
||||||
func (d *dbHelperTransactionImpl) Rollback() error {
|
func (d *dbHelperTransactionImpl) Rollback() error {
|
||||||
err := d.TX.Rollback(d.CTX)
|
if d.isCommited {
|
||||||
|
return nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return d.TX.Rollback(d.CTX);
|
||||||
}
|
}
|
||||||
|
|
||||||
// RollbackOnError implements DbHelperTransaction.
|
// RollbackOnError implements DbHelperTransaction.
|
||||||
func (d *dbHelperTransactionImpl) RollbackOnError(err error) error {
|
func (d *dbHelperTransactionImpl) RollbackOnError(err error) error {
|
||||||
if err != nil {
|
if d.isCommited || err == nil {
|
||||||
return d.Rollback()
|
return d.Rollback()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -852,16 +852,32 @@ func (q *Queries) PruneTerminatedSessions(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :exec
|
const terminateAllSessionsForUserByUsername = `-- name: TerminateAllSessionsForUserByUsername :many
|
||||||
UPDATE sessions
|
UPDATE sessions
|
||||||
SET terminated = TRUE
|
SET terminated = TRUE
|
||||||
FROM users
|
FROM users
|
||||||
WHERE sessions.user_id = users.id AND users.username = $1::text
|
WHERE sessions.user_id = users.id AND users.username = $1::text
|
||||||
|
RETURNING sessions.guid
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) error {
|
func (q *Queries) TerminateAllSessionsForUserByUsername(ctx context.Context, username string) ([]pgtype.UUID, error) {
|
||||||
_, err := q.db.Exec(ctx, terminateAllSessionsForUserByUsername, username)
|
rows, err := q.db.Query(ctx, terminateAllSessionsForUserByUsername, username)
|
||||||
return err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []pgtype.UUID
|
||||||
|
for rows.Next() {
|
||||||
|
var guid pgtype.UUID
|
||||||
|
if err := rows.Scan(&guid); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, guid)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const updateBannedUser = `-- name: UpdateBannedUser :exec
|
const updateBannedUser = `-- name: UpdateBannedUser :exec
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ var (
|
|||||||
ErrServerError = errors.New("Internal server error")
|
ErrServerError = errors.New("Internal server error")
|
||||||
|
|
||||||
ErrTokenExpired = errors.New("Token is expired")
|
ErrTokenExpired = errors.New("Token is expired")
|
||||||
ErrTokenInvalid = errors.New("Token is invalid")
|
ErrTokenInvalid = ErrInvalidToken
|
||||||
ErrWrongTokenType = errors.New("Invalid token type")
|
ErrWrongTokenType = errors.New("Invalid token type")
|
||||||
ErrSessionNotFound = errors.New("Could not find session in database")
|
ErrSessionNotFound = errors.New("Could not find session in database")
|
||||||
ErrSessionTerminated = errors.New("Session is terminated")
|
ErrSessionTerminated = errors.New("Session is terminated")
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func AuthMiddleware(log *zap.Logger, auth services.AuthService) gin.HandlerFunc
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
c.Set("session_info", sessionInfo)
|
c.Set("session_info", *sessionInfo)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -72,3 +72,13 @@ type PasswordResetCompleteRequest struct {
|
|||||||
type PasswordResetCompleteResponse struct {
|
type PasswordResetCompleteResponse struct {
|
||||||
Tokens
|
Tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ChangePasswordRequest struct {
|
||||||
|
OldPassword string `json:"old_password" binding:"required"`
|
||||||
|
NewPassword string `json:"password" binding:"required" validate:"password"`
|
||||||
|
TOTP string `json:"totp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChangePasswordResponse struct {
|
||||||
|
Tokens
|
||||||
|
}
|
||||||
|
|||||||
@@ -40,19 +40,20 @@ import (
|
|||||||
|
|
||||||
type AuthService interface {
|
type AuthService interface {
|
||||||
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
RegistrationBegin(request models.RegistrationBeginRequest) (bool, error)
|
||||||
RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error)
|
||||||
Login(request models.LoginRequest) (*models.LoginResponse, error)
|
Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error)
|
||||||
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
|
Refresh(request models.RefreshRequest) (*models.RefreshResponse, error)
|
||||||
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
|
PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error)
|
||||||
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
|
PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error)
|
||||||
|
ChangePassword(request models.ChangePasswordRequest, cinfo dto.ClientInfo) (bool, error)
|
||||||
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
|
ValidateToken(token string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type authServiceImpl struct {
|
type authServiceImpl struct {
|
||||||
log *zap.Logger
|
log *zap.Logger
|
||||||
dbctx database.DbContext
|
dbctx database.DbContext
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
smtp SmtpService
|
smtp SmtpService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService {
|
func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService {
|
||||||
@@ -75,10 +76,61 @@ func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.C
|
|||||||
panic("Failed to cache terminated session: " + err.Error())
|
panic("Failed to cache terminated session: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
panic("Failed to execute redis pipeline request for caching terminated sessions: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
|
_log.Info("Cached terminated sessions' GUIDs in Redis", zap.Int("amount", len(guids)))
|
||||||
return authService
|
return authService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *authServiceImpl) terminateAllSessionsForUser(ctx context.Context, username string, queries *database.Queries) error {
|
||||||
|
|
||||||
|
sessionGuids, err := queries.TerminateAllSessionsForUserByUsername(ctx, username); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to terminate older sessions for user trying to log in",
|
||||||
|
zap.String("username", username),
|
||||||
|
zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := a.redis.Pipeline()
|
||||||
|
for _, guid := range sessionGuids {
|
||||||
|
pipe.Set(ctx, fmt.Sprintf("session::%s::is_terminated", guid), true, time.Duration(8 * time.Hour)) // XXX: magic number
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to cache terminated sessions",
|
||||||
|
zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *authServiceImpl) registerSession(ctx context.Context, userID int64, cinfo dto.ClientInfo, queries *database.Queries) (*database.Session, error) {
|
||||||
|
|
||||||
|
session, err := queries.CreateSession(ctx, database.CreateSessionParams{
|
||||||
|
UserID: userID,
|
||||||
|
Name: utils.NewPointer(cinfo.UserAgent),
|
||||||
|
Platform: utils.NewPointer(cinfo.UserAgent),
|
||||||
|
LatestIp: utils.NewPointer(cinfo.IP),
|
||||||
|
}); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to add session to database",
|
||||||
|
zap.Error(err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.log.Info(
|
||||||
|
"Registered a new user session",
|
||||||
|
zap.String("username", cinfo.Username),
|
||||||
|
zap.String("session", cinfo.Session))
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
|
func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) {
|
||||||
|
|
||||||
var occupationStatus database.CheckUserRegistrationAvailabilityRow
|
var occupationStatus database.CheckUserRegistrationAvailabilityRow
|
||||||
@@ -95,14 +147,13 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
return false, errs.ErrServerError
|
return false, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
defer helper.RollbackOnError(err)
|
||||||
|
|
||||||
defer helper.RollbackOnError(err)
|
isInProgress, err := a.redis.Get(
|
||||||
|
|
||||||
if isInProgress, err := a.redis.Get(
|
|
||||||
context.TODO(),
|
context.TODO(),
|
||||||
fmt.Sprintf("email::%s::registration_in_progress",
|
fmt.Sprintf("email::%s::registration_in_progress",
|
||||||
request.Email),
|
request.Email),
|
||||||
).Bool(); err != nil {
|
).Bool(); if err != nil {
|
||||||
if err != redis.Nil {
|
if err != redis.Nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to look up cached registration_in_progress state of email as part of registration procedure",
|
"Failed to look up cached registration_in_progress state of email as part of registration procedure",
|
||||||
@@ -119,7 +170,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{
|
if occupationStatus, err = db.TXQueries.CheckUserRegistrationAvailability(db.CTX, database.CheckUserRegistrationAvailabilityParams{
|
||||||
Email: request.Email,
|
Email: request.Email,
|
||||||
Username: request.Username,
|
Username: request.Username,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
@@ -143,7 +194,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
context.TODO(),
|
context.TODO(),
|
||||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||||
true,
|
true,
|
||||||
time.Duration(10 * time.Minute), // XXX: magic number
|
time.Duration(10*time.Minute), // XXX: magic number
|
||||||
).Err(); err != nil {
|
).Err(); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
"Failed to falsely set cache registration_in_progress state for email as a measure to prevent email enumeration",
|
||||||
@@ -161,7 +212,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
} else {
|
} else {
|
||||||
if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{
|
if _, err := db.TXQueries.DeleteUnverifiedAccountsHavingUsernameOrEmail(db.CTX, database.DeleteUnverifiedAccountsHavingUsernameOrEmailParams{
|
||||||
Username: request.Username,
|
Username: request.Username,
|
||||||
Email: request.Email,
|
Email: request.Email,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to purge unverified accounts as part of registration",
|
"Failed to purge unverified accounts as part of registration",
|
||||||
@@ -184,8 +235,8 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{
|
if _, err = db.TXQueries.CreateLoginInformation(db.CTX, database.CreateLoginInformationParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Email: utils.NewPointer(request.Email),
|
Email: utils.NewPointer(request.Email),
|
||||||
PasswordHash: passwordHash, // Hashed in database
|
PasswordHash: passwordHash, // Hashed in database
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
|
||||||
@@ -208,9 +259,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
CodeType: int32(enums.RegistrationCodeType),
|
CodeType: int32(enums.RegistrationCodeType),
|
||||||
CodeHash: generatedCodeHash, // Hashed in database
|
CodeHash: generatedCodeHash, // Hashed in database
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
a.log.Error("Failed to add registration code to database", zap.Error(err))
|
a.log.Error("Failed to add registration code to database", zap.Error(err))
|
||||||
return false, errs.ErrServerError
|
return false, errs.ErrServerError
|
||||||
@@ -248,7 +299,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
context.TODO(),
|
context.TODO(),
|
||||||
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
fmt.Sprintf("email::%s::registration_in_progress", request.Email),
|
||||||
true,
|
true,
|
||||||
time.Duration(10 * time.Minute), // XXX: magic number
|
time.Duration(10*time.Minute), // XXX: magic number
|
||||||
).Err(); err != nil {
|
).Err(); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to cache registration_in_progress state for email",
|
"Failed to cache registration_in_progress state for email",
|
||||||
@@ -267,12 +318,11 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
func (a *authServiceImpl) RegistrationComplete(cinfo dto.ClientInfo, request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) {
|
||||||
|
|
||||||
var user database.User
|
var user database.User
|
||||||
var profile database.Profile
|
var profile database.Profile
|
||||||
var session database.Session
|
var confirmationCode database.ConfirmationCode
|
||||||
var confirmationCode database.ConfirmationCode
|
|
||||||
var accessToken, refreshToken string
|
var accessToken, refreshToken string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -304,9 +354,9 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|||||||
}
|
}
|
||||||
|
|
||||||
confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
confirmationCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
CodeType: int32(enums.RegistrationCodeType),
|
CodeType: int32(enums.RegistrationCodeType),
|
||||||
Code: request.VerificationCode,
|
Code: request.VerificationCode,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -327,7 +377,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
||||||
ID: confirmationCode.ID,
|
ID: confirmationCode.ID,
|
||||||
Used: utils.NewPointer(true),
|
Used: utils.NewPointer(true),
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -340,27 +390,26 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
|
err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{
|
||||||
ID: user.ID,
|
ID: user.ID,
|
||||||
Verified: utils.NewPointer(true),
|
Verified: utils.NewPointer(true),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error("Failed to update verified status for user",
|
a.log.Error("Failed to update verified status for user",
|
||||||
zap.String("username", user.Username),
|
zap.String("username", user.Username),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
|
profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Name: request.Name,
|
Name: request.Name,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error("Failed to create profile for user",
|
a.log.Error("Failed to create profile for user",
|
||||||
zap.String("username", user.Username),
|
zap.String("username", user.Username),
|
||||||
|
|
||||||
)
|
)
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
@@ -369,24 +418,13 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error("Failed to create profile settings for user",
|
a.log.Error("Failed to create profile settings for user",
|
||||||
zap.String("username", user.Username),
|
zap.String("username", user.Username),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: session info
|
session, err := a.registerSession(context.TODO(), user.ID, cinfo, &db.TXQueries); if err != nil {
|
||||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
a.log.Error("", zap.Error(err))
|
||||||
UserID: user.ID,
|
|
||||||
Name: utils.NewPointer("First device"),
|
|
||||||
Platform: utils.NewPointer("Unknown"),
|
|
||||||
LatestIp: utils.NewPointer("Unknown"),
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
a.log.Error(
|
|
||||||
"Failed to create a new session during registration, rolling back registration",
|
|
||||||
zap.String("username", user.Username),
|
|
||||||
zap.Error(err))
|
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -413,7 +451,7 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|||||||
zap.String("username", request.Username))
|
zap.String("username", request.Username))
|
||||||
|
|
||||||
response := models.RegistrationCompleteResponse{Tokens: models.Tokens{
|
response := models.RegistrationCompleteResponse{Tokens: models.Tokens{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
@@ -421,9 +459,8 @@ func (a *authServiceImpl) RegistrationComplete(request models.RegistrationComple
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: totp
|
// TODO: totp
|
||||||
func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) {
|
func (a *authServiceImpl) Login(cinfo dto.ClientInfo, request models.LoginRequest) (*models.LoginResponse, error) {
|
||||||
var userRow database.GetValidUserByLoginCredentialsRow
|
var userRow database.GetValidUserByLoginCredentialsRow
|
||||||
var session database.Session
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx)
|
||||||
@@ -433,7 +470,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
defer helper.RollbackOnError(err)
|
defer helper.RollbackOnError(err)
|
||||||
|
|
||||||
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
|
userRow, err = db.TXQueries.GetValidUserByLoginCredentials(db.CTX, database.GetValidUserByLoginCredentialsParams{
|
||||||
Username: request.Username,
|
Username: request.Username,
|
||||||
@@ -458,27 +495,15 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Until release 4, only 1 session at a time is supported
|
// Until release 4, only 1 session at a time is supported
|
||||||
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, request.Username); err != nil {
|
err = a.terminateAllSessionsForUser(context.TODO(), request.Username, &db.TXQueries); if err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to terminate older sessions for user trying to log in",
|
"Failed to terminate user's sessions during login",
|
||||||
zap.String("username", request.Username),
|
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
session, err := a.registerSession(context.TODO(), userRow.ID, cinfo, &db.TXQueries); if err != nil {
|
||||||
// TODO: use actual values for session metadata
|
a.log.Error("", zap.Error(err))
|
||||||
UserID: userRow.ID,
|
|
||||||
Name: utils.NewPointer("New device"),
|
|
||||||
Platform: utils.NewPointer("Unknown"),
|
|
||||||
LatestIp: utils.NewPointer("Unknown"),
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
a.log.Error(
|
|
||||||
"Failed to create session for a new login",
|
|
||||||
zap.String("username", userRow.Username),
|
|
||||||
zap.Error(err))
|
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,7 +525,7 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|||||||
}
|
}
|
||||||
|
|
||||||
response := models.LoginResponse{Tokens: models.Tokens{
|
response := models.LoginResponse{Tokens: models.Tokens{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
@@ -509,7 +534,8 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo
|
|||||||
|
|
||||||
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) {
|
||||||
|
|
||||||
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType); if err != nil {
|
sessionInfo, err := a.ValidateToken(request.RefreshToken, enums.JwtRefreshTokenType)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
if utils.ErrorIsOneOf(
|
if utils.ErrorIsOneOf(
|
||||||
err,
|
err,
|
||||||
@@ -532,7 +558,8 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
|||||||
sessionInfo.Username,
|
sessionInfo.Username,
|
||||||
sessionInfo.Session,
|
sessionInfo.Session,
|
||||||
sessionInfo.Role,
|
sessionInfo.Role,
|
||||||
); if err != nil {
|
)
|
||||||
|
if err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to generate tokens for user during refresh",
|
"Failed to generate tokens for user during refresh",
|
||||||
zap.String("username", sessionInfo.Username),
|
zap.String("username", sessionInfo.Username),
|
||||||
@@ -543,7 +570,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
|||||||
|
|
||||||
response := models.RefreshResponse{
|
response := models.RefreshResponse{
|
||||||
Tokens: models.Tokens{
|
Tokens: models.Tokens{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -551,7 +578,7 @@ func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.Refres
|
|||||||
return &response, nil
|
return &response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtTokenType) (*dto.SessionInfo, error) {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -573,7 +600,8 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|||||||
return nil, errs.ErrInvalidToken
|
return nil, errs.ErrInvalidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(*dto.UserClaims); if ok && token.Valid {
|
claims, ok := token.Claims.(*dto.UserClaims)
|
||||||
|
if ok && token.Valid {
|
||||||
|
|
||||||
if claims.Type != tokenType {
|
if claims.Type != tokenType {
|
||||||
return nil, errs.ErrWrongTokenType
|
return nil, errs.ErrWrongTokenType
|
||||||
@@ -613,7 +641,7 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|||||||
ctx,
|
ctx,
|
||||||
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
fmt.Sprintf("session::%s::is_terminated", claims.Session),
|
||||||
*session.Terminated,
|
*session.Terminated,
|
||||||
time.Duration(8 * time.Hour), // XXX: magic number
|
time.Duration(8*time.Hour), // XXX: magic number
|
||||||
).Err(); err != nil {
|
).Err(); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to cache session's is_terminated state",
|
"Failed to cache session's is_terminated state",
|
||||||
@@ -633,8 +661,8 @@ func (a *authServiceImpl) ValidateToken(jwtToken string, tokenType enums.JwtToke
|
|||||||
|
|
||||||
sessionInfo := dto.SessionInfo{
|
sessionInfo := dto.SessionInfo{
|
||||||
Username: claims.Username,
|
Username: claims.Username,
|
||||||
Session: claims.Session,
|
Session: claims.Session,
|
||||||
Role: claims.Role,
|
Role: claims.Role,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &sessionInfo, nil
|
return &sessionInfo, nil
|
||||||
@@ -712,7 +740,7 @@ func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
if _, err = db.TXQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
CodeType: int32(enums.PasswordResetCodeType),
|
CodeType: int32(enums.PasswordResetCodeType),
|
||||||
CodeHash: hashedCode,
|
CodeHash: hashedCode,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -779,9 +807,9 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
if resetCode, err = db.TXQueries.GetValidConfirmationCodeByCode(db.CTX, database.GetValidConfirmationCodeByCodeParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
CodeType: int32(enums.PasswordResetCodeType),
|
CodeType: int32(enums.PasswordResetCodeType),
|
||||||
Code: request.VerificationCode,
|
Code: request.VerificationCode,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
a.log.Warn(
|
a.log.Warn(
|
||||||
@@ -795,7 +823,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
if err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{
|
||||||
ID: resetCode.ID,
|
ID: resetCode.ID,
|
||||||
Used: utils.NewPointer(true),
|
Used: utils.NewPointer(true),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
@@ -817,7 +845,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
if err = db.TXQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
PasswordHash: hashedPassword,
|
PasswordHash: hashedPassword,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
@@ -828,22 +856,20 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
}
|
}
|
||||||
|
|
||||||
if request.LogOutSessions {
|
if request.LogOutSessions {
|
||||||
if err = db.TXQueries.TerminateAllSessionsForUserByUsername(db.CTX, user.Username); err != nil {
|
err = a.terminateAllSessionsForUser(context.TODO(), user.Username, &db.TXQueries); if err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to log out older sessions as part of user password reset",
|
"Failed to terminate user's sessions during login",
|
||||||
zap.String("email", request.Email),
|
|
||||||
zap.String("username", user.Username),
|
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
return nil, errs.ErrServerError
|
return nil, errs.ErrServerError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Name: utils.NewPointer("First device"),
|
Name: utils.NewPointer("First device"),
|
||||||
Platform: utils.NewPointer("Unknown"),
|
Platform: utils.NewPointer("Unknown"),
|
||||||
LatestIp: utils.NewPointer("Unknown"),
|
LatestIp: utils.NewPointer("Unknown"),
|
||||||
}); err != nil {
|
}); if err != nil {
|
||||||
a.log.Error(
|
a.log.Error(
|
||||||
"Failed to create new session for user as part of user password reset",
|
"Failed to create new session for user as part of user password reset",
|
||||||
zap.String("email", request.Email),
|
zap.String("email", request.Email),
|
||||||
@@ -863,7 +889,7 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
|
|
||||||
response := models.PasswordResetCompleteResponse{
|
response := models.PasswordResetCompleteResponse{
|
||||||
Tokens: models.Tokens{
|
Tokens: models.Tokens{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -878,3 +904,57 @@ func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetComp
|
|||||||
return &response, nil
|
return &response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *authServiceImpl) ChangePassword(request models.ChangePasswordRequest, uinfo dto.ClientInfo) (bool, error) {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
helper, db, err := database.NewDbHelperTransaction(a.dbctx); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to open a transaction",
|
||||||
|
zap.Error(err))
|
||||||
|
return false, errs.ErrServerError
|
||||||
|
}
|
||||||
|
defer helper.RollbackOnError(err)
|
||||||
|
|
||||||
|
linfo, err := db.TXQueries.GetLoginInformationByUsername(db.CTX, uinfo.Username); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to get user login information",
|
||||||
|
zap.Error(err))
|
||||||
|
return false, errs.ErrServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utils.CheckPasswordHash(request.OldPassword, linfo.PasswordHash) {
|
||||||
|
a.log.Warn(
|
||||||
|
"Provided invalid old password while changing password",
|
||||||
|
zap.String("username", uinfo.Username))
|
||||||
|
return false, errs.ErrForbidden
|
||||||
|
}
|
||||||
|
|
||||||
|
newPasswordHash, err := utils.HashPassword(request.NewPassword); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to hash new password while changing password",
|
||||||
|
zap.String("username", uinfo.Username),
|
||||||
|
zap.Error(err))
|
||||||
|
return false, errs.ErrServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.TXlessQueries.UpdateLoginInformationByUsername(db.CTX, database.UpdateLoginInformationByUsernameParams{
|
||||||
|
Username: uinfo.Username,
|
||||||
|
PasswordHash: newPasswordHash,
|
||||||
|
}); if err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to save new password into database",
|
||||||
|
zap.String("username", uinfo.Username),
|
||||||
|
zap.Error(err))
|
||||||
|
return false, errs.ErrServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := helper.Commit(); err != nil {
|
||||||
|
a.log.Error(
|
||||||
|
"Failed to commit transaction",
|
||||||
|
zap.Error(err))
|
||||||
|
return false, errs.ErrServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ import "errors"
|
|||||||
func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
|
func ErrorIsOneOf(err error, ignoreErrors ...error) bool {
|
||||||
for _, ignore := range ignoreErrors {
|
for _, ignore := range ignoreErrors {
|
||||||
if errors.Is(err, ignore) {
|
if errors.Is(err, ignore) {
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -259,11 +259,12 @@ WHERE
|
|||||||
terminated IS TRUE AND
|
terminated IS TRUE AND
|
||||||
last_refresh_exp_time > CURRENT_TIMESTAMP;
|
last_refresh_exp_time > CURRENT_TIMESTAMP;
|
||||||
|
|
||||||
;-- name: TerminateAllSessionsForUserByUsername :exec
|
;-- name: TerminateAllSessionsForUserByUsername :many
|
||||||
UPDATE sessions
|
UPDATE sessions
|
||||||
SET terminated = TRUE
|
SET terminated = TRUE
|
||||||
FROM users
|
FROM users
|
||||||
WHERE sessions.user_id = users.id AND users.username = @username::text;
|
WHERE sessions.user_id = users.id AND users.username = @username::text
|
||||||
|
RETURNING sessions.guid;
|
||||||
|
|
||||||
;-- name: PruneTerminatedSessions :exec
|
;-- name: PruneTerminatedSessions :exec
|
||||||
DELETE FROM sessions
|
DELETE FROM sessions
|
||||||
|
|||||||
@@ -62,8 +62,8 @@ CREATE TABLE IF NOT EXISTS "sessions" (
|
|||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
guid UUID NOT NULL DEFAULT gen_random_uuid(),
|
guid UUID NOT NULL DEFAULT gen_random_uuid(),
|
||||||
name VARCHAR(100),
|
name VARCHAR(175),
|
||||||
platform VARCHAR(32),
|
platform VARCHAR(175),
|
||||||
latest_ip VARCHAR(16),
|
latest_ip VARCHAR(16),
|
||||||
login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
login_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds',
|
last_refresh_exp_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '10080 seconds',
|
||||||
|
|||||||
Reference in New Issue
Block a user