diff --git a/pkg/packetfilter/iptables/adapter.go b/pkg/packetfilter/iptables/adapter.go
new file mode 100644
index 000000000..e50bdb64e
--- /dev/null
+++ b/pkg/packetfilter/iptables/adapter.go
@@ -0,0 +1,169 @@
+/*
+SPDX-License-Identifier: Apache-2.0
+
+Copyright Contributors to the Submariner project.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package iptables
+
+import (
+	"strings"
+
+	"github.com/coreos/go-iptables/iptables"
+	"github.com/pkg/errors"
+	level "github.com/submariner-io/admiral/pkg/log"
+	"k8s.io/utils/set"
+)
+
+// createChainIfNotExists creates chain in table unless it already exists.
+// It returns an error if either the existence check or the creation fails.
+func (a *Adapter) createChainIfNotExists(table, chain string) error {
+	exists, err := a.ipt.ChainExists(table, chain)
+	if err != nil {
+		return errors.Wrapf(err, "error finding IP table chain %q in table %q", chain, table)
+	}
+
+	if exists {
+		return nil
+	}
+
+	return errors.Wrap(a.ipt.NewChain(table, chain), "error creating IP table chain")
+}
+
+// insertUnique ensures ruleSpec appears exactly once in chain, at the given position.
+// Misplaced or duplicated occurrences are deleted before the rule is (re)inserted.
+//
+// NOTE(review): the list index is compared directly with position, which presumably relies
+// on List returning the chain declaration as entry 0 - confirm against go-iptables.
+func (a *Adapter) insertUnique(table, chain string, position int, ruleSpec []string) error {
+	rules, err := a.ipt.List(table, chain)
+	if err != nil {
+		return errors.Wrapf(err, "error listing the rules in %s chain", chain)
+	}
+
+	ruleStr := strings.Join(ruleSpec, " ")
+	isPresentAtRequiredPosition := false
+	numOccurrences := 0
+
+	for index, rule := range rules {
+		if strings.Contains(rule, ruleStr) {
+			logger.V(level.DEBUG).Infof("In %s table, iptables rule \"%s\", exists at index %d.", table, ruleStr, index)
+			numOccurrences++
+
+			if index == position {
+				isPresentAtRequiredPosition = true
+			}
+		}
+	}
+
+	// The required rule is present only once and is at the desired location.
+	if numOccurrences == 1 && isPresentAtRequiredPosition {
+		logger.V(level.DEBUG).Infof("In %s table, iptables rule \"%s\", already exists.", table, ruleStr)
+		return nil
+	}
+
+	// The rule is missing, present multiple times, or not at the desired location:
+	// remove every occurrence so that the insert below leaves exactly one.
+	for i := 0; i < numOccurrences; i++ {
+		if err := a.ipt.Delete(table, chain, ruleSpec...); err != nil {
+			return errors.Wrapf(err, "error deleting stale IP table rule %q", ruleStr)
+		}
+	}
+
+	if err := a.ipt.Insert(table, chain, position, ruleSpec...); err != nil {
+		return errors.Wrapf(err, "error inserting IP table rule %q", ruleStr)
+	}
+
+	return nil
+}
+
+// prependUnique ensures ruleSpec is present exactly once in chain, at position 1.
+func (a *Adapter) prependUnique(table, chain string, ruleSpec []string) error {
+	// Submariner requires certain iptable rules to be programmed at the beginning of an iptables Chain
+	// so that we can preserve the sourceIP for inter-cluster traffic and avoid K8s SDN making changes
+	// to the traffic.
+	// In this API, we check if the required iptable rule is present at the beginning of the chain.
+	// If the rule is already present and there are no stale[1] flows, we simply return. If not, we create one.
+	// [1] Sometimes after we program the rule at the beginning of the chain, K8s SDN might insert some
+	//     new rules ahead of the rule that we programmed. In such cases, the rule that we programmed will
+	//     not be the first rule to hit and Submariner behavior might get affected. So, we query the rules
+	//     in the chain to see if the rule slipped its position, and if so, delete all such occurrences.
+	// We then re-program a new rule at the beginning of the chain as required.
+	return a.insertUnique(table, chain, 1, ruleSpec)
+}
+
+// updateChainRules reconciles chain in table with the desired rules: rules that are missing
+// are appended and existing rules that are not in the desired list are deleted.
+// A failure to delete a stale rule is logged but not returned.
+func (a *Adapter) updateChainRules(table, chain string, rules [][]string) error {
+	existingRules, err := a.ipt.List(table, chain)
+	if err != nil {
+		return errors.Wrapf(err, "error listing the rules in table %q, chain %q", table, chain)
+	}
+
+	ruleStrings := set.New[string]()
+
+	for _, existingRule := range existingRules {
+		ruleSpec := strings.Split(existingRule, " ")
+		// Only "-A <chain> <spec...>" entries are rules; the length check guards the slicing below.
+		if len(ruleSpec) >= 2 && ruleSpec[0] == "-A" {
+			ruleSpec = ruleSpec[2:] // remove "-A", "$chain"
+			ruleStrings.Insert(strings.Trim(strings.Join(ruleSpec, " "), " "))
+		}
+	}
+
+	for _, ruleSpec := range rules {
+		ruleString := strings.Join(ruleSpec, " ")
+
+		if ruleStrings.Has(ruleString) {
+			ruleStrings.Delete(ruleString)
+		} else {
+			logger.V(level.DEBUG).Infof("Adding iptables rule in %q, %q: %q", table, chain, ruleSpec)
+
+			if err := a.ipt.Append(table, chain, ruleSpec...); err != nil {
+				return errors.Wrapf(err, "error adding rule %v to %q, %q", ruleSpec, table, chain)
+			}
+		}
+	}
+
+	// remaining elements should not be there, remove them
+	for _, rule := range ruleStrings.UnsortedList() {
+		logger.V(level.DEBUG).Infof("Deleting stale iptables rule in %q, %q: %q", table, chain, rule)
+		ruleSpec := strings.Split(rule, " ")
+
+		if err := a.ipt.Delete(table, chain, ruleSpec...); err != nil {
+			// Log and let go, as this is not a fatal error, or something that will make real harm,
+			// it's more harmful to keep retrying. At this point on next update deletion of stale rules
+			// will happen again
+			logger.Warningf("Unable to delete iptables entry from table %q, chain %q: %q: %v", table, chain, rule, err)
+		}
+	}
+
+	return nil
+}
+
+// delete removes the rule from chain in table; a rule that does not exist is not an error.
+func (a *Adapter) delete(table, chain string, rulespec ...string) error {
+	err := a.ipt.Delete(table, chain, rulespec...)
+
+	var iptError *iptables.Error
+
+	ok := errors.As(err, &iptError)
+	if ok && iptError.IsNotExist() {
+		return nil
+	}
+
+	return errors.Wrap(err, "error deleting IP table rule")
+}
diff --git a/pkg/packetfilter/iptables/iptables.go b/pkg/packetfilter/iptables/iptables.go
new file mode 100644
index 000000000..8e70ee953
--- /dev/null
+++ b/pkg/packetfilter/iptables/iptables.go
@@ -0,0 +1,631 @@
+/*
+SPDX-License-Identifier: Apache-2.0
+
+Copyright Contributors to the Submariner project.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package iptables
+
+import (
+	"crypto/sha256"
+	"encoding/base32"
+	"strconv"
+	"strings"
+
+	"github.com/coreos/go-iptables/iptables"
+	"github.com/pkg/errors"
+	"github.com/submariner-io/admiral/pkg/log"
+	"github.com/submariner-io/submariner/pkg/ipset"
+	"github.com/submariner-io/submariner/pkg/packetfilter"
+	utilexec "k8s.io/utils/exec"
+	logf "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+const (
+	remoteCIDRIPSet                    = "SUBMARINER-REMOTECIDRS"
+	localCIDRIPSet                     = "SUBMARINER-LOCALCIDRS"
+	smPostRoutingChain                 = "SUBMARINER-POSTROUTING"
+	smGlobalnetEgressChainForPods      = "SM-GN-EGRESS-PODS"
+	smGlobalnetEgressChainForNamespace = "SM-GN-EGRESS-NS"
+	gnIPSetPrefix                      = "SM-GN-"
+)
+
+var (
+	tableTypeToStr             = [packetfilter.TableTypeMAX]string{"filter", "mangle", "nat"}
+	iphookChainTypeToStr       = [packetfilter.ChainTypeMAX]string{"filter", "mangle", "nat"}
+	iphookChainTypeToTableType = [packetfilter.ChainTypeMAX]packetfilter.TableType{
+		packetfilter.TableTypeFilter,
+		packetfilter.TableTypeRoute,
+		packetfilter.TableTypeNAT,
+	}
+	chainHookToStr  = [packetfilter.ChainHookMAX]string{"PREROUTING", "INPUT", "FORWARD", "OUTPUT", "POSTROUTING"}
+	ruleActiontoStr = [packetfilter.RuleActionMAX]string{"", "ACCEPT", "TCPMSS", "MARK", "SNAT", "DNAT"}
+	logger          = log.Logger{Logger: logf.Log.WithName("IPTables")}
+	// NewFunc, when non-nil, is invoked by NewDriver instead of creating a real driver.
+	NewFunc func() (packetfilter.Driver, error)
+)
+
+// globalEgressSet holds the ipset and the SNAT rule parameters stored for one global egress handle.
+type globalEgressSet struct {
+	ipSetNamed ipset.Named
+	snatCIDR   string
+	markValue  string
+}
+
+// Adapter implements packetfilter.Driver on top of iptables and ipset.
+type Adapter struct {
+	// Basic packetfilter.Driver
+	ipt                     *iptables.IPTables
+	remoteIPSet             ipset.Named
+	localIPSet              ipset.Named
+	stringToGlobalEgressSet map[packetfilter.GlobalEgressHandle]*globalEgressSet
+	ipSetIface              ipset.Interface
+}
+
+// init registers NewDriver with the packetfilter package as a default driver named "IPTables".
+func init() {
+	// TODO: register based on iptables support status
+	packetfilter.AddDriver("IPTables", true, NewDriver)
+	logger.Info("Registered as packetfilter driver")
+}
+
+// NewDriver returns an Adapter backed by an IPv4 iptables handle and a fresh ipset interface.
+// When NewFunc is set, its result is returned instead.
+func NewDriver() (packetfilter.Driver, error) {
+	if NewFunc != nil {
+		return NewFunc()
+	}
+
+	ipt, err := iptables.New(iptables.IPFamily(iptables.ProtocolIPv4), iptables.Timeout(5))
+	if err != nil {
+		return nil, errors.Wrap(err, "error creating IP tables")
+	}
+
+	return &Adapter{
+		ipt:                     ipt,
+		stringToGlobalEgressSet: make(map[packetfilter.GlobalEgressHandle]*globalEgressSet),
+		ipSetIface:              ipset.New(utilexec.New()),
+	}, nil
+}
+
+// ChainExists reports whether chain exists in table.
+func (a *Adapter) ChainExists(table packetfilter.TableType, chain string) (bool, error) {
+	ok, err := a.ipt.ChainExists(tableTypeToStr[table], chain)
+	return ok, errors.Wrap(err, "ChainExists failed")
+}
+
+// AppendUnique translates ruleSpec to iptables arguments and passes them to iptables AppendUnique.
+func (a *Adapter) AppendUnique(table packetfilter.TableType, chain string, ruleSpec *packetfilter.Rule) error {
+	ruleSpecStr, err := ruleToRuleSpec(ruleSpec)
+	if err != nil {
+		return errors.Wrap(err, "AppendUnique failed")
+	}
+
+	return errors.Wrap(a.ipt.AppendUnique(tableTypeToStr[table], chain, ruleSpecStr...), "error AppendUnique rule")
+}
+
+// CreateIPHookChainIfNotExists creates chain.Name in the table matching chain.Type and hooks it
+// into chain.Hook with chain.JumpRule, or with a plain jump to the chain when JumpRule is nil.
+// The jump rule is prepended for ChainPriorityFirst and appended otherwise.
+func (a *Adapter) CreateIPHookChainIfNotExists(chain *packetfilter.ChainIPHook) error {
+	if err := a.createChainIfNotExists(iphookChainTypeToStr[chain.Type], chain.Name); err != nil {
+		return errors.Wrapf(err, "error creating IP tables %s:%s chain", iphookChainTypeToStr[chain.Type], chain.Name)
+	}
+
+	jumpRule := chain.JumpRule
+	if jumpRule == nil {
+		jumpRule = &packetfilter.Rule{
+			TargetChain: chain.Name,
+			Action:      packetfilter.RuleActionJump,
+		}
+	}
+
+	if chain.Priority == packetfilter.ChainPriorityFirst {
+		if err := a.PrependUnique(iphookChainTypeToTableType[chain.Type], chainHookToStr[chain.Hook], jumpRule); err != nil {
+			return errors.Wrap(err, "error PrependUnique rule")
+		}
+	} else {
+		if err := a.AppendUnique(iphookChainTypeToTableType[chain.Type], chainHookToStr[chain.Hook], jumpRule); err != nil {
+			return errors.Wrap(err, "error AppendUnique rule")
+		}
+	}
+
+	return nil
+}
+
+// CreateRegularChainIfNotExists creates chain.Name in table unless it already exists.
+func (a *Adapter) CreateRegularChainIfNotExists(table packetfilter.TableType, chain *packetfilter.ChainRegular) error {
+	if err := a.createChainIfNotExists(tableTypeToStr[table], chain.Name); err != nil {
+		return errors.Wrapf(err, "error creating IP tables %s chain", chain.Name)
+	}
+
+	return nil
+}
+
+// DeleteIPHookChain deletes the jump rule that points at chain from its hook chain and then
+// deletes chain itself.
+func (a *Adapter) DeleteIPHookChain(chain *packetfilter.ChainIPHook) error {
+	tableType := iphookChainTypeToTableType[chain.Type]
+	jumpRule := chain.JumpRule
+
+	if jumpRule == nil {
+		jumpRule = &packetfilter.Rule{
+			TargetChain: chain.Name,
+			Action:      packetfilter.RuleActionJump,
+		}
+	}
+
+	if err := a.Delete(tableType, chainHookToStr[chain.Hook], jumpRule); err != nil {
+		return errors.Wrap(err, "error deleting Jump Rule")
+	}
+
+	if err := a.ipt.DeleteChain(tableTypeToStr[tableType], chain.Name); err != nil {
+		return errors.Wrap(err, "error deleting chain")
+	}
+
+	return nil
+}
+
+// DeleteRegularChain deletes chain from table.
+func (a *Adapter) DeleteRegularChain(table packetfilter.TableType, chain string) error {
+	if err := a.ipt.DeleteChain(tableTypeToStr[table], chain); err != nil {
+		return errors.Wrap(err, "error deleting chain")
+	}
+
+	return nil
+}
+
+// ClearChain flushes chain in table through iptables ClearChain.
+func (a *Adapter) ClearChain(table packetfilter.TableType, chain string) error {
+	if err := a.ipt.ClearChain(tableTypeToStr[table], chain); err != nil {
+		return errors.Wrap(err, "error clearing chain")
+	}
+
+	return nil
+}
+
+// Delete removes the rule described by ruleSpec from chain in table; a missing rule is not an error.
+func (a *Adapter) Delete(table packetfilter.TableType, chain string, ruleSpec *packetfilter.Rule) error {
+	ruleSpecStr, err := ruleToRuleSpec(ruleSpec)
+	if err != nil {
+		return errors.Wrap(err, "error translating ruleSpec to str")
+	}
+
+	return a.delete(tableTypeToStr[table], chain, ruleSpecStr...)
+}
+
+// InsertUnique ensures ruleSpec is present exactly once in chain, at the given position.
+func (a *Adapter) InsertUnique(table packetfilter.TableType, chain string, position int, ruleSpec *packetfilter.Rule) error {
+	ruleSpecStr, err := ruleToRuleSpec(ruleSpec)
+	if err != nil {
+		return errors.Wrap(err, "error translating ruleSpec to str")
+	}
+
+	return a.insertUnique(tableTypeToStr[table], chain, position, ruleSpecStr)
+}
+
+// PrependUnique ensures ruleSpec is present exactly once in chain, as its first rule.
+func (a *Adapter) PrependUnique(table packetfilter.TableType, chain string, ruleSpec *packetfilter.Rule) error {
+	ruleSpecStr, err := ruleToRuleSpec(ruleSpec)
+	if err != nil {
+		return errors.Wrap(err, "error translating ruleSpec to str")
+	}
+
+	return a.prependUnique(tableTypeToStr[table], chain, ruleSpecStr)
+}
+
+// UpdateChainRules makes the rules of chain match rules, appending missing ones and deleting stale ones.
+func (a *Adapter) UpdateChainRules(table packetfilter.TableType, chain string, rules []*packetfilter.Rule) error {
+	rulesStr := make([][]string, len(rules))
+
+	for i, rule := range rules {
+		ruleStr, err := ruleToRuleSpec(rule)
+		if err != nil {
+			return errors.Wrap(err, "error translating rule to str")
+		}
+
+		rulesStr[i] = ruleStr
+	}
+
+	return a.updateChainRules(tableTypeToStr[table], chain, rulesStr)
+}
+
+// SetMultiClusterMssClamp creates the SUBMARINER-POSTROUTING mangle chain and the local/remote CIDR
+// ipsets and, unless clampType is UndefinedMSS, installs TCPMSS rules for traffic between the two
+// sets that clamp the MSS either to the path MTU or to mssValue.
+func (a *Adapter) SetMultiClusterMssClamp(clampType packetfilter.MssClampType, mssValue int) error {
+	if err := a.CreateIPHookChainIfNotExists(&packetfilter.ChainIPHook{
+		Name:     smPostRoutingChain,
+		Type:     packetfilter.ChainTypeRoute,
+		Hook:     packetfilter.ChainHookPostrouting,
+		Priority: packetfilter.ChainPriorityFirst,
+	}); err != nil {
+		return errors.Wrapf(err, "error creating IPHookChain chain %s", smPostRoutingChain)
+	}
+
+	a.localIPSet = ipset.NewNamed(&ipset.IPSet{
+		Name:       localCIDRIPSet,
+		SetType:    ipset.HashNet,
+		HashFamily: ipset.ProtocolFamilyIPV4,
+	}, a.ipSetIface)
+
+	if err := a.localIPSet.Create(true); err != nil {
+		return errors.Wrapf(err, "error creating ipset %q", localCIDRIPSet)
+	}
+
+	a.remoteIPSet = ipset.NewNamed(&ipset.IPSet{
+		Name:       remoteCIDRIPSet,
+		SetType:    ipset.HashNet,
+		HashFamily: ipset.ProtocolFamilyIPV4,
+	}, a.ipSetIface)
+
+	if err := a.remoteIPSet.Create(true); err != nil {
+		return errors.Wrapf(err, "error creating ipset %q", remoteCIDRIPSet)
+	}
+
+	switch clampType {
+	case packetfilter.UndefinedMSS:
+	case packetfilter.ToPMTU, packetfilter.ToValue:
+		ruleSpecSource := []string{
+			"-m", "set", "--match-set", localCIDRIPSet, "src", "-m", "set", "--match-set",
+			remoteCIDRIPSet, "dst", "-p", "tcp", "-m", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS",
+		}
+
+		ruleSpecDest := []string{
+			"-m", "set", "--match-set", remoteCIDRIPSet, "src", "-m", "set", "--match-set",
+			localCIDRIPSet, "dst", "-p", "tcp", "-m", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS",
+		}
+
+		if clampType == packetfilter.ToPMTU {
+			ruleSpecDest = append(ruleSpecDest, "--clamp-mss-to-pmtu")
+			ruleSpecSource = append(ruleSpecSource, "--clamp-mss-to-pmtu")
+		} else {
+			ruleSpecDest = append(ruleSpecDest, "--set-mss", strconv.Itoa(mssValue))
+			ruleSpecSource = append(ruleSpecSource, "--set-mss", strconv.Itoa(mssValue))
+		}
+
+		if err := a.updateChainRules(tableTypeToStr[packetfilter.TableTypeRoute], smPostRoutingChain,
+			[][]string{ruleSpecSource, ruleSpecDest}); err != nil {
+			return errors.Wrapf(err, "error updating chain %s table %s rules", smPostRoutingChain, tableTypeToStr[packetfilter.TableTypeRoute])
+		}
+	}
+
+	return nil
+}
+
+// DelMultiClusterMssClamp removes the chain, jump rule and ipsets created by SetMultiClusterMssClamp.
+// Cleanup is best effort: failures are logged and nil is always returned.
+func (a *Adapter) DelMultiClusterMssClamp() error {
+	if err := a.ipt.ClearChain(tableTypeToStr[packetfilter.TableTypeRoute], smPostRoutingChain); err != nil {
+		logger.Errorf(err, "Error flushing iptables chain %q of %q table", smPostRoutingChain,
+			tableTypeToStr[packetfilter.TableTypeRoute])
+	}
+
+	logger.Infof("Deleting iptable entry in %q chain of %q table", smPostRoutingChain, tableTypeToStr[packetfilter.TableTypeRoute])
+
+	ruleSpec := []string{"-j", smPostRoutingChain}
+	if err := a.ipt.Delete(tableTypeToStr[packetfilter.TableTypeRoute],
+		chainHookToStr[packetfilter.ChainHookPostrouting], ruleSpec...); err != nil {
+		logger.Errorf(err, "Error deleting iptables rule from %q chain", chainHookToStr[packetfilter.ChainHookPostrouting])
+	}
+
+	logger.Infof("Deleting iptable %q chain of %q table", smPostRoutingChain, tableTypeToStr[packetfilter.TableTypeRoute])
+
+	if err := a.ipt.DeleteChain(tableTypeToStr[packetfilter.TableTypeRoute], smPostRoutingChain); err != nil {
+		logger.Errorf(err, "Error deleting iptable chain %q of table %q", smPostRoutingChain,
+			tableTypeToStr[packetfilter.TableTypeRoute])
+	}
+
+	// The ipsets are only created by SetMultiClusterMssClamp; skip them when it never ran instead
+	// of calling methods on a nil ipset.Named.
+	if a.localIPSet != nil {
+		if err := a.localIPSet.Flush(); err != nil {
+			logger.Errorf(err, "Error flushing ipset %q", localCIDRIPSet)
+		}
+
+		if err := a.localIPSet.Destroy(); err != nil {
+			logger.Errorf(err, "Error deleting ipset %q", localCIDRIPSet)
+		}
+	}
+
+	if a.remoteIPSet != nil {
+		if err := a.remoteIPSet.Flush(); err != nil {
+			logger.Errorf(err, "Error flushing ipset %q", remoteCIDRIPSet)
+		}
+
+		if err := a.remoteIPSet.Destroy(); err != nil {
+			logger.Errorf(err, "Error deleting ipset %q", remoteCIDRIPSet)
+		}
+	}
+
+	return nil
+}
+
+// AddSubMultiClusterMssClamp adds localCIDRs and destCIDRs to the local and remote CIDR ipsets.
+// SetMultiClusterMssClamp must have been called first, as it creates those ipsets.
+func (a *Adapter) AddSubMultiClusterMssClamp(localCIDRs, destCIDRs []string) error {
+	for _, subnet := range localCIDRs {
+		err := a.localIPSet.AddEntry(subnet, true)
+		if err != nil {
+			return errors.Wrap(err, "error adding local IP set entry")
+		}
+	}
+
+	for _, subnet := range destCIDRs {
+		err := a.remoteIPSet.AddEntry(subnet, true)
+		if err != nil {
+			return errors.Wrap(err, "error adding remote IP set entry")
+		}
+	}
+
+	return nil
+}
+
+// DelSubMultiClusterMssClamp removes localCIDRs and destCIDRs from the local and remote CIDR ipsets.
+// SetMultiClusterMssClamp must have been called first, as it creates those ipsets.
+func (a *Adapter) DelSubMultiClusterMssClamp(localCIDRs, destCIDRs []string) error {
+	for _, subnet := range localCIDRs {
+		err := a.localIPSet.DelEntry(subnet)
+		if err != nil {
+			return errors.Wrap(err, "error deleting local IP set entry")
+		}
+	}
+
+	for _, subnet := range destCIDRs {
+		err := a.remoteIPSet.DelEntry(subnet)
+		if err != nil {
+			return errors.Wrap(err, "error deleting remote IP set entry")
+		}
+	}
+
+	return nil
+}
+
+// getGlobalEgressIPSetName derives a fixed-length ipset name from key by hashing it.
+func getGlobalEgressIPSetName(key string) string {
+	hash := sha256.Sum256([]byte(key))
+	encoded := base32.StdEncoding.EncodeToString(hash[:])
+	// Max length of IPSet name can be 31
+	return gnIPSetPrefix + encoded[:25]
+}
+
+// CreateGlobalEgressForPods creates the ipset and SNAT rule for the pods identified by key.
+func (a *Adapter) CreateGlobalEgressForPods(key, snatCIDR, markValue string) (packetfilter.GlobalEgressHandle, error) {
+	return a.createGlobalEgress(packetfilter.GlobalEgressHandle(key), snatCIDR, markValue, false)
+}
+
+// DeleteGlobalEgressForPods removes the SNAT rule and ipset created by CreateGlobalEgressForPods.
+func (a *Adapter) DeleteGlobalEgressForPods(handle packetfilter.GlobalEgressHandle) error {
+	return a.deleteGlobalEgress(handle, false)
+}
+
+// CreateGlobalEgressForNamespace creates the ipset and SNAT rule for namespace.
+func (a *Adapter) CreateGlobalEgressForNamespace(namespace, snatCIDR, markValue string) (packetfilter.GlobalEgressHandle, error) {
+	return a.createGlobalEgress(packetfilter.GlobalEgressHandle(namespace), snatCIDR, markValue, true)
+}
+
+// DeleteGlobalEgressForNamespace removes the SNAT rule and ipset created by CreateGlobalEgressForNamespace.
+func (a *Adapter) DeleteGlobalEgressForNamespace(handle packetfilter.GlobalEgressHandle) error {
+	return a.deleteGlobalEgress(handle, true)
+}
+
+// createGlobalEgress creates an ipset named after key, appends an SNAT rule matching that set and
+// markValue to the pods or namespace egress chain, and records both under key for later lookups.
+func (a *Adapter) createGlobalEgress(key packetfilter.GlobalEgressHandle,
+	snatCIDR, markValue string,
+	isNamespace bool,
+) (packetfilter.GlobalEgressHandle, error) {
+	chainName := smGlobalnetEgressChainForPods
+	egressTarget := "pods"
+
+	if isNamespace {
+		chainName = smGlobalnetEgressChainForNamespace
+		egressTarget = "namespace"
+	}
+
+	namedSet := ipset.NewNamed(&ipset.IPSet{
+		Name:    getGlobalEgressIPSetName(string(key)),
+		SetType: ipset.HashIP,
+	}, a.ipSetIface)
+
+	err := namedSet.Create(true)
+	if err != nil {
+		return "", errors.Wrapf(err, "error creating namedSet %q", namedSet.Name())
+	}
+
+	ruleSpec := []string{
+		"-p", "all", "-m", "set", "--match-set", namedSet.Name(), "src", "-m", "mark",
+		"--mark", markValue, "-j", "SNAT", "--to", snatCIDR,
+	}
+	logger.V(log.DEBUG).Infof("Installing iptable egress rules for %s %q: %s", egressTarget, key, strings.Join(ruleSpec, " "))
+
+	if err := a.ipt.AppendUnique(tableTypeToStr[packetfilter.TableTypeNAT], chainName, ruleSpec...); err != nil {
+		return "", errors.Wrapf(err, "error appending iptables rule \"%s\" chain:%s", strings.Join(ruleSpec, " "), chainName)
+	}
+
+	a.stringToGlobalEgressSet[key] = &globalEgressSet{
+		ipSetNamed: namedSet,
+		snatCIDR:   snatCIDR,
+		markValue:  markValue,
+	}
+
+	return key, nil
+}
+
+// deleteGlobalEgress deletes the SNAT rule, flushes and destroys the ipset recorded under handle, and
+// then forgets the handle. It fails if handle is unknown.
+func (a *Adapter) deleteGlobalEgress(handle packetfilter.GlobalEgressHandle, isNamespace bool) error {
+	chainName := smGlobalnetEgressChainForPods
+	egressTarget := "pods"
+
+	if isNamespace {
+		chainName = smGlobalnetEgressChainForNamespace
+		egressTarget = "namespace"
+	}
+
+	globalEgressSet, ok := a.stringToGlobalEgressSet[handle]
+	if !ok {
+		return errors.Errorf(" unable to find GlobalEgressSet with handle: %s", handle)
+	}
+
+	ruleSpec := []string{
+		"-p", "all", "-m", "set", "--match-set", globalEgressSet.ipSetNamed.Name(), "src", "-m", "mark",
+		"--mark", globalEgressSet.markValue, "-j", "SNAT", "--to", globalEgressSet.snatCIDR,
+	}
+	logger.V(log.DEBUG).Infof("Deleting iptable egress rules for %s %q: %s", egressTarget, handle, strings.Join(ruleSpec, " "))
+
+	if err := a.ipt.Delete(tableTypeToStr[packetfilter.TableTypeNAT], chainName, ruleSpec...); err != nil {
+		return errors.Wrapf(err, "error deleting iptables rule \"%s\" chain:%s", strings.Join(ruleSpec, " "), chainName)
+	}
+
+	if err := globalEgressSet.ipSetNamed.Flush(); err != nil {
+		return errors.Wrapf(err, "error flushing ipset %q", globalEgressSet.ipSetNamed.Name())
+	}
+
+	if err := globalEgressSet.ipSetNamed.Destroy(); err != nil {
+		return errors.Wrapf(err, "error deleting ipset %q", globalEgressSet.ipSetNamed.Name())
+	}
+
+	// The ipset is gone; drop the bookkeeping entry so the map does not retain stale handles.
+	delete(a.stringToGlobalEgressSet, handle)
+
+	return nil
+}
+
+// AddPodIPToGlobalEgress adds podIP to the ipset recorded under handle.
+func (a *Adapter) AddPodIPToGlobalEgress(handle packetfilter.GlobalEgressHandle, podIP string) error {
+	globalEgressSet, ok := a.stringToGlobalEgressSet[handle]
+	if !ok {
+		return errors.Errorf(" unable to find GlobalEgressSet with handle: %s", handle)
+	}
+
+	if err := globalEgressSet.ipSetNamed.AddEntry(podIP, true); err != nil {
+		return errors.Wrapf(err, "Error adding entry to GlobalEgressSet with handle: %s", handle)
+	}
+
+	return nil
+}
+
+// RemovePodIPToGlobalEgress removes podIP from the ipset recorded under handle.
+func (a *Adapter) RemovePodIPToGlobalEgress(handle packetfilter.GlobalEgressHandle, podIP string) error {
+	globalEgressSet, ok := a.stringToGlobalEgressSet[handle]
+	if !ok {
+		return errors.Errorf(" unable to find GlobalEgressSet with handle: %s", handle)
+	}
+
+	if err := globalEgressSet.ipSetNamed.DelEntry(podIP); err != nil {
+		return errors.Wrapf(err, "error removing entry from GlobalEgressSet with handle: %s", handle)
+	}
+
+	return nil
+}
+
+// protoToRuleSpec appends the iptables protocol match for proto to ruleSpec.
+//
+// NOTE(review): RuleProtoAll is the zero value of RuleProto, so a Rule that leaves Proto unset is
+// rendered with "-p all". insertUnique and updateChainRules compare rule strings against the rules
+// listed by iptables; if the listing omits "-p all" (TODO confirm) such rules never match.
+func protoToRuleSpec(ruleSpec *[]string, proto packetfilter.RuleProto) {
+	switch proto {
+	case packetfilter.RuleProtoUDP:
+		*ruleSpec = append(*ruleSpec, "-p", "udp", "-m", "udp")
+	case packetfilter.RuleProtoTCP:
+		*ruleSpec = append(*ruleSpec, "-p", "tcp", "-m", "tcp")
+	case packetfilter.RuleProtoICMP:
+		*ruleSpec = append(*ruleSpec, "-p", "icmp")
+	case packetfilter.RuleProtoAll:
+		*ruleSpec = append(*ruleSpec, "-p", "all")
+	}
+}
+
+// mssClampToRuleSpec appends the TCP SYN match and the MSS clamping option for clampType to ruleSpec.
+func mssClampToRuleSpec(ruleSpec *[]string, clampType packetfilter.MssClampType, mssValue string) {
+	switch clampType {
+	case packetfilter.UndefinedMSS:
+	case packetfilter.ToPMTU:
+		*ruleSpec = append(*ruleSpec, "-p", "tcp", "-m", "tcp", "--tcp-flags", "SYN,RST", "SYN", "--clamp-mss-to-pmtu")
+	case packetfilter.ToValue:
+		*ruleSpec = append(*ruleSpec, "-p", "tcp", "-m", "tcp", "--tcp-flags", "SYN,RST", "SYN", "--set-mss", mssValue)
+	}
+}
+
+// setToRuleSpec appends ipset matches for the non-empty source and destination set names to ruleSpec.
+func setToRuleSpec(ruleSpec *[]string, srcSetName, destSetName string) {
+	if srcSetName != "" {
+		*ruleSpec = append(*ruleSpec, "-m", "set", "--match-set", srcSetName, "src")
+	}
+
+	if destSetName != "" {
+		*ruleSpec = append(*ruleSpec, "-m", "set", "--match-set", destSetName, "dst")
+	}
+}
+
+// ruleToRuleSpec translates rule into iptables command-line arguments.
+// It returns a nil slice and an error if rule.Action is outside the RuleAction range.
+func ruleToRuleSpec(rule *packetfilter.Rule) ([]string, error) {
+	var ruleSpec []string
+	protoToRuleSpec(&ruleSpec, rule.Proto)
+
+	if rule.SrcCIDR != "" {
+		ruleSpec = append(ruleSpec, "-s", rule.SrcCIDR)
+	}
+
+	if rule.DestCIDR != "" {
+		ruleSpec = append(ruleSpec, "-d", rule.DestCIDR)
+	}
+
+	if rule.MarkValue != "" && rule.Action != packetfilter.RuleActionMark {
+		ruleSpec = append(ruleSpec, "-m", "mark", "--mark", rule.MarkValue)
+	}
+
+	setToRuleSpec(&ruleSpec, rule.SrcSetName, rule.DestSetName)
+
+	if rule.OutIntreface != "" {
+		ruleSpec = append(ruleSpec, "-o", rule.OutIntreface)
+	}
+
+	if rule.DPort != "" {
+		ruleSpec = append(ruleSpec, "--dport", rule.DPort)
+	}
+
+	switch rule.Action {
+	case packetfilter.RuleActionMAX:
+	case packetfilter.RuleActionJump:
+		ruleSpec = append(ruleSpec, "-j", rule.TargetChain)
+	case packetfilter.RuleActionAccept, packetfilter.RuleActionMss,
+		packetfilter.RuleActionMark, packetfilter.RuleActionSNAT, packetfilter.RuleActionDNAT:
+		ruleSpec = append(ruleSpec, "-j", ruleActiontoStr[rule.Action])
+	default:
+		return nil, errors.Errorf(" rule.Action %d is invalid", rule.Action)
+	}
+
+	if rule.SnatCIDR != "" {
+		ruleSpec = append(ruleSpec, "--to-source", rule.SnatCIDR)
+	}
+
+	if rule.DnatCIDR != "" {
+		ruleSpec = append(ruleSpec, "--to-destination", rule.DnatCIDR)
+	}
+
+	mssClampToRuleSpec(&ruleSpec, rule.ClampType, rule.MssValue)
+
+	if rule.MarkValue != "" && rule.Action == packetfilter.RuleActionMark {
+		ruleSpec = append(ruleSpec, "--set-mark", rule.MarkValue)
+	}
+
+	logger.V(log.DEBUG).Infof("ruleToRuleSpec IPtables rule Spec: %s", strings.Join(ruleSpec, " "))
+
+	return ruleSpec, nil
+}
diff --git a/pkg/packetfilter/packetfilter.go b/pkg/packetfilter/packetfilter.go
new file mode 100644
index 000000000..e984c9415
--- /dev/null
+++ b/pkg/packetfilter/packetfilter.go
@@ -0,0 +1,289 @@
+/*
+SPDX-License-Identifier: Apache-2.0
+
+Copyright Contributors to the Submariner project.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package packetfilter
+
+import (
+	"github.com/pkg/errors"
+	"github.com/submariner-io/admiral/pkg/log"
+	logf "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+const (
+	TableTypeFilter TableType = iota
+	TableTypeRoute            // mangle
+	TableTypeNAT
+	TableTypeMAX
+)
+
+// TableType identifies the table a chain or rule belongs to.
+type TableType uint32
+
+// GlobalEgressHandle identifies a global egress created through GlobalNetEgressInterface.
+type GlobalEgressHandle string
+
+// RuleAction is the action taken by a Rule.
+type RuleAction uint32
+
+const (
+	RuleActionJump RuleAction = iota
+	RuleActionAccept
+	RuleActionMss
+	RuleActionMark
+	RuleActionSNAT
+	RuleActionDNAT
+	RuleActionMAX
+)
+
+// RuleProto is the protocol matched by a Rule.
+type RuleProto uint32
+
+const (
+	RuleProtoAll RuleProto = iota
+	RuleProtoTCP
+	RuleProtoUDP
+	RuleProtoICMP
+)
+
+// MssClampType selects how the TCP MSS is clamped.
+type MssClampType uint32
+
+const (
+	UndefinedMSS MssClampType = iota
+	ToPMTU
+	ToValue
+)
+
+// ChainHook is the packet-processing hook an IP hook chain is attached to.
+type ChainHook uint32
+
+const (
+	ChainHookPrerouting ChainHook = iota
+	ChainHookInput
+	ChainHookForward
+	ChainHookOutput
+	ChainHookPostrouting
+	ChainHookMAX
+)
+
+// ChainPriority tells whether the jump to an IP hook chain is placed first or last in the hook.
+type ChainPriority uint32
+
+const (
+	ChainPriorityFirst ChainPriority = iota
+	ChainPriorityLast
+)
+
+// ChainType is the kind of an IP hook chain (filter, route/mangle or NAT).
+type ChainType uint32
+
+const (
+	ChainTypeFilter ChainType = iota
+	ChainTypeRoute            // mangle
+	ChainTypeNAT
+	ChainTypeMAX
+)
+
+// ChainPolicy is the policy of a chain.
+type ChainPolicy uint32
+
+const (
+	ChainPolicyAccept ChainPolicy = iota
+	ChainPolicyDrop
+)
+
+// Rule describes a packet filter rule that a Driver translates into its own rule format.
+type Rule struct {
+	DestCIDR string
+	SrcCIDR  string
+
+	SrcSetName  string
+	DestSetName string
+
+	SnatCIDR string
+	DnatCIDR string
+
+	OutIntreface string
+	TargetChain  string
+	MssValue     string
+
+	DPort     string
+	MarkValue string
+	Action    RuleAction
+	Proto     RuleProto
+	ClampType MssClampType
+}
+
+// Supported policy values are accept (which is the default) or drop.
+type ChainRegular struct {
+	Name   string
+	Policy ChainPolicy
+}
+
+// ChainIPHook describes a chain that is attached to a hook through a jump rule. JumpRule may be
+// left nil, in which case the driver decides which jump rule to use.
+type ChainIPHook struct {
+	Name     string
+	Type     ChainType
+	Hook     ChainHook
+	Priority ChainPriority
+	Policy   ChainPolicy
+	JumpRule *Rule
+}
+
+// Base is the set of chain and rule operations every driver provides.
+type Base interface {
+	// Chains
+	ChainExists(table TableType, chain string) (bool, error)
+	CreateIPHookChainIfNotExists(chain *ChainIPHook) error
+	CreateRegularChainIfNotExists(table TableType, chain *ChainRegular) error
+	DeleteIPHookChain(chain *ChainIPHook) error
+	DeleteRegularChain(table TableType, chain string) error
+	ClearChain(table TableType, chain string) error
+
+	// rules
+	Delete(table TableType, chain string, ruleSpec *Rule) error
+	InsertUnique(table TableType, chain string, position int, ruleSpec *Rule) error
+	PrependUnique(table TableType, chain string, ruleSpec *Rule) error
+	UpdateChainRules(table TableType, chain string, rules []*Rule) error
+	AppendUnique(table TableType, chain string, ruleSpec *Rule) error
+}
+
+// MssClampInterface manages TCP MSS clamping between local and remote CIDRs.
+type MssClampInterface interface {
+	SetMultiClusterMssClamp(clampType MssClampType, mssValue int) error
+	DelMultiClusterMssClamp() error
+	AddSubMultiClusterMssClamp(localCIDRs []string, destCIDRs []string) error
+	DelSubMultiClusterMssClamp(localCIDRs []string, destCIDRs []string) error
+}
+
+// GlobalNetEgressInterface manages global egress handles for pods and namespaces and their pod IPs.
+type GlobalNetEgressInterface interface {
+	CreateGlobalEgressForPods(key, snatCIDR, markValue string) (GlobalEgressHandle, error)
+	DeleteGlobalEgressForPods(handle GlobalEgressHandle) error
+	CreateGlobalEgressForNamespace(namespace, snatCIDR, markValue string) (GlobalEgressHandle, error)
+	DeleteGlobalEgressForNamespace(handle GlobalEgressHandle) error
+	AddPodIPToGlobalEgress(handle GlobalEgressHandle, podIP string) error
+	RemovePodIPToGlobalEgress(handle GlobalEgressHandle, podIP string) error
+}
+
+// Driver is the full interface a backend has to implement.
+type Driver interface {
+	Base
+	MssClampInterface
+	GlobalNetEgressInterface
+}
+
+// DriverCreateFunc creates a new Driver instance.
+type DriverCreateFunc func() (Driver, error)
+
+// driverInfo is the registration record AddDriver keeps for each backend.
+type driverInfo struct {
+	createFunc DriverCreateFunc
+	name       string
+	isDefault  bool
+}
+
+var (
+	logger = log.Logger{Logger: logf.Log.WithName("PacketFilter")}
+	// NewFunc, NewMMssClampFunc and NewGlobalNetEgressFunc, when non-nil, replace driver creation in
+	// New, NewMssClamp and NewGlobalNetEgress respectively.
+	NewFunc                func() (Base, error)
+	NewMMssClampFunc       func() (MssClampInterface, error)
+	NewGlobalNetEgressFunc func() (GlobalNetEgressInterface, error)
+
+	drivers             []driverInfo
+	selectedDriverIndex int
+)
+
+// AddDriver registers a backend driver. The first registered driver is selected initially; once
+// several drivers are registered, the first one flagged isDefault (if any) is selected.
+func AddDriver(name string, isDefault bool, createFunc DriverCreateFunc) {
+	drivers = append(drivers,
+		driverInfo{
+			createFunc: createFunc,
+			name:       name,
+			isDefault:  isDefault,
+		})
+
+	logger.Infof("%s registered as driver, isDefault:%t", name, isDefault)
+
+	if len(drivers) == 1 {
+		selectedDriverIndex = 0
+	} else {
+		for i := 0; i < len(drivers); i++ {
+			if drivers[i].isDefault {
+				selectedDriverIndex = i
+				break
+			}
+		}
+	}
+}
+
+// Adapter wraps the Driver created by the selected backend.
+type Adapter struct {
+	Driver
+}
+
+// New returns the Base operations of the selected backend driver, or the result of NewFunc
+// when it is set.
+func New() (Base, error) {
+	if NewFunc != nil {
+		return NewFunc()
+	}
+
+	return newSelectedDriver("")
+}
+
+// NewMssClamp returns the MssClampInterface of the selected backend driver, or the result of
+// NewMMssClampFunc when it is set.
+func NewMssClamp() (MssClampInterface, error) {
+	if NewMMssClampFunc != nil {
+		return NewMMssClampFunc()
+	}
+
+	return newSelectedDriver(" for MssClamp")
+}
+
+// NewGlobalNetEgress returns the GlobalNetEgressInterface of the selected backend driver, or the
+// result of NewGlobalNetEgressFunc when it is set.
+func NewGlobalNetEgress() (GlobalNetEgressInterface, error) {
+	if NewGlobalNetEgressFunc != nil {
+		return NewGlobalNetEgressFunc()
+	}
+
+	return newSelectedDriver(" for GlobalNetEgress")
+}
+
+// newSelectedDriver creates the currently selected backend driver and wraps it in an Adapter.
+// usage is appended to the log message so that the callers can be told apart.
+func newSelectedDriver(usage string) (Driver, error) {
+	if len(drivers) == 0 {
+		return nil, errors.New("no valid backend driver")
+	}
+
+	selected := drivers[selectedDriverIndex]
+	logger.Infof("Creating %s backend driver%s", selected.name, usage)
+
+	driver, err := selected.createFunc()
+	if err != nil {
+		return nil, errors.Wrap(err, "error creating driver")
+	}
+
+	return &Adapter{driver}, nil
+}
diff --git a/pkg/routeagent_driver/main.go b/pkg/routeagent_driver/main.go
index 1c8af1a67..b58ed8fff 100644
--- a/pkg/routeagent_driver/main.go
+++ b/pkg/routeagent_driver/main.go
@@ -57,6 +57,9 @@ import (
 	nodeutil "k8s.io/component-helpers/node/util"
 	logf "sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/controller-runtime/pkg/manager/signals"
+
+	// Add supported drivers.
+	_ "github.com/submariner-io/submariner/pkg/packetfilter/iptables"
 )
 
 var (