Skip to content

Commit

Permalink
use helper function for constructing state updates
Browse files Browse the repository at this point in the history
This helps preventing messages being sent with the wrong update type
and payload combination, and it is shorter/neater.

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Feb 6, 2025
1 parent 9ae3570 commit d85bd2f
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 107 deletions.
20 changes: 6 additions & 14 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,9 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
h.cfg.TailcfgDNSConfig.ExtraRecords = records

ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
Type: types.StateFullUpdate,
})
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
}
Expand Down Expand Up @@ -511,9 +509,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not

if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
notif.NotifyAll(ctx, types.UpdateFull())
}

return nil
Expand All @@ -535,9 +531,7 @@ func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not

if filterChanged {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
notif.NotifyAll(ctx, types.UpdateFull())

return true, nil
}
Expand Down Expand Up @@ -872,9 +866,7 @@ func (h *Headscale) Serve() error {
Msg("ACL policy successfully reloaded, notifying nodes of change")

ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
default:
info := func(msg string) { log.Info().Msg(msg) }
Expand Down
14 changes: 4 additions & 10 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,9 @@ func (h *Headscale) handleExistingNode(
}

ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []types.NodeID{node.ID},
})
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
if changedNodes != nil {
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
}
}

Expand All @@ -114,7 +108,7 @@ func (h *Headscale) handleExistingNode(
}

ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID)
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
}

return &tailcfg.RegisterResponse{
Expand Down Expand Up @@ -249,7 +243,7 @@ func (h *Headscale) handleRegisterWithAuthKey(

if !updateSent {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID))
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}

return &tailcfg.RegisterResponse{
Expand Down
12 changes: 3 additions & 9 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)

const (
Expand Down Expand Up @@ -626,11 +627,7 @@ func enableRoutes(tx *gorm.DB,

node.Routes = nRoutes

return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "created in db.enableRoutes",
}, nil
return ptr.To(types.UpdatePeerChanged(node.ID)), nil
}

func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
Expand Down Expand Up @@ -717,10 +714,7 @@ func ExpireExpiredNodes(tx *gorm.DB,
}

if len(expired) > 0 {
return started, types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: expired,
}, true
return started, types.UpdatePeerPatch(expired...), true
}

return started, types.StateUpdate{}, false
Expand Down
7 changes: 2 additions & 5 deletions hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/ptr"
"tailscale.com/util/set"
)

Expand Down Expand Up @@ -470,11 +471,7 @@ nodeRouteLoop:
})

if len(changedNodes) != 0 {
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: chng,
Message: "called from db.FailoverNodeRoutesIfNecessary",
}, nil
return ptr.To(types.UpdatePeerChanged(chng...)), nil
}

return nil, nil
Expand Down
48 changes: 10 additions & 38 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,7 @@ func (api headscaleV1APIServer) RegisterNode(
}
if !updateSent {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}

return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
Expand Down Expand Up @@ -319,11 +316,7 @@ func (api headscaleV1APIServer) SetTags(
}

ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.SetTags",
}, node.ID)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)

log.Trace().
Str("node", node.Hostname).
Expand Down Expand Up @@ -364,16 +357,10 @@ func (api headscaleV1APIServer) DeleteNode(
}

ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []types.NodeID{node.ID},
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))

if changedNodes != nil {
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(changedNodes...))
}

return &v1.DeleteNodeResponse{}, nil
Expand Down Expand Up @@ -401,14 +388,11 @@ func (api headscaleV1APIServer) ExpireNode(
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{node.ID},
},
types.UpdateSelf(node.ID),
node.ID)

ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)

log.Trace().
Str("node", node.Hostname).
Expand Down Expand Up @@ -439,11 +423,7 @@ func (api headscaleV1APIServer) RenameNode(
}

ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.RenameNode",
}, node.ID)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)

log.Trace().
Str("node", node.Hostname).
Expand Down Expand Up @@ -602,10 +582,7 @@ func (api headscaleV1APIServer) DisableRoute(

if update != nil {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: update,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
}

return &v1.DisableRouteResponse{}, nil
Expand Down Expand Up @@ -644,10 +621,7 @@ func (api headscaleV1APIServer) DeleteRoute(

if update != nil {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: update,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(update...))
}

return &v1.DeleteRouteResponse{}, nil
Expand Down Expand Up @@ -809,9 +783,7 @@ func (api headscaleV1APIServer) SetPolicy(
// Only send update if the packet filter has changed.
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}

response := &v1.SetPolicyResponse{
Expand Down
10 changes: 2 additions & 8 deletions hscontrol/notifier/notifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,19 +388,13 @@ func (b *batcher) flush() {
})

if b.changedNodeIDs.Slice().Len() > 0 {
update := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
}
update := types.UpdatePeerChanged(changedNodes...)

b.n.sendAll(update)
}

if len(patches) > 0 {
patchUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: patches,
}
patchUpdate := types.UpdatePeerPatch(patches...)

b.n.sendAll(patchUpdate)
}
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,12 +520,12 @@ func (a *AuthProviderOIDC) handleRegistration(
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.StateSelf(node.ID),
types.UpdateSelf(node.ID),
node.ID,
)

ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID)
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
}

return newNode, nil
Expand Down
22 changes: 4 additions & 18 deletions hscontrol/poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ func (h *Headscale) newMapSession(
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.StateUpdate{
Type: types.StateFullUpdate,
}
updateChan <- types.UpdateFull()
}

ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
Expand Down Expand Up @@ -428,12 +426,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
}

ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
change,
},
}, node.ID)
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerPatch(change), node.ID)
}

func (m *mapSession) handleEndpointUpdate() {
Expand Down Expand Up @@ -506,10 +499,7 @@ func (m *mapSession) handleEndpointUpdate() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.StateUpdate{
Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{m.node.ID},
},
types.UpdateSelf(m.node.ID),
m.node.ID)
}

Expand All @@ -530,11 +520,7 @@ func (m *mapSession) handleEndpointUpdate() {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,
types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{m.node.ID},
Message: "called from handlePoll -> update",
},
types.UpdatePeerChanged(m.node.ID),
m.node.ID,
)

Expand Down
26 changes: 23 additions & 3 deletions hscontrol/types/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,41 @@ func (su *StateUpdate) Empty() bool {
return false
}

func StateSelf(nodeID NodeID) StateUpdate {
func UpdateFull() StateUpdate {
return StateUpdate{
Type: StateFullUpdate,
}
}

func UpdateSelf(nodeID NodeID) StateUpdate {
return StateUpdate{
Type: StateSelfUpdate,
ChangeNodes: []NodeID{nodeID},
}
}

func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate {
func UpdatePeerChanged(nodeIDs ...NodeID) StateUpdate {
return StateUpdate{
Type: StatePeerChanged,
ChangeNodes: nodeIDs,
}
}

func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
func UpdatePeerPatch(changes ...*tailcfg.PeerChange) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: changes,
}
}

func UpdatePeerRemoved(nodeIDs ...NodeID) StateUpdate {
return StateUpdate{
Type: StatePeerRemoved,
Removed: nodeIDs,
}
}

func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
return StateUpdate{
Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
Expand Down

0 comments on commit d85bd2f

Please sign in to comment.