diff --git a/pkg/apis/submariner.io/v1/endpoint.go b/pkg/apis/submariner.io/v1/endpoint.go index 869c2f45c..a5d97f420 100644 --- a/pkg/apis/submariner.io/v1/endpoint.go +++ b/pkg/apis/submariner.io/v1/endpoint.go @@ -172,3 +172,15 @@ func (ep *EndpointSpec) GetPrivateIP(family k8snet.IPFamily) string { func (ep *EndpointSpec) SetPrivateIP(ip string) { ep.PrivateIPs, ep.PrivateIP = setIP(ep.PrivateIPs, ep.PrivateIP, ip) } + +func (ep *EndpointSpec) GetIPFamilies() [2]k8snet.IPFamily { + var ipFamilies [2]k8snet.IPFamily + // TODO_IPV6: set ipFamilies according to Subnets content + ipFamilies[0] = k8snet.IPv4 + + return ipFamilies +} + +func (ep *EndpointSpec) GetFamilyCableName(family k8snet.IPFamily) string { + return ep.CableName + "-ipv" + string(family) +} diff --git a/pkg/cableengine/cableengine.go b/pkg/cableengine/cableengine.go index 832b05942..fd636ea07 100644 --- a/pkg/cableengine/cableengine.go +++ b/pkg/cableengine/cableengine.go @@ -31,6 +31,7 @@ import ( "github.com/submariner-io/submariner/pkg/natdiscovery" "github.com/submariner-io/submariner/pkg/types" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8snet "k8s.io/utils/net" logf "sigs.k8s.io/controller-runtime/pkg/log" // Add supported drivers. @@ -50,10 +51,10 @@ type Engine interface { // InstallCable performs any set up work needed for connecting to given remote endpoint. // Once InstallCable completes, it should be possible to connect to remote // Pods or Services behind the given endpoint. - InstallCable(remote *v1.Endpoint) error + InstallCable(remote *v1.Endpoint, family k8snet.IPFamily) error // RemoveCable disconnects the Engine from the given remote endpoint. Upon completion. // remote Pods and Service may not be accessible anymore. - RemoveCable(remote *v1.Endpoint) error + RemoveCable(remote *v1.Endpoint, family k8snet.IPFamily) error // ListCableConnections returns a list of cable connection, and the related status. ListCableConnections() ([]v1.Connection, error) // GetLocalEndpoint returns the local endpoint for this cable engine. @@ -156,13 +157,13 @@ func (i *engine) installCableWithNATInfo(rnat *natdiscovery.NATEndpointInfo) err i.Lock() defer i.Unlock() - if _, ok := i.natDiscoveryPending[rnat.Endpoint.Spec.CableName]; !ok { + if _, ok := i.natDiscoveryPending[rnat.Endpoint.Spec.GetFamilyCableName(rnat.Family)]; !ok { return nil } - i.natDiscoveryPending[rnat.Endpoint.Spec.CableName]-- - if i.natDiscoveryPending[rnat.Endpoint.Spec.CableName] == 0 { - delete(i.natDiscoveryPending, rnat.Endpoint.Spec.CableName) + i.natDiscoveryPending[rnat.Endpoint.Spec.GetFamilyCableName(rnat.Family)]-- + if i.natDiscoveryPending[rnat.Endpoint.Spec.GetFamilyCableName(rnat.Family)] == 0 { + delete(i.natDiscoveryPending, rnat.Endpoint.Spec.GetFamilyCableName(rnat.Family)) } if !i.running { @@ -230,7 +231,7 @@ func (i *engine) installCableWithNATInfo(rnat *natdiscovery.NATEndpointInfo) err return nil } -func (i *engine) InstallCable(endpoint *v1.Endpoint) error { +func (i *engine) InstallCable(endpoint *v1.Endpoint, family k8snet.IPFamily) error { if endpoint.Spec.ClusterID == i.localCluster.ID { logger.V(log.TRACE).Infof("Not installing cable for local cluster") return nil @@ -242,28 +243,28 @@ func (i *engine) InstallCable(endpoint *v1.Endpoint) error { } i.Lock() - i.natDiscoveryPending[endpoint.Spec.CableName]++ + i.natDiscoveryPending[endpoint.Spec.GetFamilyCableName(family)]++ i.Unlock() - i.natDiscovery.AddEndpoint(endpoint) + i.natDiscovery.AddEndpoint(endpoint, family) return nil } -func (i *engine) RemoveCable(endpoint *v1.Endpoint) error { +func (i *engine) RemoveCable(endpoint *v1.Endpoint, family k8snet.IPFamily) error { if endpoint.Spec.ClusterID == i.localCluster.ID { logger.V(log.DEBUG).Infof("Cables are not added/removed for the local cluster, skipping removal") return nil } - logger.Infof("Removing Endpoint cable %q", endpoint.Spec.CableName) + logger.Infof("Removing Endpoint IP%v cable %q", family, endpoint.Spec.CableName) - i.natDiscovery.RemoveEndpoint(endpoint.Spec.CableName) + i.natDiscovery.RemoveEndpoint(endpoint.Spec.GetFamilyCableName(family)) i.Lock() defer i.Unlock() - delete(i.natDiscoveryPending, endpoint.Spec.CableName) + delete(i.natDiscoveryPending, endpoint.Spec.GetFamilyCableName(family)) if _, ok := i.installedCables[endpoint.Spec.CableName]; !ok { return nil @@ -274,9 +275,9 @@ func (i *engine) RemoveCable(endpoint *v1.Endpoint) error { return errors.Wrapf(err, "error disconnecting Endpoint cable %q", endpoint.Spec.CableName) } - delete(i.installedCables, endpoint.Spec.CableName) + delete(i.installedCables, endpoint.Spec.GetFamilyCableName(family)) - logger.Infof("Successfully removed Endpoint cable %q", endpoint.Spec.CableName) + logger.Infof("Successfully removed IP%v Endpoint cable %q", family, endpoint.Spec.CableName) return nil } diff --git a/pkg/cableengine/cableengine_test.go b/pkg/cableengine/cableengine_test.go index 63a3368b2..9c136570f 100644 --- a/pkg/cableengine/cableengine_test.go +++ b/pkg/cableengine/cableengine_test.go @@ -128,7 +128,7 @@ var _ = Describe("Cable Engine", func() { When("install cable for a remote endpoint", func() { Context("and no endpoint was previously installed for the cluster", func() { It("should connect to the endpoint", func() { - Expect(engine.InstallCable(remoteEndpoint)).To(Succeed()) + Expect(engine.InstallCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitConnectToEndpoint(natEndpointInfoFor(remoteEndpoint)) }) }) @@ -144,10 +144,10 @@ var _ = Describe("Cable Engine", func() { }) JustBeforeEach(func() { - Expect(engine.InstallCable(prevEndpoint)).To(Succeed()) + Expect(engine.InstallCable(prevEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitConnectToEndpoint(natEndpointInfoFor(prevEndpoint)) - Expect(engine.InstallCable(newEndpoint)).To(Succeed()) + Expect(engine.InstallCable(newEndpoint, k8snet.IPv4)).To(Succeed()) }) testTimestamps := func() { @@ -226,10 +226,10 @@ var _ = Describe("Cable Engine", func() { CableName: "submariner-cable-other-1.1.1.1", }} - Expect(engine.InstallCable(&otherEndpoint)).To(Succeed()) + Expect(engine.InstallCable(&otherEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitConnectToEndpoint(natEndpointInfoFor(&otherEndpoint)) - Expect(engine.InstallCable(remoteEndpoint)).To(Succeed()) + Expect(engine.InstallCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitConnectToEndpoint(natEndpointInfoFor(remoteEndpoint)) fakeDriver.AwaitNoDisconnectFromEndpoint() }) @@ -241,11 +241,11 @@ var _ = Describe("Cable Engine", func() { }) It("should not connect to the endpoint", func() { - Expect(engine.InstallCable(remoteEndpoint)).To(Succeed()) + Expect(engine.InstallCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) Eventually(natDiscovery.captureAddEndpoint).Should(Receive()) - Expect(engine.RemoveCable(remoteEndpoint)).To(Succeed()) - Eventually(natDiscovery.removeEndpoint).Should(Receive(Equal(remoteEndpoint.Spec.CableName))) + Expect(engine.RemoveCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) + Eventually(natDiscovery.removeEndpoint).Should(Receive(Equal(remoteEndpoint.Spec.GetFamilyCableName(k8snet.IPv4)))) fakeDriver.AwaitNoDisconnectFromEndpoint() natDiscovery.notifyReady(remoteEndpoint) @@ -256,21 +256,21 @@ var _ = Describe("Cable Engine", func() { When("install cable for a local endpoint", func() { It("should not connect to the endpoint", func() { - Expect(engine.InstallCable(localEndpoint)).To(Succeed()) + Expect(engine.InstallCable(localEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitNoConnectToEndpoint() }) }) When("remove cable for a remote endpoint", func() { JustBeforeEach(func() { - Expect(engine.InstallCable(remoteEndpoint)).To(Succeed()) + Expect(engine.InstallCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitConnectToEndpoint(natEndpointInfoFor(remoteEndpoint)) }) It("should disconnect from the endpoint", func() { - Expect(engine.RemoveCable(remoteEndpoint)).To(Succeed()) + Expect(engine.RemoveCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitDisconnectFromEndpoint(&remoteEndpoint.Spec) - Eventually(natDiscovery.removeEndpoint).Should(Receive(Equal(remoteEndpoint.Spec.CableName))) + Eventually(natDiscovery.removeEndpoint).Should(Receive(Equal(remoteEndpoint.Spec.GetFamilyCableName(k8snet.IPv4)))) }) Context("and the driver fails to disconnect from the endpoint", func() { @@ -279,19 +279,19 @@ var _ = Describe("Cable Engine", func() { }) It("should return an error", func() { - Expect(engine.RemoveCable(remoteEndpoint)).To(HaveOccurred()) + Expect(engine.RemoveCable(remoteEndpoint, k8snet.IPv4)).To(HaveOccurred()) }) }) }) When("remove cable for a local endpoint", func() { JustBeforeEach(func() { - Expect(engine.InstallCable(remoteEndpoint)).To(Succeed()) + Expect(engine.InstallCable(remoteEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitConnectToEndpoint(natEndpointInfoFor(remoteEndpoint)) }) It("should not disconnect from the endpoint", func() { - Expect(engine.RemoveCable(localEndpoint)).To(Succeed()) + Expect(engine.RemoveCable(localEndpoint, k8snet.IPv4)).To(Succeed()) fakeDriver.AwaitNoDisconnectFromEndpoint() Consistently(natDiscovery.removeEndpoint).ShouldNot(Receive()) }) @@ -389,7 +389,7 @@ func (n *fakeNATDiscovery) Run(_ <-chan struct{}) error { return nil } -func (n *fakeNATDiscovery) AddEndpoint(endpoint *subv1.Endpoint) { +func (n *fakeNATDiscovery) AddEndpoint(endpoint *subv1.Endpoint, _ k8snet.IPFamily) { if n.captureAddEndpoint != nil { n.captureAddEndpoint <- endpoint return @@ -415,5 +415,6 @@ func natEndpointInfoFor(endpoint *subv1.Endpoint) *natdiscovery.NATEndpointInfo UseIP: endpoint.Spec.GetPublicIP(k8snet.IPv4), UseNAT: true, Endpoint: *endpoint, + Family: k8snet.IPv4, } } diff --git a/pkg/cableengine/fake/cableengine.go b/pkg/cableengine/fake/cableengine.go index ac245f0fe..d9186f29d 100644 --- a/pkg/cableengine/fake/cableengine.go +++ b/pkg/cableengine/fake/cableengine.go @@ -26,6 +26,7 @@ import ( "github.com/submariner-io/submariner/pkg/cableengine" "github.com/submariner-io/submariner/pkg/natdiscovery" "github.com/submariner-io/submariner/pkg/types" + k8snet "k8s.io/utils/net" ) type Engine struct { //nolint:gocritic // This mutex is exposed but we tweak it in tests @@ -62,7 +63,7 @@ func (e *Engine) StartEngine() error { func (e *Engine) Stop() { } -func (e *Engine) InstallCable(endpoint *v1.Endpoint) error { +func (e *Engine) InstallCable(endpoint *v1.Endpoint, _ k8snet.IPFamily) error { err := e.ErrOnInstallCable if err != nil { e.ErrOnInstallCable = nil @@ -74,7 +75,7 @@ func (e *Engine) InstallCable(endpoint *v1.Endpoint) error { return nil } -func (e *Engine) RemoveCable(endpoint *v1.Endpoint) error { +func (e *Engine) RemoveCable(endpoint *v1.Endpoint, _ k8snet.IPFamily) error { err := e.ErrOnRemoveCable if err != nil { e.ErrOnRemoveCable = nil diff --git a/pkg/controllers/tunnel/tunnel.go b/pkg/controllers/tunnel/tunnel.go index 088f2227c..508947bd6 100644 --- a/pkg/controllers/tunnel/tunnel.go +++ b/pkg/controllers/tunnel/tunnel.go @@ -27,20 +27,39 @@ import ( v1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" "github.com/submariner-io/submariner/pkg/cableengine" "k8s.io/apimachinery/pkg/runtime" + k8snet "k8s.io/utils/net" logf "sigs.k8s.io/controller-runtime/pkg/log" ) type controller struct { - engine cableengine.Engine + engine cableengine.Engine + localIPFamilies [2]k8snet.IPFamily } var logger = log.Logger{Logger: logf.Log.WithName("Tunnel")} +func findCommonIPFamilies(local, remote [2]k8snet.IPFamily) []k8snet.IPFamily { + common := []k8snet.IPFamily{} + + for _, lf := range local { + for _, rf := range remote { + if lf == rf { + common = append(common, lf) + break + } + } + } + + return common +} + func StartController(engine cableengine.Engine, namespace string, config *watcher.Config, stopCh <-chan struct{}) error { logger.Info("Starting the tunnel controller") c := &controller{engine: engine} + c.localIPFamilies = c.engine.GetLocalEndpoint().Spec.GetIPFamilies() + config.ResourceConfigs = []watcher.ResourceConfig{ { Name: "Tunnel Controller", @@ -76,26 +95,42 @@ func (c *controller) handleCreatedOrUpdatedEndpoint(obj runtime.Object, _ int) b logger.V(log.TRACE).Infof("Tunnel controller processing added or updated submariner Endpoint object: %#v", endpoint) - err := c.engine.InstallCable(endpoint) - if err != nil { - logger.Errorf(err, "Error installing cable for Endpoint %#v", endpoint) - return true + commonIPFamilies := findCommonIPFamilies(c.localIPFamilies, endpoint.Spec.GetIPFamilies()) + + var errs []error + + for _, family := range commonIPFamilies { + err := c.engine.InstallCable(endpoint, family) + if err != nil { + logger.Errorf(err, "Error installing IP%v cable for Endpoint %#v", family, endpoint) + errs = append(errs, err) + } } - return false + return len(errs) > 0 } func (c *controller) handleRemovedEndpoint(obj runtime.Object, _ int) bool { endpoint := obj.(*v1.Endpoint) + commonIPFamilies := findCommonIPFamilies(c.localIPFamilies, endpoint.Spec.GetIPFamilies()) + logger.V(log.DEBUG).Infof("Tunnel controller processing removed submariner Endpoint object: %#v", endpoint) - if err := c.engine.RemoveCable(endpoint); err != nil { - logger.Errorf(err, "Tunnel controller failed to remove Endpoint cable %#v from the engine", endpoint) + var errs []error + + for _, family := range commonIPFamilies { + if err := c.engine.RemoveCable(endpoint, family); err != nil { + logger.Errorf(err, "Tunnel controller failed to remove Endpoint IP%v cable %#v from the engine", family, endpoint) + errs = append(errs, err) + } + } + + if len(errs) > 0 { return true } - logger.V(log.DEBUG).Infof("Tunnel controller successfully removed Endpoint cable %s from the engine", endpoint.Spec.CableName) + logger.V(log.DEBUG).Infof("Tunnel controller processing removed submariner Endpoint object: %#v", endpoint) return false } diff --git a/pkg/controllers/tunnel/tunnel_test.go b/pkg/controllers/tunnel/tunnel_test.go index 36b47c766..76a3612b2 100644 --- a/pkg/controllers/tunnel/tunnel_test.go +++ b/pkg/controllers/tunnel/tunnel_test.go @@ -133,6 +133,7 @@ var _ = Describe("Managing tunnels", func() { UseIP: endpoint.Spec.GetPrivateIP(k8snet.IPv4), UseNAT: false, Endpoint: *endpoint, + Family: k8snet.IPv4, }) } diff --git a/pkg/endpoint/local_ip.go b/pkg/endpoint/local_ip.go index 95ff56355..c0ab2560e 100644 --- a/pkg/endpoint/local_ip.go +++ b/pkg/endpoint/local_ip.go @@ -36,6 +36,18 @@ func GetLocalIPForDestination(dst string) string { return localAddr.IP.String() } +func GetLocalIPForDest(dst string, family k8snet.IPFamily) string { + switch family { + case k8snet.IPv4: + return GetLocalIPForDestination(dst) + case k8snet.IPv6: + // TODO_IPV6: add V6 healthcheck IP + case k8snet.IPFamilyUnknown: + } + + return "" +} + func GetLocalIP(family k8snet.IPFamily) string { switch family { case k8snet.IPv4: diff --git a/pkg/gateway/gateway_test.go b/pkg/gateway/gateway_test.go index 85eb4504a..2beb74e3a 100644 --- a/pkg/gateway/gateway_test.go +++ b/pkg/gateway/gateway_test.go @@ -55,6 +55,7 @@ import ( dynamicfake "k8s.io/client-go/dynamic/fake" k8sfake "k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/kubernetes/scheme" + k8snet "k8s.io/utils/net" ) const publicIP = "1.2.3.4" @@ -487,7 +488,7 @@ func (n *fakeNATDiscovery) Run(_ <-chan struct{}) error { return nil } -func (n *fakeNATDiscovery) AddEndpoint(ep *submarinerv1.Endpoint) { +func (n *fakeNATDiscovery) AddEndpoint(ep *submarinerv1.Endpoint, _ k8snet.IPFamily) { n.readyChannel <- &natdiscovery.NATEndpointInfo{ Endpoint: *ep, } diff --git a/pkg/natdiscovery/listener.go b/pkg/natdiscovery/listener.go index f34280318..c351e7cea 100644 --- a/pkg/natdiscovery/listener.go +++ b/pkg/natdiscovery/listener.go @@ -26,9 +26,26 @@ import ( "github.com/pkg/errors" natproto "github.com/submariner-io/submariner/pkg/natdiscovery/proto" "google.golang.org/protobuf/proto" + utilerrors "k8s.io/apimachinery/pkg/util/errors" + k8snet "k8s.io/utils/net" ) -func (nd *natDiscovery) runListener(stopCh <-chan struct{}) error { +func (nd *natDiscovery) runListener(family k8snet.IPFamily, stopCh <-chan struct{}) error { + var errs []error + + if family == k8snet.IPv4 { + err := nd.runListenerV4(stopCh) + if err != nil { + logger.Errorf(err, "Error running IP%v listener", family) + errs = append(errs, err) + } + } + + // TODO_IPV6: add V6 runListener for V6 + return utilerrors.NewAggregate(errs) +} + +func (nd *natDiscovery) runListenerV4(stopCh <-chan struct{}) error { if nd.serverPort == 0 { logger.Infof("NAT discovery protocol port not set for this gateway") return nil diff --git a/pkg/natdiscovery/natdiscovery.go b/pkg/natdiscovery/natdiscovery.go index 165db1822..7208438e6 100644 --- a/pkg/natdiscovery/natdiscovery.go +++ b/pkg/natdiscovery/natdiscovery.go @@ -30,19 +30,20 @@ import ( v1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" "github.com/submariner-io/submariner/pkg/endpoint" "k8s.io/apimachinery/pkg/util/wait" + k8snet "k8s.io/utils/net" logf "sigs.k8s.io/controller-runtime/pkg/log" ) type Interface interface { Run(stopCh <-chan struct{}) error - AddEndpoint(endpoint *v1.Endpoint) + AddEndpoint(endpoint *v1.Endpoint, family k8snet.IPFamily) RemoveEndpoint(endpointName string) GetReadyChannel() chan *NATEndpointInfo } type ( udpWriteFunction func(b []byte, addr *net.UDPAddr) (int, error) - findSrcIPFunction func(destinationIP string) string + findSrcIPFunction func(destinationIP string, family k8snet.IPFamily) string ) type natDiscovery struct { @@ -73,7 +74,7 @@ func newNATDiscovery(localEndpoint *endpoint.Local) (*natDiscovery, error) { localEndpoint: localEndpoint, serverPort: ndPort, remoteEndpoints: map[string]*remoteEndpointNAT{}, - findSrcIP: endpoint.GetLocalIPForDestination, + findSrcIP: endpoint.GetLocalIPForDest, requestCounter: rand.Uint64(), readyChannel: make(chan *NATEndpointInfo, 100), }, nil @@ -101,8 +102,12 @@ func (nd *natDiscovery) GetReadyChannel() chan *NATEndpointInfo { func (nd *natDiscovery) Run(stopCh <-chan struct{}) error { logger.V(log.DEBUG).Infof("NAT discovery server starting on port %d", nd.serverPort) - if err := nd.runListener(stopCh); err != nil { - return err + for _, family := range nd.localEndpoint.Spec().GetIPFamilies() { + if err := nd.runListener(family, stopCh); err != nil { + return err + } + + logger.V(log.TRACE).Infof("NAT discovery start listener IP%v", family) } go wait.Until(func() { @@ -113,11 +118,11 @@ func (nd *natDiscovery) Run(stopCh <-chan struct{}) error { return nil } -func (nd *natDiscovery) AddEndpoint(endPoint *v1.Endpoint) { +func (nd *natDiscovery) AddEndpoint(endPoint *v1.Endpoint, family k8snet.IPFamily) { nd.Lock() defer nd.Unlock() - if ep, exists := nd.remoteEndpoints[endPoint.Spec.CableName]; exists { + if ep, exists := nd.remoteEndpoints[endPoint.Spec.GetFamilyCableName(family)]; exists { if reflect.DeepEqual(ep.endpoint.Spec, endPoint.Spec) { if ep.isDiscoveryComplete() { nd.readyChannel <- ep.toNATEndpointInfo() @@ -126,11 +131,12 @@ func (nd *natDiscovery) AddEndpoint(endPoint *v1.Endpoint) { return } - logger.V(log.DEBUG).Infof("NAT discovery updated endpoint %q", endPoint.Spec.CableName) - delete(nd.remoteEndpoints, endPoint.Spec.CableName) + logger.V(log.DEBUG).Infof("NAT discovery updated endpoint IP%v %q", family, endPoint.Spec.CableName) + + delete(nd.remoteEndpoints, endPoint.Spec.GetFamilyCableName(family)) } - remoteNAT := newRemoteEndpointNAT(endPoint) + remoteNAT := newRemoteEndpointNAT(endPoint, family) // support nat discovery disabled or a remote cluster endpoint which still hasn't implemented this protocol if _, err := extractNATDiscoveryPort(&endPoint.Spec); err != nil || nd.serverPort == 0 { @@ -144,7 +150,7 @@ func (nd *natDiscovery) AddEndpoint(endPoint *v1.Endpoint) { logger.Infof("Starting NAT discovery for endpoint %q", endPoint.Spec.CableName) } - nd.remoteEndpoints[endPoint.Spec.CableName] = remoteNAT + nd.remoteEndpoints[endPoint.Spec.GetFamilyCableName(family)] = remoteNAT } func (nd *natDiscovery) RemoveEndpoint(endpointName string) { @@ -158,7 +164,7 @@ func (nd *natDiscovery) checkEndpointList() { defer nd.Unlock() for _, endpointNAT := range nd.remoteEndpoints { - name := endpointNAT.endpoint.Spec.CableName + name := endpointNAT.endpoint.Spec.GetFamilyCableName(endpointNAT.family) logger.V(log.TRACE).Infof("NAT processing remote endpoint %q", name) if endpointNAT.shouldCheck() { diff --git a/pkg/natdiscovery/natdiscovery_internal_test.go b/pkg/natdiscovery/natdiscovery_internal_test.go index 998fa0432..72576e201 100644 --- a/pkg/natdiscovery/natdiscovery_internal_test.go +++ b/pkg/natdiscovery/natdiscovery_internal_test.go @@ -50,7 +50,7 @@ var _ = When("a remote Endpoint is added", func() { JustBeforeEach(func() { forwardFromUDPChan(t.localUDPSent, t.localUDPAddr, t.remoteND, forwardHowManyFromLocal) - t.localND.AddEndpoint(&t.remoteEndpoint) + t.localND.AddEndpoint(&t.remoteEndpoint, k8snet.IPv4) t.localND.checkEndpointList() }) @@ -74,7 +74,7 @@ var _ = When("a remote Endpoint is added", func() { BeforeEach(func() { forwardHowManyFromLocal = 0 t.remoteEndpoint.Spec.PublicIPs = []string{testRemotePublicIP} - t.remoteND.AddEndpoint(&t.localEndpoint) + t.remoteND.AddEndpoint(&t.localEndpoint, k8snet.IPv4) }) JustBeforeEach(func() { @@ -91,6 +91,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: true, UseIP: t.remoteEndpoint.Spec.GetPublicIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) Expect(t.remoteND.parseAndHandleMessageFromAddress(privateIPReq, t.localUDPAddr)). @@ -100,6 +101,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: false, UseIP: t.remoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -115,6 +117,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: true, UseIP: t.remoteEndpoint.Spec.GetPublicIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) Expect(t.remoteND.parseAndHandleMessageFromAddress(privateIPReq, t.localUDPAddr)). @@ -133,6 +136,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: false, UseIP: t.remoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) Expect(t.remoteND.parseAndHandleMessageFromAddress(publicIPReq, t.localUDPAddr)). @@ -149,6 +153,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: false, UseIP: t.remoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -157,7 +162,7 @@ var _ = When("a remote Endpoint is added", func() { var newRemoteEndpoint submarinerv1.Endpoint BeforeEach(func() { - t.remoteND.AddEndpoint(&t.localEndpoint) + t.remoteND.AddEndpoint(&t.localEndpoint, k8snet.IPv4) newRemoteEndpoint = t.remoteEndpoint }) @@ -167,7 +172,7 @@ var _ = When("a remote Endpoint is added", func() { t.remoteUDPAddr.IP = net.ParseIP(newRemoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4)) forwardFromUDPChan(t.localUDPSent, t.localUDPAddr, t.remoteND, 1) - t.localND.AddEndpoint(&newRemoteEndpoint) + t.localND.AddEndpoint(&newRemoteEndpoint, k8snet.IPv4) t.localND.checkEndpointList() }) @@ -177,6 +182,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: false, UseIP: t.remoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -194,6 +200,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: newRemoteEndpoint, UseNAT: false, UseIP: newRemoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -204,12 +211,12 @@ var _ = When("a remote Endpoint is added", func() { BeforeEach(func() { forwardHowManyFromLocal = 0 - t.remoteND.AddEndpoint(&t.localEndpoint) + t.remoteND.AddEndpoint(&t.localEndpoint, k8snet.IPv4) newRemoteEndpoint = t.remoteEndpoint }) JustBeforeEach(func() { - t.localND.AddEndpoint(&newRemoteEndpoint) + t.localND.AddEndpoint(&newRemoteEndpoint, k8snet.IPv4) }) Context("with no change to the Endpoint", func() { @@ -237,6 +244,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: newRemoteEndpoint, UseNAT: false, UseIP: newRemoteEndpoint.Spec.GetPrivateIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -251,7 +259,7 @@ var _ = When("a remote Endpoint is added", func() { Expect(t.localUDPSent).To(Receive()) Consistently(t.readyChannel).ShouldNot(Receive()) - t.localND.RemoveEndpoint(t.remoteEndpoint.Spec.CableName) + t.localND.RemoveEndpoint(t.remoteEndpoint.Spec.GetFamilyCableName(k8snet.IPv4)) t.localND.checkEndpointList() Expect(t.localUDPSent).ToNot(Receive()) @@ -270,6 +278,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: true, UseIP: t.remoteEndpoint.Spec.GetPublicIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -294,6 +303,7 @@ var _ = When("a remote Endpoint is added", func() { Endpoint: t.remoteEndpoint, UseNAT: true, UseIP: t.remoteEndpoint.Spec.GetPublicIP(k8snet.IPv4), + Family: k8snet.IPv4, }))) }) }) @@ -336,10 +346,10 @@ func newDiscoveryTestDriver() *discoveryTestDriver { t.localEndpoint = createTestLocalEndpoint() t.localND, t.localUDPSent, t.readyChannel = createTestListener(&t.localEndpoint) - t.localND.findSrcIP = func(_ string) string { return testLocalPrivateIP } + t.localND.findSrcIP = func(_ string, _ k8snet.IPFamily) string { return testLocalPrivateIP } t.remoteND, t.remoteUDPSent, _ = createTestListener(&t.remoteEndpoint) - t.remoteND.findSrcIP = func(_ string) string { return testRemotePrivateIP } + t.remoteND.findSrcIP = func(_ string, _ k8snet.IPFamily) string { return testRemotePrivateIP } forwardFromUDPChan(t.remoteUDPSent, t.remoteUDPAddr, t.localND, -1) }) @@ -358,7 +368,7 @@ func newDiscoveryTestDriver() *discoveryTestDriver { func (t *discoveryTestDriver) testRemoteEndpointAdded(expIP string, expectNAT bool) { BeforeEach(func() { - t.remoteND.AddEndpoint(&t.localEndpoint) + t.remoteND.AddEndpoint(&t.localEndpoint, k8snet.IPv4) }) It("should notify with the correct NATEndpointInfo settings and stop the discovery", func() { @@ -366,6 +376,7 @@ func (t *discoveryTestDriver) testRemoteEndpointAdded(expIP string, expectNAT bo Endpoint: t.remoteEndpoint, UseNAT: expectNAT, UseIP: expIP, + Family: k8snet.IPv4, }))) // Verify it doesn't time out and try to notify of the legacy settings diff --git a/pkg/natdiscovery/proto/natdiscovery.pb.go b/pkg/natdiscovery/proto/natdiscovery.pb.go index 6a2334f89..1bc1354f5 100644 --- a/pkg/natdiscovery/proto/natdiscovery.pb.go +++ b/pkg/natdiscovery/proto/natdiscovery.pb.go @@ -18,7 +18,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.1 -// protoc v3.17.3 +// protoc v3.19.6 // source: pkg/natdiscovery/proto/natdiscovery.proto package proto diff --git a/pkg/natdiscovery/remote_endpoint.go b/pkg/natdiscovery/remote_endpoint.go index 7d9236f97..b31819099 100644 --- a/pkg/natdiscovery/remote_endpoint.go +++ b/pkg/natdiscovery/remote_endpoint.go @@ -45,6 +45,7 @@ var ( type remoteEndpointNAT struct { endpoint v1.Endpoint + family k8snet.IPFamily state endpointState lastCheck time.Time lastTransition time.Time @@ -61,6 +62,7 @@ type NATEndpointInfo struct { Endpoint v1.Endpoint UseNAT bool UseIP string + Family k8snet.IPFamily } func (rn *remoteEndpointNAT) toNATEndpointInfo() *NATEndpointInfo { @@ -68,12 +70,14 @@ func (rn *remoteEndpointNAT) toNATEndpointInfo() *NATEndpointInfo { Endpoint: rn.endpoint, UseNAT: rn.useNAT, UseIP: rn.useIP, + Family: rn.family, } } -func newRemoteEndpointNAT(endpoint *v1.Endpoint) *remoteEndpointNAT { +func newRemoteEndpointNAT(endpoint *v1.Endpoint, family k8snet.IPFamily) *remoteEndpointNAT { rnat := &remoteEndpointNAT{ endpoint: *endpoint, + family: family, state: testingPrivateAndPublicIPs, started: time.Now(), lastTransition: time.Now(), @@ -110,24 +114,24 @@ func (rn *remoteEndpointNAT) useLegacyNATSettings() { switch { case rn.usingLoadBalancer: rn.useNAT = true - rn.useIP = rn.endpoint.Spec.GetPublicIP(k8snet.IPv4) + rn.useIP = rn.endpoint.Spec.GetPublicIP(rn.family) rn.transitionToState(selectedPublicIP) - logger.V(log.DEBUG).Infof("using NAT for the load balancer backed endpoint %q, using public IP %q", rn.endpoint.Spec.CableName, - rn.useIP) + logger.V(log.DEBUG).Infof("using NAT for the load balancer backed endpoint %q, using public IP %q", + rn.endpoint.Spec.GetFamilyCableName(rn.family), rn.useIP) case rn.endpoint.Spec.NATEnabled: rn.useNAT = true - rn.useIP = rn.endpoint.Spec.GetPublicIP(k8snet.IPv4) + rn.useIP = rn.endpoint.Spec.GetPublicIP(rn.family) rn.transitionToState(selectedPublicIP) - logger.V(log.DEBUG).Infof("using NAT legacy settings for endpoint %q, using public IP %q", rn.endpoint.Spec.CableName, + logger.V(log.DEBUG).Infof("using NAT legacy settings for endpoint %q, using public IP %q", rn.endpoint.Spec.GetFamilyCableName(rn.family), rn.useIP) default: rn.useNAT = false - rn.useIP = rn.endpoint.Spec.GetPrivateIP(k8snet.IPv4) + rn.useIP = rn.endpoint.Spec.GetPrivateIP(rn.family) rn.transitionToState(selectedPrivateIP) - logger.V(log.DEBUG).Infof("using NAT legacy settings for endpoint %q, using private IP %q", rn.endpoint.Spec.CableName, - rn.useIP) + logger.V(log.DEBUG).Infof("using NAT legacy settings for endpoint %q, using private IP %q", + rn.endpoint.Spec.GetFamilyCableName(rn.family), rn.useIP) } } @@ -159,10 +163,10 @@ func (rn *remoteEndpointNAT) checkSent() { func (rn *remoteEndpointNAT) transitionToPublicIP(remoteEndpointID string, useNAT bool) bool { switch rn.state { case waitingForResponse: - rn.useIP = rn.endpoint.Spec.GetPublicIP(k8snet.IPv4) + rn.useIP = rn.endpoint.Spec.GetPublicIP(rn.family) rn.useNAT = useNAT rn.transitionToState(selectedPublicIP) - logger.V(log.DEBUG).Infof("selected public IP %q for endpoint %q", rn.useIP, rn.endpoint.Spec.CableName) + logger.V(log.DEBUG).Infof("selected public IP %q for endpoint %q", rn.useIP, rn.endpoint.Spec.GetFamilyCableName(rn.family)) return true case selectedPrivateIP: @@ -179,10 +183,10 @@ func (rn *remoteEndpointNAT) transitionToPublicIP(remoteEndpointID string, useNA func (rn *remoteEndpointNAT) transitionToPrivateIP(remoteEndpointID string, useNAT bool) bool { switch rn.state { case waitingForResponse: - rn.useIP = rn.endpoint.Spec.GetPrivateIP(k8snet.IPv4) + rn.useIP = rn.endpoint.Spec.GetPrivateIP(rn.family) rn.useNAT = useNAT rn.transitionToState(selectedPrivateIP) - logger.V(log.DEBUG).Infof("selected private IP %q for endpoint %q", rn.useIP, rn.endpoint.Spec.CableName) + logger.V(log.DEBUG).Infof("selected private IP %q for endpoint %q", rn.useIP, rn.endpoint.Spec.GetFamilyCableName(rn.family)) return true case selectedPublicIP: @@ -194,10 +198,10 @@ func (rn *remoteEndpointNAT) transitionToPrivateIP(remoteEndpointID string, useN return false } - rn.useIP = rn.endpoint.Spec.GetPrivateIP(k8snet.IPv4) + rn.useIP = rn.endpoint.Spec.GetPrivateIP(rn.family) rn.useNAT = useNAT rn.transitionToState(selectedPrivateIP) - logger.V(log.DEBUG).Infof("updated to private IP %q for endpoint %q", rn.useIP, rn.endpoint.Spec.CableName) + logger.V(log.DEBUG).Infof("updated to private IP %q for endpoint %q", rn.useIP, rn.endpoint.Spec.GetFamilyCableName(rn.family)) return true case testingPrivateAndPublicIPs: diff --git a/pkg/natdiscovery/remote_endpoint_internal_test.go b/pkg/natdiscovery/remote_endpoint_internal_test.go index 5687bacff..37a7e6c05 100644 --- a/pkg/natdiscovery/remote_endpoint_internal_test.go +++ b/pkg/natdiscovery/remote_endpoint_internal_test.go @@ -33,7 +33,7 @@ var _ = Describe("remoteEndpointNAT", func() { BeforeEach(func() { remoteEndpoint = createTestRemoteEndpoint() - rnat = newRemoteEndpointNAT(&remoteEndpoint) + rnat = newRemoteEndpointNAT(&remoteEndpoint, k8snet.IPv4) }) When("first created", func() { @@ -62,7 +62,7 @@ var _ = Describe("remoteEndpointNAT", func() { When("targeting a load balancer", func() { It("should report as timed out earlier", func() { remoteEndpoint.Spec.BackendConfig[submarinerv1.UsingLoadBalancer] = "true" - rnat = newRemoteEndpointNAT(&remoteEndpoint) + rnat = newRemoteEndpointNAT(&remoteEndpoint, k8snet.IPv4) rnat.started = time.Now().Add(-toDuration(&totalTimeoutLoadBalancer)) Expect(rnat.hasTimedOut()).To(BeTrue()) }) @@ -89,7 +89,7 @@ var _ = Describe("remoteEndpointNAT", func() { Context("and targeting a load balancer", func() { It("should select the public IP and NAT", func() { remoteEndpoint.Spec.BackendConfig[submarinerv1.UsingLoadBalancer] = "true" - rnat = newRemoteEndpointNAT(&remoteEndpoint) + rnat = newRemoteEndpointNAT(&remoteEndpoint, k8snet.IPv4) rnat.endpoint.Spec.NATEnabled = false rnat.useLegacyNATSettings() Expect(rnat.state).To(Equal(selectedPublicIP)) @@ -102,7 +102,7 @@ var _ = Describe("remoteEndpointNAT", func() { When("the public IP is selected but no check was sent", func() { It("it should not transition the state", func() { oldState := rnat.state - Expect(rnat.transitionToPublicIP(testRemoteEndpointName, false)).To(BeFalse()) + Expect(rnat.transitionToPublicIP(testRemoteEndpointNameAndFamily, false)).To(BeFalse()) Expect(rnat.state).To(Equal(oldState)) Expect(rnat.useIP).To(Equal("")) }) @@ -111,7 +111,7 @@ var _ = Describe("remoteEndpointNAT", func() { When("the private IP is selected but no check was sent", func() { It("it should not transition the state", func() { oldState := rnat.state - Expect(rnat.transitionToPrivateIP(testRemoteEndpointName, false)).To(BeFalse()) + Expect(rnat.transitionToPrivateIP(testRemoteEndpointNameAndFamily, false)).To(BeFalse()) Expect(rnat.state).To(Equal(oldState)) Expect(rnat.useIP).To(Equal("")) }) @@ -123,7 +123,7 @@ var _ = Describe("remoteEndpointNAT", func() { JustBeforeEach(func() { rnat.checkSent() - Expect(rnat.transitionToPrivateIP(testRemoteEndpointName, useNAT)).To(BeTrue()) + Expect(rnat.transitionToPrivateIP(testRemoteEndpointNameAndFamily, useNAT)).To(BeTrue()) Expect(rnat.state).To(Equal(selectedPrivateIP)) }) @@ -158,7 +158,7 @@ var _ = Describe("remoteEndpointNAT", func() { JustBeforeEach(func() { rnat.checkSent() - Expect(rnat.transitionToPublicIP(testRemoteEndpointName, useNAT)).To(BeTrue()) + Expect(rnat.transitionToPublicIP(testRemoteEndpointNameAndFamily, useNAT)).To(BeTrue()) Expect(rnat.state).To(Equal(selectedPublicIP)) }) @@ -191,8 +191,8 @@ var _ = Describe("remoteEndpointNAT", func() { Context("and the grace period has not elapsed", func() { It("should use the private IP", func() { rnat.checkSent() - Expect(rnat.transitionToPublicIP(testRemoteEndpointName, true)).To(BeTrue()) - Expect(rnat.transitionToPrivateIP(testRemoteEndpointName, false)).To(BeTrue()) + Expect(rnat.transitionToPublicIP(testRemoteEndpointNameAndFamily, true)).To(BeTrue()) + Expect(rnat.transitionToPrivateIP(testRemoteEndpointNameAndFamily, false)).To(BeTrue()) Expect(rnat.state).To(Equal(selectedPrivateIP)) Expect(rnat.useIP).To(Equal(rnat.endpoint.Spec.GetPrivateIP(k8snet.IPv4))) Expect(rnat.useNAT).To(BeFalse()) @@ -202,9 +202,9 @@ var _ = Describe("remoteEndpointNAT", func() { Context("and the grace period has elapsed", func() { It("should still use the public IP", func() { rnat.checkSent() - Expect(rnat.transitionToPublicIP(testRemoteEndpointName, true)).To(BeTrue()) + Expect(rnat.transitionToPublicIP(testRemoteEndpointNameAndFamily, true)).To(BeTrue()) rnat.lastTransition = rnat.lastTransition.Add(-time.Duration(publicToPrivateFailoverTimeout)) - Expect(rnat.transitionToPrivateIP(testRemoteEndpointName, false)).To(BeFalse()) + Expect(rnat.transitionToPrivateIP(testRemoteEndpointNameAndFamily, false)).To(BeFalse()) Expect(rnat.state).To(Equal(selectedPublicIP)) Expect(rnat.useIP).To(Equal(rnat.endpoint.Spec.GetPublicIP(k8snet.IPv4))) Expect(rnat.useNAT).To(BeTrue()) diff --git a/pkg/natdiscovery/request_handle.go b/pkg/natdiscovery/request_handle.go index dcea82007..2375fa478 100644 --- a/pkg/natdiscovery/request_handle.go +++ b/pkg/natdiscovery/request_handle.go @@ -31,11 +31,16 @@ import ( func (nd *natDiscovery) handleRequestFromAddress(req *proto.SubmarinerNATDiscoveryRequest, addr *net.UDPAddr) error { localEndpointSpec := nd.localEndpoint.Spec() + family := k8snet.IPv4 + if addr.IP.To4() == nil { + family = k8snet.IPv6 + } + response := proto.SubmarinerNATDiscoveryResponse{ RequestNumber: req.GetRequestNumber(), Sender: &proto.EndpointDetails{ ClusterId: localEndpointSpec.ClusterID, - EndpointId: localEndpointSpec.CableName, + EndpointId: localEndpointSpec.GetFamilyCableName(family), }, Receiver: req.GetSender(), ReceivedSrc: &proto.IPPortPair{ @@ -64,7 +69,7 @@ func (nd *natDiscovery) handleRequestFromAddress(req *proto.SubmarinerNATDiscove return nd.sendResponseToAddress(&response, addr) } - if req.GetReceiver().GetEndpointId() != localEndpointSpec.CableName { + if req.GetReceiver().GetEndpointId() != localEndpointSpec.GetFamilyCableName(k8snet.IPv4) { logger.Warningf("Received NAT discovery packet for endpoint %q, but we are endpoint %q "+ "if the port for NAT discovery has been mapped somewhere an error may exist", req.GetReceiver().GetEndpointId(), localEndpointSpec.CableName) diff --git a/pkg/natdiscovery/request_handle_internal_test.go b/pkg/natdiscovery/request_handle_internal_test.go index 99c22ccf5..5eec5459b 100644 --- a/pkg/natdiscovery/request_handle_internal_test.go +++ b/pkg/natdiscovery/request_handle_internal_test.go @@ -26,6 +26,7 @@ import ( submarinerv1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" natproto "github.com/submariner-io/submariner/pkg/natdiscovery/proto" "google.golang.org/protobuf/proto" + k8snet "k8s.io/utils/net" ) var _ = Describe("Request handling", func() { @@ -43,9 +44,9 @@ var _ = Describe("Request handling", func() { remoteEndpoint = createTestRemoteEndpoint() localListener, localUDPSent, _ = createTestListener(&localEndpoint) - localListener.findSrcIP = func(_ string) string { return testLocalPrivateIP } + localListener.findSrcIP = func(_ string, _ k8snet.IPFamily) string { return testLocalPrivateIP } remoteListener, remoteUDPSent, _ = createTestListener(&remoteEndpoint) - remoteListener.findSrcIP = func(_ string) string { return testRemotePrivateIP } + remoteListener.findSrcIP = func(_ string, _ k8snet.IPFamily) string { return testRemotePrivateIP } remoteUDPAddr = net.UDPAddr{ IP: net.ParseIP(testRemotePrivateIP), @@ -60,7 +61,7 @@ var _ = Describe("Request handling", func() { } requestResponseFromRemoteToLocal := func(remoteAddr *net.UDPAddr) []*natproto.SubmarinerNATDiscoveryResponse { - err := remoteListener.sendCheckRequest(newRemoteEndpointNAT(&localEndpoint)) + err := remoteListener.sendCheckRequest(newRemoteEndpointNAT(&localEndpoint, k8snet.IPv4)) Expect(err).NotTo(HaveOccurred()) return []*natproto.SubmarinerNATDiscoveryResponse{ parseResponseInLocalListener(awaitChan(remoteUDPSent), remoteAddr), /* Private IP request */ @@ -70,7 +71,7 @@ var _ = Describe("Request handling", func() { When("receiving a request with a known sender endpoint", func() { It("should respond with OK", func() { - localListener.AddEndpoint(&remoteEndpoint) + localListener.AddEndpoint(&remoteEndpoint, k8snet.IPv4) response := requestResponseFromRemoteToLocal(&remoteUDPAddr) Expect(response[0].GetResponse()).To(Equal(natproto.ResponseType_OK)) Expect(response[1].GetResponse()).To(Equal(natproto.ResponseType_NAT_DETECTED)) @@ -82,7 +83,7 @@ var _ = Describe("Request handling", func() { Context("with a modified IP", func() { It("should respond with NAT_DETECTED and SrcIpNatDetected", func() { remoteUDPAddr.IP = net.ParseIP(testRemotePublicIP) - localListener.AddEndpoint(&remoteEndpoint) + localListener.AddEndpoint(&remoteEndpoint, k8snet.IPv4) response := requestResponseFromRemoteToLocal(&remoteUDPAddr) Expect(response[0].GetResponse()).To(Equal(natproto.ResponseType_NAT_DETECTED)) Expect(response[0].GetSrcIpNatDetected()).To(BeTrue()) @@ -93,7 +94,7 @@ var _ = Describe("Request handling", func() { Context("with a modified port", func() { It("should respond with NAT_DETECTED and SrcPortNatDetected", func() { remoteUDPAddr.Port = int(testRemoteNATPort + 1) - localListener.AddEndpoint(&remoteEndpoint) + localListener.AddEndpoint(&remoteEndpoint, k8snet.IPv4) response := requestResponseFromRemoteToLocal(&remoteUDPAddr) Expect(response[0].GetResponse()).To(Equal(natproto.ResponseType_NAT_DETECTED)) Expect(response[0].GetSrcIpNatDetected()).To(BeFalse()) @@ -104,7 +105,7 @@ var _ = Describe("Request handling", func() { When("receiving a request with an unknown receiver endpoint ID", func() { It("should respond with UNKNOWN_DST_ENDPOINT", func() { - localListener.AddEndpoint(&remoteEndpoint) + localListener.AddEndpoint(&remoteEndpoint, k8snet.IPv4) localEndpoint.Spec.CableName = "invalid" response := requestResponseFromRemoteToLocal(&remoteUDPAddr) Expect(response[0].GetResponse()).To(Equal(natproto.ResponseType_UNKNOWN_DST_ENDPOINT)) @@ -113,7 +114,7 @@ var _ = Describe("Request handling", func() { When("receiving a request with an unknown receiver cluster ID", func() { It("should respond with UNKNOWN_DST_CLUSTER", func() { - localListener.AddEndpoint(&remoteEndpoint) + localListener.AddEndpoint(&remoteEndpoint, k8snet.IPv4) localEndpoint.Spec.ClusterID = "invalid" response := requestResponseFromRemoteToLocal(&remoteUDPAddr) Expect(response[0].GetResponse()).To(Equal(natproto.ResponseType_UNKNOWN_DST_CLUSTER)) @@ -165,11 +166,11 @@ func createMalformedRequest(mangleFunction func(*natproto.SubmarinerNATDiscovery request := natproto.SubmarinerNATDiscoveryRequest{ RequestNumber: 1, Sender: &natproto.EndpointDetails{ - EndpointId: testRemoteEndpointName, + EndpointId: testRemoteEndpointNameAndFamily, ClusterId: testRemoteClusterID, }, Receiver: &natproto.EndpointDetails{ - EndpointId: testLocalEndpointName, + EndpointId: testLocalEndpointNameAndFamily, ClusterId: testLocalClusterID, }, UsingSrc: &natproto.IPPortPair{ diff --git a/pkg/natdiscovery/request_send.go b/pkg/natdiscovery/request_send.go index 5454ae8d9..7a6ed803d 100644 --- a/pkg/natdiscovery/request_send.go +++ b/pkg/natdiscovery/request_send.go @@ -32,15 +32,15 @@ func (nd *natDiscovery) sendCheckRequest(remoteNAT *remoteEndpointNAT) error { var errPrivate, errPublic error var reqID uint64 - if remoteNAT.endpoint.Spec.GetPrivateIP(k8snet.IPv4) != "" { - reqID, errPrivate = nd.sendCheckRequestToTargetIP(remoteNAT, remoteNAT.endpoint.Spec.GetPrivateIP(k8snet.IPv4)) + if remoteNAT.endpoint.Spec.GetPrivateIP(remoteNAT.family) != "" { + reqID, errPrivate = nd.sendCheckRequestToTargetIP(remoteNAT, remoteNAT.endpoint.Spec.GetPrivateIP(remoteNAT.family)) if errPrivate == nil { remoteNAT.lastPrivateIPRequestID = reqID } } if remoteNAT.endpoint.Spec.GetPublicIP(k8snet.IPv4) != "" { - reqID, errPublic = nd.sendCheckRequestToTargetIP(remoteNAT, remoteNAT.endpoint.Spec.GetPublicIP(k8snet.IPv4)) + reqID, errPublic = nd.sendCheckRequestToTargetIP(remoteNAT, remoteNAT.endpoint.Spec.GetPublicIP(remoteNAT.family)) if errPublic == nil { remoteNAT.lastPublicIPRequestID = reqID } @@ -70,7 +70,7 @@ func (nd *natDiscovery) sendCheckRequestToTargetIP(remoteNAT *remoteEndpointNAT, return 0, err } - sourceIP := nd.findSrcIP(targetIP) + sourceIP := nd.findSrcIP(targetIP, remoteNAT.family) nd.requestCounter++ @@ -79,11 +79,11 @@ func (nd *natDiscovery) sendCheckRequestToTargetIP(remoteNAT *remoteEndpointNAT, request := &natproto.SubmarinerNATDiscoveryRequest{ RequestNumber: nd.requestCounter, Sender: &natproto.EndpointDetails{ - EndpointId: localEndpointSpec.CableName, + EndpointId: localEndpointSpec.GetFamilyCableName(remoteNAT.family), ClusterId: localEndpointSpec.ClusterID, }, Receiver: &natproto.EndpointDetails{ - EndpointId: remoteNAT.endpoint.Spec.CableName, + EndpointId: remoteNAT.endpoint.Spec.GetFamilyCableName(remoteNAT.family), ClusterId: remoteNAT.endpoint.Spec.ClusterID, }, UsingSrc: &natproto.IPPortPair{ diff --git a/pkg/natdiscovery/request_send_internal_test.go b/pkg/natdiscovery/request_send_internal_test.go index a5ec446bb..52b1e46e2 100644 --- a/pkg/natdiscovery/request_send_internal_test.go +++ b/pkg/natdiscovery/request_send_internal_test.go @@ -23,6 +23,7 @@ import ( . "github.com/onsi/gomega" submarinerv1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" natproto "github.com/submariner-io/submariner/pkg/natdiscovery/proto" + k8snet "k8s.io/utils/net" ) var _ = When("a request is sent", func() { @@ -43,9 +44,9 @@ var _ = When("a request is sent", func() { JustBeforeEach(func() { ndInstance, udpSent, _ = createTestListener(&localEndpoint) - ndInstance.findSrcIP = func(_ string) string { return testLocalPrivateIP } + ndInstance.findSrcIP = func(_ string, _ k8snet.IPFamily) string { return testLocalPrivateIP } - err := ndInstance.sendCheckRequest(newRemoteEndpointNAT(&remoteEndpoint)) + err := ndInstance.sendCheckRequest(newRemoteEndpointNAT(&remoteEndpoint, k8snet.IPv4)) Expect(err).NotTo(HaveOccurred()) request = parseProtocolRequest(awaitChan(udpSent)) @@ -55,13 +56,13 @@ var _ = When("a request is sent", func() { It("should set the sender fields correctly", func() { Expect(request.GetSender()).NotTo(BeNil()) Expect(request.GetSender().GetClusterId()).To(Equal(testLocalClusterID)) - Expect(request.GetSender().GetEndpointId()).To(Equal(testLocalEndpointName)) + Expect(request.GetSender().GetEndpointId()).To(Equal(testLocalEndpointNameAndFamily)) }) It("should set the receiver fields correctly", func() { Expect(request.GetReceiver()).NotTo(BeNil()) Expect(request.GetReceiver().GetClusterId()).To(Equal(testRemoteClusterID)) - Expect(request.GetReceiver().GetEndpointId()).To(Equal(testRemoteEndpointName)) + Expect(request.GetReceiver().GetEndpointId()).To(Equal(testRemoteEndpointNameAndFamily)) }) It("should set the using source fields correctly", func() { diff --git a/pkg/natdiscovery/test_utils.go b/pkg/natdiscovery/test_utils.go index 79199d581..64b1fa991 100644 --- a/pkg/natdiscovery/test_utils.go +++ b/pkg/natdiscovery/test_utils.go @@ -32,19 +32,23 @@ import ( "google.golang.org/protobuf/proto" dynamicfake "k8s.io/client-go/dynamic/fake" "k8s.io/client-go/kubernetes/scheme" + k8snet "k8s.io/utils/net" ) const ( testLocalEndpointName = "cluster-a-ep-1" - testLocalClusterID = "cluster-a" - testLocalPublicIP = "10.1.1.1" - testLocalPrivateIP = "2.2.2.2" - - testRemoteEndpointName = "cluster-b-ep-1" - testRemoteClusterID = "cluster-b" - testRemotePublicIP = "10.3.3.3" - testRemotePrivateIP = "4.4.4.4" - testRemotePrivateIP2 = "5.5.5.5" + // endpointId as formated in GetFamilyCableName. + testLocalEndpointNameAndFamily = testLocalEndpointName + "-ipv" + string(k8snet.IPv4) + testLocalClusterID = "cluster-a" + testLocalPublicIP = "10.1.1.1" + testLocalPrivateIP = "2.2.2.2" + + testRemoteEndpointName = "cluster-b-ep-1" + testRemoteEndpointNameAndFamily = "cluster-b-ep-1" + "-ipv" + string(k8snet.IPv4) + testRemoteClusterID = "cluster-b" + testRemotePublicIP = "10.3.3.3" + testRemotePrivateIP = "4.4.4.4" + testRemotePrivateIP2 = "5.5.5.5" ) var (