diff --git a/config/config.go b/config/config.go index 0a68736f64..e1ce654562 100644 --- a/config/config.go +++ b/config/config.go @@ -24,12 +24,11 @@ import ( blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" routed "github.com/libp2p/go-libp2p/p2p/host/routed" "github.com/libp2p/go-libp2p/p2p/net/swarm" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" - tptu "github.com/libp2p/go-libp2p-transport-upgrader" - logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" diff --git a/go.mod b/go.mod index 524441497b..9ff84efedd 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/ipfs/go-datastore v0.5.1 github.com/ipfs/go-ipfs-util v0.0.2 github.com/ipfs/go-log/v2 v2.5.1 + github.com/jbenet/go-temp-err-catcher v0.1.0 github.com/klauspost/compress v1.15.1 github.com/libp2p/go-buffer-pool v0.0.2 github.com/libp2p/go-eventbus v0.2.1 @@ -22,10 +23,10 @@ require ( github.com/libp2p/go-libp2p-nat v0.1.0 github.com/libp2p/go-libp2p-noise v0.4.0 github.com/libp2p/go-libp2p-peerstore v0.6.0 + github.com/libp2p/go-libp2p-pnet v0.2.0 github.com/libp2p/go-libp2p-resource-manager v0.2.1 github.com/libp2p/go-libp2p-testing v0.9.2 github.com/libp2p/go-libp2p-tls v0.4.1 - github.com/libp2p/go-libp2p-transport-upgrader v0.7.1 github.com/libp2p/go-mplex v0.7.0 github.com/libp2p/go-msgio v0.2.0 github.com/libp2p/go-netroute v0.2.0 @@ -74,16 +75,15 @@ require ( github.com/google/uuid v1.3.0 // indirect github.com/huin/goupnp v1.0.3 // indirect github.com/jackpal/go-nat-pmp v1.0.2 // indirect - github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect github.com/jbenet/goprocess v0.1.4 // indirect github.com/klauspost/cpuid/v2 v2.0.12 // indirect github.com/koron/go-ssdp v0.0.2 // indirect github.com/libp2p/go-cidranger v1.1.0 // indirect github.com/libp2p/go-flow-metrics v0.0.3 // indirect github.com/libp2p/go-libp2p-blankhost v0.3.0 // indirect - github.com/libp2p/go-libp2p-pnet v0.2.0 // indirect github.com/libp2p/go-libp2p-quic-transport v0.17.0 // indirect github.com/libp2p/go-libp2p-swarm v0.10.2 // indirect + github.com/libp2p/go-libp2p-transport-upgrader v0.7.1 // indirect github.com/libp2p/go-libp2p-yamux v0.9.1 // indirect github.com/libp2p/go-nat v0.1.0 // indirect github.com/libp2p/go-openssl v0.0.7 // indirect diff --git a/go.sum b/go.sum index 3de3c758a2..dc599a27ed 100644 --- a/go.sum +++ b/go.sum @@ -443,7 +443,6 @@ github.com/libp2p/go-libp2p-core v0.14.0/go.mod h1:tLasfcVdTXnixsLB0QYaT1syJOhsb github.com/libp2p/go-libp2p-core v0.15.1 h1:0RY+Mi/ARK9DgG1g9xVQLb8dDaaU8tCePMtGALEfBnM= github.com/libp2p/go-libp2p-core v0.15.1/go.mod h1:agSaboYM4hzB1cWekgVReqV5M4g5M+2eNNejV+1EEhs= github.com/libp2p/go-libp2p-mplex v0.4.1/go.mod h1:cmy+3GfqfM1PceHTLL7zQzAAYaryDu6iPSC+CIb094g= -github.com/libp2p/go-libp2p-mplex v0.5.0 h1:vt3k4E4HSND9XH4Z8rUpacPJFSAgLOv6HDvG8W9Ks9E= github.com/libp2p/go-libp2p-mplex v0.5.0/go.mod h1:eLImPJLkj3iG5t5lq68w3Vm5NAQ5BcKwrrb2VmOYb3M= github.com/libp2p/go-libp2p-nat v0.1.0 h1:vigUi2MEN+fwghe5ijpScxtbbDz+L/6y8XwlzYOJgSY= github.com/libp2p/go-libp2p-nat v0.1.0/go.mod h1:DQzAG+QbDYjN1/C3B6vXucLtz3u9rEonLVPtZVzQqks= diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 0219a63a85..c404231270 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -10,6 +10,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" "github.com/libp2p/go-libp2p/p2p/transport/tcp" @@ -19,7 +20,6 @@ import ( "github.com/libp2p/go-libp2p-peerstore/pstoremem" tnet "github.com/libp2p/go-libp2p-testing/net" - tptu "github.com/libp2p/go-libp2p-transport-upgrader" msmux "github.com/libp2p/go-stream-muxer-multistream" ma "github.com/multiformats/go-multiaddr" diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 6e56112f97..c1752dc871 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" "github.com/libp2p/go-libp2p/p2p/net/swarm" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" "github.com/libp2p/go-libp2p/p2p/transport/tcp" @@ -22,7 +23,6 @@ import ( "github.com/libp2p/go-libp2p-peerstore/pstoremem" tnet "github.com/libp2p/go-libp2p-testing/net" - tptu "github.com/libp2p/go-libp2p-transport-upgrader" msmux "github.com/libp2p/go-stream-muxer-multistream" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go new file mode 100644 index 0000000000..9aa3b2f2e9 --- /dev/null +++ b/p2p/net/upgrader/conn.go @@ -0,0 +1,52 @@ +package upgrader + +import ( + "fmt" + + "github.com/libp2p/go-libp2p-core/mux" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/transport" +) + +type transportConn struct { + mux.MuxedConn + network.ConnMultiaddrs + network.ConnSecurity + transport transport.Transport + scope network.ConnManagementScope + stat network.ConnStats +} + +var _ transport.CapableConn = &transportConn{} + +func (t *transportConn) Transport() transport.Transport { + return t.transport +} + +func (t *transportConn) String() string { + ts := "" + if s, ok := t.transport.(fmt.Stringer); ok { + ts = "[" + s.String() + "]" + } + return fmt.Sprintf( + " %s (%s)>", + ts, + t.LocalMultiaddr(), + t.LocalPeer(), + t.RemoteMultiaddr(), + t.RemotePeer(), + ) +} + +func (t *transportConn) Stat() network.ConnStats { + return t.stat +} + +func (t *transportConn) Scope() network.ConnScope { + return t.scope +} + +func (t *transportConn) Close() error { + defer t.scope.Done() + return t.MuxedConn.Close() +} diff --git a/p2p/net/upgrader/gater_test.go b/p2p/net/upgrader/gater_test.go new file mode 100644 index 0000000000..2d6b889058 --- /dev/null +++ b/p2p/net/upgrader/gater_test.go @@ -0,0 +1,60 @@ +package upgrader_test + +import ( + "sync" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/control" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" +) + +type testGater struct { + sync.Mutex + + blockAccept, blockSecured bool +} + +var _ connmgr.ConnectionGater = (*testGater)(nil) + +func (t *testGater) BlockAccept(block bool) { + t.Lock() + defer t.Unlock() + + t.blockAccept = block +} + +func (t *testGater) BlockSecured(block bool) { + t.Lock() + defer t.Unlock() + + t.blockSecured = block +} + +func (t *testGater) InterceptPeerDial(p peer.ID) (allow bool) { + panic("not implemented") +} + +func (t *testGater) InterceptAddrDial(id peer.ID, multiaddr ma.Multiaddr) (allow bool) { + panic("not implemented") +} + +func (t *testGater) InterceptAccept(multiaddrs network.ConnMultiaddrs) (allow bool) { + t.Lock() + defer t.Unlock() + + return !t.blockAccept +} + +func (t *testGater) InterceptSecured(direction network.Direction, id peer.ID, multiaddrs network.ConnMultiaddrs) (allow bool) { + t.Lock() + defer t.Unlock() + + return !t.blockSecured +} + +func (t *testGater) InterceptUpgraded(conn network.Conn) (allow bool, reason control.DisconnectReason) { + panic("not implemented") +} diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go new file mode 100644 index 0000000000..fa94e8e6fc --- /dev/null +++ b/p2p/net/upgrader/listener.go @@ -0,0 +1,178 @@ +package upgrader + +import ( + "context" + "fmt" + "sync" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/transport" + + logging "github.com/ipfs/go-log/v2" + tec "github.com/jbenet/go-temp-err-catcher" + manet "github.com/multiformats/go-multiaddr/net" +) + +var log = logging.Logger("upgrader") + +type listener struct { + manet.Listener + + transport transport.Transport + upgrader *upgrader + rcmgr network.ResourceManager + + incoming chan transport.CapableConn + err error + + // Used for backpressure + threshold *threshold + + // Canceling this context isn't sufficient to tear down the listener. + // Call close. + ctx context.Context + cancel func() +} + +// Close closes the listener. +func (l *listener) Close() error { + // Do this first to try to get any relevent errors. + err := l.Listener.Close() + + l.cancel() + // Drain and wait. + for c := range l.incoming { + c.Close() + } + return err +} + +// handles inbound connections. +// +// This function does a few interesting things that should be noted: +// +// 1. It logs and discards temporary/transient errors (errors with a Temporary() +// function that returns true). +// 2. It stops accepting new connections once AcceptQueueLength connections have +// been fully negotiated but not accepted. This gives us a basic backpressure +// mechanism while still allowing us to negotiate connections in parallel. +func (l *listener) handleIncoming() { + var wg sync.WaitGroup + defer func() { + // make sure we're closed + l.Listener.Close() + if l.err == nil { + l.err = fmt.Errorf("listener closed") + } + + wg.Wait() + close(l.incoming) + }() + + var catcher tec.TempErrCatcher + for l.ctx.Err() == nil { + maconn, err := l.Listener.Accept() + if err != nil { + // Note: function may pause the accept loop. + if catcher.IsTemporary(err) { + log.Infof("temporary accept error: %s", err) + continue + } + l.err = err + return + } + catcher.Reset() + + // gate the connection if applicable + if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + if err := maconn.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + + connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := maconn.Close(); err != nil { + log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + } + continue + } + + // The go routine below calls Release when the context is + // canceled so there's no need to wait on it here. + l.threshold.Wait() + + log.Debugf("listener %s got connection: %s <---> %s", + l, + maconn.LocalMultiaddr(), + maconn.RemoteMultiaddr()) + + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel := context.WithTimeout(l.ctx, l.upgrader.acceptTimeout) + defer cancel() + + conn, err := l.upgrader.Upgrade(ctx, l.transport, maconn, network.DirInbound, "", connScope) + if err != nil { + // Don't bother bubbling this up. We just failed + // to completely negotiate the connection. + log.Debugf("accept upgrade error: %s (%s <--> %s)", + err, + maconn.LocalMultiaddr(), + maconn.RemoteMultiaddr()) + connScope.Done() + return + } + + log.Debugf("listener %s accepted connection: %s", l, conn) + + // This records the fact that the connection has been + // setup and is waiting to be accepted. This call + // *never* blocks, even if we go over the threshold. It + // simply ensures that calls to Wait block while we're + // over the threshold. + l.threshold.Acquire() + defer l.threshold.Release() + + select { + case l.incoming <- conn: + case <-ctx.Done(): + if l.ctx.Err() == nil { + // Listener *not* closed but the accept timeout expired. + log.Warn("listener dropped connection due to slow accept") + } + // Wait on the context with a timeout. This way, + // if we stop accepting connections for some reason, + // we'll eventually close all the open ones + // instead of hanging onto them. + conn.Close() + } + }() + } +} + +// Accept accepts a connection. +func (l *listener) Accept() (transport.CapableConn, error) { + for c := range l.incoming { + // Could have been sitting there for a while. + if !c.IsClosed() { + return c, nil + } + } + return nil, l.err +} + +func (l *listener) String() string { + if s, ok := l.transport.(fmt.Stringer); ok { + return fmt.Sprintf("", s, l.Multiaddr()) + } + return fmt.Sprintf("", l.Multiaddr()) +} + +var _ transport.Listener = (*listener)(nil) diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go new file mode 100644 index 0000000000..1ea80d68dd --- /dev/null +++ b/p2p/net/upgrader/listener_test.go @@ -0,0 +1,405 @@ +package upgrader_test + +import ( + "context" + "errors" + "io" + "net" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p/p2p/net/upgrader" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/sec" + "github.com/libp2p/go-libp2p-core/transport" + + mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +type MuxAdapter struct { + tpt sec.SecureTransport +} + +var _ sec.SecureMuxer = &MuxAdapter{} + +func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureInbound(ctx, insecure, p) + return sconn, true, err +} + +func (mux *MuxAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureOutbound(ctx, insecure, p) + return sconn, false, err +} + +func createListener(t *testing.T, u transport.Upgrader) transport.Listener { + t.Helper() + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0") + require.NoError(t, err) + ln, err := manet.Listen(addr) + require.NoError(t, err) + return u.UpgradeListener(nil, ln) +} + +func TestAcceptSingleConn(t *testing.T) { + require := require.New(t) + + id, u := createUpgrader(t) + ln := createListener(t, u) + defer ln.Close() + + cconn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + + sconn, err := ln.Accept() + require.NoError(err) + + testConn(t, cconn, sconn) +} + +func TestAcceptMultipleConns(t *testing.T) { + require := require.New(t) + + id, u := createUpgrader(t) + ln := createListener(t, u) + defer ln.Close() + + var toClose []io.Closer + defer func() { + for _, c := range toClose { + _ = c.Close() + } + }() + + for i := 0; i < 10; i++ { + cconn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + toClose = append(toClose, cconn) + + sconn, err := ln.Accept() + require.NoError(err) + toClose = append(toClose, sconn) + + testConn(t, cconn, sconn) + } +} + +func TestConnectionsClosedIfNotAccepted(t *testing.T) { + require := require.New(t) + + var timeout = 100 * time.Millisecond + if os.Getenv("CI") != "" { + timeout = 500 * time.Millisecond + } + + id, u := createUpgrader(t, upgrader.WithAcceptTimeout(timeout)) + ln := createListener(t, u) + defer ln.Close() + + conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + + errCh := make(chan error) + go func() { + defer conn.Close() + str, err := conn.OpenStream(context.Background()) + if err != nil { + errCh <- err + return + } + // start a Read. It will block until the connection is closed + _, _ = str.Read([]byte{0}) + errCh <- nil + }() + + time.Sleep(timeout / 2) + select { + case err := <-errCh: + t.Fatalf("connection closed earlier than expected. expected nothing on channel, got: %v", err) + default: + } + + time.Sleep(timeout) + require.Nil(<-errCh) +} + +func TestFailedUpgradeOnListen(t *testing.T) { + require := require.New(t) + + id, u := createUpgraderWithMuxer(t, &errorMuxer{}) + ln := createListener(t, u) + + errCh := make(chan error) + go func() { + _, err := ln.Accept() + errCh <- err + }() + + _, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.Error(err) + + // close the listener. + ln.Close() + require.Error(<-errCh) +} + +func TestListenerClose(t *testing.T) { + require := require.New(t) + + _, u := createUpgrader(t) + ln := createListener(t, u) + + errCh := make(chan error) + go func() { + _, err := ln.Accept() + errCh <- err + }() + + select { + case err := <-errCh: + t.Fatalf("connection closed earlier than expected. expected nothing on channel, got: %v", err) + case <-time.After(200 * time.Millisecond): + // nothing in 200ms. + } + + // unblocks Accept when it is closed. + require.NoError(ln.Close()) + err := <-errCh + require.Error(err) + require.Contains(err.Error(), "use of closed network connection") + + // doesn't accept new connections when it is closed + _, err = dial(t, u, ln.Multiaddr(), peer.ID("1"), network.NullScope) + require.Error(err) +} + +func TestListenerCloseClosesQueued(t *testing.T) { + require := require.New(t) + + id, upgrader := createUpgrader(t) + ln := createListener(t, upgrader) + + var conns []transport.CapableConn + for i := 0; i < 10; i++ { + conn, err := dial(t, upgrader, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + conns = append(conns, conn) + } + + // wait for all the dials to happen. + time.Sleep(500 * time.Millisecond) + + // all the connections are opened. + for _, c := range conns { + require.False(c.IsClosed()) + } + + // expect that all the connections will be closed. + err := ln.Close() + require.NoError(err) + + // all the connections are closed. + require.Eventually(func() bool { + for _, c := range conns { + if !c.IsClosed() { + return false + } + } + return true + }, 3*time.Second, 100*time.Millisecond) + + for _, c := range conns { + _ = c.Close() + } +} + +func TestConcurrentAccept(t *testing.T) { + var num = 3 * upgrader.AcceptQueueLength + + blockingMuxer := newBlockingMuxer() + id, u := createUpgraderWithMuxer(t, blockingMuxer) + ln := createListener(t, u) + defer ln.Close() + + accepted := make(chan transport.CapableConn, num) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + _ = conn.Close() + accepted <- conn + } + }() + + // start num dials, which all block while setting up the muxer + errCh := make(chan error, num) + var wg sync.WaitGroup + for i := 0; i < num; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + if err != nil { + errCh <- err + return + } + defer conn.Close() + + _, err = conn.AcceptStream() // wait for conn to be accepted. + errCh <- err + }() + } + + time.Sleep(200 * time.Millisecond) + // the dials are still blocked, so we shouldn't have any connection available yet + require.Empty(t, accepted) + blockingMuxer.Unblock() // make all dials succeed + require.Eventually(t, func() bool { return len(accepted) == num }, 3*time.Second, 100*time.Millisecond) + wg.Wait() +} + +func TestAcceptQueueBacklogged(t *testing.T) { + require := require.New(t) + + id, u := createUpgrader(t) + ln := createListener(t, u) + defer ln.Close() + + // setup AcceptQueueLength connections, but don't accept any of them + var counter int32 // to be used atomically + doDial := func() { + conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + atomic.AddInt32(&counter, 1) + t.Cleanup(func() { conn.Close() }) + } + + for i := 0; i < upgrader.AcceptQueueLength; i++ { + go doDial() + } + + require.Eventually(func() bool { return int(atomic.LoadInt32(&counter)) == upgrader.AcceptQueueLength }, 2*time.Second, 50*time.Millisecond) + + // dial a new connection. This connection should not complete setup, since the queue is full + go doDial() + + time.Sleep(100 * time.Millisecond) + require.Equal(int(atomic.LoadInt32(&counter)), upgrader.AcceptQueueLength) + + // accept a single connection. Now the new connection should be set up, and fill the queue again + conn, err := ln.Accept() + require.NoError(err) + require.NoError(conn.Close()) + + require.Eventually(func() bool { return int(atomic.LoadInt32(&counter)) == upgrader.AcceptQueueLength+1 }, 2*time.Second, 50*time.Millisecond) +} + +func TestListenerConnectionGater(t *testing.T) { + require := require.New(t) + + testGater := &testGater{} + id, u := createUpgrader(t, upgrader.WithConnectionGater(testGater)) + + ln := createListener(t, u) + defer ln.Close() + + // no gating. + conn, err := dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + require.False(conn.IsClosed()) + _ = conn.Close() + + // rejecting after handshake. + testGater.BlockSecured(true) + testGater.BlockAccept(false) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", network.NullScope) + require.Error(err) + require.Nil(conn) + + // rejecting on accept will trigger firupgrader. + testGater.BlockSecured(true) + testGater.BlockAccept(true) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", network.NullScope) + require.Error(err) + require.Nil(conn) + + // rejecting only on acceptance. + testGater.BlockSecured(false) + testGater.BlockAccept(true) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", network.NullScope) + require.Error(err) + require.Nil(conn) + + // back to normal + testGater.BlockSecured(false) + testGater.BlockAccept(false) + conn, err = dial(t, u, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + require.False(conn.IsClosed()) + _ = conn.Close() +} + +func TestListenerResourceManagement(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + id, upgrader := createUpgrader(t, upgrader.WithResourceManager(rcmgr)) + ln := createListener(t, upgrader) + defer ln.Close() + + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + gomock.InOrder( + rcmgr.EXPECT().OpenConnection(network.DirInbound, true).Return(connScope, nil), + connScope.EXPECT().PeerScope(), + connScope.EXPECT().SetPeer(id), + connScope.EXPECT().PeerScope(), + ) + + cconn, err := dial(t, upgrader, ln.Multiaddr(), id, network.NullScope) + require.NoError(t, err) + defer cconn.Close() + + sconn, err := ln.Accept() + require.NoError(t, err) + connScope.EXPECT().Done() + defer sconn.Close() +} + +func TestListenerResourceManagementDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + id, upgrader := createUpgrader(t, upgrader.WithResourceManager(rcmgr)) + ln := createListener(t, upgrader) + + rcmgr.EXPECT().OpenConnection(network.DirInbound, true).Return(nil, errors.New("nope")) + _, err := dial(t, upgrader, ln.Multiaddr(), id, network.NullScope) + require.Error(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + ln.Accept() + }() + + select { + case <-done: + t.Fatal("accept shouldn't have accepted anything") + case <-time.After(50 * time.Millisecond): + } + require.NoError(t, ln.Close()) + <-done +} diff --git a/p2p/net/upgrader/threshold.go b/p2p/net/upgrader/threshold.go new file mode 100644 index 0000000000..1e8b112cb8 --- /dev/null +++ b/p2p/net/upgrader/threshold.go @@ -0,0 +1,50 @@ +package upgrader + +import ( + "sync" +) + +func newThreshold(cutoff int) *threshold { + t := &threshold{ + threshold: cutoff, + } + t.cond.L = &t.mu + return t +} + +type threshold struct { + mu sync.Mutex + cond sync.Cond + + count int + threshold int +} + +// Acquire increments the counter. It will not block. +func (t *threshold) Acquire() { + t.mu.Lock() + t.count++ + t.mu.Unlock() +} + +// Release decrements the counter. +func (t *threshold) Release() { + t.mu.Lock() + if t.count == 0 { + panic("negative count") + } + if t.threshold == t.count { + t.cond.Broadcast() + } + t.count-- + t.mu.Unlock() +} + +// Wait waits for the counter to drop below the threshold +func (t *threshold) Wait() { + t.mu.Lock() + for t.count >= t.threshold { + t.cond.Wait() + } + t.mu.Unlock() +} diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go new file mode 100644 index 0000000000..3ea71bf777 --- /dev/null +++ b/p2p/net/upgrader/upgrader.go @@ -0,0 +1,218 @@ +package upgrader + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + ipnet "github.com/libp2p/go-libp2p-core/pnet" + "github.com/libp2p/go-libp2p-core/sec" + "github.com/libp2p/go-libp2p-core/transport" + + pnet "github.com/libp2p/go-libp2p-pnet" + manet "github.com/multiformats/go-multiaddr/net" +) + +// ErrNilPeer is returned when attempting to upgrade an outbound connection +// without specifying a peer ID. +var ErrNilPeer = errors.New("nil peer") + +// AcceptQueueLength is the number of connections to fully setup before not accepting any new connections +var AcceptQueueLength = 16 + +const defaultAcceptTimeout = 15 * time.Second + +type Option func(*upgrader) error + +func WithPSK(psk ipnet.PSK) Option { + return func(u *upgrader) error { + u.psk = psk + return nil + } +} + +func WithAcceptTimeout(t time.Duration) Option { + return func(u *upgrader) error { + u.acceptTimeout = t + return nil + } +} + +func WithConnectionGater(g connmgr.ConnectionGater) Option { + return func(u *upgrader) error { + u.connGater = g + return nil + } +} + +func WithResourceManager(m network.ResourceManager) Option { + return func(u *upgrader) error { + u.rcmgr = m + return nil + } +} + +// Upgrader is a multistream upgrader that can upgrade an underlying connection +// to a full transport connection (secure and multiplexed). +type upgrader struct { + secure sec.SecureMuxer + muxer network.Multiplexer + + psk ipnet.PSK + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + + // AcceptTimeout is the maximum duration an Accept is allowed to take. + // This includes the time between accepting the raw network connection, + // protocol selection as well as the handshake, if applicable. + // + // If unset, the default value (15s) is used. + acceptTimeout time.Duration +} + +var _ transport.Upgrader = &upgrader{} + +func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, opts ...Option) (transport.Upgrader, error) { + u := &upgrader{ + secure: secureMuxer, + muxer: muxer, + acceptTimeout: defaultAcceptTimeout, + } + for _, opt := range opts { + if err := opt(u); err != nil { + return nil, err + } + } + if u.rcmgr == nil { + u.rcmgr = network.NullResourceManager + } + return u, nil +} + +// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener. +func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener { + ctx, cancel := context.WithCancel(context.Background()) + l := &listener{ + Listener: list, + upgrader: u, + transport: t, + rcmgr: u.rcmgr, + threshold: newThreshold(AcceptQueueLength), + incoming: make(chan transport.CapableConn), + cancel: cancel, + ctx: ctx, + } + go l.handleIncoming() + return l +} + +// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection. +func (u *upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { + c, err := u.upgrade(ctx, t, maconn, dir, p, connScope) + if err != nil { + connScope.Done() + return nil, err + } + return c, nil +} + +func (u *upgrader) upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { + if dir == network.DirOutbound && p == "" { + return nil, ErrNilPeer + } + var stat network.ConnStats + if cs, ok := maconn.(network.ConnStat); ok { + stat = cs.Stat() + } + + var conn net.Conn = maconn + if u.psk != nil { + pconn, err := pnet.NewProtectedConn(u.psk, conn) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to setup private network protector: %s", err) + } + conn = pconn + } else if ipnet.ForcePrivateNetwork { + log.Error("tried to dial with no Private Network Protector but usage of Private Networks is forced by the environment") + return nil, ipnet.ErrNotInPrivateNetwork + } + + sconn, server, err := u.setupSecurity(ctx, conn, p, dir) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to negotiate security protocol: %s", err) + } + + // call the connection gater, if one is registered. + if u.connGater != nil && !u.connGater.InterceptSecured(dir, sconn.RemotePeer(), maconn) { + if err := maconn.Close(); err != nil { + log.Errorw("failed to close connection", "peer", p, "addr", maconn.RemoteMultiaddr(), "error", err) + } + return nil, fmt.Errorf("gater rejected connection with peer %s and addr %s with direction %d", + sconn.RemotePeer().Pretty(), maconn.RemoteMultiaddr(), dir) + } + // Only call SetPeer if it hasn't already been set -- this can happen when we don't know + // the peer in advance and in some bug scenarios. + if connScope.PeerScope() == nil { + if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { + log.Debugw("resource manager blocked connection for peer", "peer", sconn.RemotePeer(), "addr", conn.RemoteAddr(), "error", err) + if err := maconn.Close(); err != nil { + log.Errorw("failed to close connection", "peer", p, "addr", maconn.RemoteMultiaddr(), "error", err) + } + return nil, fmt.Errorf("resource manager connection with peer %s and addr %s with direction %d", + sconn.RemotePeer().Pretty(), maconn.RemoteMultiaddr(), dir) + } + } + + smconn, err := u.setupMuxer(ctx, sconn, server, connScope.PeerScope()) + if err != nil { + sconn.Close() + return nil, fmt.Errorf("failed to negotiate stream multiplexer: %s", err) + } + + tc := &transportConn{ + MuxedConn: smconn, + ConnMultiaddrs: maconn, + ConnSecurity: sconn, + transport: t, + stat: stat, + scope: connScope, + } + return tc, nil +} + +func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, dir network.Direction) (sec.SecureConn, bool, error) { + if dir == network.DirInbound { + return u.secure.SecureInbound(ctx, conn, p) + } + return u.secure.SecureOutbound(ctx, conn, p) +} + +func (u *upgrader) setupMuxer(ctx context.Context, conn net.Conn, server bool, scope network.PeerScope) (network.MuxedConn, error) { + // TODO: The muxer should take a context. + done := make(chan struct{}) + + var smconn network.MuxedConn + var err error + go func() { + defer close(done) + smconn, err = u.muxer.NewConn(conn, server, scope) + }() + + select { + case <-done: + return smconn, err + case <-ctx.Done(): + // interrupt this process + conn.Close() + // wait to finish + <-done + return nil, ctx.Err() + } +} diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go new file mode 100644 index 0000000000..be201eca60 --- /dev/null +++ b/p2p/net/upgrader/upgrader_test.go @@ -0,0 +1,191 @@ +package upgrader_test + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/libp2p/go-libp2p/p2p/muxer/yamux" + "github.com/libp2p/go-libp2p/p2p/net/upgrader" + + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/sec/insecure" + "github.com/libp2p/go-libp2p-core/test" + "github.com/libp2p/go-libp2p-core/transport" + + mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" + + "github.com/golang/mock/gomock" + ma "github.com/multiformats/go-multiaddr" + + manet "github.com/multiformats/go-multiaddr/net" + "github.com/stretchr/testify/require" +) + +func createUpgrader(t *testing.T, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { + return createUpgraderWithMuxer(t, &negotiatingMuxer{}, opts...) +} + +func createUpgraderWithMuxer(t *testing.T, muxer network.Multiplexer, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { + priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(id, priv)}, muxer, opts...) + require.NoError(t, err) + return id, u +} + +// negotiatingMuxer sets up a new yamux connection +// It makes sure that this happens at the same time for client and server. +type negotiatingMuxer struct{} + +func (m *negotiatingMuxer) NewConn(c net.Conn, isServer bool, scope network.PeerScope) (network.MuxedConn, error) { + var err error + // run a fake muxer negotiation + if isServer { + _, err = c.Write([]byte("setup")) + } else { + _, err = c.Read(make([]byte, 5)) + } + if err != nil { + return nil, err + } + return yamux.DefaultTransport.NewConn(c, isServer, scope) +} + +// blockingMuxer blocks the muxer negotiation until the contain chan is closed +type blockingMuxer struct { + unblock chan struct{} +} + +var _ network.Multiplexer = &blockingMuxer{} + +func newBlockingMuxer() *blockingMuxer { + return &blockingMuxer{unblock: make(chan struct{})} +} + +func (m *blockingMuxer) NewConn(c net.Conn, isServer bool, scope network.PeerScope) (network.MuxedConn, error) { + <-m.unblock + return (&negotiatingMuxer{}).NewConn(c, isServer, scope) +} + +func (m *blockingMuxer) Unblock() { + close(m.unblock) +} + +// errorMuxer is a muxer that errors while setting up +type errorMuxer struct{} + +var _ network.Multiplexer = &errorMuxer{} + +func (m *errorMuxer) NewConn(c net.Conn, isServer bool, scope network.PeerScope) (network.MuxedConn, error) { + return nil, errors.New("mux error") +} + +func testConn(t *testing.T, clientConn, serverConn transport.CapableConn) { + t.Helper() + require := require.New(t) + + cstr, err := clientConn.OpenStream(context.Background()) + require.NoError(err) + + _, err = cstr.Write([]byte("foobar")) + require.NoError(err) + + sstr, err := serverConn.AcceptStream() + require.NoError(err) + + b := make([]byte, 6) + _, err = sstr.Read(b) + require.NoError(err) + require.Equal([]byte("foobar"), b) +} + +func dial(t *testing.T, upgrader transport.Upgrader, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (transport.CapableConn, error) { + t.Helper() + + macon, err := manet.Dial(raddr) + if err != nil { + return nil, err + } + return upgrader.Upgrade(context.Background(), nil, macon, network.DirOutbound, p, scope) +} + +func TestOutboundConnectionGating(t *testing.T) { + require := require.New(t) + + id, u := createUpgrader(t) + ln := createListener(t, u) + defer ln.Close() + + testGater := &testGater{} + _, dialUpgrader := createUpgrader(t, upgrader.WithConnectionGater(testGater)) + conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + require.NotNil(conn) + _ = conn.Close() + + // blocking accepts doesn't affect the dialling side, only the listener. + testGater.BlockAccept(true) + conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, network.NullScope) + require.NoError(err) + require.NotNil(conn) + _ = conn.Close() + + // now let's block all connections after being secured. + testGater.BlockSecured(true) + conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, network.NullScope) + require.Error(err) + require.Contains(err.Error(), "gater rejected connection") + require.Nil(conn) +} + +func TestOutboundResourceManagement(t *testing.T) { + t.Run("successful handshake", func(t *testing.T) { + id, upgrader := createUpgrader(t) + ln := createListener(t, upgrader) + defer ln.Close() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + gomock.InOrder( + connScope.EXPECT().PeerScope(), + connScope.EXPECT().SetPeer(id), + connScope.EXPECT().PeerScope().Return(network.NullScope), + ) + _, dialUpgrader := createUpgrader(t) + conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, connScope) + require.NoError(t, err) + require.NotNil(t, conn) + connScope.EXPECT().Done() + require.NoError(t, conn.Close()) + }) + + t.Run("failed negotiation", func(t *testing.T) { + id, upgrader := createUpgraderWithMuxer(t, &errorMuxer{}) + ln := createListener(t, upgrader) + defer ln.Close() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + gomock.InOrder( + connScope.EXPECT().PeerScope(), + connScope.EXPECT().SetPeer(id), + connScope.EXPECT().PeerScope().Return(network.NullScope), + connScope.EXPECT().Done(), + ) + _, dialUpgrader := createUpgrader(t) + _, err := dial(t, dialUpgrader, ln.Multiaddr(), id, connScope) + require.Error(t, err) + }) + + t.Run("blocked by the resource manager", func(t *testing.T) { + + }) +} diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 494d955e6c..7af1642261 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -86,7 +86,7 @@ func (ll *tcpListener) Accept() (manet.Conn, error) { tryKeepAlive(c, true) // We're not calling OpenConnection in the resource manager here, // since the manet.Conn doesn't allow us to save the scope. - // It's the caller's (usually the go-libp2p-transport-upgrader) responsibility + // It's the caller's (usually the p2p/net/upgrader) responsibility // to call the resource manager. return c, nil } diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 3976d76fdc..06ea31b2e9 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" @@ -17,7 +18,6 @@ import ( mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" - tptu "github.com/libp2p/go-libp2p-transport-upgrader" ma "github.com/multiformats/go-multiaddr" diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index db80277adb..b83f528f84 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -16,6 +16,7 @@ import ( "time" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" @@ -27,7 +28,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" - tptu "github.com/libp2p/go-libp2p-transport-upgrader" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require"