From d193a75e4fa9f3224863fdf74fcdc98d9e9bfb5c Mon Sep 17 00:00:00 2001 From: Andrew Gaffney Date: Tue, 8 Mar 2022 10:44:18 -0600 Subject: [PATCH] Fix race condition around handshake and muxer protocol registration Fixes #35 --- muxer/muxer.go | 20 ++++++++++++++++++-- ouroboros.go | 10 ++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/muxer/muxer.go b/muxer/muxer.go index 95580860..fc9dad82 100644 --- a/muxer/muxer.go +++ b/muxer/muxer.go @@ -5,25 +5,31 @@ import ( "encoding/binary" "fmt" "io" + "net" "sync" ) const ( // Magic number chosen to represent unknown protocols PROTOCOL_UNKNOWN uint16 = 0xabcd + + // Handshake protocol ID + PROTOCOL_HANDSHAKE = 0 ) type Muxer struct { - conn io.ReadWriteCloser + conn net.Conn sendMutex sync.Mutex + startChan chan bool ErrorChan chan error protocolSenders map[uint16]chan *Segment protocolReceivers map[uint16]chan *Segment } -func New(conn io.ReadWriteCloser) *Muxer { +func New(conn net.Conn) *Muxer { m := &Muxer{ conn: conn, + startChan: make(chan bool, 1), ErrorChan: make(chan error, 10), protocolSenders: make(map[uint16]chan *Segment), protocolReceivers: make(map[uint16]chan *Segment), @@ -32,6 +38,10 @@ func New(conn io.ReadWriteCloser) *Muxer { return m } +func (m *Muxer) Start() { + m.startChan <- true +} + func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) { // Generate channels senderChan := make(chan *Segment, 10) @@ -69,6 +79,7 @@ func (m *Muxer) Send(msg *Segment) error { } func (m *Muxer) readLoop() { + started := false for { header := SegmentHeader{} if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil { @@ -83,6 +94,11 @@ func (m *Muxer) readLoop() { if _, err := io.ReadFull(m.conn, msg.Payload); err != nil { m.ErrorChan <- err } + // Wait until the muxer is started to process anything other than handshake messages + if !started && msg.GetProtocolId() != PROTOCOL_HANDSHAKE { + <-m.startChan + started = true + } // Send message payload to proper receiver recvChan := m.protocolReceivers[msg.GetProtocolId()] if recvChan == nil { diff --git a/ouroboros.go b/ouroboros.go index f0d0ec2b..924ee57b 100644 --- a/ouroboros.go +++ b/ouroboros.go @@ -20,6 +20,7 @@ type Ouroboros struct { muxer *muxer.Muxer ErrorChan chan error sendKeepAlives bool + delayMuxerStart bool // Mini-protocols Handshake *handshake.Handshake ChainSync *chainsync.ChainSync @@ -39,6 +40,7 @@ type OuroborosOptions struct { Server bool UseNodeToNodeProtocol bool SendKeepAlives bool + DelayMuxerStart bool ChainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig BlockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig KeepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig @@ -57,6 +59,7 @@ func New(options *OuroborosOptions) (*Ouroboros, error) { localTxSubmissionCallbackConfig: options.LocalTxSubmissionCallbackConfig, ErrorChan: options.ErrorChan, sendKeepAlives: options.SendKeepAlives, + delayMuxerStart: options.DelayMuxerStart, } if o.ErrorChan == nil { o.ErrorChan = make(chan error, 10) @@ -69,6 +72,10 @@ func New(options *OuroborosOptions) (*Ouroboros, error) { return o, nil } +func (o *Ouroboros) Muxer() *muxer.Muxer { + return o.muxer +} + // Convenience function for creating a connection if you didn't provide one when // calling New() func (o *Ouroboros) Dial(proto string, address string) error { @@ -134,5 +141,8 @@ func (o *Ouroboros) setupConnection() error { o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig) o.LocalTxSubmission = localtxsubmission.New(protoOptions, o.localTxSubmissionCallbackConfig) } + if !o.delayMuxerStart { + o.muxer.Start() + } return nil }