Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow sign in with email #100

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions backend/internal/bootstrap/router_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,29 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
userService := service.NewUserService(db, jwtService, auditLogService)
userService := service.NewUserService(db, jwtService, auditLogService, emailService)
customClaimService := service.NewCustomClaimService(db)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
testService := service.NewTestService(db, appConfigService)
userGroupService := service.NewUserGroupService(db)

rateLimitMiddleware := middleware.NewRateLimitMiddleware()

// Setup global middleware
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))

// Initialize middleware
// Initialize middleware for specific routes
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()

// Set up API routes
apiGroup := r.Group("/api")
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService)
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, rateLimitMiddleware, webauthnService)
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService, appConfigService)
controller.NewUserController(apiGroup, jwtAuthMiddleware, rateLimitMiddleware, userService, appConfigService)
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService, emailService)
controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware)
controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService)
Expand Down
6 changes: 4 additions & 2 deletions backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ func (e *OidcInvalidAuthorizationCodeError) HttpStatusCode() int { return 400 }

type OidcInvalidCallbackURLError struct{}

func (e *OidcInvalidCallbackURLError) Error() string { return "invalid callback URL, it might be necessary for an admin to fix this" }
func (e *OidcInvalidCallbackURLError) Error() string {
return "invalid callback URL, it might be necessary for an admin to fix this"
}
func (e *OidcInvalidCallbackURLError) HttpStatusCode() int { return 400 }

type FileTypeNotSupportedError struct{}
Expand Down Expand Up @@ -95,7 +97,7 @@ func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbid
type TooManyRequestsError struct{}

func (e *TooManyRequestsError) Error() string {
return "Too many requests. Please wait a while before trying again."
return "Too many requests"
}
func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests }

Expand Down
21 changes: 19 additions & 2 deletions backend/internal/controller/user_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler)
}

type UserController struct {
Expand Down Expand Up @@ -141,7 +142,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
return
}

token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt, c.ClientIP(), c.Request.UserAgent())
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil {
c.Error(err)
return
Expand All @@ -150,8 +151,24 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
c.JSON(http.StatusCreated, gin.H{"token": token})
}

func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}

err := uc.UserService.RequestOneTimeAccessEmail(input.Email)
if err != nil {
c.Error(err)
return
}

c.Status(http.StatusNoContent)
}

func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
return
Expand Down
25 changes: 13 additions & 12 deletions backend/internal/dto/app_config_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@ type AppConfigVariableDto struct {
}

type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"`
SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
EmailEnabled string `json:"emailEnabled" binding:"required"`
SmtHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
SmtpUser string `json:"smtpUser"`
SmtpPassword string `json:"smtpPassword"`
SmtpTls string `json:"smtpTls"`
SmtpSkipCertVerify string `json:"smtpSkipCertVerify"`
AppName string `json:"appName" binding:"required,min=1,max=30"`
SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"`
EmailEnabled string `json:"emailEnabled" binding:"required"`
SmtHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
SmtpUser string `json:"smtpUser"`
SmtpPassword string `json:"smtpPassword"`
SmtpTls string `json:"smtpTls"`
SmtpSkipCertVerify string `json:"smtpSkipCertVerify"`
}
4 changes: 4 additions & 0 deletions backend/internal/dto/user_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}

type OneTimeAccessEmailDto struct {
Email string `json:"email" binding:"required,email"`
}
16 changes: 8 additions & 8 deletions backend/internal/middleware/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ func NewRateLimitMiddleware() *RateLimitMiddleware {
}

func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
// Map to store the rate limiters per IP
var clients = make(map[string]*client)
var mu sync.Mutex

// Start the cleanup routine
go cleanupClients()
go cleanupClients(&mu, clients)

return func(c *gin.Context) {
ip := c.ClientIP()
Expand All @@ -29,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
return
}

limiter := getLimiter(ip, limit, burst)
limiter := getLimiter(ip, limit, burst, &mu, clients)
if !limiter.Allow() {
c.Error(&common.TooManyRequestsError{})
c.Abort()
Expand All @@ -45,12 +49,8 @@ type client struct {
lastSeen time.Time
}

// Map to store the rate limiters per IP
var clients = make(map[string]*client)
var mu sync.Mutex

// Cleanup routine to remove stale clients that haven't been seen for a while
func cleanupClients() {
func cleanupClients(mu *sync.Mutex, clients map[string]*client) {
for {
time.Sleep(time.Minute)
mu.Lock()
Expand All @@ -64,7 +64,7 @@ func cleanupClients() {
}

// getLimiter retrieves the rate limiter for a given IP address, creating one if it doesn't exist
func getLimiter(ip string, limit rate.Limit, burst int) *rate.Limiter {
func getLimiter(ip string, limit rate.Limit, burst int, mu *sync.Mutex, clients map[string]*client) *rate.Limiter {
mu.Lock()
defer mu.Unlock()

Expand Down
9 changes: 5 additions & 4 deletions backend/internal/model/app_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ type AppConfigVariable struct {
}

type AppConfig struct {
AppName AppConfigVariable
SessionDuration AppConfigVariable
EmailsVerified AppConfigVariable
AllowOwnAccountEdit AppConfigVariable
AppName AppConfigVariable
SessionDuration AppConfigVariable
EmailsVerified AppConfigVariable
AllowOwnAccountEdit AppConfigVariable
EmailOneTimeAccessEnabled AppConfigVariable

BackgroundImageType AppConfigVariable
LogoLightImageType AppConfigVariable
Expand Down
13 changes: 13 additions & 0 deletions backend/internal/service/app_config_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ var defaultDbConfig = model.AppConfig{
IsPublic: true,
DefaultValue: "true",
},
EmailOneTimeAccessEnabled: model.AppConfigVariable{
Key: "emailOneTimeAccessEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
BackgroundImageType: model.AppConfigVariable{
Key: "backgroundImageType",
Type: "string",
Expand Down Expand Up @@ -119,6 +125,13 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()

// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key {
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false"
}
}

var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
Expand Down
15 changes: 13 additions & 2 deletions backend/internal/service/email_service_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
/**
How to add new template:
- pick unique and descriptive template ${name} (for example "login-with-new-device")
- in backend/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl"
- in backend/resources/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl"
- create xxxxTemplate and xxxxTemplateData (for example NewLoginTemplate and NewLoginTemplateData)
- Path *must* be ${name}
- add xxxTemplate.Path to "emailTemplatePaths" at the end
Expand All @@ -27,6 +27,13 @@ var NewLoginTemplate = email.Template[NewLoginTemplateData]{
},
}

var OneTimeAccessTemplate = email.Template[OneTimeAccessTemplateData]{
Path: "one-time-access",
Title: func(data *email.TemplateData[OneTimeAccessTemplateData]) string {
return "One time access"
},
}

var TestTemplate = email.Template[struct{}]{
Path: "test",
Title: func(data *email.TemplateData[struct{}]) string {
Expand All @@ -42,5 +49,9 @@ type NewLoginTemplateData struct {
DateTime time.Time
}

type OneTimeAccessTemplateData = struct {
Link string
}

// this is list of all template paths used for preloading templates
var emailTemplatesPaths = []string{NewLoginTemplate.Path, TestTemplate.Path}
var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path}
50 changes: 44 additions & 6 deletions backend/internal/service/user_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@ package service

import (
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/model"
"github.com/stonith404/pocket-id/backend/internal/model/types"
"github.com/stonith404/pocket-id/backend/internal/utils"
"github.com/stonith404/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
"log"
"time"
)

type UserService struct {
db *gorm.DB
jwtService *JwtService
auditLogService *AuditLogService
emailService *EmailService
}

func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService) *UserService {
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService}
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService) *UserService {
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService}
}

func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]model.User, utils.PaginationResponse, error) {
Expand Down Expand Up @@ -89,7 +93,39 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
return user, nil
}

func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time, ipAddress, userAgent string) (string, error) {
func (s *UserService) RequestOneTimeAccessEmail(emailAddress string) error {
var user model.User
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil {
// Do not return error if user not found to prevent email enumeration
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
} else {
return err
}
}

oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(time.Hour))
if err != nil {
return err
}
link := fmt.Sprintf("%s/login/%s", common.EnvConfig.AppURL, oneTimeAccessToken)

go func() {
err := SendEmail(s.emailService, email.Address{
Name: user.Username,
Email: user.Email,
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
Link: link,
})
if err != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
}
}()

return nil
}

func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(16)
if err != nil {
return "", err
Expand All @@ -105,12 +141,10 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim
return "", err
}

s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, userID, model.AuditLogData{})

return oneTimeAccessToken.Token, nil
}

func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, string, error) {
func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) {
var oneTimeAccessToken model.OneTimeAccessToken
if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -127,6 +161,10 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, stri
return model.User{}, "", err
}

if ipAddress != "" && userAgent != "" {
s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{})
}

return oneTimeAccessToken.User, accessToken, nil
}

Expand Down
2 changes: 0 additions & 2 deletions backend/internal/utils/email/email_service_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
ttemplate "text/template"
)

const templateComponentsDir = "components"

type Template[V any] struct {
Path string
Title func(data *TemplateData[V]) string
Expand Down
15 changes: 15 additions & 0 deletions backend/resources/email-templates/components/style_html.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,20 @@
font-size: 1rem;
line-height: 1.5;
}
.button {
border-radius: 0.375rem;
font-size: 1rem;
font-weight: 500;
background-color: #000000;
color: #ffffff;
padding: 0.7rem 1.5rem;
outline: none;
border: none;
text-decoration: none;
}
.button-container {
text-align: center;
margin-top: 24px;
}
</style>
{{ end }}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{{ define "base" }}
<div class="header">
<div class="logo">
<img src="{{ .LogoURL }}" alt="Pocket ID"/>
<img src="{{ .LogoURL }}" alt="{{ .AppName }}"/>
<h1>{{ .AppName }}</h1>
</div>
<div class="warning">Warning</div>
Expand Down
Loading
Loading