diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index f2c9a34..889f1b8 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -34,6 +34,8 @@ type Claims struct { jwt.RegisteredClaims } +// TODO: validate token type +// TODO: validate session guid func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { cfg := config.GetConfig() diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index bccaa75..a89812a 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -1,17 +1,17 @@ // Copyright (c) 2025 Nikolai Papin -// +// // This file is part of Easywish -// +// // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. -// +// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See // the GNU General Public License for more details. -// +// // You should have received a copy of the GNU General Public License // along with this program. If not, see . @@ -24,12 +24,13 @@ import ( "easywish/internal/utils" "easywish/internal/utils/enums" + "github.com/jackc/pgx/v5" "go.uber.org/zap" ) type AuthService interface { RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) - RegistrationComplete(request models.RegistrationBeginRequest) (*models.RegistrationCompleteResponse, error) + RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) Login(request models.LoginRequest) (*models.LoginResponse, error) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) } @@ -53,7 +54,7 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ var err error - if user, err = db.TXQueries.CreateUser(db.CTX, request.Username); err != nil { // TODO: validation + if user, err = db.TXQueries.CreateUser(db.CTX, request.Username); err != nil { a.log.Error("Failed to add user to database", zap.Error(err)) return false, errs.ErrServerError } @@ -94,31 +95,190 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ return true, nil } -func (a *authServiceImpl) RegistrationComplete(request models.RegistrationBeginRequest) (*models.RegistrationCompleteResponse, error) { - return nil, errs.ErrNotImplemented +func (a *authServiceImpl) RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) { + + var user database.User + var profile database.Profile + var session database.Session + var confirmationCode database.ConfirmationCode + var accessToken, refreshToken string + var err error + + helper, db, _ := database.NewDbHelperTransaction(a.dbctx) + + user, err = db.TXQueries.GetUserByUsername(db.CTX, request.Username) + + if err != nil { + a.log.Error( + "Failed to find user attempting to complete registration", + zap.String("username", request.Username), + zap.Error(err)) + return nil, errs.ErrUserNotFound + } + + confirmationCode, err = db.TXQueries.GetConfirmationCodeByCode(db.CTX, database.GetConfirmationCodeByCodeParams{ + UserID: user.ID, + CodeType: int32(enums.RegistrationCodeType), + Code: request.VerificationCode, + }) + + if err != nil { + a.log.Warn( + "User supplied wrong confirmation code for completing registration", + zap.String("username", user.Username), + zap.Error(err)) + return nil, errs.ErrForbidden + } + + err = db.TXQueries.UpdateConfirmationCode(db.CTX, database.UpdateConfirmationCodeParams{ + ID: confirmationCode.ID, + Used: utils.NewPointer(true), + }) + + if err != nil { + a.log.Error( + "Failed to update the user's registration code used state", + zap.String("username", user.Username), + zap.Int64("confirmation_code_id", confirmationCode.ID), + zap.Error(err), + ) + return nil, errs.ErrServerError + } + + err = db.TXQueries.UpdateUser(db.CTX, database.UpdateUserParams{ + Verified: utils.NewPointer(true), + }) + + if err != nil { + a.log.Error("Failed to update verified status for user", + zap.String("username", user.Username), + zap.Error(err)) + return nil, errs.ErrServerError + } + + profile, err = db.TXQueries.CreateProfile(db.CTX, database.CreateProfileParams{ + UserID: user.ID, + Name: request.Name, + AvatarUrl: request.AvatarUrl, + }) + + if err != nil { + a.log.Error("Failed to create profile for user", + zap.String("username", user.Username), + + ) + return nil, errs.ErrServerError + } + + _, err = db.TXQueries.CreateProfileSettings(db.CTX, profile.ID) + + if err != nil { + a.log.Error("Failed to create profile settings for user", + zap.String("username", user.Username), + zap.Error(err)) + return nil, errs.ErrServerError + } + + session, err = db.TXQueries.CreateSession(db.CTX, database.CreateSessionParams{ + UserID: user.ID, + Name: utils.NewPointer("First device"), + Platform: utils.NewPointer("Unknown"), + LatestIp: utils.NewPointer("Unknown"), + }) + + if err != nil { + a.log.Error( + "Failed to create a new session during registration, rolling back registration", + zap.String("username", user.Username), + zap.Error(err)) + return nil, errs.ErrServerError + } + + accessToken, refreshToken, err = utils.GenerateTokens(user.Username, session.Guid.String()) + + if err != nil { + a.log.Error( + "Failed to create tokens for newly registered user, rolling back registration", + zap.String("username", user.Username), + zap.Error(err)) + return nil, errs.ErrServerError + } + + helper.Commit() + + response := models.RegistrationCompleteResponse{Tokens: models.Tokens{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }} + + return &response, errs.ErrNotImplemented } +// TODO: totp +// TODO: banned user check func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginResponse, error) { - conn, ctx, err := utils.GetDbConn() - if err != nil { - return nil, err - } - defer conn.Close(ctx) + var userRow database.GetUserByLoginCredentialsRow + var session database.Session - queries := database.New(conn) + helper, db, _ := database.NewDbHelperTransaction(a.dbctx) + defer helper.Rollback() - user, err := queries.GetUserByLoginCredentials(ctx, database.GetUserByLoginCredentialsParams{ + var err error + + userRow, err = db.TXQueries.GetUserByLoginCredentials(db.CTX, database.GetUserByLoginCredentialsParams{ Username: request.Username, Password: request.Password, }) if err != nil { - return nil, errs.ErrUnauthorized + a.log.Warn( + "Failed login attempt", + zap.Error(err)) + + var returnedError error + + switch err { + case pgx.ErrNoRows: + returnedError = errs.ErrForbidden + default: + returnedError = errs.ErrServerError + } + + return nil, returnedError } - accessToken, refreshToken, err := utils.GenerateTokens(user.Username) + session, err = db.TXlessQueries.CreateSession(db.CTX, database.CreateSessionParams{ + UserID: userRow.ID, + Name: utils.NewPointer("New device"), + Platform: utils.NewPointer("Unknown"), + LatestIp: utils.NewPointer("Unknown"), + }) - return &models.LoginResponse{Tokens: models.Tokens{AccessToken: accessToken, RefreshToken: refreshToken}}, nil + if err != nil { + a.log.Error( + "Failed to create session for a new login", + zap.String("username", userRow.Username), + zap.Error(err)) + return nil, errs.ErrServerError + } + + accessToken, refreshToken, err := utils.GenerateTokens(userRow.Username, session.Guid.String()) + if err != nil { + a.log.Error( + "Failed to generate tokens for a new login", + zap.String("username", userRow.Username), + zap.Error(err)) + return nil, errs.ErrServerError + } + + helper.Commit() + + response := models.LoginResponse{Tokens: models.Tokens{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }} + + return &response, nil } func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { diff --git a/backend/internal/utils/enums/enums.go b/backend/internal/utils/enums/enums.go index 8f7cd5e..6a8f39e 100644 --- a/backend/internal/utils/enums/enums.go +++ b/backend/internal/utils/enums/enums.go @@ -29,3 +29,9 @@ const ( UserRole AdminRole ) + +type JwtTokenType int32 +const ( + JwtAccessTokenType JwtTokenType = iota + JwtRefreshTokenType +) diff --git a/backend/internal/utils/jwt.go b/backend/internal/utils/jwt.go index bf25213..6b38273 100644 --- a/backend/internal/utils/jwt.go +++ b/backend/internal/utils/jwt.go @@ -1,17 +1,17 @@ // Copyright (c) 2025 Nikolai Papin -// +// // This file is part of Easywish -// +// // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. -// +// // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See // the GNU General Public License for more details. -// +// // You should have received a copy of the GNU General Public License // along with this program. If not, see . @@ -19,22 +19,27 @@ package utils import ( "easywish/config" + "easywish/internal/utils/enums" "time" "github.com/golang-jwt/jwt/v5" ) -func GenerateTokens(username string) (accessToken, refreshToken string, err error) { +func GenerateTokens(username string, sessionGuid string) (accessToken, refreshToken string, err error) { cfg := config.GetConfig() accessClaims := jwt.MapClaims{ "username": username, + "guid": sessionGuid, + "type": enums.JwtAccessTokenType, "exp": time.Now().Add(time.Minute * time.Duration(cfg.JwtExpAccess)).Unix(), } accessToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims).SignedString([]byte(cfg.JwtSecret)) refreshClaims := jwt.MapClaims{ "username": username, + "guid": sessionGuid, + "type": enums.JwtRefreshTokenType, "exp": time.Now().Add(time.Hour * time.Duration(cfg.JwtExpRefresh)).Unix(), } refreshToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims).SignedString([]byte(cfg.JwtSecret))