Skip to content

Commit

Permalink
Extend AccountManager with external cache and group/user management m…
Browse files Browse the repository at this point in the history
…ethods (netbirdio#1289)
  • Loading branch information
surik authored Nov 13, 2023
1 parent 3fd09c0 commit 89846b7
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 19 deletions.
37 changes: 28 additions & 9 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour
)

type ExternalCacheManager cache.CacheInterface[*idp.UserData]

func cacheEntryExpiration() time.Duration {
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
return time.Duration(r) * time.Millisecond
Expand All @@ -57,12 +59,14 @@ type AccountManager interface {
InviteUser(accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error
DeletePeer(accountID, peerID, userID string) error
Expand Down Expand Up @@ -106,19 +110,21 @@ type AccountManager interface {
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
GetExternalCacheManager() ExternalCacheManager
}

type DefaultAccountManager struct {
Store Store
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager
idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData]
ctx context.Context
eventStore activity.Store
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager
idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData]
externalCacheManager ExternalCacheManager
ctx context.Context
eventStore activity.Store

// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
Expand Down Expand Up @@ -817,9 +823,13 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage

goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
goCacheStore := cacheStore.NewGoCache(goCacheClient)

am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))

// TODO: what is max expiration time? Should be quite long
am.externalCacheManager = cache.New[*idp.UserData](
cacheStore.NewGoCache(goCacheClient),
)

if !isNil(am.idpManager) {
go func() {
err := am.warmupIDPCache()
Expand All @@ -834,6 +844,10 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
return am, nil
}

func (am *DefaultAccountManager) GetExternalCacheManager() ExternalCacheManager {
return am.externalCacheManager
}

// UpdateAccountSettings updates Account settings.
// Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account.
Expand Down Expand Up @@ -1095,10 +1109,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
if !user.IsServiceUser {
users[user.Id] = struct{}{}
if user.IsServiceUser {
continue
}
if user.Issued == UserIssuedIntegration && user.LastLogin.IsZero() {
continue
}
users[user.Id] = struct{}{}
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)
Expand Down
36 changes: 31 additions & 5 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type MockAccountManager struct {
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error
DeletePeerFunc func(accountID, peerKey, userID string) error
Expand Down Expand Up @@ -54,6 +55,7 @@ type MockAccountManager struct {
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
Expand All @@ -77,6 +79,7 @@ type MockAccountManager struct {
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
GetExternalCacheManagerFunc func() server.ExternalCacheManager
}

// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
Expand Down Expand Up @@ -339,15 +342,22 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
if am.UpdatePeerMetaFunc != nil {
return am.UpdatePeerMetaFunc(peerID, meta)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
}

// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
if am.GetUserFunc != nil {
return am.GetUserFunc(claims)
}
return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
}

func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) {
if am.ListUsersFunc != nil {
return am.ListUsers(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
}

// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
Expand All @@ -363,7 +373,7 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *server.
if am.UpdatePeerFunc != nil {
return am.UpdatePeerFunc(accountID, userID, peer)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeerFunc is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
}

// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
Expand Down Expand Up @@ -441,6 +451,14 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
}

// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
if am.SaveUserFunc != nil {
return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
}

// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
Expand Down Expand Up @@ -519,7 +537,7 @@ func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer
if am.GetAccountFromTokenFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllPeers is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
}

// GetDNSDomain mocks GetDNSDomain of the AccountManager interface
Expand All @@ -535,7 +553,7 @@ func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.E
if am.GetEventsFunc != nil {
return am.GetEventsFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllEvents is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented")
}

// GetDNSSettings mocks GetDNSSettings of the AccountManager interface
Expand Down Expand Up @@ -600,3 +618,11 @@ func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string
am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta)
}
}

// GetExternalCacheManager mocks GetExternalCacheManager of the AccountManager interface
func (am *MockAccountManager) GetExternalCacheManager() server.ExternalCacheManager {
if am.GetExternalCacheManagerFunc() != nil {
return am.GetExternalCacheManagerFunc()
}
return nil
}
68 changes: 64 additions & 4 deletions management/server/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@ type IntegrationReference struct {
}

func (ir IntegrationReference) String() string {
return fmt.Sprintf("%d:%s", ir.ID, ir.IntegrationType)
return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID)
}

func (ir IntegrationReference) CacheKey(path ...string) string {
if len(path) == 0 {
return ir.String()
}
return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":"))
}

// User represents a user of the system
Expand Down Expand Up @@ -355,6 +362,25 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
return user, nil
}

// ListUsers returns lists of all users under the account.
// It doesn't populate user information such a email or name.
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()

account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}

users := make([]*User, 0, len(account.Users))
for _, item := range account.Users {
users = append(users, item)
}

return users, nil
}

func (am *DefaultAccountManager) deleteServiceUser(account *Account, initiatorUserID string, targetUser *User) {
meta := map[string]any{"name": targetUser.ServiceUserName}
am.StoreEvent(initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta)
Expand Down Expand Up @@ -654,8 +680,13 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID st
}

// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) {
return am.SaveOrAddUser(accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
}

// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()

Expand All @@ -679,7 +710,11 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd

oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
}
// will add a user based on input
oldUser = update
}

if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
Expand All @@ -691,6 +726,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
}

// only auto groups, revoked status, and name can be updated for now
// when addIfNotExists is set to true the newUser will use all fields from the update input
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
Expand Down Expand Up @@ -779,7 +815,16 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
return nil, err
}
if userData == nil {
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
// lets check external cache
key := newUser.IntegrationReference.CacheKey(account.Id, newUser.Id)
log.Debugf("looking up user %s of account %s in external cache", key, account.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
}

return newUser.ToUserInfo(info)
}
return newUser.ToUserInfo(userData)
}
Expand Down Expand Up @@ -839,7 +884,19 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users {
if user.Issued == UserIssuedIntegration && user.LastLogin.IsZero() {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = struct{}{}
continue
}
usersFromIntegration = append(usersFromIntegration, info)
continue
}
if !user.IsServiceUser {
users[user.Id] = struct{}{}
}
Expand All @@ -848,6 +905,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
if err != nil {
return nil, err
}
log.Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
log.Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
queriedUsers = append(queriedUsers, usersFromIntegration...)
}

userInfos := make([]*UserInfo, 0)
Expand Down
Loading

0 comments on commit 89846b7

Please sign in to comment.