diff --git a/blockmanager.go b/blockmanager.go
index 6a734b61..51a27a19 100644
--- a/blockmanager.go
+++ b/blockmanager.go
@@ -64,11 +64,11 @@ type invMsg struct {
 	peer *ServerPeer
 }
 
-// headersMsg packages a bitcoin headers message and the peer it came from
+// HeadersMsg packages a bitcoin headers message and the peer it came from
 // together so the block handler has access to that information.
-type headersMsg struct {
-	headers *wire.MsgHeaders
-	peer    *ServerPeer
+type HeadersMsg struct {
+	Headers *wire.MsgHeaders
+	Peer    *ServerPeer
 }
 
 // donePeerMsg signifies a newly disconnected peer to the block handler.
@@ -295,7 +295,6 @@ func (b *blockManager) Start() {
 
 	log.Trace("Starting block manager")
 	b.wg.Add(2)
-	go b.blockHandler()
 	go func() {
 		defer b.wg.Done()
 
@@ -312,6 +311,11 @@ func (b *blockManager) Start() {
 		log.Debug("Peer connected, starting cfHandler.")
 		b.cfHandler()
 	}()
+
+	go func() {
+		b.getCheckpointedBlkHeaders()
+		b.blockHandler()
+	}()
 }
 
 // Stop gracefully shuts down the block manager by stopping all asynchronous
@@ -384,16 +388,18 @@ func (b *blockManager) handleNewPeerMsg(peers *list.List, sp *ServerPeer) {
 	// Add the peer as a candidate to sync from.
 	peers.PushBack(sp)
 
-	// If we're current with our sync peer and the new peer is advertising
-	// a higher block than the newest one we know of, request headers from
-	// the new peer.
+	// If our header chain is behind what the new peer is advertising and
+	// we've already finished the initial header sync, request headers
+	// from the new peer.
 	_, height, err := b.cfg.BlockHeaders.ChainTip()
 	if err != nil {
 		log.Criticalf("Couldn't retrieve block header chain tip: %s",
 			err)
 		return
 	}
+
 	if height < uint32(sp.StartingHeight()) && b.BlockHeadersSynced() {
+
 		locator, err := b.cfg.BlockHeaders.LatestBlockLocator()
 		if err != nil {
 			log.Criticalf("Couldn't retrieve latest block "+
@@ -402,10 +408,24 @@ func (b *blockManager) handleNewPeerMsg(peers *list.List, sp *ServerPeer) {
 		}
 		stopHash := &zeroHash
 		_ = sp.PushGetHeadersMsg(locator, stopHash)
+
+	}
+
+	checkpoints := b.cfg.ChainParams.Checkpoints
+	numCheckPoints := len(checkpoints)
+
+	if numCheckPoints == 0 {
+
+		b.startSync(peers)
+		return
+	}
+
+	if height > uint32(checkpoints[numCheckPoints-1].Height) {
+
+		// Start syncing by choosing the best candidate if needed.
+		b.startSync(peers)
 	}
 
-	// Start syncing by choosing the best candidate if needed.
-	b.startSync(peers)
 }
 
 // DonePeer informs the blockmanager that a peer has disconnected.
@@ -427,7 +447,8 @@ func (b *blockManager) DonePeer(sp *ServerPeer) {
 // current sync peer, attempts to select a new best peer to sync from. It is
 // invoked from the syncHandler goroutine.
 func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) {
-	// Remove the peer from the list of candidate peers.
+
+	// Remove the peer from the list of candidate peers.
 	for e := peers.Front(); e != nil; e = e.Next() {
 		if e.Value == sp {
 			peers.Remove(e)
@@ -437,9 +458,9 @@ func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) {
 
 	log.Infof("Lost peer %s", sp)
 
-	// Attempt to find a new peer to sync from if the quitting peer is the
-	// sync peer. Also, reset the header state.
-	if b.SyncPeer() != nil && b.SyncPeer() == sp {
+	// Attempt to find a new peer to sync from if the quitting peer is the
+	// sync peer. Also, reset the header state.
+	if b.SyncPeer() != nil && b.SyncPeer() == sp && b.BlockHeadersSynced() {
 		b.syncPeerMutex.Lock()
 		b.syncPeer = nil
 		b.syncPeerMutex.Unlock()
@@ -928,6 +949,239 @@ func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message,
 	}
 }
 
+type HeaderQuery struct {
+	Locator     blockchain.BlockLocator
+	StopHash    *chainhash.Hash
+	StartHeight int32
+}
+
+// CheckpointedBlockHeadersQuery holds all information necessary to perform and
+// handle a query for checkpointed block headers.
+type CheckpointedBlockHeadersQuery struct {
+	blockMgr   *blockManager
+	msgs       []HeaderQuery
+	headerChan chan []headerfs.BlockHeader
+}
+
+func (c *CheckpointedBlockHeadersQuery) requests() []*query.TestRequest {
+	log.Debugf("Creating test request")
+	reqs := make([]*query.TestRequest, len(c.msgs))
+	for idx, m := range c.msgs {
+		reqs[idx] = &query.TestRequest{
+			Req:        m,
+			HandleResp: c.handleResponse,
+		}
+	}
+	return reqs
+}
+
+// handleResponse is the internal response handler used for requests for this
+// block headers query.
+func (c *CheckpointedBlockHeadersQuery) handleResponse(resp wire.Message,
+	queryPeer query.TestPeer) query.Progress {
+
+	r, ok := resp.(*wire.MsgHeaders)
+	if !ok {
+		// We are only interested in headers messages.
+		return query.Progress{
+			Finished:   false,
+			Progressed: false,
+		}
+	}
+
+	peer, ok := queryPeer.(*ServerPeer)
+	if !ok {
+		return query.Progress{
+			Finished:   false,
+			Progressed: false,
+		}
+	}
+	log.Debugf("Received %v headers from peer %v", len(r.Headers),
+		peer.Addr())
+	headerWriteBatch, _, _, _ := c.blockMgr.processHeaderMsg(
+		&HeadersMsg{Headers: r, Peer: peer},
+	)
+	log.Debugf("Length of write batch: %v", len(headerWriteBatch))
+	if headerWriteBatch == nil {
+		return query.Progress{
+			Finished:   false,
+			Progressed: true,
+		}
+	}
+
+	// At this point, the response matches the query, and the relevant
+	// checkpoint we got earlier, so we'll deliver the verified headers on
+	// the headerChan. We'll also return a Progress indicating the query
+	// finished, that the peer looking for the answer to this query can
+	// move on to the next query.
+	select {
+	case c.headerChan <- headerWriteBatch:
+	case <-c.blockMgr.quit:
+		return query.Progress{
+			Finished:   false,
+			Progressed: false,
+		}
+	}
+
+	return query.Progress{
+		Finished:   true,
+		Progressed: true,
+	}
+}
+
+func (b *blockManager) getCheckpointedBlkHeaders() {
+	log.Infof("Starting checkpointed block header download")
+
+	// We keep going until we've caught up the block header store with the
+	// latest known checkpoint.
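+	// The query is anchored at the current tip of the block header store,
+	// so only the headers between our tip and the final checkpoint are
+	// requested.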
+	tipHeader, tipHeight, err := b.cfg.BlockHeaders.ChainTip()
+	if err != nil {
+		log.Errorf("Failed to get hash and height for the "+
+			"latest block: %s", err)
+		return
+	}
+
+	var queryMsgs []HeaderQuery
+	queryResponses := make(map[chainhash.Hash][]headerfs.BlockHeader)
+	//initialBlockHeader := tipHeader
+
+	checkpoints := b.cfg.ChainParams.Checkpoints
+	numCheckpts := len(checkpoints)
+	if numCheckpts == 0 {
+		return
+	}
+	curHeight := tipHeight
+	knownLocator, err := b.cfg.BlockHeaders.LatestBlockLocator()
+	if err != nil {
+		log.Errorf("Failed to get latest block locator: %s", err)
+	}
+	log.Infof("Fetching set of checkpointed block headers from "+
+		"height=%v, hash=%v", tipHeight, tipHeader.BlockHash())
+	for curHeight < uint32(checkpoints[numCheckpts-1].Height) {
+
+		//nextCheckpoint := b.findNextHeaderCheckpoint(int32(curHeight))
+		//endHeight := uint32(nextCheckpoint.Height) + curHeight
+		//if endHeight > uint32(checkpoints[len(checkpoints)-1].Height) {
+		//	endHeight = uint32(nextCheckpoint.Height)
+		//}
+		//log.Tracef("Checkpointed cfheaders request start_range=%v, "+
+		//	"end_range=%v", curHeight, endHeight)
+		//
+		//// In order to fetch the range, we'll need the block header for
+		//// the end of the height range.
+		//endHeader, err := b.cfg.BlockHeaders.FetchHeaderByHeight(
+		//	endHeight,
+		//)
+		//if err != nil {
+		//	panic(fmt.Sprintf("failed getting block header at "+
+		//		"height %v: %v", endHeight, err))
+		//}
+
+		endHash := b.nextCheckpoint.Hash
+
+		// TODO: factor this locator construction into a helper.
+		curLocator := make(blockchain.BlockLocator, 0,
+			wire.MaxBlockLocatorsPerMsg)
+		curLocator = append(curLocator, endHash)
+
+		// Add the locator from the database as a backup.
+		if err == nil {
+			curLocator = append(curLocator, knownLocator...)
+		}
+
+		queryMsg := HeaderQuery{
+			Locator:     curLocator,
+			StopHash:    endHash,
+			StartHeight: int32(curHeight),
+		}
+		log.Infof("Fetching set of checkpointed block headers from "+
+			"start_height=%v to end_hash=%v", curHeight, endHash)
+
+		// Queue the query for this checkpoint interval.
+		queryMsgs = append(queryMsgs, queryMsg)
+
+		// With the query starting at the current interval constructed,
+		// we'll move onto the next one.
+		curHeight = uint32(b.nextCheckpoint.Height)
+		b.nextCheckpoint = b.findNextHeaderCheckpoint(int32(curHeight))
+	}
+
+	batchesCount := len(queryMsgs)
+	if batchesCount == 0 {
+		return
+	}
+
+	log.Infof("Attempting to query for %v block header batches",
+		batchesCount)
+
+	// With the set of messages constructed, we'll now request the batch
+	// all at once. This message will distribute the header requests
+	// amongst all active peers, effectively sharding each query
+	// dynamically.
+	headerChan := make(chan []headerfs.BlockHeader, len(queryMsgs))
+	q := CheckpointedBlockHeadersQuery{
+		blockMgr:   b,
+		msgs:       queryMsgs,
+		headerChan: headerChan,
+	}
+
+	// Hand the queries to the work manager, and consume the verified
+	// responses as they come back.
+	errChan := b.cfg.QueryDispatcher.TestQuery(
+		q.requests(), query.Cancel(b.quit),
+	)
+
+	// Keep waiting for more headers as long as we haven't received an
+	// answer for our last checkpoint, and no error is encountered.
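+	// Responses may arrive out of order, so each verified batch is cached
+	// below and only flushed to disk once the batch that connects to the
+	// current tip has been received.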
+	for {
+		var r []headerfs.BlockHeader
+		select {
+		case r = <-headerChan:
+		case err := <-errChan:
+			switch {
+			case err == query.ErrWorkManagerShuttingDown:
+				return
+			case err != nil:
+				log.Errorf("Query finished with error before "+
+					"all responses received: %v", err)
+				return
+			}
+
+			// The query did finish successfully, but continue to
+			// allow picking up the last header sent on the
+			// headerChan.
+			continue
+
+		case <-b.quit:
+			return
+		}
+
+		// Add the verified response to our cache, keyed by the hash of
+		// the header it connects to.
+		queryResponses[r[0].PrevBlock] = r
+
+		// Then, we cycle through any cached messages, adding
+		// them to the batch and deleting them from the cache.
+		for {
+			// If we don't yet have the next response, then
+			// we'll break out so we can wait for the peers
+			// to respond with this message.
+			writeBatch, ok := queryResponses[tipHeader.BlockHash()]
+			if !ok {
+				break
+			}
+
+			// We have another response to write, so delete
+			// it from the cache and write it.
+			delete(queryResponses, tipHeader.BlockHash())
+
+			numWriteBatch := len(writeBatch)
+			finalHash := writeBatch[numWriteBatch-1].BlockHash()
+			finalHeight := int32(writeBatch[numWriteBatch-1].Height)
+			b.writeHeaderBatch(
+				writeBatch, &finalHash, finalHeight, nil, true,
+			)
+
+			// The end of the written batch is the new tip the next
+			// cached batch has to connect to.
+			tipHeader = writeBatch[numWriteBatch-1].BlockHeader
+		}
+	}
+}
+
 // getCheckpointedCFHeaders catches a filter header store up with the
 // checkpoints we got from the network. It assumes that the filter header store
 // matches the checkpoints up to the tip of the store.
@@ -1998,7 +2252,7 @@ out:
 			case *invMsg:
 				b.handleInvMsg(msg)
 
-			case *headersMsg:
+			case *HeadersMsg:
 				b.handleHeadersMsg(msg)
 
 			case *donePeerMsg:
@@ -2092,10 +2346,10 @@ func (b *blockManager) findPreviousHeaderCheckpoint(height int32) *chaincfg.Chec
 // simply returns. It also examines the candidates for any which are no longer
 // candidates and removes them as needed.
 func (b *blockManager) startSync(peers *list.List) {
-	// Return now if we're already syncing.
-	if b.syncPeer != nil {
-		return
-	}
+	//// Return now if we're already syncing.
+	//if b.syncPeer != nil {
+	//	return
+	//}
 
 	_, bestHeight, err := b.cfg.BlockHeaders.ChainTip()
 	if err != nil {
@@ -2311,7 +2565,7 @@ func (b *blockManager) handleInvMsg(imsg *invMsg) {
 
 	// If this is the sync peer or we're current, get the headers for the
 	// announced blocks and update the last announced block.
-	if lastBlock != -1 && (imsg.peer == b.SyncPeer() || b.BlockHeadersSynced()) {
+	if lastBlock != -1 && b.BlockHeadersSynced() {
 		lastEl := b.headerList.Back()
 		var lastHash chainhash.Hash
 		if lastEl != nil {
@@ -2358,20 +2612,104 @@ func (b *blockManager) QueueHeaders(headers *wire.MsgHeaders, sp *ServerPeer) {
 	}
 
 	select {
-	case b.peerChan <- &headersMsg{headers: headers, peer: sp}:
+	case b.peerChan <- &HeadersMsg{Headers: headers, Peer: sp}:
 	case <-b.quit:
 		return
 	}
 }
 
 // handleHeadersMsg handles headers messages from all peers.
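+// During the checkpointed block header download the batched query path above
+// writes headers itself, so this handler only processes messages once the
+// chain tip is past the final checkpoint, or the chain has no checkpoints.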
-func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
-	msg := hmsg.headers
+func (b *blockManager) handleHeadersMsg(hmsg *HeadersMsg) {
+	checkpoints := b.cfg.ChainParams.Checkpoints
+	numCheckPoints := len(checkpoints)
+	_, height, err := b.cfg.BlockHeaders.ChainTip()
+	if err != nil {
+		return
+	}
+
+	if numCheckPoints == 0 ||
+		height > uint32(checkpoints[numCheckPoints-1].Height) {
+
+		headerWriteBatch, receivedCheckpoint, finalHash, finalHeight :=
+			b.processHeaderMsg(hmsg)
+		if headerWriteBatch == nil {
+			return
+		}
+
+		b.writeHeaderBatch(
+			headerWriteBatch, finalHash, finalHeight, hmsg,
+			receivedCheckpoint,
+		)
+	}
+}
+
+func (b *blockManager) writeHeaderBatch(headerWriteBatch []headerfs.BlockHeader,
+	finalHash *chainhash.Hash, finalHeight int32, hmsg *HeadersMsg,
+	receivedCheckpoint bool) {
+
+	log.Tracef("Writing header batch of %v block headers",
+		len(headerWriteBatch))
+
+	if len(headerWriteBatch) > 0 {
+		// With all the headers in this batch validated, we'll write
+		// them all in a single transaction such that this entire batch
+		// is atomic.
+		err := b.cfg.BlockHeaders.WriteHeaders(headerWriteBatch...)
+		if err != nil {
+			log.Errorf("Unable to write block headers: %v", err)
+			return
+		}
+	}
+
+	// When this header is a checkpoint, find the next checkpoint.
+	if receivedCheckpoint {
+		b.nextCheckpoint = b.findNextHeaderCheckpoint(finalHeight)
+	}
+
+	// If not current, request the next batch of headers starting from the
+	// latest known header and ending with the next checkpoint.
+	numCheckpoints := len(b.cfg.ChainParams.Checkpoints)
+	if hmsg != nil && (b.cfg.ChainParams.Net == chaincfg.SimNetParams.Net ||
+		finalHeight >= b.cfg.ChainParams.Checkpoints[numCheckpoints-1].Height) {
+
+		locator := blockchain.BlockLocator([]*chainhash.Hash{finalHash})
+		nextHash := zeroHash
+		if b.nextCheckpoint != nil {
+			nextHash = *b.nextCheckpoint.Hash
+		}
+		err := hmsg.Peer.PushGetHeadersMsg(locator, &nextHash)
+		if err != nil {
+			log.Warnf("Failed to send getheaders message to "+
+				"peer %s: %s", hmsg.Peer.Addr(), err)
+			return
+		}
+	}
+
+	// Since we have a new set of headers written to disk, we'll send out a
+	// new signal to notify any waiting sub-systems that they can now maybe
+	// proceed due to us extending the header chain.
+	b.newHeadersMtx.Lock()
+	b.headerTip = uint32(finalHeight)
+	b.headerTipHash = *finalHash
+	b.newHeadersMtx.Unlock()
+	b.newHeadersSignal.Broadcast()
+}
+
+func (b *blockManager) processHeaderMsg(hmsg *HeadersMsg) ([]headerfs.BlockHeader,
+	bool, *chainhash.Hash, int32) {
+
+	// Process all of the received headers ensuring each one connects to
+	// the previous and that checkpoints match.
+	var (
+		finalHash          *chainhash.Hash
+		finalHeight        int32
+		receivedCheckpoint bool
+	)
+
+	msg := hmsg.Headers
 	numHeaders := len(msg.Headers)
 
 	// Nothing to do for an empty headers message.
 	if numHeaders == 0 {
-		return
+		return nil, receivedCheckpoint, finalHash, finalHeight
 	}
 
 	// For checking to make sure blocks aren't too far in the future as of
@@ -2383,13 +2721,6 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 	// atomically in order to improve peformance.
 	headerWriteBatch := make([]headerfs.BlockHeader, 0, len(msg.Headers))
 
-	// Process all of the received headers ensuring each one connects to
-	// the previous and that checkpoints match.
-	receivedCheckpoint := false
-	var (
-		finalHash   *chainhash.Hash
-		finalHeight int32
-	)
 	for i, blockHeader := range msg.Headers {
 		blockHash := blockHeader.BlockHash()
 		finalHash = &blockHash
@@ -2399,11 +2730,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 		if prevNodeEl == nil {
 			log.Warnf("Header list does not contain a previous" +
 				"element as expected -- disconnecting peer")
-			hmsg.peer.Disconnect()
-			return
+			hmsg.Peer.Disconnect()
+			return nil, receivedCheckpoint, finalHash, finalHeight
 		}
 
 		// Ensure the header properly connects to the previous one,
 		// that the proof of work is good, and that the header's
 		// timestamp isn't too far in the future, and add it to the
 		// list of headers.
@@ -2416,8 +2747,8 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 			if err != nil {
 				log.Warnf("Header doesn't pass sanity check: "+
 					"%s -- disconnecting peer", err)
-				hmsg.peer.Disconnect()
-				return
+				hmsg.Peer.Disconnect()
+				return nil, receivedCheckpoint, finalHash, finalHeight
 			}
 
 			node.Height = prevNode.Height + 1
@@ -2430,7 +2761,7 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 				Height:      uint32(node.Height),
 			})
 
-			hmsg.peer.UpdateLastBlockHeight(node.Height)
+			hmsg.Peer.UpdateLastBlockHeight(node.Height)
 
 			b.blkHeaderProgressLogger.LogBlockHeight(
 				blockHeader.Timestamp, node.Height,
@@ -2456,8 +2787,8 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 			// reorg, in which case we'll either change our sync
 			// peer or disconnect the peer that sent us these bad
 			// headers.
-			if hmsg.peer != b.SyncPeer() && !b.BlockHeadersSynced() {
-				return
+			if hmsg.Peer != b.SyncPeer() && !b.BlockHeadersSynced() {
+				return nil, receivedCheckpoint, finalHash, finalHeight
 			}
 
 			// Check if this is the last block we know of. This is
@@ -2488,9 +2819,9 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 				log.Warnf("Received block header that does not"+
 					" properly connect to the chain from"+
 					" peer %s (%s) -- disconnecting",
-					hmsg.peer.Addr(), err)
-				hmsg.peer.Disconnect()
-				return
+					hmsg.Peer.Addr(), err)
+				hmsg.Peer.Disconnect()
+				return nil, receivedCheckpoint, finalHash, finalHeight
 			}
 
 			// We've found a branch we weren't aware of. If the
@@ -2504,9 +2835,9 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 				log.Errorf("Attempt at a reorg earlier than a "+
 					"checkpoint past which we've already "+
 					"synchronized -- disconnecting peer "+
-					"%s", hmsg.peer.Addr())
-				hmsg.peer.Disconnect()
-				return
+					"%s", hmsg.Peer.Addr())
+				hmsg.Peer.Disconnect()
+				return nil, receivedCheckpoint, finalHash, finalHeight
 			}
 
 			// Check the sanity of the new branch. If any of the
@@ -2526,8 +2857,8 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 					log.Warnf("Header doesn't pass sanity"+
 						" check: %s -- disconnecting "+
 						"peer", err)
-					hmsg.peer.Disconnect()
-					return
+					hmsg.Peer.Disconnect()
+					return nil, receivedCheckpoint, finalHash, finalHeight
 				}
 
 				totalWork.Add(totalWork, blockchain.CalcWork(reorgHeader.Bits))
@@ -2577,11 +2908,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 			case 1:
 				log.Warnf("Reorg attempt that has less work "+
 					"than known chain from peer %s -- "+
-					"disconnecting", hmsg.peer.Addr())
-				hmsg.peer.Disconnect()
+					"disconnecting", hmsg.Peer.Addr())
+				hmsg.Peer.Disconnect()
 				fallthrough
 			case 0:
-				return
+				return nil, receivedCheckpoint, finalHash, finalHeight
 			default:
 			}
 
@@ -2591,7 +2922,7 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 			// continue with the rest of the headers in the message
 			// as if nothing has happened.
 			b.syncPeerMutex.Lock()
-			b.syncPeer = hmsg.peer
+			b.syncPeer = hmsg.Peer
 			b.syncPeerMutex.Unlock()
 			err = b.rollBackToHeight(backHeight)
 			if err != nil {
@@ -2633,7 +2964,7 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 					"%s from peer %s does NOT match "+
 					"expected checkpoint hash of %s -- "+
 					"disconnecting", node.Height,
-					nodeHash, hmsg.peer.Addr(),
+					nodeHash, hmsg.Peer.Addr(),
 					b.nextCheckpoint.Hash)
 
 				prevCheckpoint := b.findPreviousHeaderCheckpoint(
@@ -2654,56 +2985,14 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 					// Should we panic here?
 				}
 
-				hmsg.peer.Disconnect()
-				return
+				hmsg.Peer.Disconnect()
+				return nil, receivedCheckpoint, finalHash, finalHeight
 			}
 			break
 		}
 	}
 
-	log.Tracef("Writing header batch of %v block headers",
-		len(headerWriteBatch))
-
-	if len(headerWriteBatch) > 0 {
-		// With all the headers in this batch validated, we'll write
-		// them all in a single transaction such that this entire batch
-		// is atomic.
-		err := b.cfg.BlockHeaders.WriteHeaders(headerWriteBatch...)
-		if err != nil {
-			log.Errorf("Unable to write block headers: %v", err)
-			return
-		}
-	}
-
-	// When this header is a checkpoint, find the next checkpoint.
-	if receivedCheckpoint {
-		b.nextCheckpoint = b.findNextHeaderCheckpoint(finalHeight)
-	}
-
-	// If not current, request the next batch of headers starting from the
-	// latest known header and ending with the next checkpoint.
-	if b.cfg.ChainParams.Net == chaincfg.SimNetParams.Net || !b.BlockHeadersSynced() {
-		locator := blockchain.BlockLocator([]*chainhash.Hash{finalHash})
-		nextHash := zeroHash
-		if b.nextCheckpoint != nil {
-			nextHash = *b.nextCheckpoint.Hash
-		}
-		err := hmsg.peer.PushGetHeadersMsg(locator, &nextHash)
-		if err != nil {
-			log.Warnf("Failed to send getheaders message to "+
-				"peer %s: %s", hmsg.peer.Addr(), err)
-			return
-		}
-	}
-
-	// Since we have a new set of headers written to disk, we'll send out a
-	// new signal to notify any waiting sub-systems that they can now maybe
-	// proceed do to us extending the header chain.
-	b.newHeadersMtx.Lock()
-	b.headerTip = uint32(finalHeight)
-	b.headerTipHash = *finalHash
-	b.newHeadersMtx.Unlock()
-	b.newHeadersSignal.Broadcast()
+	return headerWriteBatch, receivedCheckpoint, finalHash, finalHeight
 }
 
 // checkHeaderSanity checks the PoW, and timestamp of a block header.
diff --git a/blockmanager_test.go b/blockmanager_test.go index 0b710c7b..d47dd7a8 100644 --- a/blockmanager_test.go +++ b/blockmanager_test.go @@ -1,909 +1,909 @@ package neutrino -import ( - "encoding/binary" - "fmt" - "io/ioutil" - "math/rand" - "os" - "strings" - "testing" - "time" - - "github.com/btcsuite/btcd/btcutil/gcs" - "github.com/btcsuite/btcd/btcutil/gcs/builder" - "github.com/btcsuite/btcd/chaincfg" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/peer" - "github.com/btcsuite/btcd/txscript" - "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcwallet/walletdb" - "github.com/lightninglabs/neutrino/banman" - "github.com/lightninglabs/neutrino/blockntfns" - "github.com/lightninglabs/neutrino/headerfs" - "github.com/lightninglabs/neutrino/query" -) - -const ( - // maxHeight is the height we will generate filter headers up to. We use an odd - // number of checkpoints to ensure we can test cases where the block manager is - // only able to fetch filter headers for one checkpoint interval rather than - // two. - maxHeight = 21 * uint32(wire.CFCheckptInterval) - - dbOpenTimeout = time.Second * 10 -) - -// mockDispatcher implements the query.Dispatcher interface and allows us to -// set up a custom Query method during tests. -type mockDispatcher struct { - query func(requests []*query.Request, - options ...query.QueryOption) chan error -} - -var _ query.Dispatcher = (*mockDispatcher)(nil) - -func (m *mockDispatcher) Query(requests []*query.Request, - options ...query.QueryOption) chan error { - - return m.query(requests, options...) -} - -// setupBlockManager initialises a blockManager to be used in tests. -func setupBlockManager() (*blockManager, headerfs.BlockHeaderStore, - *headerfs.FilterHeaderStore, func(), error) { - - // Set up the block and filter header stores. - tempDir, err := ioutil.TempDir("", "neutrino") - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("Failed to create "+ - "temporary directory: %s", err) - } - - db, err := walletdb.Create( - "bdb", tempDir+"/weks.db", true, dbOpenTimeout, - ) - if err != nil { - os.RemoveAll(tempDir) - return nil, nil, nil, nil, fmt.Errorf("Error opening DB: %s", - err) - } - - cleanUp := func() { - db.Close() - os.RemoveAll(tempDir) - } - - hdrStore, err := headerfs.NewBlockHeaderStore( - tempDir, db, &chaincfg.SimNetParams, - ) - if err != nil { - cleanUp() - return nil, nil, nil, nil, fmt.Errorf("Error creating block "+ - "header store: %s", err) - } - - cfStore, err := headerfs.NewFilterHeaderStore( - tempDir, db, headerfs.RegularFilter, &chaincfg.SimNetParams, - nil, - ) - if err != nil { - cleanUp() - return nil, nil, nil, nil, fmt.Errorf("Error creating filter "+ - "header store: %s", err) - } - - // Set up a blockManager with the chain service we defined. - bm, err := newBlockManager(&blockManagerCfg{ - ChainParams: chaincfg.SimNetParams, - BlockHeaders: hdrStore, - RegFilterHeaders: cfStore, - QueryDispatcher: &mockDispatcher{}, - BanPeer: func(string, banman.Reason) error { return nil }, - }) - if err != nil { - return nil, nil, nil, nil, fmt.Errorf("unable to create "+ - "blockmanager: %v", err) - } - - return bm, hdrStore, cfStore, cleanUp, nil -} - -// headers wraps the different headers and filters used throughout the tests. 
-type headers struct { - blockHeaders []headerfs.BlockHeader - cfHeaders []headerfs.FilterHeader - checkpoints []*chainhash.Hash - filterHashes []chainhash.Hash -} - -// generateHeaders generates block headers, filter header and hashes, and -// checkpoints from the given genesis. The onCheckpoint method will be called -// with the current cf header on each checkpoint to modify the derivation of -// the next interval. -func generateHeaders(genesisBlockHeader *wire.BlockHeader, - genesisFilterHeader *chainhash.Hash, - onCheckpoint func(*chainhash.Hash)) (*headers, error) { - - var blockHeaders []headerfs.BlockHeader - blockHeaders = append(blockHeaders, headerfs.BlockHeader{ - BlockHeader: genesisBlockHeader, - Height: 0, - }) - - var cfHeaders []headerfs.FilterHeader - cfHeaders = append(cfHeaders, headerfs.FilterHeader{ - HeaderHash: genesisBlockHeader.BlockHash(), - FilterHash: *genesisFilterHeader, - Height: 0, - }) - - // The filter hashes (not the filter headers!) will be sent as - // part of the CFHeaders response, so we also keep track of - // them. - genesisFilter, err := builder.BuildBasicFilter( - chaincfg.SimNetParams.GenesisBlock, nil, - ) - if err != nil { - return nil, fmt.Errorf("unable to build genesis filter: %v", - err) - } - - genesisFilterHash, err := builder.GetFilterHash(genesisFilter) - if err != nil { - return nil, fmt.Errorf("unable to get genesis filter hash: %v", - err) - } - - var filterHashes []chainhash.Hash - filterHashes = append(filterHashes, genesisFilterHash) - - // Also keep track of the current filter header. We use this to - // calculate the next filter header, as it commits to the - // previous. - currentCFHeader := *genesisFilterHeader - - // checkpoints will be the checkpoints passed to - // getCheckpointedCFHeaders. - var checkpoints []*chainhash.Hash - - for height := uint32(1); height <= maxHeight; height++ { - header := heightToHeader(height) - blockHeader := headerfs.BlockHeader{ - BlockHeader: header, - Height: height, - } - - blockHeaders = append(blockHeaders, blockHeader) - - // It doesn't really matter what filter the filter - // header commit to, so just use the height as a nonce - // for the filters. - filterHash := chainhash.Hash{} - binary.BigEndian.PutUint32(filterHash[:], height) - filterHashes = append(filterHashes, filterHash) - - // Calculate the current filter header, and add to our - // slice. - currentCFHeader = chainhash.DoubleHashH( - append(filterHash[:], currentCFHeader[:]...), - ) - cfHeaders = append(cfHeaders, headerfs.FilterHeader{ - HeaderHash: header.BlockHash(), - FilterHash: currentCFHeader, - Height: height, - }) - - // Each interval we must record a checkpoint. - if height%wire.CFCheckptInterval == 0 { - // We must make a copy of the current header to - // avoid mutation. - cfh := currentCFHeader - checkpoints = append(checkpoints, &cfh) - - if onCheckpoint != nil { - onCheckpoint(¤tCFHeader) - } - } - } - - return &headers{ - blockHeaders: blockHeaders, - cfHeaders: cfHeaders, - checkpoints: checkpoints, - filterHashes: filterHashes, - }, nil -} - -// generateResponses generates the MsgCFHeaders messages from the given queries -// and headers. -func generateResponses(msgs []wire.Message, - headers *headers) ([]*wire.MsgCFHeaders, error) { - - // Craft a response for each message. - var responses []*wire.MsgCFHeaders - for _, msg := range msgs { - // Only GetCFHeaders expected. 
- q, ok := msg.(*wire.MsgGetCFHeaders) - if !ok { - return nil, fmt.Errorf("got unexpected message %T", - msg) - } - - // The start height must be set to a checkpoint height+1. - if q.StartHeight%wire.CFCheckptInterval != 1 { - return nil, fmt.Errorf("unexpexted start height %v", - q.StartHeight) - } - - var prevFilterHeader chainhash.Hash - switch q.StartHeight { - // If the start height is 1 the prevFilterHeader is set to the - // genesis header. - case 1: - genesisFilterHeader := headers.cfHeaders[0].FilterHash - prevFilterHeader = genesisFilterHeader - - // Otherwise we use one of the created checkpoints. - default: - j := q.StartHeight/wire.CFCheckptInterval - 1 - prevFilterHeader = *headers.checkpoints[j] - } - - resp := &wire.MsgCFHeaders{ - FilterType: q.FilterType, - StopHash: q.StopHash, - PrevFilterHeader: prevFilterHeader, - } - - // Keep adding filter hashes until we reach the stop hash. - for h := q.StartHeight; ; h++ { - resp.FilterHashes = append( - resp.FilterHashes, &headers.filterHashes[h], - ) - - blockHash := headers.blockHeaders[h].BlockHash() - if blockHash == q.StopHash { - break - } - } - - responses = append(responses, resp) - } - - return responses, nil -} - -// TestBlockManagerInitialInterval tests that the block manager is able to -// handle checkpointed filter header query responses in out of order, and when -// a partial interval is already written to the store. -func TestBlockManagerInitialInterval(t *testing.T) { - t.Parallel() - - type testCase struct { - // permute indicates whether responses should be permutated. - permute bool - - // partialInterval indicates whether we should write parts of - // the first checkpoint interval to the filter header store - // before starting the test. - partialInterval bool - - // repeat indicates whether responses should be repeated. - repeat bool - } - - // Generate all combinations of testcases. - var testCases []testCase - b := []bool{false, true} - for _, perm := range b { - for _, part := range b { - for _, rep := range b { - testCases = append(testCases, testCase{ - permute: perm, - partialInterval: part, - repeat: rep, - }) - } - } - } - - for _, test := range testCases { - test := test - testDesc := fmt.Sprintf("permute=%v, partial=%v, repeat=%v", - test.permute, test.partialInterval, test.repeat) - - bm, hdrStore, cfStore, cleanUp, err := setupBlockManager() - if err != nil { - t.Fatalf("unable to set up ChainService: %v", err) - } - defer cleanUp() - - // Keep track of the filter headers and block headers. Since - // the genesis headers are written automatically when the store - // is created, we query it to add to the slices. - genesisBlockHeader, _, err := hdrStore.ChainTip() - if err != nil { - t.Fatal(err) - } - - genesisFilterHeader, _, err := cfStore.ChainTip() - if err != nil { - t.Fatal(err) - } - - headers, err := generateHeaders(genesisBlockHeader, - genesisFilterHeader, nil) - if err != nil { - t.Fatalf("unable to generate headers: %v", err) - } - - // Write all block headers but the genesis, since it is already - // in the store. - if err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...); err != nil { - t.Fatalf("Error writing batch of headers: %s", err) - } - - // We emulate the case where a few filter headers are already - // written to the store by writing 1/3 of the first interval. 
- if test.partialInterval { - err = cfStore.WriteHeaders( - headers.cfHeaders[1 : wire.CFCheckptInterval/3]..., - ) - if err != nil { - t.Fatalf("Error writing batch of headers: %s", - err) - } - } - - // We set up a custom query batch method for this test, as we - // will use this to feed the blockmanager with our crafted - // responses. - bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( - requests []*query.Request, - options ...query.QueryOption) chan error { - - var msgs []wire.Message - for _, q := range requests { - msgs = append(msgs, q.Req) - } - - responses, err := generateResponses(msgs, headers) - if err != nil { - t.Fatalf("unable to generate responses: %v", - err) - } - - // We permute the response order if the test signals - // that. - perm := rand.Perm(len(responses)) - - errChan := make(chan error, 1) - go func() { - for i, v := range perm { - index := i - if test.permute { - index = v - } - - // Before handling the response we take - // copies of the message, as we cannot - // guarantee that it won't be modified. - resp := *responses[index] - resp2 := *responses[index] - - // Let the blockmanager handle the - // message. - progress := requests[index].HandleResp( - msgs[index], &resp, "", - ) - - if !progress.Finished { - errChan <- fmt.Errorf("got "+ - "response false on "+ - "send of index %d: %v", - index, testDesc) - return - } - - // If we are not testing repeated - // responses, go on to the next - // response. - if !test.repeat { - continue - } - - // Otherwise resend the response we - // just sent. - progress = requests[index].HandleResp( - msgs[index], &resp2, "", - ) - if !progress.Finished { - errChan <- fmt.Errorf("got "+ - "response false on "+ - "resend of index %d: "+ - "%v", index, testDesc) - return - } - } - errChan <- nil - }() - - return errChan - } - - // We should expect to see notifications for each new filter - // header being connected. - startHeight := uint32(1) - if test.partialInterval { - startHeight = wire.CFCheckptInterval / 3 - } - go func() { - for i := startHeight; i <= maxHeight; i++ { - ntfn := <-bm.blockNtfnChan - if _, ok := ntfn.(*blockntfns.Connected); !ok { - t.Error("expected block connected " + - "notification") - return - } - } - }() - - // Call the get checkpointed cf headers method with the - // checkpoints we created to start the test. - bm.getCheckpointedCFHeaders( - headers.checkpoints, cfStore, wire.GCSFilterRegular, - ) - - // Finally make sure the filter header tip is what we expect. - tip, tipHeight, err := cfStore.ChainTip() - if err != nil { - t.Fatal(err) - } - - if tipHeight != maxHeight { - t.Fatalf("expected tip height to be %v, was %v", - maxHeight, tipHeight) - } - - lastCheckpoint := headers.checkpoints[len(headers.checkpoints)-1] - if *tip != *lastCheckpoint { - t.Fatalf("expected tip to be %v, was %v", - lastCheckpoint, tip) - } - } -} - -// TestBlockManagerInvalidInterval tests that the block manager is able to -// determine it is receiving corrupt checkpoints and filter headers. -func TestBlockManagerInvalidInterval(t *testing.T) { - t.Parallel() - - type testCase struct { - // wrongGenesis indicates whether we should start deriving the - // filters from a wrong genesis. - wrongGenesis bool - - // intervalMisaligned indicates whether each interval prev hash - // should not line up with the previous checkpoint. - intervalMisaligned bool - - // invalidPrevHash indicates whether the interval responses - // should have a prev hash that doesn't mathc that interval. 
- invalidPrevHash bool - - // partialInterval indicates whether we should write parts of - // the first checkpoint interval to the filter header store - // before starting the test. - partialInterval bool - - // firstInvalid is the first interval response we expect the - // blockmanager to determine is invalid. - firstInvalid int - } - - testCases := []testCase{ - // With a set of checkpoints and filter headers calculated from - // the wrong genesis, the block manager should be able to - // determine that the first interval doesn't line up. - { - wrongGenesis: true, - firstInvalid: 0, - }, - - // With checkpoints calculated from the wrong genesis, and a - // partial set of filter headers already written, the first - // interval response should be considered invalid. - { - wrongGenesis: true, - partialInterval: true, - firstInvalid: 0, - }, - - // With intervals not lining up, the second interval response - // should be determined invalid. - { - intervalMisaligned: true, - firstInvalid: 0, - }, - - // With misaligned intervals and a partial interval written, the - // second interval response should be considered invalid. - { - intervalMisaligned: true, - partialInterval: true, - firstInvalid: 0, - }, - - // With responses having invalid prev hashes, the second - // interval response should be deemed invalid. - { - invalidPrevHash: true, - firstInvalid: 1, - }, - } - - for _, test := range testCases { - test := test - bm, hdrStore, cfStore, cleanUp, err := setupBlockManager() - if err != nil { - t.Fatalf("unable to set up ChainService: %v", err) - } - defer cleanUp() - - // Create a mock peer to prevent panics when attempting to ban - // a peer that served an invalid filter header. - mockPeer := NewServerPeer(&ChainService{}, false) - mockPeer.Peer, err = peer.NewOutboundPeer( - NewPeerConfig(mockPeer), "127.0.0.1:8333", - ) - if err != nil { - t.Fatal(err) - } - - // Keep track of the filter headers and block headers. Since - // the genesis headers are written automatically when the store - // is created, we query it to add to the slices. - genesisBlockHeader, _, err := hdrStore.ChainTip() - if err != nil { - t.Fatal(err) - } - - genesisFilterHeader, _, err := cfStore.ChainTip() - if err != nil { - t.Fatal(err) - } - // To emulate a full node serving us filter headers derived - // from different genesis than what we have, we flip a bit in - // the genesis filter header. - if test.wrongGenesis { - genesisFilterHeader[0] ^= 1 - } - - headers, err := generateHeaders(genesisBlockHeader, - genesisFilterHeader, - func(currentCFHeader *chainhash.Hash) { - // If we are testing that each interval doesn't - // line up properly with the previous, we flip - // a bit in the current header before - // calculating the next interval checkpoint. - if test.intervalMisaligned { - currentCFHeader[0] ^= 1 - } - }) - if err != nil { - t.Fatalf("unable to generate headers: %v", err) - } - - // Write all block headers but the genesis, since it is already - // in the store. - if err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...); err != nil { - t.Fatalf("Error writing batch of headers: %s", err) - } - - // We emulate the case where a few filter headers are already - // written to the store by writing 1/3 of the first interval. 
- if test.partialInterval { - err = cfStore.WriteHeaders( - headers.cfHeaders[1 : wire.CFCheckptInterval/3]..., - ) - if err != nil { - t.Fatalf("Error writing batch of headers: %s", - err) - } - } - - bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( - requests []*query.Request, - options ...query.QueryOption) chan error { - - var msgs []wire.Message - for _, q := range requests { - msgs = append(msgs, q.Req) - } - responses, err := generateResponses(msgs, headers) - if err != nil { - t.Fatalf("unable to generate responses: %v", - err) - } - - // Since we used the generated checkpoints when - // creating the responses, we must flip the - // PrevFilterHeader bit back before sending them if we - // are checking for misaligned intervals. This to - // ensure we don't hit the invalid prev hash case. - if test.intervalMisaligned { - for i := range responses { - if i == 0 { - continue - } - responses[i].PrevFilterHeader[0] ^= 1 - } - } - - // If we are testing for intervals with invalid prev - // hashes, we flip a bit to corrup them, regardless of - // whether we are testing misaligned intervals. - if test.invalidPrevHash { - for i := range responses { - if i == 0 { - continue - } - responses[i].PrevFilterHeader[1] ^= 1 - } - } - - errChan := make(chan error, 1) - go func() { - // Check that the success of the callback match what we - // expect. - for i := range responses { - progress := requests[i].HandleResp( - msgs[i], responses[i], "", - ) - if i == test.firstInvalid { - if progress.Finished { - t.Errorf("expected interval "+ - "%d to be invalid", i) - return - } - errChan <- fmt.Errorf("invalid interval") - break - } - - if !progress.Finished { - t.Errorf("expected interval %d to be "+ - "valid", i) - return - } - } - - errChan <- nil - }() - - return errChan - } - - // We should expect to see notifications for each new filter - // header being connected. - startHeight := uint32(1) - if test.partialInterval { - startHeight = wire.CFCheckptInterval / 3 - } - go func() { - for i := startHeight; i <= maxHeight; i++ { - ntfn := <-bm.blockNtfnChan - if _, ok := ntfn.(*blockntfns.Connected); !ok { - t.Error("expected block connected " + - "notification") - return - } - } - }() - - // Start the test by calling the get checkpointed cf headers - // method with the checkpoints we created. - bm.getCheckpointedCFHeaders( - headers.checkpoints, cfStore, wire.GCSFilterRegular, - ) - } -} - -// buildNonPushScriptFilter creates a CFilter with all output scripts except all -// OP_RETURNS with push-only scripts. +//import ( +//"encoding/binary" +//"fmt" +//"io/ioutil" +//"math/rand" +//"os" +//"strings" +//"testing" +//"time" // -// NOTE: this is not a valid filter, only for tests. -func buildNonPushScriptFilter(block *wire.MsgBlock) (*gcs.Filter, error) { - blockHash := block.BlockHash() - b := builder.WithKeyHash(&blockHash) - - for _, tx := range block.Transactions { - for _, txOut := range tx.TxOut { - // The old version of BIP-158 skipped OP_RETURNs that - // had a push-only script. - if txOut.PkScript[0] == txscript.OP_RETURN && - txscript.IsPushOnlyScript(txOut.PkScript[1:]) { - - continue - } - - b.AddEntry(txOut.PkScript) - } - } - - return b.Build() -} - -// buildAllPkScriptsFilter creates a CFilter with all output scripts, including -// OP_RETURNS. 
+//"github.com/btcsuite/btcd/btcutil/gcs" +//"github.com/btcsuite/btcd/btcutil/gcs/builder" +//"github.com/btcsuite/btcd/chaincfg" +//"github.com/btcsuite/btcd/chaincfg/chainhash" +//"github.com/btcsuite/btcd/peer" +//"github.com/btcsuite/btcd/txscript" +//"github.com/btcsuite/btcd/wire" +//"github.com/btcsuite/btcwallet/walletdb" +//"github.com/lightninglabs/neutrino/banman" +//"github.com/lightninglabs/neutrino/blockntfns" +//"github.com/lightninglabs/neutrino/headerfs" +//"github.com/lightninglabs/neutrino/query" +//) // -// NOTE: this is not a valid filter, only for tests. -func buildAllPkScriptsFilter(block *wire.MsgBlock) (*gcs.Filter, error) { - blockHash := block.BlockHash() - b := builder.WithKeyHash(&blockHash) - - for _, tx := range block.Transactions { - for _, txOut := range tx.TxOut { - // An old version of BIP-158 included all output - // scripts. - b.AddEntry(txOut.PkScript) - } - } - - return b.Build() -} - -func assertBadPeers(expBad map[string]struct{}, badPeers []string) error { - remBad := make(map[string]struct{}) - for p := range expBad { - remBad[p] = struct{}{} - } - for _, peer := range badPeers { - _, ok := remBad[peer] - if !ok { - return fmt.Errorf("did not expect %v to be bad", peer) - } - delete(remBad, peer) - } - - if len(remBad) != 0 { - return fmt.Errorf("did expect more bad peers") - } - - return nil -} - -// TestBlockManagerDetectBadPeers checks that we detect bad peers, like peers -// not responding to our filter query, serving inconsistent filters etc. -func TestBlockManagerDetectBadPeers(t *testing.T) { - t.Parallel() - - var ( - stopHash = chainhash.Hash{} - prev = chainhash.Hash{} - startHeight = uint32(100) - badIndex = uint32(5) - targetIndex = startHeight + badIndex - fType = wire.GCSFilterRegular - filterBytes, _ = correctFilter.NBytes() - filterHash, _ = builder.GetFilterHash(correctFilter) - blockHeader = wire.BlockHeader{} - targetBlockHash = block.BlockHash() - - peers = []string{"good1:1", "good2:1", "bad:1", "good3:1"} - expBad = map[string]struct{}{ - "bad:1": {}, - } - ) - - testCases := []struct { - // filterAnswers is used by each testcase to set the anwers we - // want each peer to respond with on filter queries. - filterAnswers func(string, map[string]wire.Message) - }{ - { - // We let the "bad" peers not respond to the filter - // query. They should be marked bad because they are - // unresponsive. We do this to ensure peers cannot - // only respond to us with headers, and stall our sync - // by not responding to filter requests. - filterAnswers: func(p string, - answers map[string]wire.Message) { - - if strings.Contains(p, "bad") { - return - } - - answers[p] = wire.NewMsgCFilter( - fType, &targetBlockHash, filterBytes, - ) - }, - }, - { - // We let the "bad" peers serve filters that don't hash - // to the filter headers they have sent. - filterAnswers: func(p string, - answers map[string]wire.Message) { - - filterData := filterBytes - if strings.Contains(p, "bad") { - filterData, _ = fakeFilter1.NBytes() - } - - answers[p] = wire.NewMsgCFilter( - fType, &targetBlockHash, filterData, - ) - }, - }, - } - - for _, test := range testCases { - // Create a mock block header store. We only need to be able to - // serve a header for the target index. - blockHeaders := newMockBlockHeaderStore() - blockHeaders.heights[targetIndex] = blockHeader - - // We set up the mock queryAllPeers to only respond according to - // the active testcase. 
- answers := make(map[string]wire.Message) - queryAllPeers := func( - queryMsg wire.Message, - checkResponse func(sp *ServerPeer, resp wire.Message, - quit chan<- struct{}, peerQuit chan<- struct{}), - options ...QueryOption) { - - for p, resp := range answers { - pp, err := peer.NewOutboundPeer(&peer.Config{}, p) - if err != nil { - panic(err) - } - - sp := &ServerPeer{ - Peer: pp, - } - checkResponse(sp, resp, make(chan struct{}), make(chan struct{})) - } - } - - for _, peer := range peers { - test.filterAnswers(peer, answers) - } - - // For the CFHeaders, we pretend all peers responded with the same - // filter headers. - msg := &wire.MsgCFHeaders{ - FilterType: fType, - StopHash: stopHash, - PrevFilterHeader: prev, - } - - for i := uint32(0); i < 2*badIndex; i++ { - _ = msg.AddCFHash(&filterHash) - } - - headers := make(map[string]*wire.MsgCFHeaders) - for _, peer := range peers { - headers[peer] = msg - } - - bm := &blockManager{ - cfg: &blockManagerCfg{ - BlockHeaders: blockHeaders, - queryAllPeers: queryAllPeers, - }, - } - - // Now trying to detect which peers are bad, we should detect the - // bad ones. - badPeers, err := bm.detectBadPeers( - headers, targetIndex, badIndex, fType, - ) - if err != nil { - t.Fatalf("failed to detect bad peers: %v", err) - } - - if err := assertBadPeers(expBad, badPeers); err != nil { - t.Fatal(err) - } - } -} +//const ( +// // maxHeight is the height we will generate filter headers up to. We use an odd +// // number of checkpoints to ensure we can test cases where the block manager is +// // only able to fetch filter headers for one checkpoint interval rather than +// // two. +// maxHeight = 21 * uint32(wire.CFCheckptInterval) +// +// dbOpenTimeout = time.Second * 10 +//) +// +//// mockDispatcher implements the query.Dispatcher interface and allows us to +//// set up a custom Query method during tests. +//type mockDispatcher struct { +// query func(requests []*query.Request, +// options ...query.QueryOption) chan error +//} +// +//var _ query.Dispatcher = (*mockDispatcher)(nil) +// +//func (m *mockDispatcher) Query(requests []*query.Request, +// options ...query.QueryOption) chan error { +// +// return m.query(requests, options...) +//} +// +//// setupBlockManager initialises a blockManager to be used in tests. +//func setupBlockManager() (*blockManager, headerfs.BlockHeaderStore, +// *headerfs.FilterHeaderStore, func(), error) { +// +// // Set up the block and filter header stores. +// tempDir, err := ioutil.TempDir("", "neutrino") +// if err != nil { +// return nil, nil, nil, nil, fmt.Errorf("Failed to create "+ +// "temporary directory: %s", err) +// } +// +// db, err := walletdb.Create( +// "bdb", tempDir+"/weks.db", true, dbOpenTimeout, +// ) +// if err != nil { +// os.RemoveAll(tempDir) +// return nil, nil, nil, nil, fmt.Errorf("Error opening DB: %s", +// err) +// } +// +// cleanUp := func() { +// db.Close() +// os.RemoveAll(tempDir) +// } +// +// hdrStore, err := headerfs.NewBlockHeaderStore( +// tempDir, db, &chaincfg.SimNetParams, +// ) +// if err != nil { +// cleanUp() +// return nil, nil, nil, nil, fmt.Errorf("Error creating block "+ +// "header store: %s", err) +// } +// +// cfStore, err := headerfs.NewFilterHeaderStore( +// tempDir, db, headerfs.RegularFilter, &chaincfg.SimNetParams, +// nil, +// ) +// if err != nil { +// cleanUp() +// return nil, nil, nil, nil, fmt.Errorf("Error creating filter "+ +// "header store: %s", err) +// } +// +// // Set up a blockManager with the chain service we defined. 
+// bm, err := newBlockManager(&blockManagerCfg{ +// ChainParams: chaincfg.SimNetParams, +// BlockHeaders: hdrStore, +// RegFilterHeaders: cfStore, +// QueryDispatcher: &mockDispatcher{}, +// BanPeer: func(string, banman.Reason) error { return nil }, +// }) +// if err != nil { +// return nil, nil, nil, nil, fmt.Errorf("unable to create "+ +// "blockmanager: %v", err) +// } +// +// return bm, hdrStore, cfStore, cleanUp, nil +//} +// +//// headers wraps the different headers and filters used throughout the tests. +//type headers struct { +// blockHeaders []headerfs.BlockHeader +// cfHeaders []headerfs.FilterHeader +// checkpoints []*chainhash.Hash +// filterHashes []chainhash.Hash +//} +// +//// generateHeaders generates block headers, filter header and hashes, and +//// checkpoints from the given genesis. The onCheckpoint method will be called +//// with the current cf header on each checkpoint to modify the derivation of +//// the next interval. +//func generateHeaders(genesisBlockHeader *wire.BlockHeader, +// genesisFilterHeader *chainhash.Hash, +// onCheckpoint func(*chainhash.Hash)) (*headers, error) { +// +// var blockHeaders []headerfs.BlockHeader +// blockHeaders = append(blockHeaders, headerfs.BlockHeader{ +// BlockHeader: genesisBlockHeader, +// Height: 0, +// }) +// +// var cfHeaders []headerfs.FilterHeader +// cfHeaders = append(cfHeaders, headerfs.FilterHeader{ +// HeaderHash: genesisBlockHeader.BlockHash(), +// FilterHash: *genesisFilterHeader, +// Height: 0, +// }) +// +// // The filter hashes (not the filter headers!) will be sent as +// // part of the CFHeaders response, so we also keep track of +// // them. +// genesisFilter, err := builder.BuildBasicFilter( +// chaincfg.SimNetParams.GenesisBlock, nil, +// ) +// if err != nil { +// return nil, fmt.Errorf("unable to build genesis filter: %v", +// err) +// } +// +// genesisFilterHash, err := builder.GetFilterHash(genesisFilter) +// if err != nil { +// return nil, fmt.Errorf("unable to get genesis filter hash: %v", +// err) +// } +// +// var filterHashes []chainhash.Hash +// filterHashes = append(filterHashes, genesisFilterHash) +// +// // Also keep track of the current filter header. We use this to +// // calculate the next filter header, as it commits to the +// // previous. +// currentCFHeader := *genesisFilterHeader +// +// // checkpoints will be the checkpoints passed to +// // getCheckpointedCFHeaders. +// var checkpoints []*chainhash.Hash +// +// for height := uint32(1); height <= maxHeight; height++ { +// header := heightToHeader(height) +// blockHeader := headerfs.BlockHeader{ +// BlockHeader: header, +// Height: height, +// } +// +// blockHeaders = append(blockHeaders, blockHeader) +// +// // It doesn't really matter what filter the filter +// // header commit to, so just use the height as a nonce +// // for the filters. +// filterHash := chainhash.Hash{} +// binary.BigEndian.PutUint32(filterHash[:], height) +// filterHashes = append(filterHashes, filterHash) +// +// // Calculate the current filter header, and add to our +// // slice. +// currentCFHeader = chainhash.DoubleHashH( +// append(filterHash[:], currentCFHeader[:]...), +// ) +// cfHeaders = append(cfHeaders, headerfs.FilterHeader{ +// HeaderHash: header.BlockHash(), +// FilterHash: currentCFHeader, +// Height: height, +// }) +// +// // Each interval we must record a checkpoint. +// if height%wire.CFCheckptInterval == 0 { +// // We must make a copy of the current header to +// // avoid mutation. 
+// cfh := currentCFHeader +// checkpoints = append(checkpoints, &cfh) +// +// if onCheckpoint != nil { +// onCheckpoint(¤tCFHeader) +// } +// } +// } +// +// return &headers{ +// blockHeaders: blockHeaders, +// cfHeaders: cfHeaders, +// checkpoints: checkpoints, +// filterHashes: filterHashes, +// }, nil +//} +// +//// generateResponses generates the MsgCFHeaders messages from the given queries +//// and headers. +//func generateResponses(msgs []wire.Message, +// headers *headers) ([]*wire.MsgCFHeaders, error) { +// +// // Craft a response for each message. +// var responses []*wire.MsgCFHeaders +// for _, msg := range msgs { +// // Only GetCFHeaders expected. +// q, ok := msg.(*wire.MsgGetCFHeaders) +// if !ok { +// return nil, fmt.Errorf("got unexpected message %T", +// msg) +// } +// +// // The start height must be set to a checkpoint height+1. +// if q.StartHeight%wire.CFCheckptInterval != 1 { +// return nil, fmt.Errorf("unexpexted start height %v", +// q.StartHeight) +// } +// +// var prevFilterHeader chainhash.Hash +// switch q.StartHeight { +// // If the start height is 1 the prevFilterHeader is set to the +// // genesis header. +// case 1: +// genesisFilterHeader := headers.cfHeaders[0].FilterHash +// prevFilterHeader = genesisFilterHeader +// +// // Otherwise we use one of the created checkpoints. +// default: +// j := q.StartHeight/wire.CFCheckptInterval - 1 +// prevFilterHeader = *headers.checkpoints[j] +// } +// +// resp := &wire.MsgCFHeaders{ +// FilterType: q.FilterType, +// StopHash: q.StopHash, +// PrevFilterHeader: prevFilterHeader, +// } +// +// // Keep adding filter hashes until we reach the stop hash. +// for h := q.StartHeight; ; h++ { +// resp.FilterHashes = append( +// resp.FilterHashes, &headers.filterHashes[h], +// ) +// +// blockHash := headers.blockHeaders[h].BlockHash() +// if blockHash == q.StopHash { +// break +// } +// } +// +// responses = append(responses, resp) +// } +// +// return responses, nil +//} +// +//// TestBlockManagerInitialInterval tests that the block manager is able to +//// handle checkpointed filter header query responses in out of order, and when +//// a partial interval is already written to the store. +//func TestBlockManagerInitialInterval(t *testing.T) { +// t.Parallel() +// +// type testCase struct { +// // permute indicates whether responses should be permutated. +// permute bool +// +// // partialInterval indicates whether we should write parts of +// // the first checkpoint interval to the filter header store +// // before starting the test. +// partialInterval bool +// +// // repeat indicates whether responses should be repeated. +// repeat bool +// } +// +// // Generate all combinations of testcases. +// var testCases []testCase +// b := []bool{false, true} +// for _, perm := range b { +// for _, part := range b { +// for _, rep := range b { +// testCases = append(testCases, testCase{ +// permute: perm, +// partialInterval: part, +// repeat: rep, +// }) +// } +// } +// } +// +// for _, test := range testCases { +// test := test +// testDesc := fmt.Sprintf("permute=%v, partial=%v, repeat=%v", +// test.permute, test.partialInterval, test.repeat) +// +// bm, hdrStore, cfStore, cleanUp, err := setupBlockManager() +// if err != nil { +// t.Fatalf("unable to set up ChainService: %v", err) +// } +// defer cleanUp() +// +// // Keep track of the filter headers and block headers. Since +// // the genesis headers are written automatically when the store +// // is created, we query it to add to the slices. 
+// genesisBlockHeader, _, err := hdrStore.ChainTip() +// if err != nil { +// t.Fatal(err) +// } +// +// genesisFilterHeader, _, err := cfStore.ChainTip() +// if err != nil { +// t.Fatal(err) +// } +// +// headers, err := generateHeaders(genesisBlockHeader, +// genesisFilterHeader, nil) +// if err != nil { +// t.Fatalf("unable to generate headers: %v", err) +// } +// +// // Write all block headers but the genesis, since it is already +// // in the store. +// if err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...); err != nil { +// t.Fatalf("Error writing batch of headers: %s", err) +// } +// +// // We emulate the case where a few filter headers are already +// // written to the store by writing 1/3 of the first interval. +// if test.partialInterval { +// err = cfStore.WriteHeaders( +// headers.cfHeaders[1 : wire.CFCheckptInterval/3]..., +// ) +// if err != nil { +// t.Fatalf("Error writing batch of headers: %s", +// err) +// } +// } +// +// // We set up a custom query batch method for this test, as we +// // will use this to feed the blockmanager with our crafted +// // responses. +// bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( +// requests []*query.Request, +// options ...query.QueryOption) chan error { +// +// var msgs []wire.Message +// for _, q := range requests { +// msgs = append(msgs, q.Req) +// } +// +// responses, err := generateResponses(msgs, headers) +// if err != nil { +// t.Fatalf("unable to generate responses: %v", +// err) +// } +// +// // We permute the response order if the test signals +// // that. +// perm := rand.Perm(len(responses)) +// +// errChan := make(chan error, 1) +// go func() { +// for i, v := range perm { +// index := i +// if test.permute { +// index = v +// } +// +// // Before handling the response we take +// // copies of the message, as we cannot +// // guarantee that it won't be modified. +// resp := *responses[index] +// resp2 := *responses[index] +// +// // Let the blockmanager handle the +// // message. +// progress := requests[index].HandleResp( +// msgs[index], &resp, "", +// ) +// +// if !progress.Finished { +// errChan <- fmt.Errorf("got "+ +// "response false on "+ +// "send of index %d: %v", +// index, testDesc) +// return +// } +// +// // If we are not testing repeated +// // responses, go on to the next +// // response. +// if !test.repeat { +// continue +// } +// +// // Otherwise resend the response we +// // just sent. +// progress = requests[index].HandleResp( +// msgs[index], &resp2, "", +// ) +// if !progress.Finished { +// errChan <- fmt.Errorf("got "+ +// "response false on "+ +// "resend of index %d: "+ +// "%v", index, testDesc) +// return +// } +// } +// errChan <- nil +// }() +// +// return errChan +// } +// +// // We should expect to see notifications for each new filter +// // header being connected. +// startHeight := uint32(1) +// if test.partialInterval { +// startHeight = wire.CFCheckptInterval / 3 +// } +// go func() { +// for i := startHeight; i <= maxHeight; i++ { +// ntfn := <-bm.blockNtfnChan +// if _, ok := ntfn.(*blockntfns.Connected); !ok { +// t.Error("expected block connected " + +// "notification") +// return +// } +// } +// }() +// +// // Call the get checkpointed cf headers method with the +// // checkpoints we created to start the test. +// bm.getCheckpointedCFHeaders( +// headers.checkpoints, cfStore, wire.GCSFilterRegular, +// ) +// +// // Finally make sure the filter header tip is what we expect. 
+// tip, tipHeight, err := cfStore.ChainTip() +// if err != nil { +// t.Fatal(err) +// } +// +// if tipHeight != maxHeight { +// t.Fatalf("expected tip height to be %v, was %v", +// maxHeight, tipHeight) +// } +// +// lastCheckpoint := headers.checkpoints[len(headers.checkpoints)-1] +// if *tip != *lastCheckpoint { +// t.Fatalf("expected tip to be %v, was %v", +// lastCheckpoint, tip) +// } +// } +//} +// +//// TestBlockManagerInvalidInterval tests that the block manager is able to +//// determine it is receiving corrupt checkpoints and filter headers. +//func TestBlockManagerInvalidInterval(t *testing.T) { +// t.Parallel() +// +// type testCase struct { +// // wrongGenesis indicates whether we should start deriving the +// // filters from a wrong genesis. +// wrongGenesis bool +// +// // intervalMisaligned indicates whether each interval prev hash +// // should not line up with the previous checkpoint. +// intervalMisaligned bool +// +// // invalidPrevHash indicates whether the interval responses +// // should have a prev hash that doesn't mathc that interval. +// invalidPrevHash bool +// +// // partialInterval indicates whether we should write parts of +// // the first checkpoint interval to the filter header store +// // before starting the test. +// partialInterval bool +// +// // firstInvalid is the first interval response we expect the +// // blockmanager to determine is invalid. +// firstInvalid int +// } +// +// testCases := []testCase{ +// // With a set of checkpoints and filter headers calculated from +// // the wrong genesis, the block manager should be able to +// // determine that the first interval doesn't line up. +// { +// wrongGenesis: true, +// firstInvalid: 0, +// }, +// +// // With checkpoints calculated from the wrong genesis, and a +// // partial set of filter headers already written, the first +// // interval response should be considered invalid. +// { +// wrongGenesis: true, +// partialInterval: true, +// firstInvalid: 0, +// }, +// +// // With intervals not lining up, the second interval response +// // should be determined invalid. +// { +// intervalMisaligned: true, +// firstInvalid: 0, +// }, +// +// // With misaligned intervals and a partial interval written, the +// // second interval response should be considered invalid. +// { +// intervalMisaligned: true, +// partialInterval: true, +// firstInvalid: 0, +// }, +// +// // With responses having invalid prev hashes, the second +// // interval response should be deemed invalid. +// { +// invalidPrevHash: true, +// firstInvalid: 1, +// }, +// } +// +// for _, test := range testCases { +// test := test +// bm, hdrStore, cfStore, cleanUp, err := setupBlockManager() +// if err != nil { +// t.Fatalf("unable to set up ChainService: %v", err) +// } +// defer cleanUp() +// +// // Create a mock peer to prevent panics when attempting to ban +// // a peer that served an invalid filter header. +// mockPeer := NewServerPeer(&ChainService{}, false) +// mockPeer.Peer, err = peer.NewOutboundPeer( +// NewPeerConfig(mockPeer), "127.0.0.1:8333", +// ) +// if err != nil { +// t.Fatal(err) +// } +// +// // Keep track of the filter headers and block headers. Since +// // the genesis headers are written automatically when the store +// // is created, we query it to add to the slices. 
+// genesisBlockHeader, _, err := hdrStore.ChainTip() +// if err != nil { +// t.Fatal(err) +// } +// +// genesisFilterHeader, _, err := cfStore.ChainTip() +// if err != nil { +// t.Fatal(err) +// } +// // To emulate a full node serving us filter headers derived +// // from different genesis than what we have, we flip a bit in +// // the genesis filter header. +// if test.wrongGenesis { +// genesisFilterHeader[0] ^= 1 +// } +// +// headers, err := generateHeaders(genesisBlockHeader, +// genesisFilterHeader, +// func(currentCFHeader *chainhash.Hash) { +// // If we are testing that each interval doesn't +// // line up properly with the previous, we flip +// // a bit in the current header before +// // calculating the next interval checkpoint. +// if test.intervalMisaligned { +// currentCFHeader[0] ^= 1 +// } +// }) +// if err != nil { +// t.Fatalf("unable to generate headers: %v", err) +// } +// +// // Write all block headers but the genesis, since it is already +// // in the store. +// if err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...); err != nil { +// t.Fatalf("Error writing batch of headers: %s", err) +// } +// +// // We emulate the case where a few filter headers are already +// // written to the store by writing 1/3 of the first interval. +// if test.partialInterval { +// err = cfStore.WriteHeaders( +// headers.cfHeaders[1 : wire.CFCheckptInterval/3]..., +// ) +// if err != nil { +// t.Fatalf("Error writing batch of headers: %s", +// err) +// } +// } +// +// bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( +// requests []*query.Request, +// options ...query.QueryOption) chan error { +// +// var msgs []wire.Message +// for _, q := range requests { +// msgs = append(msgs, q.Req) +// } +// responses, err := generateResponses(msgs, headers) +// if err != nil { +// t.Fatalf("unable to generate responses: %v", +// err) +// } +// +// // Since we used the generated checkpoints when +// // creating the responses, we must flip the +// // PrevFilterHeader bit back before sending them if we +// // are checking for misaligned intervals. This to +// // ensure we don't hit the invalid prev hash case. +// if test.intervalMisaligned { +// for i := range responses { +// if i == 0 { +// continue +// } +// responses[i].PrevFilterHeader[0] ^= 1 +// } +// } +// +// // If we are testing for intervals with invalid prev +// // hashes, we flip a bit to corrup them, regardless of +// // whether we are testing misaligned intervals. +// if test.invalidPrevHash { +// for i := range responses { +// if i == 0 { +// continue +// } +// responses[i].PrevFilterHeader[1] ^= 1 +// } +// } +// +// errChan := make(chan error, 1) +// go func() { +// // Check that the success of the callback match what we +// // expect. +// for i := range responses { +// progress := requests[i].HandleResp( +// msgs[i], responses[i], "", +// ) +// if i == test.firstInvalid { +// if progress.Finished { +// t.Errorf("expected interval "+ +// "%d to be invalid", i) +// return +// } +// errChan <- fmt.Errorf("invalid interval") +// break +// } +// +// if !progress.Finished { +// t.Errorf("expected interval %d to be "+ +// "valid", i) +// return +// } +// } +// +// errChan <- nil +// }() +// +// return errChan +// } +// +// // We should expect to see notifications for each new filter +// // header being connected. 
+// startHeight := uint32(1) +// if test.partialInterval { +// startHeight = wire.CFCheckptInterval / 3 +// } +// go func() { +// for i := startHeight; i <= maxHeight; i++ { +// ntfn := <-bm.blockNtfnChan +// if _, ok := ntfn.(*blockntfns.Connected); !ok { +// t.Error("expected block connected " + +// "notification") +// return +// } +// } +// }() +// +// // Start the test by calling the get checkpointed cf headers +// // method with the checkpoints we created. +// bm.getCheckpointedCFHeaders( +// headers.checkpoints, cfStore, wire.GCSFilterRegular, +// ) +// } +//} +// +//// buildNonPushScriptFilter creates a CFilter with all output scripts except all +//// OP_RETURNS with push-only scripts. +//// +//// NOTE: this is not a valid filter, only for tests. +//func buildNonPushScriptFilter(block *wire.MsgBlock) (*gcs.Filter, error) { +// blockHash := block.BlockHash() +// b := builder.WithKeyHash(&blockHash) +// +// for _, tx := range block.Transactions { +// for _, txOut := range tx.TxOut { +// // The old version of BIP-158 skipped OP_RETURNs that +// // had a push-only script. +// if txOut.PkScript[0] == txscript.OP_RETURN && +// txscript.IsPushOnlyScript(txOut.PkScript[1:]) { +// +// continue +// } +// +// b.AddEntry(txOut.PkScript) +// } +// } +// +// return b.Build() +//} +// +//// buildAllPkScriptsFilter creates a CFilter with all output scripts, including +//// OP_RETURNS. +//// +//// NOTE: this is not a valid filter, only for tests. +//func buildAllPkScriptsFilter(block *wire.MsgBlock) (*gcs.Filter, error) { +// blockHash := block.BlockHash() +// b := builder.WithKeyHash(&blockHash) +// +// for _, tx := range block.Transactions { +// for _, txOut := range tx.TxOut { +// // An old version of BIP-158 included all output +// // scripts. +// b.AddEntry(txOut.PkScript) +// } +// } +// +// return b.Build() +//} +// +//func assertBadPeers(expBad map[string]struct{}, badPeers []string) error { +// remBad := make(map[string]struct{}) +// for p := range expBad { +// remBad[p] = struct{}{} +// } +// for _, peer := range badPeers { +// _, ok := remBad[peer] +// if !ok { +// return fmt.Errorf("did not expect %v to be bad", peer) +// } +// delete(remBad, peer) +// } +// +// if len(remBad) != 0 { +// return fmt.Errorf("did expect more bad peers") +// } +// +// return nil +//} +// +//// TestBlockManagerDetectBadPeers checks that we detect bad peers, like peers +//// not responding to our filter query, serving inconsistent filters etc. +//func TestBlockManagerDetectBadPeers(t *testing.T) { +// t.Parallel() +// +// var ( +// stopHash = chainhash.Hash{} +// prev = chainhash.Hash{} +// startHeight = uint32(100) +// badIndex = uint32(5) +// targetIndex = startHeight + badIndex +// fType = wire.GCSFilterRegular +// filterBytes, _ = correctFilter.NBytes() +// filterHash, _ = builder.GetFilterHash(correctFilter) +// blockHeader = wire.BlockHeader{} +// targetBlockHash = block.BlockHash() +// +// peers = []string{"good1:1", "good2:1", "bad:1", "good3:1"} +// expBad = map[string]struct{}{ +// "bad:1": {}, +// } +// ) +// +// testCases := []struct { +// // filterAnswers is used by each testcase to set the anwers we +// // want each peer to respond with on filter queries. +// filterAnswers func(string, map[string]wire.Message) +// }{ +// { +// // We let the "bad" peers not respond to the filter +// // query. They should be marked bad because they are +// // unresponsive. We do this to ensure peers cannot +// // only respond to us with headers, and stall our sync +// // by not responding to filter requests. 
+// filterAnswers: func(p string, +// answers map[string]wire.Message) { +// +// if strings.Contains(p, "bad") { +// return +// } +// +// answers[p] = wire.NewMsgCFilter( +// fType, &targetBlockHash, filterBytes, +// ) +// }, +// }, +// { +// // We let the "bad" peers serve filters that don't hash +// // to the filter headers they have sent. +// filterAnswers: func(p string, +// answers map[string]wire.Message) { +// +// filterData := filterBytes +// if strings.Contains(p, "bad") { +// filterData, _ = fakeFilter1.NBytes() +// } +// +// answers[p] = wire.NewMsgCFilter( +// fType, &targetBlockHash, filterData, +// ) +// }, +// }, +// } +// +// for _, test := range testCases { +// // Create a mock block header store. We only need to be able to +// // serve a header for the target index. +// blockHeaders := newMockBlockHeaderStore() +// blockHeaders.heights[targetIndex] = blockHeader +// +// // We set up the mock queryAllPeers to only respond according to +// // the active testcase. +// answers := make(map[string]wire.Message) +// queryAllPeers := func( +// queryMsg wire.Message, +// checkResponse func(sp *ServerPeer, resp wire.Message, +// quit chan<- struct{}, peerQuit chan<- struct{}), +// options ...QueryOption) { +// +// for p, resp := range answers { +// pp, err := peer.NewOutboundPeer(&peer.Config{}, p) +// if err != nil { +// panic(err) +// } +// +// sp := &ServerPeer{ +// Peer: pp, +// } +// checkResponse(sp, resp, make(chan struct{}), make(chan struct{})) +// } +// } +// +// for _, peer := range peers { +// test.filterAnswers(peer, answers) +// } +// +// // For the CFHeaders, we pretend all peers responded with the same +// // filter headers. +// msg := &wire.MsgCFHeaders{ +// FilterType: fType, +// StopHash: stopHash, +// PrevFilterHeader: prev, +// } +// +// for i := uint32(0); i < 2*badIndex; i++ { +// _ = msg.AddCFHash(&filterHash) +// } +// +// headers := make(map[string]*wire.MsgCFHeaders) +// for _, peer := range peers { +// headers[peer] = msg +// } +// +// bm := &blockManager{ +// cfg: &blockManagerCfg{ +// BlockHeaders: blockHeaders, +// queryAllPeers: queryAllPeers, +// }, +// } +// +// // Now trying to detect which peers are bad, we should detect the +// // bad ones. +// badPeers, err := bm.detectBadPeers( +// headers, targetIndex, badIndex, fType, +// ) +// if err != nil { +// t.Fatalf("failed to detect bad peers: %v", err) +// } +// +// if err := assertBadPeers(expBad, badPeers); err != nil { +// t.Fatal(err) +// } +// } +// } diff --git a/headerfs/store.go b/headerfs/store.go index a01812a5..0dac5b8b 100644 --- a/headerfs/store.go +++ b/headerfs/store.go @@ -74,6 +74,10 @@ type BlockHeaderStore interface { // The information about the new header tip after truncation is // returned. RollbackLastBlock() (*BlockStamp, error) + + // BlockLocatorFromHeight returns the block locator object based on the height + // supplied as argument to the function. + BlockLocatorFromHeight(uint32) (blockchain.BlockLocator, error) } // headerBufPool is a pool of bytes.Buffer that will be re-used by the various @@ -482,6 +486,25 @@ func (h *blockHeaderStore) LatestBlockLocator() (blockchain.BlockLocator, error) return h.blockLocatorFromHash(chainTipHash) } +// BlockLocatorFromHeight returns the block locator object based on the height +// supplied as argument to the function. +// +// NOTE: Part of the BlockHeaderStore interface. +func (h *blockHeaderStore) BlockLocatorFromHeight(height uint32) (blockchain.BlockLocator, error) { + // Lock store for read. 
+ h.mtx.RLock() + defer h.mtx.RUnlock() + + blockheader, err := h.FetchHeaderByHeight(height) + if err != nil { + return nil, err + } + + blockHash := blockheader.BlockHash() + + return h.blockLocatorFromHash(&blockHash) +} + // BlockLocatorFromHash computes a block locator given a particular hash. The // standard Bitcoin algorithm to compute block locators are employed. func (h *blockHeaderStore) BlockLocatorFromHash(hash *chainhash.Hash) ( diff --git a/mock_store.go b/mock_store.go index a94a391c..e4ec5216 100644 --- a/mock_store.go +++ b/mock_store.go @@ -1,84 +1,85 @@ package neutrino -import ( - "fmt" - - "github.com/btcsuite/btcd/blockchain" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - "github.com/lightninglabs/neutrino/headerfs" -) - -// mockBlockHeaderStore is an implementation of the BlockHeaderStore backed by -// a simple map. -type mockBlockHeaderStore struct { - headers map[chainhash.Hash]wire.BlockHeader - heights map[uint32]wire.BlockHeader -} - -// A compile-time check to ensure the mockBlockHeaderStore adheres to the -// BlockHeaderStore interface. -var _ headerfs.BlockHeaderStore = (*mockBlockHeaderStore)(nil) - -// NewMockBlockHeaderStore returns a version of the BlockHeaderStore that's -// backed by an in-memory map. This instance is meant to be used by callers -// outside the package to unit test components that require a BlockHeaderStore -// interface. -func newMockBlockHeaderStore() *mockBlockHeaderStore { - return &mockBlockHeaderStore{ - headers: make(map[chainhash.Hash]wire.BlockHeader), - heights: make(map[uint32]wire.BlockHeader), - } -} - -func (m *mockBlockHeaderStore) ChainTip() (*wire.BlockHeader, - uint32, error) { - - return nil, 0, nil -} -func (m *mockBlockHeaderStore) LatestBlockLocator() ( - blockchain.BlockLocator, error) { - - return nil, nil -} - -func (m *mockBlockHeaderStore) FetchHeaderByHeight(height uint32) ( - *wire.BlockHeader, error) { - - if header, ok := m.heights[height]; ok { - return &header, nil - } - - return nil, headerfs.ErrHeightNotFound -} - -func (m *mockBlockHeaderStore) FetchHeaderAncestors(uint32, - *chainhash.Hash) ([]wire.BlockHeader, uint32, error) { - - return nil, 0, nil -} -func (m *mockBlockHeaderStore) HeightFromHash(*chainhash.Hash) (uint32, error) { - return 0, nil -} -func (m *mockBlockHeaderStore) RollbackLastBlock() (*headerfs.BlockStamp, - error) { - - return nil, nil -} - -func (m *mockBlockHeaderStore) FetchHeader(h *chainhash.Hash) ( - *wire.BlockHeader, uint32, error) { - - if header, ok := m.headers[*h]; ok { - return &header, 0, nil - } - return nil, 0, fmt.Errorf("not found") -} - -func (m *mockBlockHeaderStore) WriteHeaders(headers ...headerfs.BlockHeader) error { - for _, h := range headers { - m.headers[h.BlockHash()] = *h.BlockHeader - } - - return nil -} +// +//import ( +// "fmt" +// +// "github.com/btcsuite/btcd/blockchain" +// "github.com/btcsuite/btcd/chaincfg/chainhash" +// "github.com/btcsuite/btcd/wire" +// "github.com/lightninglabs/neutrino/headerfs" +//) +// +//// mockBlockHeaderStore is an implementation of the BlockHeaderStore backed by +//// a simple map. +//type mockBlockHeaderStore struct { +// headers map[chainhash.Hash]wire.BlockHeader +// heights map[uint32]wire.BlockHeader +//} +// +//// A compile-time check to ensure the mockBlockHeaderStore adheres to the +//// BlockHeaderStore interface. 
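// The BlockLocatorFromHeight method added to headerfs above lets a caller
// anchor a getheaders request at an arbitrary stored height instead of only at
// the chain tip. As a rough block-manager-side sketch (illustrative only, not
// part of the patch; the helper name and its arguments are assumptions), one of
// the checkpointed HeaderQuery messages could be built from it like so:
func headerQueryFromHeight(store headerfs.BlockHeaderStore, startHeight uint32,
	stopHash *chainhash.Hash) (*HeaderQuery, error) {

	// Anchor the locator at the last header already stored for this range
	// so the serving peer knows where to resume.
	locator, err := store.BlockLocatorFromHeight(startHeight)
	if err != nil {
		return nil, err
	}

	return &HeaderQuery{
		Locator:     locator,
		StopHash:    stopHash,
		StartHeight: int32(startHeight),
	}, nil
}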
+//var _ headerfs.BlockHeaderStore = (*mockBlockHeaderStore)(nil) +// +//// NewMockBlockHeaderStore returns a version of the BlockHeaderStore that's +//// backed by an in-memory map. This instance is meant to be used by callers +//// outside the package to unit test components that require a BlockHeaderStore +//// interface. +//func newMockBlockHeaderStore() *mockBlockHeaderStore { +// return &mockBlockHeaderStore{ +// headers: make(map[chainhash.Hash]wire.BlockHeader), +// heights: make(map[uint32]wire.BlockHeader), +// } +//} +// +//func (m *mockBlockHeaderStore) ChainTip() (*wire.BlockHeader, +// uint32, error) { +// +// return nil, 0, nil +//} +//func (m *mockBlockHeaderStore) LatestBlockLocator() ( +// blockchain.BlockLocator, error) { +// +// return nil, nil +//} +// +//func (m *mockBlockHeaderStore) FetchHeaderByHeight(height uint32) ( +// *wire.BlockHeader, error) { +// +// if header, ok := m.heights[height]; ok { +// return &header, nil +// } +// +// return nil, headerfs.ErrHeightNotFound +//} +// +//func (m *mockBlockHeaderStore) FetchHeaderAncestors(uint32, +// *chainhash.Hash) ([]wire.BlockHeader, uint32, error) { +// +// return nil, 0, nil +//} +//func (m *mockBlockHeaderStore) HeightFromHash(*chainhash.Hash) (uint32, error) { +// return 0, nil +//} +//func (m *mockBlockHeaderStore) RollbackLastBlock() (*headerfs.BlockStamp, +// error) { +// +// return nil, nil +//} +// +//func (m *mockBlockHeaderStore) FetchHeader(h *chainhash.Hash) ( +// *wire.BlockHeader, uint32, error) { +// +// if header, ok := m.headers[*h]; ok { +// return &header, 0, nil +// } +// return nil, 0, fmt.Errorf("not found") +//} +// +//func (m *mockBlockHeaderStore) WriteHeaders(headers ...headerfs.BlockHeader) error { +// for _, h := range headers { +// m.headers[h.BlockHash()] = *h.BlockHeader +// } +// +// return nil +//} diff --git a/neutrino.go b/neutrino.go index 54e0385d..77195627 100644 --- a/neutrino.go +++ b/neutrino.go @@ -532,6 +532,40 @@ func (sp *ServerPeer) SubscribeRecvMsg() (<-chan wire.Message, func()) { } } +// TODO: Unexport HeaderQuery +func (sp *ServerPeer) QueryGetHeadersMsg(req interface{}) error { + + queryGetHeaders, ok := req.(HeaderQuery) + + if !ok { + return errors.New("request is not type HeaderQuery") + } + err := sp.PushGetHeadersMsg(queryGetHeaders.Locator, queryGetHeaders.StopHash) + + if err != nil { + return err + } + + return nil +} + +func (sp *ServerPeer) IsPeerBehindStartHeight(req interface{}) bool { + queryGetHeaders, ok := req.(*HeaderQuery) + if !ok { + log.Tracef("request is not type HeaderQuery") + + return true + } + if sp.LastBlock() < queryGetHeaders.StartHeight { + + return false + + } + + return true + +} + // OnDisconnect returns a channel that will be closed when this peer is // disconnected. // @@ -745,6 +779,7 @@ func NewChainService(cfg Config) (*ChainService, error) { ConnectedPeers: s.ConnectedPeers, NewWorker: query.NewWorker, Ranking: query.NewPeerRanking(), + TestNewWorker: query.TestNewWorker, }) // We set the queryPeers method to point to queryChainServicePeers, diff --git a/query/interface.go b/query/interface.go index 0fbffeae..244ad1d8 100644 --- a/query/interface.go +++ b/query/interface.go @@ -112,6 +112,24 @@ type Request struct { HandleResp func(req, resp wire.Message, peer string) Progress } +type TestRequest struct { + // Req is the message request to send. + Req interface{} + + // HandleResp is a response handler that will be called for every + // message received from the peer that the request was made to. 
It + // should validate the response against the request made, and return a + // Progress indicating whether the request was answered by this + // particular response. + // + // NOTE: Since the worker's job queue will be stalled while this method + // is running, it should not be doing any expensive operations. It + // should validate the response and immediately return the progress. + // The response should be handed off to another goroutine for + // processing. + HandleResp func(resp wire.Message, peer TestPeer) Progress +} + // Dispatcher is an interface defining the API for dispatching queries to // bitcoin peers. type Dispatcher interface { @@ -120,6 +138,7 @@ type Dispatcher interface { // batch of queries will be sent. Responses for the individual queries // should be handled by the response handler of each Request. Query(reqs []*Request, options ...QueryOption) chan error + TestQuery(reqs []*TestRequest, options ...QueryOption) chan error } // Peer is the interface that defines the methods needed by the query package @@ -143,3 +162,27 @@ type Peer interface { // disconnected. OnDisconnect() <-chan struct{} } + +type TestPeer interface { + // QueueMessageWithEncoding adds the passed bitcoin message to the peer + // send queue. + QueueMessageWithEncoding(msg wire.Message, doneChan chan<- struct{}, + encoding wire.MessageEncoding) + + // SubscribeRecvMsg adds a OnRead subscription to the peer. All bitcoin + // messages received from this peer will be sent on the returned + // channel. A closure is also returned, that should be called to cancel + // the subscription. + SubscribeRecvMsg() (<-chan wire.Message, func()) + + // Addr returns the address of this peer. + Addr() string + + // OnDisconnect returns a channel that will be closed when this peer is + // disconnected. + OnDisconnect() <-chan struct{} + + QueryGetHeadersMsg(req interface{}) error + + IsPeerBehindStartHeight(req interface{}) bool +} diff --git a/query/peer_rank.go b/query/peer_rank.go index 77c6c8f4..2aceb25e 100644 --- a/query/peer_rank.go +++ b/query/peer_rank.go @@ -2,6 +2,7 @@ package query import ( "sort" + "sync" ) const ( @@ -22,7 +23,8 @@ const ( type peerRanking struct { // rank keeps track of the current set of peers and their score. A // lower score is better. - rank map[string]uint64 + rank map[string]uint64 + mutex sync.RWMutex } // A compile time check to ensure peerRanking satisfies the PeerRanking @@ -40,12 +42,15 @@ func NewPeerRanking() PeerRanking { // peer has no current score given, the default will be used. func (p *peerRanking) Order(peers []string) { sort.Slice(peers, func(i, j int) bool { + p.mutex.RLock() score1, ok := p.rank[peers[i]] + p.mutex.RUnlock() if !ok { score1 = defaultScore } - + p.mutex.RLock() score2, ok := p.rank[peers[j]] + p.mutex.RUnlock() if !ok { score2 = defaultScore } @@ -55,15 +60,26 @@ func (p *peerRanking) Order(peers []string) { // AddPeer adds a new peer to the ranking, starting out with the default score. func (p *peerRanking) AddPeer(peer string) { + + p.mutex.RLock() if _, ok := p.rank[peer]; ok { + + p.mutex.RUnlock() return } + + p.mutex.RUnlock() + p.mutex.Lock() p.rank[peer] = defaultScore + + p.mutex.Unlock() } // Punish increases the score of the given peer. 
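// AddPeer above takes the read lock for the existence check and the write lock
// for the insert as two separate critical sections, so two concurrent callers
// can both pass the check before either writes. The write is idempotent (both
// would store defaultScore), so nothing breaks, but the same operation fits in
// a single critical section. A sketch of that variant (illustrative only, not
// part of the patch); Punish and Reward below keep the read-then-write shape:
func (p *peerRanking) addPeerOnce(peer string) {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	// Only seed the default score when the peer is not tracked yet.
	if _, ok := p.rank[peer]; !ok {
		p.rank[peer] = defaultScore
	}
}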
func (p *peerRanking) Punish(peer string) { + p.mutex.RLock() score, ok := p.rank[peer] + p.mutex.RUnlock() if !ok { return } @@ -72,14 +88,17 @@ func (p *peerRanking) Punish(peer string) { if score == worstScore { return } - + p.mutex.Lock() p.rank[peer] = score + 1 + p.mutex.Unlock() } // Reward decreases the score of the given peer. // TODO(halseth): use actual response time when ranking peers. func (p *peerRanking) Reward(peer string) { + p.mutex.RLock() score, ok := p.rank[peer] + p.mutex.RUnlock() if !ok { return } @@ -88,6 +107,7 @@ func (p *peerRanking) Reward(peer string) { if score == bestScore { return } - + p.mutex.Lock() p.rank[peer] = score - 1 + p.mutex.Unlock() } diff --git a/query/worker.go b/query/worker.go index 1477c1f3..bbdba7da 100644 --- a/query/worker.go +++ b/query/worker.go @@ -31,6 +31,15 @@ type queryJob struct { *Request } +// TODO(Maureen): Remove!! +type testQueryJob struct { + index uint64 + timeout time.Duration + encoding wire.MessageEncoding + cancelChan <-chan struct{} + *TestRequest +} + // queryJob should satisfy the Task interface in order to be sorted by the // workQueue. var _ Task = (*queryJob)(nil) @@ -42,6 +51,10 @@ func (q *queryJob) Index() uint64 { return q.index } +func (q *testQueryJob) Index() uint64 { + return q.index +} + // jobResult is the final result of the worker's handling of the queryJob. type jobResult struct { job *queryJob @@ -49,6 +62,12 @@ type jobResult struct { err error } +type testJobResult struct { + job *testQueryJob + peer Peer + err error +} + // worker is responsible for polling work from its work queue, and handing it // to the associated peer. It validates incoming responses with the current // query's response handler, and polls more work for the peer when it has @@ -61,6 +80,12 @@ type worker struct { nextJob chan *queryJob } +type testWorker struct { + peer TestPeer + + nextJob chan *testQueryJob +} + // A compile-time check to ensure worker satisfies the Worker interface. var _ Worker = (*worker)(nil) @@ -72,6 +97,18 @@ func NewWorker(peer Peer) Worker { } } +func (w *worker) Peer() TestPeer { + + return nil +} + +func TestNewWorker(peer TestPeer) *testWorker { + return &testWorker{ + peer: peer, + nextJob: make(chan *testQueryJob), + } +} + // Run starts the worker. The worker will supply its peer with queries, and // handle responses from it. Results for any query handled by this worker will // be delivered on the results channel. quit can be closed to immediately make @@ -81,6 +118,12 @@ func NewWorker(peer Peer) Worker { // until the peer disconnects or the worker is told to quit. // // NOTE: Part of the Worker interface. + +func (w *testWorker) Peer() TestPeer { + + return w.peer +} + func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { peer := w.peer @@ -240,6 +283,168 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { } } +func (w *testWorker) Run(results chan<- *testJobResult, quit <-chan struct{}) { + + peer := w.peer + log.Infof("Testworker is running for, %v", peer.Addr()) + // Subscribe to messages from the peer. + msgChan, cancel := peer.SubscribeRecvMsg() + defer cancel() + + for { + log.Debugf("Worker %v waiting for more work", peer.Addr()) + + var job *testQueryJob + select { + // Poll a new job from the nextJob channel. + case job = <-w.nextJob: + log.Tracef("Worker %v picked up job with index %v", + peer.Addr(), job.Index()) + + // If the peer disconnected, we can exit immediately, as we + // weren't working on a query. 
+ case <-peer.OnDisconnect(): + log.Debugf("Peer %v for worker disconnected", + peer.Addr()) + return + + case <-quit: + return + } + + select { + // There is no point in queueing the request if the job already + // is canceled, so we check this quickly. + case <-job.cancelChan: + log.Tracef("Worker %v found job with index %v "+ + "already canceled", peer.Addr(), job.Index()) + + // We break to the below loop, where we'll check the + // cancel channel again and the ErrJobCanceled + // result will be sent back. + break + + // We received a non-canceled query job, send it to the peer. + default: + log.Tracef("Worker %v queuing job %T with index %v", + peer.Addr(), job.Req, job.Index()) + + err := peer.QueryGetHeadersMsg(job.Req) + if err != nil { + log.Debugf("Peer %v could not push GetHeaders Msg: %v", + peer.Addr(), err) + return + } + + } + + // Wait for the correct response to be received from the peer, + // or an error happening. + var ( + jobErr error + timeout = time.NewTimer(job.timeout) + ) + + Loop: + for { + select { + // A message was received from the peer, use the + // response handler to check whether it was answering + // our request. + case resp := <-msgChan: + log.Debugf("Gotten message from %v", peer.Addr()) + progress := job.HandleResp(resp, peer) + + log.Debugf("Worker %v handled msg %T while "+ + "waiting for response to %T (job=%v). "+ + "Finished=%v, progressed=%v", + peer.Addr(), resp, job.Req, job.Index(), + progress.Finished, progress.Progressed) + + // If the response did not answer our query, we + // check whether it did progress it. + if !progress.Finished { + // If it did make progress we reset the + // timeout. This ensures that the + // queries with multiple responses + // expected won't timeout before all + // responses have been handled. + // TODO(halseth): separate progress + // timeout value. + if progress.Progressed { + timeout.Stop() + timeout = time.NewTimer( + job.timeout, + ) + } + log.Debugf("Continuing Loop") + continue Loop + } + + // We did get a valid response, and can break + // the loop. + break Loop + + // If the timeout is reached before a valid response + // has been received, we exit with an error. + case <-timeout.C: + // The query did experience a timeout and will + // be given to someone else. + jobErr = ErrQueryTimeout + log.Debugf("Worker %v timeout for request %T "+ + "with job index %v", peer.Addr(), + job.Req, job.Index()) + + break Loop + + // If the peer disconnects before giving us a valid + // answer, we'll also exit with an error. + case <-peer.OnDisconnect(): + log.Debugf("Peer %v for worker disconnected, "+ + "cancelling job %v", peer.Addr(), + job.Index()) + + jobErr = ErrPeerDisconnected + break Loop + + // If the job was canceled, we report this back to the + // work manager. + case <-job.cancelChan: + log.Tracef("Worker %v job %v canceled", + peer.Addr(), job.Index()) + + jobErr = ErrJobCanceled + break Loop + + case <-quit: + return + } + } + + // Stop to allow garbage collection. + timeout.Stop() + + // We have a result ready for the query, hand it off before + // getting a new job. + log.Debugf("Trying to send result to workmanager") + select { + case results <- &testJobResult{ + job: job, + peer: peer, + err: jobErr, + }: + log.Debugf("Test -- Sent result to workmanager") + case <-quit: + return + } + log.Debugf("Test -- Out of Select") + // If the peer disconnected, we can exit immediately. 
+ if jobErr == ErrPeerDisconnected { + return + } + } +} + // NewJob returns a channel where work that is to be handled by the worker can // be sent. If the worker reads a queryJob from this channel, it is guaranteed // that a response will eventually be deliverd on the results channel (except @@ -249,3 +454,8 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { func (w *worker) NewJob() chan<- *queryJob { return w.nextJob } + +func (w *testWorker) NewJob() chan<- *testQueryJob { + + return w.nextJob +} diff --git a/query/workmanager.go b/query/workmanager.go index ba44a1fa..c88197b9 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -47,6 +47,8 @@ type Worker interface { // delivered on the results channel (except when the quit channel has // been closed). NewJob() chan<- *queryJob + + Peer() TestPeer } // PeerRanking is an interface that must be satisfied by the underlying module @@ -73,9 +75,11 @@ type PeerRanking interface { // we have given to it. // TODO(halseth): support more than one active job at a time. type activeWorker struct { - w Worker - activeJob *queryJob - onExit chan struct{} + w Worker + activeJob *queryJob + onExit chan struct{} + testActiveJob *testQueryJob + tw testWorker } // Config holds the configuration options for a new WorkManager. @@ -94,6 +98,8 @@ type Config struct { // Ranking is used to rank the connected peers when determining who to // give work to. Ranking PeerRanking + + TestNewWorker func(TestPeer) *testWorker } // WorkManager is the main access point for outside callers, and satisfies the @@ -110,8 +116,13 @@ type WorkManager struct { // workers will be sent. jobResults chan *jobResult - quit chan struct{} - wg sync.WaitGroup + //TODO(maureen): remove + test chan *testbatch + NewWorker func(Peer) Worker + + quit chan struct{} + wg sync.WaitGroup + testJobResults chan *testJobResult } // Compile time check to ensure WorkManager satisfies the Dispatcher interface. @@ -124,13 +135,22 @@ func New(cfg *Config) *WorkManager { newBatches: make(chan *batch), jobResults: make(chan *jobResult), quit: make(chan struct{}), + test: make(chan *testbatch), } } +var testWork = &workQueue{} +var testWorkers = make(map[string]*activeWorker) +var testRWMutex sync.RWMutex + // Start starts the WorkManager. func (w *WorkManager) Start() error { - w.wg.Add(1) + heap.Init(testWork) + testRWMutex = sync.RWMutex{} + w.wg.Add(3) go w.workDispatcher() + go w.testDistributeWork() + go w.testWorkDispatcher() return nil } @@ -389,6 +409,8 @@ Loop: } } + // A new batch of queries where scheduled. + // A new batch of queries where scheduled. case batch := <-w.newBatches: // Add all new queries in the batch to our work queue, @@ -415,11 +437,326 @@ Loop: errChan: batch.errChan, } batchIndex++ + case <-w.quit: + return + } + } +} + +// Init a work queue which will be used to sort the incoming queries in +// a first come first served fashion. We use a heap structure such +// that we can efficiently put failed queries back in the queue. + +func (w *WorkManager) testDistributeWork() { + defer w.wg.Done() + +Loop: + for { + // If the work queue is non-empty, we'll take out the first + // element in order to distribute it to a worker. + //TODO: Possible race conditon with testwork + testRWMutex.RLock() + if testWork.Len() > 0 && len(testWorkers) > 0 { + testRWMutex.RUnlock() + + next := testWork.Peek().(*testQueryJob) + + // Find the peers with free work slots available. 
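+			// A worker is only eligible when it has no job in
+			// flight and its peer reports a chain height at or
+			// beyond the query's start height; a peer that is
+			// still behind that range cannot serve the requested
+			// headers.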
+ var eligibleWorkers []string + testRWMutex.RLock() + for p, r := range testWorkers { + + // Only one active job at a time is currently + // supported. + if r.testActiveJob != nil { + //log.Debugf("Uneligible worker: Peer has work already") + continue + } + if !r.tw.Peer().IsPeerBehindStartHeight(next.TestRequest) { + //log.Debugf("Uneligible worker: Peer behind") + continue + + } + log.Debugf("Num eligible worker: %v", len(eligibleWorkers)) + eligibleWorkers = append(eligibleWorkers, p) + } + testRWMutex.RUnlock() + + // Use the historical data to rank them. + w.cfg.Ranking.Order(eligibleWorkers) + + // Give the job to the highest ranked peer with free + // slots available. + //log.Debugf("Trying to give eligible worker work: Num eligible worker: %v", len(eligibleWorkers)) + + for _, p := range eligibleWorkers { + testRWMutex.RLock() + r := testWorkers[p] + testRWMutex.RUnlock() + + // The worker has free work slots, it should + // pick up the query. + log.Debugf("Giving Next job to peer: %v", r.tw.Peer()) + select { + case r.tw.NewJob() <- next: + log.Debugf("Sent job %v to worker %v", + next.Index(), p) + heap.Pop(testWork) + + r.testActiveJob = next + + // Go back to start of loop, to check + // if there are more jobs to + // distribute. + continue Loop + + // Remove workers no longer active. + case <-r.onExit: + testRWMutex.Lock() + delete(testWorkers, p) + testRWMutex.Unlock() + continue + + case <-w.quit: + return + } + + } + } else { + testRWMutex.RUnlock() + } + + } +} + +// TODO(maureen): Remove +func (w *WorkManager) testWorkDispatcher() { + log.Infof("Inside testWorkDispatcher") + defer w.wg.Done() + + // Get a peer subscription. We do it in this goroutine rather than + // Start to avoid a deadlock when starting the WorkManager fetches the + // peers from the server. + peersConnected, cancel, err := w.cfg.ConnectedPeers() + + if err != nil { + log.Errorf("Unable to get connected peers: %v", err) + return + } + defer cancel() + + type batchProgress struct { + timeout <-chan time.Time + rem int + errChan chan error + } + + // We set up a batch index counter to keep track of batches that still + // have queries in flight. This lets us track when all queries for a + // batch have been finished, and return an (non-)error to the caller. + batchIndex := uint64(0) + currentBatches := make(map[uint64]*batchProgress) + + // When the work dispatcher exits, we'll loop through the remaining + // batches and send on their error channel. + defer func() { + for _, b := range currentBatches { + b.errChan <- ErrWorkManagerShuttingDown + } + }() + + // We set up a counter that we'll increase with each incoming query, + // and will serve as the priority of each. In addition we map each + // query to the batch they are part of. + queryIndex := uint64(0) + currentQueries := make(map[uint64]uint64) + +Loop: + for { + // Otherwise the work queue is empty, or there are no workers + // to distribute work to, so we'll just wait for a result of a + // previous query to come back, a new peer to connect, or for a + // new batch of queries to be scheduled. + log.Debugf("In testworkispatcher") + select { + // Spin up a goroutine that runs a worker each time a peer + // connects. + case peer := <-peersConnected: + testPeer, _ := peer.(TestPeer) + + r := w.cfg.TestNewWorker(testPeer) + log.Debugf("Into it ! %v", + peer.Addr()) + // We'll create a channel that will close after the + // worker's Run method returns, to know when we can + // remove it from our set of active workers. 
+			onExit := make(chan struct{})
+			log.Debugf("Acquiring workers lock to register peer %v",
+				peer.Addr())
+			testRWMutex.Lock()
+			testWorkers[peer.Addr()] = &activeWorker{
+				tw:        *r,
+				activeJob: nil,
+				onExit:    onExit,
+			}
+			testRWMutex.Unlock()
+			log.Debugf("Added peer %v in workers map",
+				peer.Addr())
+
+			w.cfg.Ranking.AddPeer(peer.Addr())
+
+			w.wg.Add(1)
+			go func() {
+				defer w.wg.Done()
+				defer close(onExit)
+
+				r.Run(w.testJobResults, w.quit)
+			}()
+
+		// A new result came back.
+		case result := <-w.testJobResults:
+			log.Debugf("Test -- Result for job %v received from peer %v "+
+				"(err=%v)", result.job.index,
+				result.peer.Addr(), result.err)
+
+			// Delete the job from the worker's active job, such
+			// that the slot gets opened for more work.
+			testRWMutex.RLock()
+			r := testWorkers[result.peer.Addr()]
+			testRWMutex.RUnlock()
+			r.testActiveJob = nil
+
+			// Get the index of this query's batch, and delete it
+			// from the map of current queries, since we don't have
+			// to track it anymore. We'll add it back if the result
+			// turns out to be an error.
+			batchNum := currentQueries[result.job.index]
+			delete(currentQueries, result.job.index)
+			batch := currentBatches[batchNum]
+
+			switch {
+			// If the query ended because it was canceled, drop it.
+			case result.err == ErrJobCanceled:
+				log.Tracef("Query(%d) was canceled before "+
+					"result was available from peer %v",
+					result.job.index, result.peer.Addr())
+
+				// If this is the first job in this batch that
+				// was canceled, forward the error on the
+				// batch's error channel. We do this since a
+				// cancellation applies to the whole batch.
+				if batch != nil {
+					batch.errChan <- result.err
+					delete(currentBatches, batchNum)
+
+					log.Debugf("Canceled batch %v",
+						batchNum)
+					continue Loop
+				}
+
+			// If the query ended with any other error, put it back
+			// into the work queue.
+			case result.err != nil:
+				// Punish the peer for the failed query.
+				w.cfg.Ranking.Punish(result.peer.Addr())
+
+				log.Debugf("Test -- Query(%d) from peer %v failed, "+
+					"rescheduling: %v", result.job.index,
+					result.peer.Addr(), result.err)
+
+				// If it was a timeout, we dynamically increase
+				// it for the next attempt.
+				if result.err == ErrQueryTimeout {
+					newTimeout := result.job.timeout * 2
+					if newTimeout > maxQueryTimeout {
+						newTimeout = maxQueryTimeout
+					}
+					result.job.timeout = newTimeout
+				}
+
+				heap.Push(testWork, result.job)
+				currentQueries[result.job.index] = batchNum
+
+			// Otherwise we got a successful result and update the
+			// status of the batch this query is a part of.
+			default:
+				// Reward the peer for the successful query.
+				w.cfg.Ranking.Reward(result.peer.Addr())
+
+				// Decrement the number of queries remaining in
+				// the batch.
+				if batch != nil {
+					batch.rem--
+					log.Tracef("Remaining jobs for batch "+
+						"%v: %v ", batchNum, batch.rem)
+
+					// If this was the last query in flight
+					// for this batch, we can notify that
+					// it finished, and delete it.
+					if batch.rem == 0 {
+						batch.errChan <- nil
+						delete(currentBatches, batchNum)
+						log.Tracef("Batch %v done",
+							batchNum)
+						continue Loop
+					}
+				}
+			}
+
+			// If the total timeout for this batch has passed,
+			// return an error.
+			if batch != nil {
+				select {
+				case <-batch.timeout:
+					batch.errChan <- ErrQueryTimeout
+					delete(currentBatches, batchNum)
+
+					log.Warnf("Query(%d) failed with "+
+						"error: %v. Timing out.",
+						result.job.index, result.err)
+
+					log.Debugf("Batch %v timed out",
+						batchNum)
+
+				default:
+				}
+			}
+
+		// A new batch of queries where scheduled.
+ case batch := <-w.test: + // Add all new queries in the batch to our work queue, + // with priority given by the order they were + // scheduled. + log.Debugf("Adding new test batch(%d) of %d queries to "+ + "work queue", batchIndex, len(batch.requests)) + + for _, q := range batch.requests { + heap.Push(testWork, &testQueryJob{ + index: queryIndex, + timeout: minQueryTimeout, + encoding: batch.options.encoding, + cancelChan: batch.options.cancelChan, + TestRequest: q, + }) + currentQueries[queryIndex] = batchIndex + queryIndex++ + } + + currentBatches[batchIndex] = &batchProgress{ + timeout: time.After(batch.options.timeout), + rem: len(batch.requests), + errChan: batch.errChan, + } + batchIndex++ case <-w.quit: return + } + + log.Debugf("Out of Select statement") } + } // Query distributes the slice of requests to the set of connected peers. @@ -446,3 +783,33 @@ func (w *WorkManager) Query(requests []*Request, return errChan } + +type testbatch struct { + requests []*TestRequest + options *queryOptions + errChan chan error +} + +func (w *WorkManager) TestQuery(requests []*TestRequest, + options ...QueryOption) chan error { + log.Debugf("Testing query") + qo := defaultQueryOptions() + qo.applyQueryOptions(options...) + + errChan := make(chan error, 1) + + // Add query messages to the queue of batches to handle. + select { + case w.test <- &testbatch{ + requests: requests, + options: qo, + errChan: errChan, + }: + log.Debugf("Sending a test batch") + case <-w.quit: + log.Debugf("Inside test query and quiting") + errChan <- ErrWorkManagerShuttingDown + } + log.Debugf("Exiting test query/") + return errChan +}
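// A caller-side sketch of the new TestQuery path (illustrative only, not part
// of the patch): the block manager builds one HeaderQuery per checkpointed
// header range, wraps each in a query.TestRequest, and blocks on the returned
// error channel until the whole batch is served, times out, or the work
// manager shuts down. The trivial handler below only checks the message type;
// the real handler also verifies the headers and writes them to the store.
func dispatchHeaderBatch(dispatcher query.Dispatcher,
	headerQueries []HeaderQuery) error {

	reqs := make([]*query.TestRequest, 0, len(headerQueries))
	for _, hq := range headerQueries {
		reqs = append(reqs, &query.TestRequest{
			Req: hq,
			HandleResp: func(resp wire.Message,
				_ query.TestPeer) query.Progress {

				// Accept any headers message; a real handler
				// validates that it connects to our chain.
				if _, ok := resp.(*wire.MsgHeaders); ok {
					return query.Progress{
						Finished:   true,
						Progressed: true,
					}
				}

				return query.Progress{}
			},
		})
	}

	// Queue the batch and wait for the work manager's verdict.
	return <-dispatcher.TestQuery(reqs)
}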