diff --git a/go/controller/post_signin_idtoken.go b/go/controller/post_signin_idtoken.go index bb9f26c8..27ba2067 100644 --- a/go/controller/post_signin_idtoken.go +++ b/go/controller/post_signin_idtoken.go @@ -50,15 +50,15 @@ func (ctrl *Controller) postSigninIdtokenValidateRequest( func (ctrl *Controller) postSigninIdtokenCheckUserExists( ctx context.Context, email, providerID, providerUserID string, logger *slog.Logger, -) (sql.AuthUser, bool, *APIError) { +) (sql.AuthUser, bool, bool, *APIError) { user, apiError := ctrl.wf.GetUserByProviderUserID(ctx, providerID, providerUserID, logger) switch { case errors.Is(apiError, ErrUserProviderNotFound): case apiError != nil: logger.Error("error getting user by provider user id", logError(apiError)) - return user, false, apiError + return user, false, false, apiError default: - return user, true, nil + return user, true, true, nil } user, apiError = ctrl.wf.GetUserByEmail(ctx, email, logger) @@ -66,12 +66,12 @@ func (ctrl *Controller) postSigninIdtokenCheckUserExists( case errors.Is(apiError, ErrUserEmailNotFound): case apiError != nil: logger.Error("error getting user by email", logError(apiError)) - return sql.AuthUser{}, false, ErrInternalServerError + return sql.AuthUser{}, false, false, ErrInternalServerError default: - return user, true, nil + return user, true, false, nil } - return user, false, nil + return user, false, false, nil } func (ctrl *Controller) PostSigninIdtoken( //nolint:ireturn @@ -94,15 +94,17 @@ func (ctrl *Controller) PostSigninIdtoken( //nolint:ireturn return ctrl.respondWithError(ErrInvalidEmailPassword), nil } - user, found, apiError := ctrl.postSigninIdtokenCheckUserExists( + user, userFound, providerFound, apiError := ctrl.postSigninIdtokenCheckUserExists( ctx, profile.Email, string(req.Body.Provider), profile.ProviderUserID, logger, ) if apiError != nil { return ctrl.respondWithError(apiError), nil } - if found { - return ctrl.postSigninIdtokenSignin(ctx, user, logger) + if userFound { + return ctrl.postSigninIdtokenSignin( + ctx, user, providerFound, req.Body.Provider, profile.ProviderUserID, logger, + ) } return ctrl.postSigninIdtokenSignup(ctx, req, profile, logger) @@ -233,10 +235,25 @@ func (ctrl *Controller) postSigninIdtokenSignupWithoutSession( func (ctrl *Controller) postSigninIdtokenSignin( //nolint:ireturn ctx context.Context, user sql.AuthUser, + providerFound bool, + provider api.Provider, + providerUserID string, logger *slog.Logger, ) (api.PostSigninIdtokenResponseObject, error) { logger.Info("user found, signing in") + if !providerFound { + if _, apiErr := ctrl.wf.InsertUserProvider( + ctx, + user.ID, + string(provider), + providerUserID, + logger, + ); apiErr != nil { + return ctrl.respondWithError(apiErr), nil + } + } + session, err := ctrl.wf.NewSession(ctx, user, logger) if err != nil { logger.Error("error getting new session", logError(err)) diff --git a/go/controller/post_signin_idtoken_test.go b/go/controller/post_signin_idtoken_test.go index 7960e690..5c7ef9e8 100644 --- a/go/controller/post_signin_idtoken_test.go +++ b/go/controller/post_signin_idtoken_test.go @@ -606,6 +606,26 @@ func TestPostSigninIdToken(t *testing.T) { //nolint:maintidx WebauthnCurrentChallenge: pgtype.Text{}, }, nil) + mock.EXPECT().InsertUserProvider( + gomock.Any(), + sql.InsertUserProviderParams{ + UserID: userID, + ProviderID: "fake", + ProviderUserID: "106964149809169421082", + }, + ).Return( + sql.AuthUserProvider{ + ID: userID, + CreatedAt: pgtype.Timestamptz{}, //nolint:exhaustruct + UpdatedAt: pgtype.Timestamptz{}, //nolint:exhaustruct + UserID: userID, + AccessToken: "unset", + RefreshToken: pgtype.Text{}, //nolint:exhaustruct + ProviderID: "fake", + ProviderUserID: "106964149809169421082", + }, nil, + ) + mock.EXPECT().GetUserRoles( gomock.Any(), userID, ).Return([]sql.AuthUserRole{