Skip to content

Commit

Permalink
fix modify user reauth bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dreth committed Aug 16, 2024
1 parent 78761c5 commit 28617b0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
61 changes: 46 additions & 15 deletions backend/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"hbd/encryption"
Expand Down Expand Up @@ -145,8 +146,14 @@ func Register(c *gin.Context) {
// As the user was successfully created, send a telegram message through the bot and ID to confirm the registration
telegram.SendTelegramMessage(req.TelegramBotAPIKey, req.TelegramUserID, fmt.Sprintf("🎂 Your user has been successfully registered, through this bot and user ID you'll receive your birthday reminders (if there's any) at %s (Timezone: %s).\n\nIf you encounter any issues using the app or want to give any feedback to us. Please open an issue here: https://github.com/dreth/hbd/issues, thanks and we hope you find the application useful!", req.ReminderTime, req.Timezone))

// Return the token and the user's details
token, err := GenerateJWT(req.Email, 720)
// Get the JWT duration from the header or use the default
jwtDuration, err := GetJWTDurationFromHeader(c, 720)
if err != nil {
jwtDuration = 720
}

// Generate JWT token
token, err := GenerateJWT(req.Email, jwtDuration)
if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) {
return
} else {
Expand Down Expand Up @@ -213,16 +220,23 @@ func Login(c *gin.Context) {
emailHash := encryption.HashStringWithSHA256(req.Email)
passwordHash := encryption.HashStringWithSHA256(req.Password)

// Fetch the user with the given email hash and password hash from the database
_, err := models.Users(
qm.Where("email_hash = ?", emailHash),
qm.Where("password_hash = ?", passwordHash),
).One(c.Request.Context(), boil.GetContextDB())
if err != nil {

// If no user is found, return a 401 Unauthorized
if err == sql.ErrNoRows {
c.JSON(http.StatusUnauthorized, structs.Error{Error: "invalid email or password"})
return
}

// Handle other errors separately
if err != nil {
c.JSON(http.StatusInternalServerError, structs.Error{Error: "an unexpected error occurred"})
return
}

// Set the user email in the context
c.Set("Email", req.Email)

Expand All @@ -238,7 +252,6 @@ func Login(c *gin.Context) {
}

// Generate JWT token
println(jwtDuration)
token, err := GenerateJWT(req.Email, jwtDuration)
if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) {
return
Expand Down Expand Up @@ -287,7 +300,7 @@ func Me(c *gin.Context) {
// @x-order 4
func ModifyUser(c *gin.Context) {
// Retrieve the user from the database
user, err := GetUserByEmail(c)
user, originalEmail, err := GetUserByEmail(c)
if helper.HE(c, err, http.StatusUnauthorized, "invalid email", false) {
return
}
Expand Down Expand Up @@ -388,21 +401,39 @@ func ModifyUser(c *gin.Context) {
return
}

// After committing the transaction, emit another JWT token with the new email
token, err := GenerateJWT(req.NewEmail, 720)
if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) {
return
}

// Get user data post-changes
userData, err := GetUserData(c)
if helper.HE(c, err, http.StatusInternalServerError, "invalid email or password", true) {
return
}

// Update the email in the context
if req.NewEmail != "" {
c.Set("Email", req.NewEmail)
if (req.NewEmail != "") || (req.NewPassword != "" && req.NewEmail == "") {
// Possible scenarios
// 1. New email is empty, but password is not
// 2. New email is not empty (regardless of password)
// Case 1: New email is empty, but password is not
if req.NewPassword != "" && req.NewEmail == "" {
req.NewEmail = originalEmail
}

// Case 2: New email is not empty (regardless of password)
if req.NewEmail != "" {
c.Set("Email", req.NewEmail)
}

// After committing the transaction, emit another JWT token with the new email
// Get the JWT duration from the header or use the default
jwtDuration, err := GetJWTDurationFromHeader(c, 720)
if err != nil {
jwtDuration = 720
}

// Generate JWT token
token, err := GenerateJWT(req.NewEmail, jwtDuration)
if helper.HE(c, err, http.StatusInternalServerError, "failed to generate token", false) {
return
}

// Return the new token with the new user data
c.JSON(http.StatusOK, structs.LoginSuccess{
Expand Down Expand Up @@ -433,7 +464,7 @@ func ModifyUser(c *gin.Context) {
// @x-order 5
func DeleteUser(c *gin.Context) {
// Retrieve the user from the database
user, err := GetUserByEmail(c)
user, _, err := GetUserByEmail(c)
if helper.HE(c, err, http.StatusUnauthorized, "invalid email", false) {
return
}
Expand Down
8 changes: 4 additions & 4 deletions backend/auth/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// 2. Hashes the email using SHA-256.
// 3. Queries the database for a user with the given email hash.
// 4. Returns the user object or an error if the user is not found or an error occurs.
func GetUserByEmail(c *gin.Context) (*models.User, error) {
func GetUserByEmail(c *gin.Context) (*models.User, string, error) {
// Get the email from the context
email := c.GetString("Email")

Expand All @@ -40,10 +40,10 @@ func GetUserByEmail(c *gin.Context) (*models.User, error) {
qm.Where("email_hash = ?", emailHash),
).One(c.Request.Context(), boil.GetContextDB())
if err != nil {
return nil, errors.New("invalid email")
return nil, email, errors.New("invalid email")
}

return user, nil
return user, email, nil
}

// GetUserData fetches and returns user data including decrypted Telegram bot API key and user ID,
Expand All @@ -69,7 +69,7 @@ func GetUserByEmail(c *gin.Context) (*models.User, error) {
// - (*structs.UserData, error): A pointer to the UserData struct containing user details and birthdays, or an error.
func GetUserData(c *gin.Context) (*structs.UserData, error) {
// Get the user by its email
user, err := GetUserByEmail(c)
user, _, err := GetUserByEmail(c)
if err != nil {
return nil, errors.New("invalid email")
}
Expand Down

0 comments on commit 28617b0

Please sign in to comment.