Skip to content

Commit

Permalink
feat: add buffered decoder, bit sequence and decode with length (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvladco authored Feb 5, 2025
1 parent e035c67 commit 83b396c
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 4 deletions.
98 changes: 97 additions & 1 deletion pkg/serialization/codec/jam/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/eigerco/strawberry/internal/crypto"
)

type BitSequence []bool

func Unmarshal(data []byte, dst interface{}) error {
dstv := reflect.ValueOf(dst)
if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
Expand All @@ -25,6 +27,50 @@ func Unmarshal(data []byte, dst interface{}) error {
return ds.unmarshal(indirect(dstv))
}

func NewDecoder(reader io.Reader) *Decoder {
return &Decoder{
byteReader{reader},
}
}

type Decoder struct {
byteReader
}

func (d *Decoder) Decode(dst any) error {
dstv := reflect.ValueOf(dst)
if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
return fmt.Errorf(ErrUnsupportedType, dst)
}

return d.unmarshal(indirect(dstv))
}

func (d *Decoder) DecodeFixedLength(dst any, length uint) error {
dstv := reflect.ValueOf(dst)
if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
return fmt.Errorf(ErrUnsupportedType, dst)
}
dstv = indirect(dstv)

in := dstv.Interface()
switch v := in.(type) {
case int8, uint8, int16, uint16, int32, uint32, int64, uint64:
return d.decodeFixedWidth(dstv, length)
case []byte:
return d.decodeBytesFixedLength(dstv, length)
case BitSequence:
if err := d.decodeBitsFixedLength(&v, length); err != nil {
return err
}
inType := reflect.TypeOf(in)
dstv.Set(reflect.ValueOf(v).Convert(inType))
default:
return fmt.Errorf(ErrUnsupportedType, dst)
}
return nil
}

type byteReader struct {
io.Reader
}
Expand All @@ -50,6 +96,8 @@ func (br *byteReader) unmarshal(value reflect.Value) error {
return br.decodeFixedWidth(value, l)
case []byte:
return br.decodeBytes(value)
case BitSequence:
return br.decodeBits(value)
case bool:
return br.decodeBool(value)
default:
Expand All @@ -72,6 +120,12 @@ func (br *byteReader) handleReflectTypes(value reflect.Value) error {
if value.Type() == reflect.TypeOf(ed25519.PublicKey{}) {
return br.decodeEd25519PublicKey(value)
}
if value.Type() == reflect.TypeOf(BitSequence{}) {
return br.decodeBits(value)
}
if value.Type() == reflect.TypeOf([]byte{}) {
return br.decodeBytes(value)
}
return br.decodeSlice(value)
case reflect.Map:
return br.decodeMap(value)
Expand Down Expand Up @@ -376,15 +430,19 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error {
if err != nil {
return err
}
return br.decodeBytesFixedLength(dstv, length)
}

// decodeBytes is used to decode with a destination of []byte
func (br *byteReader) decodeBytesFixedLength(dstv reflect.Value, length uint) error {
if length > math.MaxUint32 {
return ErrExceedingByteArrayLimit
}

b := make([]byte, length)

if length > 0 {
_, err = br.Read(b)
_, err := br.Read(b)
if err != nil {
return err
}
Expand All @@ -396,6 +454,44 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error {
return nil
}

// decodeBytes is used to decode with a destination of []byte
func (br *byteReader) decodeBits(dstv reflect.Value) error {
length, err := br.decodeLength()
if err != nil {
return err
}
var v BitSequence
if err := br.decodeBitsFixedLength(&v, length); err != nil {
return err
}
in := dstv.Interface()
inType := reflect.TypeOf(in)
dstv.Set(reflect.ValueOf(v).Convert(inType))
return nil
}

func (br *byteReader) decodeBitsFixedLength(v *BitSequence, bytesLength uint) (err error) {
if bytesLength > math.MaxUint32 {
return ErrExceedingByteArrayLimit
}
bb := make([]byte, bytesLength)
if _, err = br.Reader.Read(bb); err != nil {
return err
}
if bytesLength == 0 {
return nil
}
*v = make(BitSequence, bytesLength*8)
for i := range *v {
mod := i % 8
b := bb[i/8]
pow2 := byte(1 << mod) // powers of 2
(*v)[i] = b&pow2 == pow2 // identify the bit
}
return nil
}

// decodeFixedWidth E_{l∈N}(N_{2^8l} → Yl) (eq. C.5)
func (br *byteReader) decodeFixedWidth(dstv reflect.Value, length uint) error {
typ := dstv.Type()

Expand Down
58 changes: 58 additions & 0 deletions pkg/serialization/codec/jam/decode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package jam

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
)

func TestDecodeBits(t *testing.T) {
tests := []struct {
name string
input []byte
expect BitSequence
leftover int
err error
}{{
name: "empty",
input: []byte{},
expect: BitSequence{},
}, {
name: "1 bytes",
input: []byte{255},
expect: BitSequence{true, true, true, true, true, true, true, true},
}, {
name: "1.5 bytes",
input: []byte{0, 255},
expect: BitSequence{
false, false, false, false, false, false, false, false,
true, true, true, true, true, true, true, true,
},
}, {
name: "5 bytes",
input: []byte{17, 25, 0, 1, 2},
expect: BitSequence{
true, false, false, false, true, false, false, false,
true, false, false, true, true, false, false, false,
false, false, false, false, false, false, false, false,
true, false, false, false, false, false, false, false,
false, true, false, false, false, false, false, false,
},
}}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
buff := bytes.NewBuffer(tc.input)
d := NewDecoder(buff)

actual := BitSequence{}
err := d.DecodeFixedLength(&actual, uint(len(tc.input)))
if tc.err != nil {
assert.Equal(t, tc.err, err)
}

assert.Equal(t, tc.expect, actual)
assert.Equal(t, tc.leftover, buff.Len())
})
}
}
36 changes: 33 additions & 3 deletions pkg/serialization/codec/jam/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func (bw *byteWriter) marshal(in interface{}) error {
return bw.encodeFixedWidth(v, l)
case []byte:
return bw.encodeBytes(v)
case BitSequence:
return bw.encodeBits(v)
case bool:
return bw.encodeBool(v)
default:
Expand Down Expand Up @@ -74,10 +76,16 @@ func (bw *byteWriter) handleReflectTypes(in interface{}) error {
case reflect.Array:
return bw.encodeArray(in)
case reflect.Slice:
if pk, ok := in.(ed25519.PublicKey); ok {
return bw.encodeEd25519PublicKey(pk)
switch v := in.(type) {
case ed25519.PublicKey:
return bw.encodeEd25519PublicKey(v)
case BitSequence:
return bw.encodeBits(v)
case []byte:
return bw.encodeBytes(v)
default:
return bw.encodeSlice(in)
}
return bw.encodeSlice(in)
case reflect.Map:
return bw.encodeMap(in)
default:
Expand Down Expand Up @@ -288,6 +296,28 @@ func (bw *byteWriter) encodeBytes(b []byte) error {
return err
}

func (bw *byteWriter) encodeBits(bitSequence BitSequence) error {
length := len(bitSequence) / 8
if length > 0 && length%8 == 0 {
length += 1
}
err := bw.encodeLength(length)
if err != nil {
return err
}

bb := make([]byte, length)
for i, b := range bitSequence {
if b {
pow2 := byte(1 << (i % 8)) // powers of 2
bb[i/8] |= pow2 // identify the bit
}
}

_, err = bw.Write(bb)
return err
}

func (bw *byteWriter) encodeFixedWidth(i interface{}, l uint) error {
val := reflect.ValueOf(i)

Expand Down
10 changes: 10 additions & 0 deletions pkg/serialization/codec/jam/encode_decode_jam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type TestStruct struct {
LargeUint uint
PubKey *ed25519.PublicKey
InnerSlice []InnerStruct
Bits jam.BitSequence
}

func TestMarshalUnmarshal(t *testing.T) {
Expand All @@ -40,6 +41,15 @@ func TestMarshalUnmarshal(t *testing.T) {
{2, 3, 4, 5},
{3, 4, 5, 6},
},
Bits: jam.BitSequence{
true, true, true, true, true, true, true, true,
true, true, true, true, true, true, true, false,
true, true, true, true, true, true, false, false,
true, true, true, true, false, false, false, false,
true, true, false, false, false, false, false, false,
true, false, false, false, false, false, false, false,
false, false, false, false, false, false, false, false,
},
}

marshaledData, err := jam.Marshal(original)
Expand Down

0 comments on commit 83b396c

Please sign in to comment.