diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 7a8f329..612a022 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -45,6 +45,7 @@ import ( "easywish/internal/controllers" "easywish/internal/database" "easywish/internal/logger" + redisclient "easywish/internal/redisClient" "easywish/internal/routes" "easywish/internal/services" "easywish/internal/validation" @@ -59,10 +60,13 @@ func main() { panic(err) } + cfg := config.GetConfig() + fx.New( fx.Provide( logger.NewLogger, logger.NewSyncLogger, + redisclient.NewRedisClient, gin.Default, ), database.Module, @@ -80,7 +84,7 @@ func main() { // Gin server := &http.Server{ - Addr: fmt.Sprintf(":%s", strconv.Itoa(int(config.GetConfig().Port))), + Addr: fmt.Sprintf(":%s", strconv.Itoa(int(cfg.Port))), Handler: router, } diff --git a/backend/go.mod b/backend/go.mod index 3af408a..09e54d1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -19,7 +19,9 @@ require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/bytedance/sonic v1.13.3 // indirect github.com/bytedance/sonic/loader v0.2.4 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.5 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.9 // indirect github.com/gin-contrib/sse v1.1.0 // indirect @@ -29,8 +31,10 @@ require ( github.com/go-openapi/swag v0.23.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/backend/go.sum b/backend/go.sum index de96d7a..58bd2be 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -5,12 +5,16 @@ github.com/bytedance/sonic v1.13.3/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1 github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= @@ -39,6 +43,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= @@ -48,6 +54,8 @@ github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVI github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= diff --git a/backend/internal/controllers/auth.go b/backend/internal/controllers/auth.go index 537201e..0691f6f 100644 --- a/backend/internal/controllers/auth.go +++ b/backend/internal/controllers/auth.go @@ -42,12 +42,12 @@ type AuthController interface { } type authControllerImpl struct { - authService services.AuthService log *zap.Logger + auth services.AuthService } -func NewAuthController(_log *zap.Logger, as services.AuthService) AuthController { - return &authControllerImpl{log: _log, authService: as} +func NewAuthController(_log *zap.Logger, _auth services.AuthService) AuthController { + return &authControllerImpl{log: _log, auth: _auth} } // @Summary Acquire tokens via login credentials (and 2FA code if needed) @@ -65,7 +65,7 @@ func (a *authControllerImpl) Login(c *gin.Context) { return } - response, err := a.authService.Login(request.Body) + response, err := a.auth.Login(request.Body) if err != nil { if errors.Is(err, errs.ErrForbidden) { @@ -134,7 +134,7 @@ func (a *authControllerImpl) RegistrationBegin(c *gin.Context) { return } - _, err := a.authService.RegistrationBegin(request.Body) + _, err := a.auth.RegistrationBegin(request.Body) if err != nil { if errors.Is(err, errs.ErrUsernameTaken) || errors.Is(err, errs.ErrEmailTaken) { @@ -164,7 +164,7 @@ func (a *authControllerImpl) RegistrationComplete(c *gin.Context) { return } - response, err := a.authService.RegistrationComplete(request.Body) + response, err := a.auth.RegistrationComplete(request.Body) if err != nil { if errors.Is(err, errs.ErrForbidden) { diff --git a/backend/internal/errors/general.go b/backend/internal/errors/general.go index 641a4f6..40e9f1d 100644 --- a/backend/internal/errors/general.go +++ b/backend/internal/errors/general.go @@ -25,4 +25,5 @@ var ( ErrNotImplemented = errors.New("Feature is not implemented") ErrBadRequest = errors.New("Bad request") ErrForbidden = errors.New("Access is denied") + ErrTooManyRequests = errors.New("Too many requests") ) diff --git a/backend/internal/services/auth.go b/backend/internal/services/auth.go index a416d2d..076428c 100644 --- a/backend/internal/services/auth.go +++ b/backend/internal/services/auth.go @@ -18,6 +18,7 @@ package services import ( + "context" "easywish/config" "easywish/internal/database" errs "easywish/internal/errors" @@ -26,7 +27,10 @@ import ( "easywish/internal/utils/enums" "errors" "fmt" + "time" + "github.com/go-redis/redis/v8" + "github.com/google/uuid" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" "go.uber.org/zap" @@ -37,16 +41,19 @@ type AuthService interface { RegistrationComplete(request models.RegistrationCompleteRequest) (*models.RegistrationCompleteResponse, error) Login(request models.LoginRequest) (*models.LoginResponse, error) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) + PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) + PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error) } type authServiceImpl struct { log *zap.Logger - smtp SmtpService dbctx database.DbContext + redis *redis.Client + smtp SmtpService } -func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _smtp SmtpService) AuthService { - return &authServiceImpl{log: _log, dbctx: _dbctx, smtp: _smtp} +func NewAuthService(_log *zap.Logger, _dbctx database.DbContext, _redis *redis.Client, _smtp SmtpService) AuthService { + return &authServiceImpl{log: _log, dbctx: _dbctx, redis: _redis, smtp: _smtp} } func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequest) (bool, error) { @@ -61,6 +68,9 @@ func (a *authServiceImpl) RegistrationBegin(request models.RegistrationBeginRequ var err error + // TODO: get user if it exists. If not verified and no valid code exists, delete + // and recreate + if user, err = db.TXQueries.CreateUser(db.CTX, request.Username); err != nil { if errs.MatchPgError(err, pgerrcode.UniqueViolation) { @@ -368,3 +378,106 @@ func (a *authServiceImpl) Login(request models.LoginRequest) (*models.LoginRespo func (a *authServiceImpl) Refresh(request models.RefreshRequest) (*models.RefreshResponse, error) { return nil, errs.ErrNotImplemented } + +func (a *authServiceImpl) PasswordResetBegin(request models.PasswordResetBeginRequest) (bool, error) { + + var user database.User + var generatedCode, hashedCode string + var err error + + helper, db, err := database.NewDbHelperTransaction(a.dbctx) + defer helper.Rollback() + + ctx := context.TODO() + + cooldownTimeUnix, redisErr := a.redis.Get(ctx, fmt.Sprintf("email::%s::reset_cooldown", request.Email)).Int64() + if redisErr != nil && redisErr != redis.Nil { + a.log.Error( + "Failed to get reset_cooldown state for user", + zap.String("email", request.Email), + zap.Error(redisErr)) + return false, errs.ErrServerError + + } else if err == nil { + current_time := time.Now() + if current_time.Unix() < cooldownTimeUnix { + a.log.Warn( + "Attempted to request a new password reset code for email on active reset cooldown", + zap.String("email", request.Email)) + return false, errs.ErrTooManyRequests + } + } + + if user, err = db.TXQueries.GetUserByEmail(db.CTX, request.Email); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + // Enable cooldown for the email despite that account does not exist + err := a.redis.Set( + ctx, + fmt.Sprintf("email::%s::reset_cooldown", request.Email), + time.Now().Add(10*time.Minute), + time.Duration(10*time.Minute), + ).Err() + + if err != nil { + a.log.Error( + "Failed to set reset cooldown for email", + zap.Error(err)) + return false, err + } + + a.log.Warn( + "Requested password reset email for unexistent user", + zap.String("email", request.Email)) + return true, nil + } + a.log.Error( + "Failed to retrieve user from database", + zap.String("email", request.Email), + zap.Error(err)) + return false, errs.ErrServerError + } + + generatedCode = uuid.New().String() + if hashedCode, err = utils.HashPassword(generatedCode); err != nil { + a.log.Error( + "Failed to hash password reset code for user", + zap.String("username", user.Username), + zap.Error(err)) + return false, errs.ErrServerError + } + + if _, err = db.TXlessQueries.CreateConfirmationCode(db.CTX, database.CreateConfirmationCodeParams{ + UserID: user.ID, + CodeType: int32(enums.PasswordResetCodeType), + CodeHash: hashedCode, + }); err != nil { + a.log.Error( + "Failed to save user password reset code to the database", + zap.String("username", user.Username), + zap.Error(err)) + } + + err = a.redis.Set( + ctx, + fmt.Sprintf("email::%s::reset_cooldown", request.Email), + time.Now().Add(10*time.Minute), + time.Duration(10*time.Minute), + ).Err() + + if err != nil { + a.log.Error( + "Failed to set reset cooldown for email. Cancelling password reset", + zap.Error(err)) + return false, err + } + + helper.Commit() + + return true, nil +} + +func (a *authServiceImpl) PasswordResetComplete(request models.PasswordResetCompleteRequest) (*models.PasswordResetCompleteResponse, error) { + + return nil, errs.ErrNotImplemented +} +