From b27986a5b0f11f54064e80811b524a3fb9be863b Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Wed, 5 Feb 2025 16:01:44 +0300 Subject: [PATCH] refactor: remove rolename from invitations table Signed-off-by: Felix Gateru --- domains/api/http/invitations.go | 5 ----- domains/events/events.go | 3 --- domains/invitations.go | 1 - domains/postgres/init.go | 1 - domains/postgres/invitations.go | 14 ++++---------- domains/postgres/invitations_test.go | 26 -------------------------- domains/service.go | 10 ++++++---- domains/service_test.go | 12 +++++++++++- pkg/sdk/invitations.go | 6 +++--- pkg/sdk/invitations_test.go | 26 +++++++++++++++----------- 10 files changed, 39 insertions(+), 65 deletions(-) diff --git a/domains/api/http/invitations.go b/domains/api/http/invitations.go index c999058480..551d567d80 100644 --- a/domains/api/http/invitations.go +++ b/domains/api/http/invitations.go @@ -121,10 +121,6 @@ func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{}, if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - roleName, err := apiutil.ReadStringQuery(r, roleNameKey, "") - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } domainID, err := apiutil.ReadStringQuery(r, domainIDKey, "") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) @@ -144,7 +140,6 @@ func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{}, InvitedBy: invitedBy, UserID: userID, RoleID: roleID, - RoleName: roleName, DomainID: domainID, State: state, }, diff --git a/domains/events/events.go b/domains/events/events.go index aa81f7e3d2..f910d4148d 100644 --- a/domains/events/events.go +++ b/domains/events/events.go @@ -357,9 +357,6 @@ func (lie listInvitationsEvent) Encode() (map[string]interface{}, error) { if lie.RoleID != "" { val["role_id"] = lie.RoleID } - if lie.RoleName != "" { - val["role_name"] = lie.RoleName - } if lie.State.String() != "" { val["state"] = lie.State.String() } diff --git a/domains/invitations.go b/domains/invitations.go index 2051ea0013..b6b60ef3e4 100644 --- a/domains/invitations.go +++ b/domains/invitations.go @@ -52,7 +52,6 @@ type InvitationPageMeta struct { UserID string `json:"user_id,omitempty" db:"user_id,omitempty"` DomainID string `json:"domain_id,omitempty" db:"domain_id,omitempty"` RoleID string `json:"role_id,omitempty" db:"role_id,omitempty"` - RoleName string `json:"role_name,omitempty" db:"role_name,omitempty"` InvitedByOrUserID string `db:"invited_by_or_user_id,omitempty"` State State `json:"state,omitempty"` } diff --git a/domains/postgres/init.go b/domains/postgres/init.go index 64ba04423c..2703663dc6 100644 --- a/domains/postgres/init.go +++ b/domains/postgres/init.go @@ -48,7 +48,6 @@ func Migration() (*migrate.MemoryMigrationSource, error) { user_id VARCHAR(36) NOT NULL, domain_id VARCHAR(36) NOT NULL, role_id VARCHAR(36) NOT NULL, - role_name VARCHAR(200), created_at TIMESTAMP NOT NULL, updated_at TIMESTAMP, confirmed_at TIMESTAMP, diff --git a/domains/postgres/invitations.go b/domains/postgres/invitations.go index d5f7035d4e..e677d007f6 100644 --- a/domains/postgres/invitations.go +++ b/domains/postgres/invitations.go @@ -16,8 +16,8 @@ import ( ) func (repo domainRepo) SaveInvitation(ctx context.Context, invitation domains.Invitation) (err error) { - q := `INSERT INTO invitations (invited_by, user_id, domain_id, role_id, role_name, created_at) - VALUES (:invited_by, :user_id, :domain_id, :role_id, :role_name, :created_at)` + q := `INSERT INTO invitations (invited_by, user_id, domain_id, role_id, created_at) + VALUES (:invited_by, :user_id, :domain_id, :role_id, :created_at)` dbInv := toDBInvitation(invitation) if _, err = repo.db.NamedExecContext(ctx, q, dbInv); err != nil { @@ -28,7 +28,7 @@ func (repo domainRepo) SaveInvitation(ctx context.Context, invitation domains.In } func (repo domainRepo) RetrieveInvitation(ctx context.Context, userID, domainID string) (domains.Invitation, error) { - q := `SELECT invited_by, user_id, domain_id, role_id, role_name, created_at, updated_at, confirmed_at, rejected_at FROM invitations WHERE user_id = :user_id AND domain_id = :domain_id;` + q := `SELECT invited_by, user_id, domain_id, role_id, created_at, updated_at, confirmed_at, rejected_at FROM invitations WHERE user_id = :user_id AND domain_id = :domain_id;` dbinv := dbInvitation{ UserID: userID, @@ -55,7 +55,7 @@ func (repo domainRepo) RetrieveInvitation(ctx context.Context, userID, domainID func (repo domainRepo) RetrieveAllInvitations(ctx context.Context, pm domains.InvitationPageMeta) (domains.InvitationPage, error) { query := pageQuery(pm) - q := fmt.Sprintf("SELECT invited_by, user_id, domain_id, role_id, role_name, created_at, updated_at, confirmed_at, rejected_at FROM invitations %s LIMIT :limit OFFSET :offset;", query) + q := fmt.Sprintf("SELECT invited_by, user_id, domain_id, role_id, created_at, updated_at, confirmed_at, rejected_at FROM invitations %s LIMIT :limit OFFSET :offset;", query) rows, err := repo.db.NamedQueryContext(ctx, q, pm) if err != nil { @@ -148,9 +148,6 @@ func pageQuery(pm domains.InvitationPageMeta) string { if pm.RoleID != "" { query = append(query, "role_id = :role_id") } - if pm.RoleName != "" { - query = append(query, "role_name = :role_name") - } if pm.InvitedByOrUserID != "" { query = append(query, "(invited_by = :invited_by_or_user_id OR user_id = :invited_by_or_user_id)") } @@ -176,7 +173,6 @@ type dbInvitation struct { UserID string `db:"user_id"` DomainID string `db:"domain_id"` RoleID string `db:"role_id,omitempty"` - RoleName string `db:"role_name,omitempty"` Relation string `db:"relation"` CreatedAt time.Time `db:"created_at"` UpdatedAt sql.NullTime `db:"updated_at,omitempty"` @@ -201,7 +197,6 @@ func toDBInvitation(inv domains.Invitation) dbInvitation { UserID: inv.UserID, DomainID: inv.DomainID, RoleID: inv.RoleID, - RoleName: inv.RoleName, CreatedAt: inv.CreatedAt, UpdatedAt: updatedAt, ConfirmedAt: confirmedAt, @@ -226,7 +221,6 @@ func toInvitation(dbinv dbInvitation) domains.Invitation { UserID: dbinv.UserID, DomainID: dbinv.DomainID, RoleID: dbinv.RoleID, - RoleName: dbinv.RoleName, CreatedAt: dbinv.CreatedAt, UpdatedAt: updatedAt, ConfirmedAt: confirmedAt, diff --git a/domains/postgres/invitations_test.go b/domains/postgres/invitations_test.go index 4bb69463a3..3e1680dbce 100644 --- a/domains/postgres/invitations_test.go +++ b/domains/postgres/invitations_test.go @@ -46,7 +46,6 @@ func TestSaveInvitation(t *testing.T) { UserID: userID, DomainID: domainID, RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: nil, @@ -59,7 +58,6 @@ func TestSaveInvitation(t *testing.T) { DomainID: domainID, CreatedAt: time.Now(), RoleID: roleID, - RoleName: "admin", ConfirmedAt: time.Now(), }, err: nil, @@ -71,7 +69,6 @@ func TestSaveInvitation(t *testing.T) { UserID: userID, DomainID: domainID, RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: repoerr.ErrConflict, @@ -83,7 +80,6 @@ func TestSaveInvitation(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: repoerr.ErrMalformedEntity, @@ -95,7 +91,6 @@ func TestSaveInvitation(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: invalidUUID, RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: repoerr.ErrMalformedEntity, @@ -107,7 +102,6 @@ func TestSaveInvitation(t *testing.T) { UserID: invalidUUID, DomainID: testsutil.GenerateUUID(t), RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: repoerr.ErrMalformedEntity, @@ -118,7 +112,6 @@ func TestSaveInvitation(t *testing.T) { InvitedBy: testsutil.GenerateUUID(t), UserID: testsutil.GenerateUUID(t), RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: repoerr.ErrCreateEntity, @@ -129,7 +122,6 @@ func TestSaveInvitation(t *testing.T) { InvitedBy: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: nil, @@ -140,7 +132,6 @@ func TestSaveInvitation(t *testing.T) { DomainID: domainID, UserID: testsutil.GenerateUUID(t), RoleID: roleID, - RoleName: "admin", CreatedAt: time.Now(), }, err: nil, @@ -151,18 +142,6 @@ func TestSaveInvitation(t *testing.T) { InvitedBy: testsutil.GenerateUUID(t), UserID: testsutil.GenerateUUID(t), DomainID: domainID, - RoleName: "admin", - CreatedAt: time.Now(), - }, - err: nil, - }, - { - desc: "add invitation with empty invitation role name", - invitation: domains.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: domainID, - RoleID: roleID, CreatedAt: time.Now(), }, err: nil, @@ -192,7 +171,6 @@ func TestInvitationRetrieve(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: testsutil.GenerateUUID(t), - RoleName: "admin", CreatedAt: time.Now().UTC().Truncate(time.Microsecond), } @@ -285,7 +263,6 @@ func TestInvitationRetrieveAll(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: testsutil.GenerateUUID(t), - RoleName: "admin", CreatedAt: time.Now().UTC().Truncate(time.Microsecond), } err := repo.SaveInvitation(context.Background(), invitation) @@ -669,7 +646,6 @@ func TestInvitationUpdateConfirmation(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: testsutil.GenerateUUID(t), - RoleName: "admin", CreatedAt: time.Now(), } err := repo.SaveInvitation(context.Background(), invitation) @@ -732,7 +708,6 @@ func TestInvitationUpdateRejection(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: testsutil.GenerateUUID(t), - RoleName: "admin", CreatedAt: time.Now(), } err := repo.SaveInvitation(context.Background(), invitation) @@ -795,7 +770,6 @@ func TestInvitationDelete(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: domainID, RoleID: testsutil.GenerateUUID(t), - RoleName: "admin", CreatedAt: time.Now(), } err := repo.SaveInvitation(context.Background(), invitation) diff --git a/domains/service.go b/domains/service.go index b1d387e0c7..c8643b695c 100644 --- a/domains/service.go +++ b/domains/service.go @@ -178,12 +178,9 @@ func (svc service) ListDomains(ctx context.Context, session authn.Session, p Pag } func (svc *service) SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) error { - role, err := svc.repo.RetrieveRole(ctx, invitation.RoleID) - if err != nil { + if _, err := svc.repo.RetrieveRole(ctx, invitation.RoleID); err != nil { return errors.Wrap(svcerr.ErrInvalidRole, err) } - invitation.RoleName = role.Name - invitation.InvitedBy = session.UserID invitation.CreatedAt = time.Now() @@ -199,11 +196,16 @@ func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, u if err != nil { return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err) } + role, err := svc.repo.RetrieveRole(ctx, inv.RoleID) + if err != nil { + return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err) + } actions, err := svc.repo.RoleListActions(ctx, inv.RoleID) if err != nil { return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err) } inv.Actions = actions + inv.RoleName = role.Name return inv, nil } diff --git a/domains/service_test.go b/domains/service_test.go index 891a02e118..58e1c21e94 100644 --- a/domains/service_test.go +++ b/domains/service_test.go @@ -628,7 +628,6 @@ func TestViewInvitation(t *testing.T) { UserID: testsutil.GenerateUUID(t), DomainID: testsutil.GenerateUUID(t), RoleID: testsutil.GenerateUUID(t), - RoleName: "admin", Actions: []string{"read", "delete"}, CreatedAt: time.Now().Add(-time.Hour), UpdatedAt: time.Now().Add(-time.Hour), @@ -643,6 +642,7 @@ func TestViewInvitation(t *testing.T) { resp domains.Invitation retrieveInvitationErr error listRolesErr error + retrieveRoleErr error err error }{ { @@ -669,17 +669,27 @@ func TestViewInvitation(t *testing.T) { listRolesErr: repoerr.ErrNotFound, err: svcerr.ErrViewEntity, }, + { + desc: "view invitation with failed to retrieve role", + userID: validInvitation.UserID, + domainID: validInvitation.DomainID, + session: validSession, + retrieveRoleErr: repoerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := drepo.On("RetrieveInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.retrieveInvitationErr) repoCall1 := drepo.On("RoleListActions", context.Background(), tc.resp.RoleID).Return(tc.resp.Actions, tc.listRolesErr) + repoCall2 := drepo.On("RetrieveRole", context.Background(), tc.resp.RoleID).Return(roles.Role{}, tc.retrieveRoleErr) inv, err := svc.ViewInvitation(context.Background(), tc.session, tc.userID, tc.domainID) assert.True(t, errors.Contains(err, tc.err)) assert.Equal(t, tc.resp, inv, tc.desc) repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() }) } } diff --git a/pkg/sdk/invitations.go b/pkg/sdk/invitations.go index 64cb400f58..434a07c361 100644 --- a/pkg/sdk/invitations.go +++ b/pkg/sdk/invitations.go @@ -43,7 +43,7 @@ func (sdk mgSDK) SendInvitation(invitation Invitation, token string) (err error) return errors.NewSDKError(err) } - url := sdk.domainsURL + "/" + invitationsEndpoint + url := sdk.domainsURL + "/" + invitation.DomainID + "/" + invitationsEndpoint _, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusCreated) @@ -51,7 +51,7 @@ func (sdk mgSDK) SendInvitation(invitation Invitation, token string) (err error) } func (sdk mgSDK) Invitation(userID, domainID, token string) (invitation Invitation, err error) { - url := sdk.domainsURL + "/" + invitationsEndpoint + "/" + userID + "/" + domainID + url := sdk.domainsURL + "/" + domainID + "/" + invitationsEndpoint + "/" + userID _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) if sdkerr != nil { @@ -121,7 +121,7 @@ func (sdk mgSDK) RejectInvitation(domainID, token string) (err error) { } func (sdk mgSDK) DeleteInvitation(userID, domainID, token string) (err error) { - url := sdk.domainsURL + "/" + invitationsEndpoint + "/" + userID + "/" + domainID + url := sdk.domainsURL + "/" + domainID + "/" + invitationsEndpoint + "/" + userID _, _, sdkerr := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent) diff --git a/pkg/sdk/invitations_test.go b/pkg/sdk/invitations_test.go index c520957a9e..078f61a90b 100644 --- a/pkg/sdk/invitations_test.go +++ b/pkg/sdk/invitations_test.go @@ -118,7 +118,11 @@ func TestSendInvitation(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { if tc.token == valid { - tc.session = smqauthn.Session{UserID: tc.sendInvitationReq.UserID, DomainID: tc.sendInvitationReq.DomainID} + tc.session = smqauthn.Session{ + UserID: tc.sendInvitationReq.UserID, + DomainID: tc.sendInvitationReq.DomainID, + DomainUserID: tc.sendInvitationReq.DomainID + "_" + tc.sendInvitationReq.UserID, + } } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := svc.On("SendInvitation", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr) @@ -186,14 +190,14 @@ func TestViewInvitation(t *testing.T) { err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), }, { - desc: "view invitation with empty userID", + desc: "view invitation with empty domainID", token: validToken, - userID: "", - domainID: invitation.DomainID, + userID: invitation.UserID, + domainID: "", svcRes: domains.Invitation{}, svcErr: nil, response: sdk.Invitation{}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest), }, { desc: "view invitation with invalid domainID", @@ -209,7 +213,7 @@ func TestViewInvitation(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { if tc.token == valid { - tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID} + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.userID} } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcRes, tc.svcErr) @@ -504,12 +508,12 @@ func TestDeleteInvitation(t *testing.T) { err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), }, { - desc: "delete invitation with empty userID", + desc: "delete invitation with empty domainID", token: validToken, - userID: "", - domainID: invitation.DomainID, + userID: invitation.UserID, + domainID: "", svcErr: nil, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest), }, { desc: "delete invitation with invalid domainID", @@ -523,7 +527,7 @@ func TestDeleteInvitation(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { if tc.token == valid { - tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID} + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.userID} } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr)