Skip to content

Commit

Permalink
error check for UnmarshalBinary
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Jan 17, 2025
1 parent ba17687 commit 3929edc
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
20 changes: 17 additions & 3 deletions core/trie/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"math"
"math/bits"
Expand Down Expand Up @@ -424,12 +425,25 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) {
// Example:
//
// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}}
func (b *BitArray) UnmarshalBinary(data []byte) {
b.len = data[0]
func (b *BitArray) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return errors.New("empty data")
}

length := data[0]
byteCount := (uint(length) + 7) / 8 // Round up to nearest byte

if len(data) < int(byteCount)+1 {
return fmt.Errorf("invalid data length: got %d bytes, expected %d", len(data), byteCount+1)
}

b.len = length

var bs [32]byte
copy(bs[32-b.byteCount():], data[1:])
copy(bs[32-byteCount:], data[1:])
b.setBytes32(bs[:])

return nil
}

// Sets the bit array to the same value as x.
Expand Down
9 changes: 7 additions & 2 deletions core/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,14 @@ func (n *Node) UnmarshalBinary(data []byte) error {
n.Right = new(BitArray)
}

n.Left.UnmarshalBinary(data)
if err := n.Left.UnmarshalBinary(data); err != nil {
return err
}
data = data[n.Left.EncodedLen():]
n.Right.UnmarshalBinary(data)

if err := n.Right.UnmarshalBinary(data); err != nil {
return err
}
data = data[n.Right.EncodedLen():]

if n.LeftHash == nil {
Expand Down
3 changes: 1 addition & 2 deletions core/trie/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ func (t *Storage) RootKey() (*BitArray, error) {
var rootKey *BitArray
if err := t.txn.Get(t.prefix, func(val []byte) error {
rootKey = new(BitArray)
rootKey.UnmarshalBinary(val)
return nil
return rootKey.UnmarshalBinary(val)
}); err != nil {
return nil, err
}
Expand Down

0 comments on commit 3929edc

Please sign in to comment.