From 2c5435ef558260d34ff6a99be3cfa88cd8cceae6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 24 Jan 2025 15:47:28 +0100 Subject: [PATCH 1/5] Implement iptables dnat/snat rules --- client/firewall/iptables/acl_linux.go | 21 +- client/firewall/iptables/manager_linux.go | 13 +- client/firewall/iptables/router_linux.go | 238 ++++++++++++++---- client/firewall/iptables/router_linux_test.go | 18 +- client/firewall/iptables/rule.go | 2 +- client/firewall/manager/firewall.go | 5 +- client/firewall/manager/forward_rule.go | 6 +- client/firewall/uspfilter/rule.go | 2 +- client/firewall/uspfilter/uspfilter.go | 14 +- client/firewall/uspfilter/uspfilter_test.go | 4 +- client/internal/acl/id/id.go | 2 +- client/internal/acl/manager.go | 4 +- client/internal/acl/manager_test.go | 4 +- client/internal/ingressgw/manager.go | 8 +- client/internal/ingressgw/manager_test.go | 4 +- 15 files changed, 250 insertions(+), 95 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 6c4895e05ed..e380faf90b0 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -30,10 +30,8 @@ type entry struct { } type aclManager struct { - iptablesClient *iptables.IPTables - wgIface iFaceMapper - routingFwChainName string - + iptablesClient *iptables.IPTables + wgIface iFaceMapper entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore @@ -41,12 +39,10 @@ type aclManager struct { stateManager *statemanager.Manager } -func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { +func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { m := &aclManager{ - iptablesClient: iptablesClient, - wgIface: wgIface, - routingFwChainName: routingFwChainName, - + iptablesClient: iptablesClient, + wgIface: wgIface, entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), @@ -314,9 +310,12 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) + // inbound is handled by our ACLs, the rest is dropped + // for outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules. m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) - m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) + + m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN}) } func (m *aclManager) seedInitialOptionalEntries() { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 4a977aea07c..192fae749f9 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { return nil, fmt.Errorf("create router: %w", err) } - m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) + m.aclMgr, err = newAclManager(iptablesClient, wgIface) if err != nil { return nil, fmt.Errorf("create acl manager: %w", err) } @@ -213,13 +213,20 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } +// AddDNATRule adds a DNAT rule func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { - return nil, fmt.Errorf("not implemented") + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddDNATRule(rule) } // DeleteDNATRule deletes a DNAT rule func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { - return nil + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteDNATRule(rule) } func getConntrackEstablished() []string { diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a47d3ffe698..6d55cebe4c0 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -23,22 +23,36 @@ import ( // constants needed to manage and create iptable rules const ( - tableFilter = "filter" - tableNat = "nat" - tableMangle = "mangle" + tableFilter = "filter" + tableNat = "nat" + tableMangle = "mangle" + chainPOSTROUTING = "POSTROUTING" chainPREROUTING = "PREROUTING" chainRTNAT = "NETBIRD-RT-NAT" - chainRTFWD = "NETBIRD-RT-FWD" + chainRTFWDIN = "NETBIRD-RT-FWD-IN" + chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" + chainRTRDR = "NETBIRD-RT-RDR" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" - jumpPre = "jump-pre" - jumpNat = "jump-nat" - matchSet = "--match-set" + jumpManglePre = "jump-mangle-pre" + jumpNatPre = "jump-nat-pre" + jumpNatPost = "jump-nat-post" + matchSet = "--match-set" + + dnatSuffix = "_dnat" + snatSuffix = "_snat" + fwdSuffix = "_fwd" ) +type ruleInfo struct { + chain string + table string + rule []string +} + type routeFilteringRuleParams struct { Sources []netip.Prefix Destination netip.Prefix @@ -135,7 +149,7 @@ func (r *router) AddRouteFiltering( } rule := genRouteFilteringRuleSpec(params) - if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil { return nil, fmt.Errorf("add route rule: %v", err) } @@ -147,12 +161,12 @@ func (r *router) AddRouteFiltering( } func (r *router) DeleteRouteRule(rule firewall.Rule) error { - ruleKey := rule.GetRuleID() + ruleKey := rule.ID() if rule, exists := r.rules[ruleKey]; exists { setName := r.findSetNameInRule(rule) - if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { return fmt.Errorf("delete route rule: %v", err) } delete(r.rules, ruleKey) @@ -255,7 +269,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { } rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} - if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil { return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } @@ -268,7 +282,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } delete(r.rules, ruleKey) @@ -296,7 +310,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { continue } - if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } else { delete(r.rules, k) @@ -334,9 +348,11 @@ func (r *router) cleanUpDefaultForwardRules() error { chain string table string }{ - {chainRTFWD, tableFilter}, - {chainRTNAT, tableNat}, + {chainRTFWDIN, tableFilter}, + {chainRTFWDOUT, tableFilter}, {chainRTPRE, tableMangle}, + {chainRTNAT, tableNat}, + {chainRTRDR, tableNat}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { @@ -356,16 +372,22 @@ func (r *router) createContainers() error { chain string table string }{ - {chainRTFWD, tableFilter}, + {chainRTFWDIN, tableFilter}, + {chainRTFWDOUT, tableFilter}, {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, + {chainRTRDR, tableNat}, } { - if err := r.createAndSetupChain(chainInfo.chain); err != nil { + if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } } - if err := r.insertEstablishedRule(chainRTFWD); err != nil { + if err := r.insertEstablishedRule(chainRTFWDIN); err != nil { + return fmt.Errorf("insert established rule: %w", err) + } + + if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil { return fmt.Errorf("insert established rule: %w", err) } @@ -406,27 +428,6 @@ func (r *router) addPostroutingRules() error { return nil } -func (r *router) createAndSetupChain(chain string) error { - table := r.getTableForChain(chain) - - if err := r.iptablesClient.NewChain(table, chain); err != nil { - return fmt.Errorf("failed creating chain %s, error: %v", chain, err) - } - - return nil -} - -func (r *router) getTableForChain(chain string) string { - switch chain { - case chainRTNAT: - return tableNat - case chainRTPRE: - return tableMangle - default: - return tableFilter - } -} - func (r *router) insertEstablishedRule(chain string) error { establishedRule := getConntrackEstablished() @@ -445,28 +446,43 @@ func (r *router) addJumpRules() error { // Jump to NAT chain natRule := []string{"-j", chainRTNAT} if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { - return fmt.Errorf("add nat jump rule: %v", err) + return fmt.Errorf("add nat postrouting jump rule: %v", err) } - r.rules[jumpNat] = natRule + r.rules[jumpNatPost] = natRule - // Jump to prerouting chain + // Jump to mangle prerouting chain preRule := []string{"-j", chainRTPRE} if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil { - return fmt.Errorf("add prerouting jump rule: %v", err) + return fmt.Errorf("add mangle prerouting jump rule: %v", err) + } + r.rules[jumpManglePre] = preRule + + // Jump to nat prerouting chain + rdrRule := []string{"-j", chainRTRDR} + if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil { + return fmt.Errorf("add nat prerouting jump rule: %v", err) } - r.rules[jumpPre] = preRule + r.rules[jumpNatPre] = rdrRule return nil } func (r *router) cleanJumpRules() error { - for _, ruleKey := range []string{jumpNat, jumpPre} { + for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { if rule, exists := r.rules[ruleKey]; exists { - table := tableNat - chain := chainPOSTROUTING - if ruleKey == jumpPre { + var table, chain string + switch ruleKey { + case jumpNatPost: + table = tableNat + chain = chainPOSTROUTING + case jumpManglePre: table = tableMangle chain = chainPREROUTING + case jumpNatPre: + table = tableNat + chain = chainPREROUTING + default: + return fmt.Errorf("unknown jump rule: %s", ruleKey) } if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil { @@ -511,6 +527,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } r.rules[ruleKey] = rule + + r.updateState() return nil } @@ -526,6 +544,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { log.Debugf("marking rule %s not found", ruleKey) } + r.updateState() return nil } @@ -555,6 +574,129 @@ func (r *router) updateState() { } } +func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { + ruleKey := rule.ID() + if _, exists := r.rules[ruleKey+dnatSuffix]; exists { + return rule, nil + } + + toDestination := rule.TranslatedAddress.String() + switch { + case len(rule.TranslatedPort.Values) == 0: + // no translated port, use original port + case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2: + // need the "/originalport" suffix to avoid dnat port randomization + toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0]) + case len(rule.TranslatedPort.Values) == 1: + toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0]) + default: + return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort) + } + + proto := strings.ToLower(string(rule.Protocol)) + + rules := make(map[string]ruleInfo, 3) + + // DNAT rule + dnatRule := []string{ + "!", "-i", r.wgIface.Name(), + "-p", proto, + "-j", "DNAT", + "--to-destination", toDestination, + } + dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...) + rules[ruleKey+dnatSuffix] = ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + // SNAT rule + snatRule := []string{ + "-o", r.wgIface.Name(), + "-p", proto, + "-d", rule.TranslatedAddress.String(), + "-j", "MASQUERADE", + } + snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...) + rules[ruleKey+snatSuffix] = ruleInfo{ + table: tableNat, + chain: chainRTNAT, + rule: snatRule, + } + + // Forward filtering rule, if fwd policy is DROP + forwardRule := []string{ + "-o", r.wgIface.Name(), + "-p", proto, + "-d", rule.TranslatedAddress.String(), + "-j", "ACCEPT", + } + forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...) + rules[ruleKey+fwdSuffix] = ruleInfo{ + table: tableFilter, + chain: chainRTFWDOUT, + rule: forwardRule, + } + + for key, ruleInfo := range rules { + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + if rollbackErr := r.rollbackRules(rules); rollbackErr != nil { + log.Errorf("rollback failed: %v", rollbackErr) + } + return nil, fmt.Errorf("add rule %s: %w", key, err) + } + r.rules[key] = ruleInfo.rule + } + + r.updateState() + return rule, nil +} + +func (r *router) rollbackRules(rules map[string]ruleInfo) error { + var merr *multierror.Error + for key, ruleInfo := range rules { + if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err)) + // On rollback error, add to rules map for next cleanup + r.rules[key] = ruleInfo.rule + } + } + if merr != nil { + r.updateState() + } + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) DeleteDNATRule(rule firewall.Rule) error { + ruleKey := rule.ID() + + var merr *multierror.Error + if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err)) + } + delete(r.rules, ruleKey+dnatSuffix) + } + + if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err)) + } + delete(r.rules, ruleKey+snatSuffix) + } + + if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) + } + delete(r.rules, ruleKey+fwdSuffix) + } + + r.updateState() + return nberrors.FormatErrorOrNil(merr) +} + func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { var rule []string diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 0eb20756756..3f132504fec 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { }() // Now 5 rules: - // 1. established rule in forward chain - // 2. jump rule to NAT chain - // 3. jump rule to PRE chain - // 4. static outbound masquerade rule - // 5. static return masquerade rule - require.Len(t, manager.rules, 5, "should have created rules map") + // 1. established rule forward in + // 2. estbalished rule forward out + // 3. jump rule to POST nat chain + // 4. jump rule to PRE mangle chain + // 5. jump rule to PRE nat chain + // 6. static outbound masquerade rule + // 7. static return masquerade rule + require.Len(t, manager.rules, 7, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) @@ -332,14 +334,14 @@ func TestRouter_AddRouteFiltering(t *testing.T) { require.NoError(t, err, "AddRouteFiltering failed") // Check if the rule is in the internal map - rule, ok := r.rules[ruleKey.GetRuleID()] + rule, ok := r.rules[ruleKey.ID()] assert.True(t, ok, "Rule not found in internal map") // Log the internal rule t.Logf("Internal rule: %v", rule) // Check if the rule exists in iptables - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) + exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...) assert.NoError(t, err, "Failed to check rule existence") assert.True(t, exists, "Rule not found in iptables") diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index e90e32f8b02..aa4d2d07900 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -12,6 +12,6 @@ type Rule struct { } // GetRuleID returns the rule id -func (r *Rule) GetRuleID() string { +func (r *Rule) ID() string { return r.ruleID } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index dc4b737b601..a031213ea4a 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -26,8 +26,8 @@ const ( // Each firewall type for different OS can use different type // of the properties to hold data of the created rule type Rule interface { - // GetRuleID returns the rule id - GetRuleID() string + // ID returns the rule id + ID() string } // RuleDirection is the traffic direction which a rule is applied @@ -104,7 +104,6 @@ type Manager interface { AddDNATRule(ForwardRule) (Rule, error) // DeleteDNATRule deletes a DNAT rule - // todo: do you need a string ID or the complete rule? DeleteDNATRule(Rule) error } diff --git a/client/firewall/manager/forward_rule.go b/client/firewall/manager/forward_rule.go index 347a62ac865..21a43520e58 100644 --- a/client/firewall/manager/forward_rule.go +++ b/client/firewall/manager/forward_rule.go @@ -5,8 +5,6 @@ import ( "net/netip" ) -type ForwardRuleID string - // ForwardRule todo figure out better place to this to avoid circular imports type ForwardRule struct { Protocol Protocol @@ -15,13 +13,13 @@ type ForwardRule struct { TranslatedPort Port } -func (r ForwardRule) ID() ForwardRuleID { +func (r ForwardRule) ID() string { id := fmt.Sprintf("%s;%s;%s;%s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String()) - return ForwardRuleID(id) + return id } func (r ForwardRule) String() string { diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index c59d4b264ce..aa346dea6f6 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -24,6 +24,6 @@ type Rule struct { } // GetRuleID returns the rule id -func (r *Rule) GetRuleID() string { +func (r *Rule) ID() string { return r.id } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index b5f0c811a16..049011f73c7 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "errors" "fmt" "net" "net/netip" @@ -25,7 +26,8 @@ const layerTypeAll = 0 const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" var ( - errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") + errRouteNotSupported = errors.New("route not supported with userspace firewall") + errNatNotSupported = errors.New("nat not supported with userspace firewall") ) // IFaceMapper defines subset methods of interface required for manager @@ -250,12 +252,18 @@ func (m *Manager) Flush() error { return nil } // AddDNATRule adds a DNAT rule func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { - return nil, fmt.Errorf("not implemented") + if m.nativeFirewall == nil { + return nil, errNatNotSupported + } + return m.nativeFirewall.AddDNATRule(rule) } // DeleteDNATRule deletes a DNAT rule func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { - return nil + if m.nativeFirewall == nil { + return errNatNotSupported + } + return m.nativeFirewall.DeleteDNATRule(rule) } // DropOutgoing filter outgoing packets diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 9d795de691f..c4c02330b9a 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -114,7 +114,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { + if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok { t.Errorf("rule2 is not in the incomingRules") } } @@ -128,7 +128,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok { + if _, ok := m.incomingRules[ip.String()][r.ID()]; ok { t.Errorf("rule2 is not in the incomingRules") } } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index 8ce73655d5f..93f16b429e1 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -12,7 +12,7 @@ import ( type RuleID string -func (r RuleID) GetRuleID() string { +func (r RuleID) ID() string { return string(r) } diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 9ec0bb031de..e810d4179fe 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -245,7 +245,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul return "", fmt.Errorf("add route rule: %w", err) } - return id.RuleID(addedRule.GetRuleID()), nil + return id.RuleID(addedRule.ID()), nil } func (d *DefaultManager) protoRuleToFirewallRule( @@ -499,7 +499,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { for _, rules := range newRulePairs { for _, rule := range rules { if err := d.firewall.DeletePeerRule(rule); err != nil { - log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) + log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err) } } } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 6049b4f48e2..1edbeb9aec2 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -73,7 +73,7 @@ func TestDefaultManager(t *testing.T) { t.Run("add extra rules", func(t *testing.T) { existedPairs := map[string]struct{}{} for id := range acl.peerRulesPairs { - existedPairs[id.GetRuleID()] = struct{}{} + existedPairs[id.ID()] = struct{}{} } // remove first rule @@ -99,7 +99,7 @@ func TestDefaultManager(t *testing.T) { // check that old rule was removed previousCount := 0 for id := range acl.peerRulesPairs { - if _, ok := existedPairs[id.GetRuleID()]; ok { + if _, ok := existedPairs[id.ID()]; ok { previousCount++ } } diff --git a/client/internal/ingressgw/manager.go b/client/internal/ingressgw/manager.go index 5fbf939da98..829c8a09a9e 100644 --- a/client/internal/ingressgw/manager.go +++ b/client/internal/ingressgw/manager.go @@ -24,14 +24,14 @@ type RulePair struct { type Manager struct { dnatFirewall DNATFirewall - rules map[firewall.ForwardRuleID]RulePair // keys is the ID of the ForwardRule + rules map[string]RulePair // keys is the ID of the ForwardRule rulesMu sync.Mutex } func NewManager(dnatFirewall DNATFirewall) *Manager { return &Manager{ dnatFirewall: dnatFirewall, - rules: make(map[firewall.ForwardRuleID]RulePair), + rules: make(map[string]RulePair), } } @@ -40,7 +40,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error { defer h.rulesMu.Unlock() var mErr *multierror.Error - toDelete := make(map[firewall.ForwardRuleID]RulePair) + toDelete := make(map[string]RulePair) for id, r := range h.rules { toDelete[id] = r } @@ -89,7 +89,7 @@ func (h *Manager) Close() error { } } - h.rules = make(map[firewall.ForwardRuleID]RulePair) + h.rules = make(map[string]RulePair) return nberrors.FormatErrorOrNil(mErr) } diff --git a/client/internal/ingressgw/manager_test.go b/client/internal/ingressgw/manager_test.go index 009d8c46b29..c334e82cc94 100644 --- a/client/internal/ingressgw/manager_test.go +++ b/client/internal/ingressgw/manager_test.go @@ -14,10 +14,10 @@ var ( ) type MocFwRule struct { - id firewall.ForwardRuleID + id string } -func (m *MocFwRule) GetRuleID() string { +func (m *MocFwRule) ID() string { return string(m.id) } From e9e112747a1142b3e3f78a96fe90f29be45eb5d0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 25 Jan 2025 13:18:22 +0100 Subject: [PATCH 2/5] Add nftables --- client/firewall/iptables/router_linux.go | 4 +- client/firewall/nftables/acl_linux.go | 6 +- client/firewall/nftables/manager_linux.go | 13 +- client/firewall/nftables/router_linux.go | 320 +++++++++++++++++- client/firewall/nftables/router_linux_test.go | 4 +- client/firewall/nftables/rule_linux.go | 2 +- 6 files changed, 322 insertions(+), 27 deletions(-) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 6d55cebe4c0..d3441c69a3a 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -584,11 +584,11 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { switch { case len(rule.TranslatedPort.Values) == 0: // no translated port, use original port + case len(rule.TranslatedPort.Values) == 1: + toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0]) case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2: // need the "/originalport" suffix to avoid dnat port randomization toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0]) - case len(rule.TranslatedPort.Values) == 1: - toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0]) default: return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort) } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index fc5cc6873cf..df55aa0439e 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { log.Errorf("failed to delete mangle rule: %v", err) } } - delete(m.rules, r.GetRuleID()) + delete(m.rules, r.ID()) return m.rConn.Flush() } @@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { log.Errorf("failed to delete mangle rule: %v", err) } } - delete(m.rules, r.GetRuleID()) + delete(m.rules, r.ID()) return m.rConn.Flush() } @@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { return err } - delete(m.rules, r.GetRuleID()) + delete(m.rules, r.ID()) m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name) if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 2a064e7dafc..3f2ae8a9730 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -331,15 +331,18 @@ func (m *Manager) Flush() error { // AddDNATRule adds a DNAT rule func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { - r := &Rule{ - ruleID: string(rule.ID()), - } - return r, nil + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddDNATRule(rule) } // DeleteDNATRule deletes a DNAT rule func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { - return nil + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteDNATRule(rule) } func (m *Manager) createWorkTable() (*nftables.Table, error) { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 19734673b72..03c8b08ab2e 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -14,6 +14,7 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/google/nftables/xt" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -25,12 +26,18 @@ import ( ) const ( - chainNameRoutingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-postrouting" - chainNameForward = "FORWARD" + tableNat = "nat" + chainNameNatPrerouting = "PREROUTING" + chainNameRoutingFw = "netbird-rt-fwd" + chainNameRoutingNat = "netbird-rt-postrouting" + chainNameRoutingRdr = "netbird-rt-redirect" + chainNameForward = "FORWARD" userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + + dnatSuffix = "_dnat" + snatSuffix = "_snat" ) const refreshRulesMapError = "refresh rules map: %w" @@ -98,7 +105,53 @@ func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() - return r.removeAcceptForwardRules() + var merr *multierror.Error + + if err := r.removeAcceptForwardRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + } + + if err := r.removeNatPreroutingRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) removeNatPreroutingRules() error { + table := &nftables.Table{ + Name: tableNat, + Family: nftables.TableFamilyIPv4, + } + chain := &nftables.Chain{ + Name: chainNameNatPrerouting, + Table: table, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityNATDest, + Type: nftables.ChainTypeNAT, + } + rules, err := r.conn.GetRules(table, chain) + if err != nil { + log.Debugf("err type: %T", err) + return fmt.Errorf("get rules from nat table: %w", err) + } + + var merr *multierror.Error + + // Delete rules that have our UserData suffix + for _, rule := range rules { + if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) { + continue + } + if err := r.conn.DelRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err)) + } + } + + if err := r.conn.Flush(); err != nil { + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) + } + return nberrors.FormatErrorOrNil(merr) } func (r *router) loadFilterTable() (*nftables.Table, error) { @@ -133,14 +186,22 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) + r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingRdr, + Table: r.workTable, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityNATDest, + Type: nftables.ChainTypeNAT, + }) + // Chain is created by acl manager // TODO: move creation to a common place r.chains[chainNamePrerouting] = &nftables.Chain{ Name: chainNamePrerouting, Table: r.workTable, - Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, } // Add the single NAT rule that matches on mark @@ -275,7 +336,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return fmt.Errorf(refreshRulesMapError, err) } - ruleKey := rule.GetRuleID() + ruleKey := rule.ID() nftRule, exists := r.rules[ruleKey] if !exists { log.Debugf("route rule %s not found", ruleKey) @@ -890,6 +951,241 @@ func (r *router) refreshRulesMap() error { return nil } +func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { + ruleKey := rule.ID() + if _, exists := r.rules[ruleKey+dnatSuffix]; exists { + return rule, nil + } + + protoNum, err := protoToInt(rule.Protocol) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %w", err) + } + + if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil { + return nil, err + } + + r.addDnatMasq(rule, protoNum, ruleKey) + + // Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT. + // To overcome DROP policies in other chains, we'd have to add rules to the chains there. + // We also cannot just add "oif accept" there and filter in our own table as we don't know what is supposed to be allowed. + // TODO: find chains with drop policies and add rules there + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush rules: %w", err) + } + + return &rule, nil +} + +func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error { + dnatExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + } + dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...) + + var regProtoMin, regProtoMax uint32 + switch { + case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2: + // shifted translated port is not supported in nftables, so we hand this over to xtables + if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] || + rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] { + return r.addXTablesRedirect(dnatExprs, ruleKey, rule) + } + + dnatExprs = append(dnatExprs, + &expr.Immediate{ + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]), + }, + &expr.Immediate{ + Register: 3, + Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]), + }, + ) + regProtoMin = 2 + regProtoMax = 3 + case len(rule.TranslatedPort.Values) == 0: + // address only + dnatExprs = append(dnatExprs, + &expr.Immediate{ + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + ) + case len(rule.TranslatedPort.Values) == 1: + dnatExprs = append(dnatExprs, + &expr.Immediate{ + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]), + }, + ) + regProtoMin = 2 + default: + return fmt.Errorf("invalid translated port: %v", rule.TranslatedPort) + } + + dnatExprs = append(dnatExprs, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: regProtoMin, + RegProtoMax: regProtoMax, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: dnatExprs, + UserData: []byte(ruleKey + dnatSuffix), + } + r.conn.AddRule(dnatRule) + r.rules[ruleKey+dnatSuffix] = dnatRule + + return nil +} + +func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error { + dnatExprs = append(dnatExprs, + &expr.Counter{}, + &expr.Target{ + Name: "DNAT", + Rev: 2, + Info: &xt.NatRange2{ + NatRange: xt.NatRange{ + Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset), + MinIP: rule.TranslatedAddress.AsSlice(), + MaxIP: rule.TranslatedAddress.AsSlice(), + MinPort: rule.TranslatedPort.Values[0], + MaxPort: rule.TranslatedPort.Values[1], + }, + BasePort: rule.DestinationPort.Values[0], + }, + }, + ) + + dnatRule := &nftables.Rule{ + Table: &nftables.Table{ + Name: tableNat, + Family: nftables.TableFamilyIPv4, + }, + Chain: &nftables.Chain{ + Name: chainNameNatPrerouting, + Table: r.filterTable, + Type: nftables.ChainTypeNAT, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityNATDest, + }, + Exprs: dnatExprs, + UserData: []byte(ruleKey + dnatSuffix), + } + r.conn.AddRule(dnatRule) + r.rules[ruleKey+dnatSuffix] = dnatRule + + return nil +} + +func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) { + masqExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + } + + masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...) + masqExprs = append(masqExprs, &expr.Masq{}) + + masqRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: masqExprs, + UserData: []byte(ruleKey + snatSuffix), + } + r.conn.AddRule(masqRule) + r.rules[ruleKey+snatSuffix] = masqRule +} + +func (r *router) DeleteDNATRule(rule firewall.Rule) error { + ruleKey := rule.ID() + + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + var merr *multierror.Error + if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists { + if err := r.conn.DelRule(dnatRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err)) + } + } + + if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists { + if err := r.conn.DelRule(masqRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err)) + } + } + + if err := r.conn.Flush(); err != nil { + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) + } + + if merr == nil { + delete(r.rules, ruleKey+dnatSuffix) + delete(r.rules, ruleKey+snatSuffix) + } + + return nberrors.FormatErrorOrNil(merr) +} + // generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { var offset uint32 @@ -953,15 +1249,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any { if port.IsRange && len(port.Values) == 2 { // Handle port range exprs = append(exprs, - &expr.Cmp{ - Op: expr.CmpOpGte, - Register: 1, - Data: binaryutil.BigEndian.PutUint16(port.Values[0]), - }, - &expr.Cmp{ - Op: expr.CmpOpLte, + &expr.Range{ + Op: expr.CmpOpEq, Register: 1, - Data: binaryutil.BigEndian.PutUint16(port.Values[1]), + FromData: binaryutil.BigEndian.PutUint16(port.Values[0]), + ToData: binaryutil.BigEndian.PutUint16(port.Values[1]), }, ) } else { diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 2a5d7168d5c..78763c2a079 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { }) // Check if the rule is in the internal map - rule, ok := r.rules[ruleKey.GetRuleID()] + rule, ok := r.rules[ruleKey.ID()] assert.True(t, ok, "Rule not found in internal map") t.Log("Internal rule expressions:") @@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { var nftRule *nftables.Rule for _, rule := range rules { - if string(rule.UserData) == ruleKey.GetRuleID() { + if string(rule.UserData) == ruleKey.ID() { nftRule = rule break } diff --git a/client/firewall/nftables/rule_linux.go b/client/firewall/nftables/rule_linux.go index 4d652346b95..a90b74e36c9 100644 --- a/client/firewall/nftables/rule_linux.go +++ b/client/firewall/nftables/rule_linux.go @@ -16,6 +16,6 @@ type Rule struct { } // GetRuleID returns the rule id -func (r *Rule) GetRuleID() string { +func (r *Rule) ID() string { return r.ruleID } From 2ed63561665732e1a350096b2077d734dbb43345 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 29 Jan 2025 14:31:25 +0100 Subject: [PATCH 3/5] Fix map --- client/firewall/iptables/acl_linux.go | 4 ++-- client/firewall/nftables/router_linux.go | 1 - client/internal/ingressgw/manager.go | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index e380faf90b0..8f1b231b835 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -310,8 +310,8 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) - // inbound is handled by our ACLs, the rest is dropped - // for outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules. + // Inbound is handled by our ACLs, the rest is dropped. + // For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules. m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 03c8b08ab2e..8af30e51fbc 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -132,7 +132,6 @@ func (r *router) removeNatPreroutingRules() error { } rules, err := r.conn.GetRules(table, chain) if err != nil { - log.Debugf("err type: %T", err) return fmt.Errorf("get rules from nat table: %w", err) } diff --git a/client/internal/ingressgw/manager.go b/client/internal/ingressgw/manager.go index 07404703317..33882de4a4c 100644 --- a/client/internal/ingressgw/manager.go +++ b/client/internal/ingressgw/manager.go @@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error { var mErr *multierror.Error - toDelete := make(map[firewall.ForwardRuleID]RulePair, len(h.rules)) + toDelete := make(map[string]RulePair, len(h.rules)) for id, r := range h.rules { toDelete[id] = r } From dbd53604eb9c819b7e03ee2a8f5fa8142fb17780 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 29 Jan 2025 14:31:25 +0100 Subject: [PATCH 4/5] Reduce complexity --- client/firewall/nftables/router_linux.go | 104 ++++++++++++++--------- 1 file changed, 62 insertions(+), 42 deletions(-) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 8af30e51fbc..3a96ea39bfe 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1002,54 +1002,19 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule } dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...) - var regProtoMin, regProtoMax uint32 - switch { - case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2: - // shifted translated port is not supported in nftables, so we hand this over to xtables + // shifted translated port is not supported in nftables, so we hand this over to xtables + if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 { if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] || rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] { return r.addXTablesRedirect(dnatExprs, ruleKey, rule) } + } - dnatExprs = append(dnatExprs, - &expr.Immediate{ - Register: 1, - Data: rule.TranslatedAddress.AsSlice(), - }, - &expr.Immediate{ - Register: 2, - Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]), - }, - &expr.Immediate{ - Register: 3, - Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]), - }, - ) - regProtoMin = 2 - regProtoMax = 3 - case len(rule.TranslatedPort.Values) == 0: - // address only - dnatExprs = append(dnatExprs, - &expr.Immediate{ - Register: 1, - Data: rule.TranslatedAddress.AsSlice(), - }, - ) - case len(rule.TranslatedPort.Values) == 1: - dnatExprs = append(dnatExprs, - &expr.Immediate{ - Register: 1, - Data: rule.TranslatedAddress.AsSlice(), - }, - &expr.Immediate{ - Register: 2, - Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]), - }, - ) - regProtoMin = 2 - default: - return fmt.Errorf("invalid translated port: %v", rule.TranslatedPort) + additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule) + if err != nil { + return err } + dnatExprs = append(dnatExprs, additionalExprs...) dnatExprs = append(dnatExprs, &expr.NAT{ @@ -1073,6 +1038,61 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule return nil } +func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) { + switch { + case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2: + return r.handlePortRange(rule) + case len(rule.TranslatedPort.Values) == 0: + return r.handleAddressOnly(rule) + case len(rule.TranslatedPort.Values) == 1: + return r.handleSinglePort(rule) + default: + return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort) + } +} + +func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) { + exprs := []expr.Any{ + &expr.Immediate{ + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]), + }, + &expr.Immediate{ + Register: 3, + Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]), + }, + } + return exprs, 2, 3, nil +} + +func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) { + exprs := []expr.Any{ + &expr.Immediate{ + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + } + return exprs, 0, 0, nil +} + +func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) { + exprs := []expr.Any{ + &expr.Immediate{ + Register: 1, + Data: rule.TranslatedAddress.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]), + }, + } + return exprs, 2, 0, nil +} + func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error { dnatExprs = append(dnatExprs, &expr.Counter{}, From 94ffafdb49e2e675fbf7216ed0ca12d59972c4f6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 29 Jan 2025 15:21:33 +0100 Subject: [PATCH 5/5] Fix test regarding port ranges --- client/firewall/nftables/router_linux_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 78763c2a079..bb4b6e89bff 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { payloadFound = true } - case *expr.Cmp: - if port.IsRange { - if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { + case *expr.Range: + if port.IsRange && len(port.Values) == 2 { + fromPort := binary.BigEndian.Uint16(ex.FromData) + toPort := binary.BigEndian.Uint16(ex.ToData) + if fromPort == port.Values[0] && toPort == port.Values[1] { portMatchFound = true } - } else { + } + case *expr.Cmp: + if !port.IsRange { if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { portValue := binary.BigEndian.Uint16(ex.Data) for _, p := range port.Values { - if uint16(p) == portValue { + if p == portValue { portMatchFound = true break }