From aaa94a22223d3bee5e1684e8d5f543b8d05cdc3e Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 17 Jan 2025 15:40:19 +0800 Subject: [PATCH] Bytes() return fixed array --- core/trie/bitarray.go | 13 ++++++++----- core/trie/bitarray_test.go | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index b2bf95269..d1d28eab7 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -38,7 +38,8 @@ func NewBitArray(length uint8, val uint64) BitArray { // Returns the felt representation of the bit array. func (b *BitArray) Felt() felt.Felt { var f felt.Felt - f.SetBytes(b.Bytes()) + bt := b.Bytes() + f.SetBytes(bt[:]) return f } @@ -47,7 +48,7 @@ func (b *BitArray) Len() uint8 { } // Returns the bytes representation of the bit array in big endian format -func (b *BitArray) Bytes() []byte { +func (b *BitArray) Bytes() [32]byte { var res [32]byte binary.BigEndian.PutUint64(res[0:8], b.words[3]) @@ -55,7 +56,7 @@ func (b *BitArray) Bytes() []byte { binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - return res[:] + return res } // Sets the bit array to the least significant 'n' bits of x. @@ -512,15 +513,17 @@ func (b *BitArray) Copy() BitArray { // Returns the encoded string representation of the bit array. func (b *BitArray) EncodedString() string { var res []byte + bt := b.Bytes() res = append(res, b.len) - res = append(res, b.Bytes()...) + res = append(res, bt[:]...) return string(res) } // Returns a string representation of the bit array. // This is typically used for logging or debugging. func (b *BitArray) String() string { - return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) + bt := b.Bytes() + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bt[:])) } func (b *BitArray) setFelt(f *felt.Felt) { diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 123ed5244..4786eee2e 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -97,7 +97,7 @@ func TestBytes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.ba.Bytes() - if !bytes.Equal(got, tt.want[:]) { + if !bytes.Equal(got[:], tt.want[:]) { t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) }