Skip to content

Commit

Permalink
Merge pull request #465 from benluddy/go-string-as-cbor-byte-string
Browse files Browse the repository at this point in the history
New options for encoding Go strings to and from CBOR byte strings
  • Loading branch information
fxamacker authored Jan 9, 2024
2 parents 83ec0aa + ae20110 commit cd0553c
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 93 deletions.
3 changes: 2 additions & 1 deletion bytestring.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (bs *ByteString) UnmarshalCBOR(data []byte) error {
return &UnmarshalTypeError{CBORType: typ.String(), GoType: typeByteString.String()}
}

*bs = ByteString(d.parseByteString())
b, _ := d.parseByteString()
*bs = ByteString(b)
return nil
}
196 changes: 136 additions & 60 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,23 @@ func (bidm BigIntDecMode) valid() bool {
return bidm >= 0 && bidm < maxBigIntDecMode
}

// ByteStringToStringMode specifies the behavior when decoding a CBOR byte string into a Go string.
type ByteStringToStringMode int

const (
// ByteStringToStringForbidden generates an error on an attempt to decode a CBOR byte string into a Go string.
ByteStringToStringForbidden ByteStringToStringMode = iota

// ByteStringToStringAllowed permits decoding a CBOR byte string into a Go string.
ByteStringToStringAllowed

maxByteStringToStringMode
)

func (bstsm ByteStringToStringMode) valid() bool {
return bstsm >= 0 && bstsm < maxByteStringToStringMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -467,6 +484,15 @@ type DecOptions struct {

// BigIntDec specifies how to decode CBOR bignum to Go interface{}.
BigIntDec BigIntDecMode

// DefaultByteStringType is the Go type that should be produced when decoding a CBOR byte
// string into an empty interface value. Types to which a []byte is convertible are valid
// for this option, except for array and pointer-to-array types. If nil, the default is
// []byte.
DefaultByteStringType reflect.Type

// ByteStringToString specifies the behavior when decoding a CBOR byte string into a Go string.
ByteStringToString ByteStringToStringMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -581,21 +607,29 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.BigIntDec.valid() {
return nil, errors.New("cbor: invalid BigIntDec " + strconv.Itoa(int(opts.BigIntDec)))
}
if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) {
return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType)
}
if !opts.ByteStringToString.valid() {
return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
}
return &dm, nil
}
Expand Down Expand Up @@ -647,41 +681,45 @@ type DecMode interface {
}

type decMode struct {
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
}

var defaultDecMode, _ = DecOptions{}.decMode()

// DecOptions returns user specified options used to create this DecMode.
func (dm *decMode) DecOptions() DecOptions {
return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
}
}

Expand Down Expand Up @@ -979,8 +1017,8 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return fillNegativeInt(t, nValue, v)

case cborTypeByteString:
b := d.parseByteString()
return fillByteString(t, b, v)
b, copied := d.parseByteString()
return fillByteString(t, b, !copied, v, d.dm.byteStringToString)

case cborTypeTextString:
b, err := d.parseTextString()
Expand Down Expand Up @@ -1017,15 +1055,15 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
switch tagNum {
case 2:
// Bignum (tag 2) can be decoded to uint, int, float, slice, array, or big.Int.
b := d.parseByteString()
b, copied := d.parseByteString()
bi := new(big.Int).SetBytes(b)

if tInfo.nonPtrType == typeBigInt {
v.Set(reflect.ValueOf(*bi))
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, v)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
}
if bi.IsUint64() {
return fillPositiveInt(t, bi.Uint64(), v)
Expand All @@ -1037,7 +1075,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}
case 3:
// Bignum (tag 3) can be decoded to int, float, slice, array, or big.Int.
b := d.parseByteString()
b, copied := d.parseByteString()
bi := new(big.Int).SetBytes(b)
bi.Add(bi, big.NewInt(1))
bi.Neg(bi)
Expand All @@ -1047,7 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, v)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
}
if bi.IsInt64() {
return fillNegativeInt(t, bi.Int64(), v)
Expand Down Expand Up @@ -1279,7 +1317,29 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nValue, nil

case cborTypeByteString:
return d.parseByteString(), nil
switch d.dm.defaultByteStringType {
case nil, typeByteSlice:
b, copied := d.parseByteString()
if copied {
return b, nil
}
clone := make([]byte, len(b))
copy(clone, b)
return clone, nil
case typeString:
b, _ := d.parseByteString()
return string(b), nil
default:
b, copied := d.parseByteString()
if copied || d.dm.defaultByteStringType.Kind() == reflect.String {
// Avoid an unnecessary copy since the conversion to string must
// copy the underlying bytes.
return reflect.ValueOf(b).Convert(d.dm.defaultByteStringType).Interface(), nil
}
clone := make([]byte, len(b))
copy(clone, b)
return reflect.ValueOf(clone).Convert(d.dm.defaultByteStringType).Interface(), nil
}
case cborTypeTextString:
b, err := d.parseTextString()
if err != nil {
Expand All @@ -1296,15 +1356,15 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
d.off = tagOff
return d.parseToTime()
case 2:
b := d.parseByteString()
b, _ := d.parseByteString()
bi := new(big.Int).SetBytes(b)

if d.dm.bigIntDec == BigIntDecodePointer {
return bi, nil
}
return *bi, nil
case 3:
b := d.parseByteString()
b, _ := d.parseByteString()
bi := new(big.Int).SetBytes(b)
bi.Add(bi, big.NewInt(1))
bi.Neg(bi)
Expand Down Expand Up @@ -1376,15 +1436,16 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nil, nil
}

// parseByteString parses CBOR encoded byte string. It returns a byte slice
// pointing to a copy of parsed data.
func (d *decoder) parseByteString() []byte {
// parseByteString parses a CBOR encoded byte string. The returned byte slice
// may be backed directly by the input. The second return value will be true if
// and only if the slice is backed by a copy of the input. Callers are
// responsible for making a copy if necessary.
func (d *decoder) parseByteString() ([]byte, bool) {
_, ai, val := d.getHead()
if ai != 31 {
b := make([]byte, int(val))
copy(b, d.data[d.off:d.off+int(val)])
b := d.data[d.off : d.off+int(val)]
d.off += int(val)
return b
return b, false
}
// Process indefinite length string chunks.
b := []byte{}
Expand All @@ -1393,7 +1454,7 @@ func (d *decoder) parseByteString() []byte {
b = append(b, d.data[d.off:d.off+int(val)]...)
d.off += int(val)
}
return b
return b, true
}

// parseTextString parses CBOR encoded text string. It returns a byte slice
Expand Down Expand Up @@ -2082,6 +2143,8 @@ var (
typeBigInt = reflect.TypeOf(big.Int{})
typeUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
typeString = reflect.TypeOf("")
typeByteSlice = reflect.TypeOf([]byte(nil))
)

func fillNil(_ cborType, v reflect.Value) error {
Expand Down Expand Up @@ -2184,18 +2247,31 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillByteString(t cborType, val []byte, v reflect.Value) error {
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error {
if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
if u, ok := v.Interface().(encoding.BinaryUnmarshaler); ok {
// The contract of BinaryUnmarshaler forbids
// retaining the input bytes, so no copying is
// required even if val is shared.
return u.UnmarshalBinary(val)
}
}
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
if bsts == ByteStringToStringAllowed && v.Kind() == reflect.String {
v.SetString(string(val))
return nil
}
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
v.SetBytes(val)
src := val
if shared {
// SetBytes shares the underlying bytes of the source slice.
src = make([]byte, len(val))
copy(src, val)
}
v.SetBytes(src)
return nil
}
if v.Kind() == reflect.Array && v.Type().Elem().Kind() == reflect.Uint8 {
Expand Down
Loading

0 comments on commit cd0553c

Please sign in to comment.