diff --git a/pkg/apis/submariner.io/v1/endpoint.go b/pkg/apis/submariner.io/v1/endpoint.go index f78510622..6d1c2cad7 100644 --- a/pkg/apis/submariner.io/v1/endpoint.go +++ b/pkg/apis/submariner.io/v1/endpoint.go @@ -25,6 +25,7 @@ import ( "github.com/pkg/errors" "github.com/submariner-io/admiral/pkg/log" "github.com/submariner-io/admiral/pkg/resource" + "github.com/submariner-io/submariner/pkg/cidr" "k8s.io/apimachinery/pkg/api/equality" k8snet "k8s.io/utils/net" logf "sigs.k8s.io/controller-runtime/pkg/log" @@ -173,12 +174,8 @@ func (ep *EndpointSpec) SetPrivateIP(ip string) { ep.PrivateIPs, ep.PrivateIP = setIP(ep.PrivateIPs, ep.PrivateIP, ip) } -func (ep *EndpointSpec) GetIPFamilies() [2]k8snet.IPFamily { - var ipFamilies [2]k8snet.IPFamily - // TODO_IPV6: set ipFamilies according to Subnets content - ipFamilies[0] = k8snet.IPv4 - - return ipFamilies +func (ep *EndpointSpec) GetIPFamilies() []k8snet.IPFamily { + return cidr.ExtractIPFamilies(ep.Subnets) } func (ep *EndpointSpec) GetFamilyCableName(family k8snet.IPFamily) string { diff --git a/pkg/apis/submariner.io/v1/endpoint_test.go b/pkg/apis/submariner.io/v1/endpoint_test.go index a58871fe2..76024000e 100644 --- a/pkg/apis/submariner.io/v1/endpoint_test.go +++ b/pkg/apis/submariner.io/v1/endpoint_test.go @@ -28,6 +28,8 @@ import ( const ( ipV4Addr = "1.2.3.4" ipV6Addr = "2001:db8:3333:4444:5555:6666:7777:8888" + ipV4CIDR = "10.16.1.0/32" + ipV6CIDR = "2002::1234:abcd:ffff:c0a8:101/64" ) var _ = Describe("EndpointSpec", func() { @@ -40,6 +42,7 @@ var _ = Describe("EndpointSpec", func() { Context("GetPrivateIP", testGetPrivateIP) Context("SetPrivateIP", testSetPrivateIP) Context("GetFamilyCableName", testGetFamilyCableName) + Context("GetIPFamilies", testGetIPFamilies) }) func testGenerateName() { @@ -424,3 +427,12 @@ func testGetFamilyCableName() { }) }) } + +func testGetIPFamilies() { + It("should return the correct families", func() { + Expect((&v1.EndpointSpec{Subnets: []string{ipV4CIDR}}).GetIPFamilies()).To(Equal([]k8snet.IPFamily{k8snet.IPv4})) + Expect((&v1.EndpointSpec{Subnets: []string{ipV6CIDR}}).GetIPFamilies()).To(Equal([]k8snet.IPFamily{k8snet.IPv6})) + Expect((&v1.EndpointSpec{Subnets: []string{ipV6CIDR, ipV4CIDR}}).GetIPFamilies()).To( + Equal([]k8snet.IPFamily{k8snet.IPv6, k8snet.IPv4})) + }) +} diff --git a/pkg/cidr/cidr_test.go b/pkg/cidr/cidr_test.go new file mode 100644 index 000000000..f0b385d24 --- /dev/null +++ b/pkg/cidr/cidr_test.go @@ -0,0 +1,49 @@ +/* +SPDX-License-Identifier: Apache-2.0 + +Copyright Contributors to the Submariner project. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cidr_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/submariner-io/submariner/pkg/cidr" + k8snet "k8s.io/utils/net" +) + +func TestCidr(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "CIDR Suite") +} + +const ( + ipV4CIDR = "1.2.3.4/16" + ipV6CIDR = "2002::1234:abcd:ffff:c0a8:101/64" +) + +var _ = Describe("ExtractIPFamilies", func() { + It("should return the correct families", func() { + Expect(cidr.ExtractIPFamilies([]string{ipV4CIDR})).To(Equal([]k8snet.IPFamily{k8snet.IPv4})) + Expect(cidr.ExtractIPFamilies([]string{ipV6CIDR})).To(Equal([]k8snet.IPFamily{k8snet.IPv6})) + Expect(cidr.ExtractIPFamilies([]string{ipV4CIDR, ipV6CIDR})).To(Equal([]k8snet.IPFamily{k8snet.IPv4, k8snet.IPv6})) + Expect(cidr.ExtractIPFamilies([]string{ipV4CIDR, ipV4CIDR})).To(Equal([]k8snet.IPFamily{k8snet.IPv4})) + Expect(cidr.ExtractIPFamilies([]string{})).To(BeEmpty()) + Expect(cidr.ExtractIPFamilies([]string{"bogus"})).To(BeEmpty()) + }) +}) diff --git a/pkg/cidr/iputil.go b/pkg/cidr/iputil.go index 3cac33e24..5f30df5a7 100644 --- a/pkg/cidr/iputil.go +++ b/pkg/cidr/iputil.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" "github.com/submariner-io/admiral/pkg/log" + "github.com/submariner-io/admiral/pkg/slices" k8snet "k8s.io/utils/net" logf "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -98,3 +99,18 @@ func ExtractIPv4Subnets(cidrList []string) []string { return ipv4Cidrs } + +func ExtractIPFamilies(fromCIDRs []string) []k8snet.IPFamily { + var ipFamilies []k8snet.IPFamily + + for _, cidr := range fromCIDRs { + f := k8snet.IPFamilyOfCIDRString(cidr) + if f != k8snet.IPFamilyUnknown { + ipFamilies, _ = slices.AppendIfNotPresent(ipFamilies, f, func(e k8snet.IPFamily) k8snet.IPFamily { + return e + }) + } + } + + return ipFamilies +} diff --git a/pkg/controllers/tunnel/tunnel.go b/pkg/controllers/tunnel/tunnel.go index d87176e37..33c070733 100644 --- a/pkg/controllers/tunnel/tunnel.go +++ b/pkg/controllers/tunnel/tunnel.go @@ -23,6 +23,7 @@ import ( "github.com/pkg/errors" "github.com/submariner-io/admiral/pkg/log" + "github.com/submariner-io/admiral/pkg/slices" "github.com/submariner-io/admiral/pkg/watcher" v1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" "github.com/submariner-io/submariner/pkg/cableengine" @@ -33,24 +34,15 @@ import ( type controller struct { engine cableengine.Engine - localIPFamilies [2]k8snet.IPFamily + localIPFamilies []k8snet.IPFamily } var logger = log.Logger{Logger: logf.Log.WithName("Tunnel")} -func findCommonIPFamilies(local, remote [2]k8snet.IPFamily) []k8snet.IPFamily { - common := []k8snet.IPFamily{} - - for _, lf := range local { - for _, rf := range remote { - if lf == rf { - common = append(common, lf) - break - } - } - } - - return common +func findCommonIPFamilies(local, remote []k8snet.IPFamily) []k8snet.IPFamily { + return slices.Intersect(local, remote, func(e k8snet.IPFamily) k8snet.IPFamily { + return e + }) } func StartController(engine cableengine.Engine, namespace string, config *watcher.Config, stopCh <-chan struct{}) error { diff --git a/pkg/controllers/tunnel/tunnel_test.go b/pkg/controllers/tunnel/tunnel_test.go index a77359e13..6829dfabf 100644 --- a/pkg/controllers/tunnel/tunnel_test.go +++ b/pkg/controllers/tunnel/tunnel_test.go @@ -47,6 +47,7 @@ import ( const ( namespace = "submariner" + ipV4CIDR = "1.2.3.4/16" ) func init() { @@ -64,15 +65,21 @@ var _ = BeforeSuite(func() { var _ = Describe("Managing tunnels", func() { var ( - config *watcher.Config - endpoints dynamic.ResourceInterface - endpoint *v1.Endpoint - stopCh chan struct{} + config *watcher.Config + endpoints dynamic.ResourceInterface + endpoint *v1.Endpoint + localEPSpec *v1.EndpointSpec + stopCh chan struct{} ) BeforeEach(func() { fakeDriver = fake.New() + localEPSpec = &v1.EndpointSpec{ + Backend: fake.DriverName, + Subnets: []string{ipV4CIDR}, + } + endpoint = &v1.Endpoint{ ObjectMeta: metav1.ObjectMeta{ Name: "east-submariner-cable-east-192-68-1-1", @@ -83,6 +90,7 @@ var _ = Describe("Managing tunnels", func() { ClusterID: "east", Hostname: "redsox", PrivateIPs: []string{"192.68.1.2"}, + Subnets: []string{ipV4CIDR}, }, } @@ -106,9 +114,7 @@ var _ = Describe("Managing tunnels", func() { }) JustBeforeEach(func() { - localEp := submendpoint.NewLocal(&v1.EndpointSpec{ - Backend: fake.DriverName, - }, fakeClient.NewSimpleDynamicClient(kubeScheme.Scheme), "") + localEp := submendpoint.NewLocal(localEPSpec, fakeClient.NewSimpleDynamicClient(kubeScheme.Scheme), "") engine := cableengine.NewEngine(&types.SubmarinerCluster{}, localEp) diff --git a/pkg/gateway/gateway_test.go b/pkg/gateway/gateway_test.go index 2beb74e3a..c7a407514 100644 --- a/pkg/gateway/gateway_test.go +++ b/pkg/gateway/gateway_test.go @@ -122,6 +122,7 @@ var _ = Describe("Run", func() { endpoint := t.awaitRemoteEndpointSyncedLocal(t.createRemoteEndpointOnBroker()) fakeDriver.AwaitConnectToEndpoint(&natdiscovery.NATEndpointInfo{ Endpoint: *endpoint, + Family: k8snet.IPv4, }) By("Setting leases resource updates to fail") @@ -173,6 +174,7 @@ var _ = Describe("Run", func() { endpoint2 := t.awaitRemoteEndpointSyncedLocal(brokerEndpoint) fakeDriver.AwaitConnectToEndpoint(&natdiscovery.NATEndpointInfo{ Endpoint: *endpoint2, + Family: k8snet.IPv4, }) fakeDriver.AwaitDisconnectFromEndpoint(&endpoint.Spec) @@ -488,9 +490,10 @@ func (n *fakeNATDiscovery) Run(_ <-chan struct{}) error { return nil } -func (n *fakeNATDiscovery) AddEndpoint(ep *submarinerv1.Endpoint, _ k8snet.IPFamily) { +func (n *fakeNATDiscovery) AddEndpoint(ep *submarinerv1.Endpoint, family k8snet.IPFamily) { n.readyChannel <- &natdiscovery.NATEndpointInfo{ Endpoint: *ep, + Family: family, } }