diff --git a/management/server/account.go b/management/server/account.go index 72966065076..26559ce4891 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -42,6 +42,8 @@ const ( DefaultPeerLoginExpiration = 24 * time.Hour ) +type ExternalCacheManager cache.CacheInterface[*idp.UserData] + func cacheEntryExpiration() time.Duration { r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds()) return time.Duration(r) * time.Millisecond @@ -57,12 +59,14 @@ type AccountManager interface { InviteUser(accountID string, initiatorUserID string, targetUserID string) error ListSetupKeys(accountID, userID string) ([]*SetupKey, error) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) + SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) MarkPATUsed(tokenID string) error GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) + ListUsers(accountID string) ([]*User, error) GetPeers(accountID, userID string) ([]*Peer, error) MarkPeerConnected(peerKey string, connected bool) error DeletePeer(accountID, peerID, userID string) error @@ -106,6 +110,7 @@ type AccountManager interface { LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) + GetExternalCacheManager() ExternalCacheManager } type DefaultAccountManager struct { @@ -113,12 +118,13 @@ type DefaultAccountManager struct { // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded - cacheLoading map[string]chan struct{} - peersUpdateManager *PeersUpdateManager - idpManager idp.Manager - cacheManager cache.CacheInterface[[]*idp.UserData] - ctx context.Context - eventStore activity.Store + cacheLoading map[string]chan struct{} + peersUpdateManager *PeersUpdateManager + idpManager idp.Manager + cacheManager cache.CacheInterface[[]*idp.UserData] + externalCacheManager ExternalCacheManager + ctx context.Context + eventStore activity.Store // singleAccountMode indicates whether the instance has a single account. // If true, then every new user will end up under the same account. @@ -817,9 +823,13 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) goCacheStore := cacheStore.NewGoCache(goCacheClient) - am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore)) + // TODO: what is max expiration time? Should be quite long + am.externalCacheManager = cache.New[*idp.UserData]( + cacheStore.NewGoCache(goCacheClient), + ) + if !isNil(am.idpManager) { go func() { err := am.warmupIDPCache() @@ -834,6 +844,10 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage return am, nil } +func (am *DefaultAccountManager) GetExternalCacheManager() ExternalCacheManager { + return am.externalCacheManager +} + // UpdateAccountSettings updates Account settings. // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. @@ -1095,10 +1109,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { users := make(map[string]struct{}, len(account.Users)) + // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { - if !user.IsServiceUser { - users[user.Id] = struct{}{} + if user.IsServiceUser { + continue + } + if user.Issued == UserIssuedIntegration && user.LastLogin.IsZero() { + continue } + users[user.Id] = struct{}{} } log.Debugf("looking up user %s of account %s in cache", userID, account.Id) userData, err := am.lookupCache(users, account.Id) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 46b8551c56b..d953e2fbf5d 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -20,6 +20,7 @@ type MockAccountManager struct { GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) + ListUsersFunc func(accountID string) ([]*server.User, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool) error DeletePeerFunc func(accountID, peerKey, userID string) error @@ -54,6 +55,7 @@ type MockAccountManager struct { SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) + SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error @@ -77,6 +79,7 @@ type MockAccountManager struct { SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) + GetExternalCacheManagerFunc func() server.ExternalCacheManager } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -339,7 +342,7 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst if am.UpdatePeerMetaFunc != nil { return am.UpdatePeerMetaFunc(peerID, meta) } - return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented") + return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented") } // GetUser mock implementation of GetUser from server.AccountManager interface @@ -347,7 +350,14 @@ func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*se if am.GetUserFunc != nil { return am.GetUserFunc(claims) } - return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") +} + +func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) { + if am.ListUsersFunc != nil { + return am.ListUsers(accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } // UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager @@ -363,7 +373,7 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *server. if am.UpdatePeerFunc != nil { return am.UpdatePeerFunc(accountID, userID, peer) } - return nil, status.Errorf(codes.Unimplemented, "method UpdatePeerFunc is is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is is not implemented") } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface @@ -441,6 +451,14 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented") } +// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface +func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { + if am.SaveUserFunc != nil { + return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists) + } + return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") +} + // DeleteUser mocks DeleteUser of the AccountManager interface func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error { if am.DeleteUserFunc != nil { @@ -519,7 +537,7 @@ func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer if am.GetAccountFromTokenFunc != nil { return am.GetPeersFunc(accountID, userID) } - return nil, status.Errorf(codes.Unimplemented, "method GetAllPeers is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") } // GetDNSDomain mocks GetDNSDomain of the AccountManager interface @@ -535,7 +553,7 @@ func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.E if am.GetEventsFunc != nil { return am.GetEventsFunc(accountID, userID) } - return nil, status.Errorf(codes.Unimplemented, "method GetAllEvents is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented") } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface @@ -600,3 +618,11 @@ func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta) } } + +// GetExternalCacheManager mocks GetExternalCacheManager of the AccountManager interface +func (am *MockAccountManager) GetExternalCacheManager() server.ExternalCacheManager { + if am.GetExternalCacheManagerFunc() != nil { + return am.GetExternalCacheManagerFunc() + } + return nil +} diff --git a/management/server/user.go b/management/server/user.go index 2eaea88eb6c..adb3539b0a7 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -52,7 +52,14 @@ type IntegrationReference struct { } func (ir IntegrationReference) String() string { - return fmt.Sprintf("%d:%s", ir.ID, ir.IntegrationType) + return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) +} + +func (ir IntegrationReference) CacheKey(path ...string) string { + if len(path) == 0 { + return ir.String() + } + return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) } // User represents a user of the system @@ -355,6 +362,25 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( return user, nil } +// ListUsers returns lists of all users under the account. +// It doesn't populate user information such a email or name. +func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + users := make([]*User, 0, len(account.Users)) + for _, item := range account.Users { + users = append(users, item) + } + + return users, nil +} + func (am *DefaultAccountManager) deleteServiceUser(account *Account, initiatorUserID string, targetUser *User) { meta := map[string]any{"name": targetUser.ServiceUserName} am.StoreEvent(initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) @@ -654,8 +680,13 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID st } // SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. -// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) { + return am.SaveOrAddUser(accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound +} + +// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist +// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. +func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -679,7 +710,11 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd oldUser := account.Users[update.Id] if oldUser == nil { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist") + if !addIfNotExists { + return nil, status.Errorf(status.NotFound, "user to update doesn't exist") + } + // will add a user based on input + oldUser = update } if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked { @@ -691,6 +726,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd } // only auto groups, revoked status, and name can be updated for now + // when addIfNotExists is set to true the newUser will use all fields from the update input newUser := oldUser.Copy() newUser.Role = update.Role newUser.Blocked = update.Blocked @@ -839,7 +875,19 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( queriedUsers := make([]*idp.UserData, 0) if !isNil(am.idpManager) { users := make(map[string]struct{}, len(account.Users)) + usersFromIntegration := make([]*idp.UserData, 0) for _, user := range account.Users { + if user.Issued == UserIssuedIntegration && !user.LastLogin.IsZero() { + key := user.IntegrationReference.CacheKey(accountID, user.Id) + info, err := am.externalCacheManager.Get(am.ctx, key) + if err != nil { + log.Infof("Get ExternalCache for key: %s, error: %s", key, err) + users[user.Id] = struct{}{} + continue + } + usersFromIntegration = append(usersFromIntegration, info) + continue + } if !user.IsServiceUser { users[user.Id] = struct{}{} } @@ -848,6 +896,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( if err != nil { return nil, err } + log.Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID) + log.Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID) + queriedUsers = append(queriedUsers, usersFromIntegration...) } userInfos := make([]*UserInfo, 0) diff --git a/management/server/user_test.go b/management/server/user_test.go index 40e7f9a2d27..818a267d4ae 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1,16 +1,21 @@ package server import ( + "context" "fmt" "reflect" "testing" "time" + "github.com/eko/gocache/v3/cache" + cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -549,6 +554,95 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { assert.False(t, user.IsBlocked()) } +func TestDefaultAccountManager_ListUsers(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users["normal_user1"] = NewRegularUser("normal_user1") + account.Users["normal_user2"] = NewRegularUser("normal_user2") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + users, err := am.ListUsers(mockAccountID) + if err != nil { + t.Fatalf("Error when checking user role: %s", err) + } + + admins := 0 + regular := 0 + for _, user := range users { + if user.IsAdmin() { + admins++ + continue + } + regular++ + } + assert.Equal(t, 3, len(users)) + assert.Equal(t, 1, admins) + assert.Equal(t, 2, regular) +} + +func TestDefaultAccountManager_ExternalCache(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + externalUser := &User{ + Id: "externalUser", + Role: UserRoleUser, + Issued: UserIssuedIntegration, + IntegrationReference: IntegrationReference{ + ID: 1, + IntegrationType: "external", + }, + } + account.Users[externalUser.Id] = externalUser + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + idpManager: &idp.GoogleWorkspaceManager{}, // empty manager + cacheLoading: map[string]chan struct{}{}, + cacheManager: cache.New[[]*idp.UserData]( + cacheStore.NewGoCache(gocache.New(CacheExpirationMax, 30*time.Minute)), + ), + externalCacheManager: cache.New[*idp.UserData]( + cacheStore.NewGoCache(gocache.New(CacheExpirationMax, 30*time.Minute)), + ), + } + + // pretend that we receive mockUserID from IDP + err = am.cacheManager.Set(am.ctx, mockAccountID, []*idp.UserData{{Name: mockUserID, ID: mockUserID}}) + assert.NoError(t, err) + + cacheManager := am.GetExternalCacheManager() + cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) + err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) + assert.NoError(t, err) + + infos, err := am.GetUsersFromAccount(mockAccountID, mockUserID) + assert.NoError(t, err) + assert.Equal(t, 2, len(infos)) + var user *UserInfo + for _, info := range infos { + if info.ID == externalUser.Id { + user = info + } + } + assert.NotNil(t, user) + assert.Equal(t, "user@example.com", user.Email) +} + func TestUser_IsAdmin(t *testing.T) { user := NewAdminUser(mockUserID) @@ -710,5 +804,4 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked) } } - }