Skip to content

Commit

Permalink
Merge pull request #34 from cloudstruct/feature/refactor-expose-proto…
Browse files Browse the repository at this point in the history
…cols

Minor refactor to make library more composable
  • Loading branch information
agaffney authored Mar 7, 2022
2 parents 88740b0 + 93d9413 commit d838151
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 137 deletions.
18 changes: 9 additions & 9 deletions protocol/blockfetch/blockfetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ var (
STATE_DONE = protocol.NewState(4, "Done")
)

var stateMap = protocol.StateMap{
var StateMap = protocol.StateMap{
STATE_IDLE: protocol.StateMapEntry{
Agency: protocol.AGENCY_CLIENT,
Transitions: []protocol.StateTransition{
Expand Down Expand Up @@ -95,7 +95,7 @@ func New(options protocol.ProtocolOptions, callbackConfig *BlockFetchCallbackCon
Role: options.Role,
MessageHandlerFunc: b.messageHandler,
MessageFromCborFunc: NewMsgFromCbor,
StateMap: stateMap,
StateMap: StateMap,
InitialState: STATE_IDLE,
}
b.proto = protocol.New(protoConfig)
Expand All @@ -120,12 +120,12 @@ func (b *BlockFetch) messageHandler(msg protocol.Message) error {
}

func (b *BlockFetch) RequestRange(start []interface{}, end []interface{}) error {
msg := newMsgRequestRange(start, end)
msg := NewMsgRequestRange(start, end)
return b.proto.SendMessage(msg, false)
}

func (b *BlockFetch) ClientDone() error {
msg := newMsgClientDone()
msg := NewMsgClientDone()
return b.proto.SendMessage(msg, false)
}

Expand All @@ -149,18 +149,18 @@ func (b *BlockFetch) handleBlock(msgGeneric protocol.Message) error {
if b.callbackConfig.BlockFunc == nil {
return fmt.Errorf("received block-fetch Block message but no callback function is defined")
}
msg := msgGeneric.(*msgBlock)
msg := msgGeneric.(*MsgBlock)
// Decode only enough to get the block type value
var wrapBlock wrappedBlock
if _, err := utils.CborDecode(msg.WrappedBlock, &wrapBlock); err != nil {
var wrappedBlock WrappedBlock
if _, err := utils.CborDecode(msg.WrappedBlock, &wrappedBlock); err != nil {
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
}
blk, err := block.NewBlockFromCbor(wrapBlock.Type, wrapBlock.RawBlock)
blk, err := block.NewBlockFromCbor(wrappedBlock.Type, wrappedBlock.RawBlock)
if err != nil {
return err
}
// Call the user callback function
return b.callbackConfig.BlockFunc(wrapBlock.Type, blk)
return b.callbackConfig.BlockFunc(wrappedBlock.Type, blk)
}

func (b *BlockFetch) handleBatchDone() error {
Expand Down
34 changes: 17 additions & 17 deletions protocol/blockfetch/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ func NewMsgFromCbor(msgType uint, data []byte) (protocol.Message, error) {
var ret protocol.Message
switch msgType {
case MESSAGE_TYPE_REQUEST_RANGE:
ret = &msgRequestRange{}
ret = &MsgRequestRange{}
case MESSAGE_TYPE_CLIENT_DONE:
ret = &msgClientDone{}
ret = &MsgClientDone{}
case MESSAGE_TYPE_START_BATCH:
ret = &msgStartBatch{}
ret = &MsgStartBatch{}
case MESSAGE_TYPE_NO_BLOCKS:
ret = &msgNoBlocks{}
ret = &MsgNoBlocks{}
case MESSAGE_TYPE_BLOCK:
ret = &msgBlock{}
ret = &MsgBlock{}
case MESSAGE_TYPE_BATCH_DONE:
ret = &msgBatchDone{}
ret = &MsgBatchDone{}
}
if _, err := utils.CborDecode(data, ret); err != nil {
return nil, fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
Expand All @@ -42,14 +42,14 @@ func NewMsgFromCbor(msgType uint, data []byte) (protocol.Message, error) {
return ret, nil
}

type msgRequestRange struct {
type MsgRequestRange struct {
protocol.MessageBase
Start interface{} //point
End interface{} //point
}

func newMsgRequestRange(start interface{}, end interface{}) *msgRequestRange {
m := &msgRequestRange{
func NewMsgRequestRange(start interface{}, end interface{}) *MsgRequestRange {
m := &MsgRequestRange{
MessageBase: protocol.MessageBase{
MessageType: MESSAGE_TYPE_REQUEST_RANGE,
},
Expand All @@ -59,33 +59,33 @@ func newMsgRequestRange(start interface{}, end interface{}) *msgRequestRange {
return m
}

type msgClientDone struct {
type MsgClientDone struct {
protocol.MessageBase
}

func newMsgClientDone() *msgClientDone {
m := &msgClientDone{
func NewMsgClientDone() *MsgClientDone {
m := &MsgClientDone{
MessageBase: protocol.MessageBase{
MessageType: MESSAGE_TYPE_CLIENT_DONE,
},
}
return m
}

type msgStartBatch struct {
type MsgStartBatch struct {
protocol.MessageBase
}

type msgNoBlocks struct {
type MsgNoBlocks struct {
protocol.MessageBase
}

type msgBlock struct {
type MsgBlock struct {
protocol.MessageBase
WrappedBlock []byte
}

type msgBatchDone struct {
type MsgBatchDone struct {
protocol.MessageBase
}

Expand All @@ -97,7 +97,7 @@ type point struct {
}
*/

type wrappedBlock struct {
type WrappedBlock struct {
// Tells the CBOR decoder to convert to/from a struct and a CBOR array
_ struct{} `cbor:",toarray"`
Type uint
Expand Down
38 changes: 20 additions & 18 deletions protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ var (
STATE_DONE = protocol.NewState(5, "Done")
)

var stateMap = protocol.StateMap{
var StateMap = protocol.StateMap{
STATE_IDLE: protocol.StateMapEntry{
Agency: protocol.AGENCY_CLIENT,
Transitions: []protocol.StateTransition{
Expand Down Expand Up @@ -112,9 +112,11 @@ type ChainSyncDoneFunc func() error
func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConfig) *ChainSync {
// Use node-to-client protocol ID
protocolId := PROTOCOL_ID_NTC
msgFromCborFunc := NewMsgFromCborNtC
if options.Mode == protocol.ProtocolModeNodeToNode {
// Use node-to-node protocol ID
protocolId = PROTOCOL_ID_NTN
msgFromCborFunc = NewMsgFromCborNtN
}
c := &ChainSync{
callbackConfig: callbackConfig,
Expand All @@ -127,8 +129,8 @@ func New(options protocol.ProtocolOptions, callbackConfig *ChainSyncCallbackConf
Mode: options.Mode,
Role: options.Role,
MessageHandlerFunc: c.messageHandler,
MessageFromCborFunc: c.NewMsgFromCbor,
StateMap: stateMap,
MessageFromCborFunc: msgFromCborFunc,
StateMap: StateMap,
InitialState: STATE_IDLE,
}
c.proto = protocol.New(protoConfig)
Expand Down Expand Up @@ -157,12 +159,12 @@ func (c *ChainSync) messageHandler(msg protocol.Message) error {
}

func (c *ChainSync) RequestNext() error {
msg := newMsgRequestNext()
msg := NewMsgRequestNext()
return c.proto.SendMessage(msg, false)
}

func (c *ChainSync) FindIntersect(points []interface{}) error {
msg := newMsgFindIntersect(points)
msg := NewMsgFindIntersect(points)
return c.proto.SendMessage(msg, false)
}

Expand All @@ -179,19 +181,19 @@ func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error {
return fmt.Errorf("received chain-sync RollForward message but no callback function is defined")
}
if c.proto.Mode() == protocol.ProtocolModeNodeToNode {
msg := msgGeneric.(*msgRollForwardNtN)
msg := msgGeneric.(*MsgRollForwardNtN)
var blockHeader interface{}
var blockType uint
blockHeaderType := msg.WrappedHeader.Type
switch blockHeaderType {
case block.BLOCK_HEADER_TYPE_BYRON:
var wrapHeaderByron wrappedHeaderByron
if _, err := utils.CborDecode(msg.WrappedHeader.RawData, &wrapHeaderByron); err != nil {
var wrappedHeaderByron WrappedHeaderByron
if _, err := utils.CborDecode(msg.WrappedHeader.RawData, &wrappedHeaderByron); err != nil {
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
}
blockType = wrapHeaderByron.Unknown.Type
blockType = wrappedHeaderByron.Unknown.Type
var err error
blockHeader, err = block.NewBlockHeaderFromCbor(blockType, wrapHeaderByron.RawHeader)
blockHeader, err = block.NewBlockHeaderFromCbor(blockType, wrappedHeaderByron.RawHeader)
if err != nil {
return err
}
Expand All @@ -218,26 +220,26 @@ func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error {
// Call the user callback function
return c.callbackConfig.RollForwardFunc(blockType, blockHeader)
} else {
msg := msgGeneric.(*msgRollForwardNtC)
msg := msgGeneric.(*MsgRollForwardNtC)
// Decode only enough to get the block type value
var wrapBlock wrappedBlock
if _, err := utils.CborDecode(msg.WrappedData, &wrapBlock); err != nil {
var wrappedBlock WrappedBlock
if _, err := utils.CborDecode(msg.WrappedData, &wrappedBlock); err != nil {
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
}
blk, err := block.NewBlockFromCbor(wrapBlock.Type, wrapBlock.RawBlock)
blk, err := block.NewBlockFromCbor(wrappedBlock.Type, wrappedBlock.RawBlock)
if err != nil {
return err
}
// Call the user callback function
return c.callbackConfig.RollForwardFunc(wrapBlock.Type, blk)
return c.callbackConfig.RollForwardFunc(wrappedBlock.Type, blk)
}
}

func (c *ChainSync) handleRollBackward(msgGeneric protocol.Message) error {
if c.callbackConfig.RollBackwardFunc == nil {
return fmt.Errorf("received chain-sync RollBackward message but no callback function is defined")
}
msg := msgGeneric.(*msgRollBackward)
msg := msgGeneric.(*MsgRollBackward)
// Call the user callback function
return c.callbackConfig.RollBackwardFunc(msg.Point, msg.Tip)
}
Expand All @@ -246,7 +248,7 @@ func (c *ChainSync) handleIntersectFound(msgGeneric protocol.Message) error {
if c.callbackConfig.IntersectFoundFunc == nil {
return fmt.Errorf("received chain-sync IntersectFound message but no callback function is defined")
}
msg := msgGeneric.(*msgIntersectFound)
msg := msgGeneric.(*MsgIntersectFound)
// Call the user callback function
return c.callbackConfig.IntersectFoundFunc(msg.Point, msg.Tip)
}
Expand All @@ -255,7 +257,7 @@ func (c *ChainSync) handleIntersectNotFound(msgGeneric protocol.Message) error {
if c.callbackConfig.IntersectNotFoundFunc == nil {
return fmt.Errorf("received chain-sync IntersectNotFound message but no callback function is defined")
}
msg := msgGeneric.(*msgIntersectNotFound)
msg := msgGeneric.(*MsgIntersectNotFound)
// Call the user callback function
return c.callbackConfig.IntersectNotFoundFunc(msg.Tip)
}
Expand Down
Loading

0 comments on commit d838151

Please sign in to comment.