diff --git a/pkg/serialization/codec/jam/decode.go b/pkg/serialization/codec/jam/decode.go index 3b7ca95..eb75dfe 100644 --- a/pkg/serialization/codec/jam/decode.go +++ b/pkg/serialization/codec/jam/decode.go @@ -13,6 +13,8 @@ import ( "github.com/eigerco/strawberry/internal/crypto" ) +type BitSequence []bool + func Unmarshal(data []byte, dst interface{}) error { dstv := reflect.ValueOf(dst) if dstv.Kind() != reflect.Ptr || dstv.IsNil() { @@ -25,6 +27,50 @@ func Unmarshal(data []byte, dst interface{}) error { return ds.unmarshal(indirect(dstv)) } +func NewDecoder(reader io.Reader) *Decoder { + return &Decoder{ + byteReader{reader}, + } +} + +type Decoder struct { + byteReader +} + +func (d *Decoder) Decode(dst any) error { + dstv := reflect.ValueOf(dst) + if dstv.Kind() != reflect.Ptr || dstv.IsNil() { + return fmt.Errorf(ErrUnsupportedType, dst) + } + + return d.unmarshal(indirect(dstv)) +} + +func (d *Decoder) DecodeFixedLength(dst any, length uint) error { + dstv := reflect.ValueOf(dst) + if dstv.Kind() != reflect.Ptr || dstv.IsNil() { + return fmt.Errorf(ErrUnsupportedType, dst) + } + dstv = indirect(dstv) + + in := dstv.Interface() + switch v := in.(type) { + case int8, uint8, int16, uint16, int32, uint32, int64, uint64: + return d.decodeFixedWidth(dstv, length) + case []byte: + return d.decodeBytesFixedLength(dstv, length) + case BitSequence: + if err := d.decodeBitsFixedLength(&v, length); err != nil { + return err + } + inType := reflect.TypeOf(in) + dstv.Set(reflect.ValueOf(v).Convert(inType)) + default: + return fmt.Errorf(ErrUnsupportedType, dst) + } + return nil +} + type byteReader struct { io.Reader } @@ -50,6 +96,8 @@ func (br *byteReader) unmarshal(value reflect.Value) error { return br.decodeFixedWidth(value, l) case []byte: return br.decodeBytes(value) + case BitSequence: + return br.decodeBits(value) case bool: return br.decodeBool(value) default: @@ -72,6 +120,12 @@ func (br *byteReader) handleReflectTypes(value reflect.Value) error { if value.Type() == reflect.TypeOf(ed25519.PublicKey{}) { return br.decodeEd25519PublicKey(value) } + if value.Type() == reflect.TypeOf(BitSequence{}) { + return br.decodeBits(value) + } + if value.Type() == reflect.TypeOf([]byte{}) { + return br.decodeBytes(value) + } return br.decodeSlice(value) case reflect.Map: return br.decodeMap(value) @@ -376,7 +430,11 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error { if err != nil { return err } + return br.decodeBytesFixedLength(dstv, length) +} +// decodeBytes is used to decode with a destination of []byte +func (br *byteReader) decodeBytesFixedLength(dstv reflect.Value, length uint) error { if length > math.MaxUint32 { return ErrExceedingByteArrayLimit } @@ -384,7 +442,7 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error { b := make([]byte, length) if length > 0 { - _, err = br.Read(b) + _, err := br.Read(b) if err != nil { return err } @@ -396,6 +454,44 @@ func (br *byteReader) decodeBytes(dstv reflect.Value) error { return nil } +// decodeBytes is used to decode with a destination of []byte +func (br *byteReader) decodeBits(dstv reflect.Value) error { + length, err := br.decodeLength() + if err != nil { + return err + } + var v BitSequence + if err := br.decodeBitsFixedLength(&v, length); err != nil { + return err + } + in := dstv.Interface() + inType := reflect.TypeOf(in) + dstv.Set(reflect.ValueOf(v).Convert(inType)) + return nil +} + +func (br *byteReader) decodeBitsFixedLength(v *BitSequence, bytesLength uint) (err error) { + if bytesLength > math.MaxUint32 { + return ErrExceedingByteArrayLimit + } + bb := make([]byte, bytesLength) + if _, err = br.Reader.Read(bb); err != nil { + return err + } + if bytesLength == 0 { + return nil + } + *v = make(BitSequence, bytesLength*8) + for i := range *v { + mod := i % 8 + b := bb[i/8] + pow2 := byte(1 << mod) // powers of 2 + (*v)[i] = b&pow2 == pow2 // identify the bit + } + return nil +} + +// decodeFixedWidth E_{l∈N}(N_{2^8l} → Yl) (eq. C.5) func (br *byteReader) decodeFixedWidth(dstv reflect.Value, length uint) error { typ := dstv.Type() diff --git a/pkg/serialization/codec/jam/decode_test.go b/pkg/serialization/codec/jam/decode_test.go new file mode 100644 index 0000000..4980b52 --- /dev/null +++ b/pkg/serialization/codec/jam/decode_test.go @@ -0,0 +1,58 @@ +package jam + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDecodeBits(t *testing.T) { + tests := []struct { + name string + input []byte + expect BitSequence + leftover int + err error + }{{ + name: "empty", + input: []byte{}, + expect: BitSequence{}, + }, { + name: "1 bytes", + input: []byte{255}, + expect: BitSequence{true, true, true, true, true, true, true, true}, + }, { + name: "1.5 bytes", + input: []byte{0, 255}, + expect: BitSequence{ + false, false, false, false, false, false, false, false, + true, true, true, true, true, true, true, true, + }, + }, { + name: "5 bytes", + input: []byte{17, 25, 0, 1, 2}, + expect: BitSequence{ + true, false, false, false, true, false, false, false, + true, false, false, true, true, false, false, false, + false, false, false, false, false, false, false, false, + true, false, false, false, false, false, false, false, + false, true, false, false, false, false, false, false, + }, + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + buff := bytes.NewBuffer(tc.input) + d := NewDecoder(buff) + + actual := BitSequence{} + err := d.DecodeFixedLength(&actual, uint(len(tc.input))) + if tc.err != nil { + assert.Equal(t, tc.err, err) + } + + assert.Equal(t, tc.expect, actual) + assert.Equal(t, tc.leftover, buff.Len()) + }) + } +} diff --git a/pkg/serialization/codec/jam/encode.go b/pkg/serialization/codec/jam/encode.go index dda603f..fb87e00 100644 --- a/pkg/serialization/codec/jam/encode.go +++ b/pkg/serialization/codec/jam/encode.go @@ -47,6 +47,8 @@ func (bw *byteWriter) marshal(in interface{}) error { return bw.encodeFixedWidth(v, l) case []byte: return bw.encodeBytes(v) + case BitSequence: + return bw.encodeBits(v) case bool: return bw.encodeBool(v) default: @@ -74,10 +76,16 @@ func (bw *byteWriter) handleReflectTypes(in interface{}) error { case reflect.Array: return bw.encodeArray(in) case reflect.Slice: - if pk, ok := in.(ed25519.PublicKey); ok { - return bw.encodeEd25519PublicKey(pk) + switch v := in.(type) { + case ed25519.PublicKey: + return bw.encodeEd25519PublicKey(v) + case BitSequence: + return bw.encodeBits(v) + case []byte: + return bw.encodeBytes(v) + default: + return bw.encodeSlice(in) } - return bw.encodeSlice(in) case reflect.Map: return bw.encodeMap(in) default: @@ -288,6 +296,28 @@ func (bw *byteWriter) encodeBytes(b []byte) error { return err } +func (bw *byteWriter) encodeBits(bitSequence BitSequence) error { + length := len(bitSequence) / 8 + if length > 0 && length%8 == 0 { + length += 1 + } + err := bw.encodeLength(length) + if err != nil { + return err + } + + bb := make([]byte, length) + for i, b := range bitSequence { + if b { + pow2 := byte(1 << (i % 8)) // powers of 2 + bb[i/8] |= pow2 // identify the bit + } + } + + _, err = bw.Write(bb) + return err +} + func (bw *byteWriter) encodeFixedWidth(i interface{}, l uint) error { val := reflect.ValueOf(i) diff --git a/pkg/serialization/codec/jam/encode_decode_jam_test.go b/pkg/serialization/codec/jam/encode_decode_jam_test.go index 4fcfa57..012974f 100644 --- a/pkg/serialization/codec/jam/encode_decode_jam_test.go +++ b/pkg/serialization/codec/jam/encode_decode_jam_test.go @@ -23,6 +23,7 @@ type TestStruct struct { LargeUint uint PubKey *ed25519.PublicKey InnerSlice []InnerStruct + Bits jam.BitSequence } func TestMarshalUnmarshal(t *testing.T) { @@ -40,6 +41,15 @@ func TestMarshalUnmarshal(t *testing.T) { {2, 3, 4, 5}, {3, 4, 5, 6}, }, + Bits: jam.BitSequence{ + true, true, true, true, true, true, true, true, + true, true, true, true, true, true, true, false, + true, true, true, true, true, true, false, false, + true, true, true, true, false, false, false, false, + true, true, false, false, false, false, false, false, + true, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, + }, } marshaledData, err := jam.Marshal(original)