From 30aefa6c0bdac69c31175b1e96a83952cc28d127 Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Wed, 13 Dec 2023 08:55:46 -0500 Subject: [PATCH 1/3] Add option to encode Go strings to CBOR byte strings. The String encode option (of type StringMode) supports use cases that must not produce invalid CBOR (even if well-formed) and must be able to encode Go strings that do not contain valid UTF-8 sequences, without the overhead of sanitizing all input Go values. The default value of the option is StringToTextString, which encodes Go strings to CBOR text strings (major type 3) and is identical to the preexisting behavior. Signed-off-by: Ben Luddy --- encode.go | 89 +++++++++++++++++++++++++++++++++++--------------- encode_test.go | 61 ++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 27 deletions(-) diff --git a/encode.go b/encode.go index a08320b2..eba09c53 100644 --- a/encode.go +++ b/encode.go @@ -156,6 +156,27 @@ func (sm SortMode) valid() bool { return sm >= 0 && sm < maxSortMode } +// StringMode specifies how to encode Go string values. +type StringMode int + +const ( + // StringToTextString encodes Go string to CBOR text string (major type 3). + StringToTextString StringMode = iota + + // StringToByteString encodes Go string to CBOR byte string (major type 2). + StringToByteString +) + +func (st StringMode) cborType() (cborType, error) { + switch st { + case StringToTextString: + return cborTypeTextString, nil + case StringToByteString: + return cborTypeByteString, nil + } + return 0, errors.New("cbor: invalid StringType " + strconv.Itoa(int(st))) +} + // ShortestFloatMode specifies which floating-point format should // be used as the shortest possible format for CBOR encoding. // It is not used for encoding Infinity and NaN values. @@ -355,6 +376,11 @@ type EncOptions struct { // OmitEmptyMode specifies how to encode struct fields with omitempty tag. OmitEmpty OmitEmptyMode + + // String specifies which CBOR type to use when encoding Go strings. 
+ // - CBOR text string (major type 3) is default + // - CBOR byte string (major type 2) + String StringMode } // CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding, @@ -533,18 +559,24 @@ func (opts EncOptions) encMode() (*encMode, error) { if !opts.OmitEmpty.valid() { return nil, errors.New("cbor: invalid OmitEmpty " + strconv.Itoa(int(opts.OmitEmpty))) } + stringMajorType, err := opts.String.cborType() + if err != nil { + return nil, err + } em := encMode{ - sort: opts.Sort, - shortestFloat: opts.ShortestFloat, - nanConvert: opts.NaNConvert, - infConvert: opts.InfConvert, - bigIntConvert: opts.BigIntConvert, - time: opts.Time, - timeTag: opts.TimeTag, - indefLength: opts.IndefLength, - nilContainers: opts.NilContainers, - tagsMd: opts.TagsMd, - omitEmpty: opts.OmitEmpty, + sort: opts.Sort, + shortestFloat: opts.ShortestFloat, + nanConvert: opts.NaNConvert, + infConvert: opts.InfConvert, + bigIntConvert: opts.BigIntConvert, + time: opts.Time, + timeTag: opts.TimeTag, + indefLength: opts.IndefLength, + nilContainers: opts.NilContainers, + tagsMd: opts.TagsMd, + omitEmpty: opts.OmitEmpty, + stringType: opts.String, + stringMajorType: stringMajorType, } return &em, nil } @@ -557,21 +589,23 @@ type EncMode interface { } type encMode struct { - tags tagProvider - sort SortMode - shortestFloat ShortestFloatMode - nanConvert NaNConvertMode - infConvert InfConvertMode - bigIntConvert BigIntConvertMode - time TimeMode - timeTag EncTagMode - indefLength IndefLengthMode - nilContainers NilContainersMode - tagsMd TagsMode - omitEmpty OmitEmptyMode -} - -var defaultEncMode = &encMode{} + tags tagProvider + sort SortMode + shortestFloat ShortestFloatMode + nanConvert NaNConvertMode + infConvert InfConvertMode + bigIntConvert BigIntConvertMode + time TimeMode + timeTag EncTagMode + indefLength IndefLengthMode + nilContainers NilContainersMode + tagsMd TagsMode + omitEmpty OmitEmptyMode + stringType StringMode + stringMajorType cborType +} + +var defaultEncMode, 
_ = EncOptions{}.encMode() // EncOptions returns user specified options used to create this EncMode. func (em *encMode) EncOptions() EncOptions { @@ -586,6 +620,7 @@ func (em *encMode) EncOptions() EncOptions { IndefLength: em.indefLength, TagsMd: em.tagsMd, OmitEmpty: em.omitEmpty, + String: em.stringType, } } @@ -882,7 +917,7 @@ func encodeString(e *encoderBuffer, em *encMode, v reflect.Value) error { e.Write(b) } s := v.String() - encodeHead(e, byte(cborTypeTextString), uint64(len(s))) + encodeHead(e, byte(em.stringMajorType), uint64(len(s))) e.WriteString(s) return nil } diff --git a/encode_test.go b/encode_test.go index 8a5d1b5a..ec92334f 100644 --- a/encode_test.go +++ b/encode_test.go @@ -3675,6 +3675,29 @@ func TestEncModeInvalidTimeTag(t *testing.T) { } } +func TestEncModeStringType(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + wantErrorMsg string + }{ + { + name: "", + opts: EncOptions{String: -1}, + wantErrorMsg: "cbor: invalid StringType -1", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.EncMode() + if err == nil { + t.Errorf("EncMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("EncMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + func TestEncIndefiniteLengthOption(t *testing.T) { // Default option allows indefinite length items var buf bytes.Buffer @@ -3994,3 +4017,41 @@ func TestMapWithSimpleValueKey(t *testing.T) { t.Errorf("Marshal(%v) = 0x%x, want 0x%x", v, encodedData, data) } } + +func TestMarshalStringType(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + in string + want []byte + }{ + { + name: "to byte string", + opts: EncOptions{String: StringToByteString}, + in: "01234", + want: hexDecode("453031323334"), + }, + { + name: "to text string", + opts: EncOptions{String: StringToTextString}, + in: "01234", + want: hexDecode("653031323334"), + }, + } { + t.Run(tc.name, func(t 
*testing.T) { + em, err := tc.opts.EncMode() + if err != nil { + t.Fatal(err) + } + + got, err := em.Marshal(tc.in) + if err != nil { + t.Errorf("unexpected error from Marshal(%q): %v", tc.in, err) + } + + if !bytes.Equal(got, tc.want) { + t.Errorf("Marshal(%q): wanted %x, got %x", tc.in, tc.want, got) + } + }) + } +} From ebc84ccb9cd354f08b874c8264cb28ccbc22512d Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Fri, 15 Dec 2023 16:40:19 -0500 Subject: [PATCH 2/3] Add option to decode CBOR byte string into interface{} as Go string. Signed-off-by: Ben Luddy --- bytestring.go | 3 +- decode.go | 166 +++++++++++++++++++++++++++++++------------------ decode_test.go | 98 ++++++++++++++++++++++++++++- diagnose.go | 6 +- 4 files changed, 207 insertions(+), 66 deletions(-) diff --git a/bytestring.go b/bytestring.go index 26f5ef91..52a28eda 100644 --- a/bytestring.go +++ b/bytestring.go @@ -57,6 +57,7 @@ func (bs *ByteString) UnmarshalCBOR(data []byte) error { return &UnmarshalTypeError{CBORType: typ.String(), GoType: typeByteString.String()} } - *bs = ByteString(d.parseByteString()) + b, _ := d.parseByteString() + *bs = ByteString(b) return nil } diff --git a/decode.go b/decode.go index ffb62e82..06da2b27 100644 --- a/decode.go +++ b/decode.go @@ -467,6 +467,12 @@ type DecOptions struct { // BigIntDec specifies how to decode CBOR bignum to Go interface{}. BigIntDec BigIntDecMode + + // DefaultByteStringType is the Go type that should be produced when decoding a CBOR byte + // string into an empty interface value. Types to which a []byte is convertible are valid + // for this option, except for array and pointer-to-array types. If nil, the default is + // []byte. + DefaultByteStringType reflect.Type } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). 
@@ -581,21 +587,25 @@ func (opts DecOptions) decMode() (*decMode, error) { if !opts.BigIntDec.valid() { return nil, errors.New("cbor: invalid BigIntDec " + strconv.Itoa(int(opts.BigIntDec))) } + if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) { + return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType) + } dm := decMode{ - dupMapKey: opts.DupMapKey, - timeTag: opts.TimeTag, - maxNestedLevels: opts.MaxNestedLevels, - maxArrayElements: opts.MaxArrayElements, - maxMapPairs: opts.MaxMapPairs, - indefLength: opts.IndefLength, - tagsMd: opts.TagsMd, - intDec: opts.IntDec, - mapKeyByteString: opts.MapKeyByteString, - extraReturnErrors: opts.ExtraReturnErrors, - defaultMapType: opts.DefaultMapType, - utf8: opts.UTF8, - fieldNameMatching: opts.FieldNameMatching, - bigIntDec: opts.BigIntDec, + dupMapKey: opts.DupMapKey, + timeTag: opts.TimeTag, + maxNestedLevels: opts.MaxNestedLevels, + maxArrayElements: opts.MaxArrayElements, + maxMapPairs: opts.MaxMapPairs, + indefLength: opts.IndefLength, + tagsMd: opts.TagsMd, + intDec: opts.IntDec, + mapKeyByteString: opts.MapKeyByteString, + extraReturnErrors: opts.ExtraReturnErrors, + defaultMapType: opts.DefaultMapType, + utf8: opts.UTF8, + fieldNameMatching: opts.FieldNameMatching, + bigIntDec: opts.BigIntDec, + defaultByteStringType: opts.DefaultByteStringType, } return &dm, nil } @@ -647,21 +657,22 @@ type DecMode interface { } type decMode struct { - tags tagProvider - dupMapKey DupMapKeyMode - timeTag DecTagMode - maxNestedLevels int - maxArrayElements int - maxMapPairs int - indefLength IndefLengthMode - tagsMd TagsMode - intDec IntDecMode - mapKeyByteString MapKeyByteStringMode - extraReturnErrors ExtraDecErrorCond - defaultMapType reflect.Type - utf8 UTF8Mode - fieldNameMatching 
FieldNameMatchingMode - bigIntDec BigIntDecMode + tags tagProvider + dupMapKey DupMapKeyMode + timeTag DecTagMode + maxNestedLevels int + maxArrayElements int + maxMapPairs int + indefLength IndefLengthMode + tagsMd TagsMode + intDec IntDecMode + mapKeyByteString MapKeyByteStringMode + extraReturnErrors ExtraDecErrorCond + defaultMapType reflect.Type + utf8 UTF8Mode + fieldNameMatching FieldNameMatchingMode + bigIntDec BigIntDecMode + defaultByteStringType reflect.Type } var defaultDecMode, _ = DecOptions{}.decMode() @@ -669,19 +680,20 @@ var defaultDecMode, _ = DecOptions{}.decMode() // DecOptions returns user specified options used to create this DecMode. func (dm *decMode) DecOptions() DecOptions { return DecOptions{ - DupMapKey: dm.dupMapKey, - TimeTag: dm.timeTag, - MaxNestedLevels: dm.maxNestedLevels, - MaxArrayElements: dm.maxArrayElements, - MaxMapPairs: dm.maxMapPairs, - IndefLength: dm.indefLength, - TagsMd: dm.tagsMd, - IntDec: dm.intDec, - MapKeyByteString: dm.mapKeyByteString, - ExtraReturnErrors: dm.extraReturnErrors, - UTF8: dm.utf8, - FieldNameMatching: dm.fieldNameMatching, - BigIntDec: dm.bigIntDec, + DupMapKey: dm.dupMapKey, + TimeTag: dm.timeTag, + MaxNestedLevels: dm.maxNestedLevels, + MaxArrayElements: dm.maxArrayElements, + MaxMapPairs: dm.maxMapPairs, + IndefLength: dm.indefLength, + TagsMd: dm.tagsMd, + IntDec: dm.intDec, + MapKeyByteString: dm.mapKeyByteString, + ExtraReturnErrors: dm.extraReturnErrors, + UTF8: dm.utf8, + FieldNameMatching: dm.fieldNameMatching, + BigIntDec: dm.bigIntDec, + DefaultByteStringType: dm.defaultByteStringType, } } @@ -979,8 +991,8 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin return fillNegativeInt(t, nValue, v) case cborTypeByteString: - b := d.parseByteString() - return fillByteString(t, b, v) + b, copied := d.parseByteString() + return fillByteString(t, b, !copied, v) case cborTypeTextString: b, err := d.parseTextString() @@ -1017,7 +1029,7 @@ func (d *decoder) 
parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin switch tagNum { case 2: // Bignum (tag 2) can be decoded to uint, int, float, slice, array, or big.Int. - b := d.parseByteString() + b, copied := d.parseByteString() bi := new(big.Int).SetBytes(b) if tInfo.nonPtrType == typeBigInt { @@ -1025,7 +1037,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin return nil } if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array { - return fillByteString(t, b, v) + return fillByteString(t, b, !copied, v) } if bi.IsUint64() { return fillPositiveInt(t, bi.Uint64(), v) @@ -1037,7 +1049,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin } case 3: // Bignum (tag 3) can be decoded to int, float, slice, array, or big.Int. - b := d.parseByteString() + b, copied := d.parseByteString() bi := new(big.Int).SetBytes(b) bi.Add(bi, big.NewInt(1)) bi.Neg(bi) @@ -1047,7 +1059,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin return nil } if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array { - return fillByteString(t, b, v) + return fillByteString(t, b, !copied, v) } if bi.IsInt64() { return fillNegativeInt(t, bi.Int64(), v) @@ -1279,7 +1291,29 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return nValue, nil case cborTypeByteString: - return d.parseByteString(), nil + switch d.dm.defaultByteStringType { + case nil, typeByteSlice: + b, copied := d.parseByteString() + if copied { + return b, nil + } + clone := make([]byte, len(b)) + copy(clone, b) + return clone, nil + case typeString: + b, _ := d.parseByteString() + return string(b), nil + default: + b, copied := d.parseByteString() + if copied || d.dm.defaultByteStringType.Kind() == reflect.String { + // Avoid an unnecessary copy since the conversion to string must + // copy the underlying bytes. 
+ return reflect.ValueOf(b).Convert(d.dm.defaultByteStringType).Interface(), nil + } + clone := make([]byte, len(b)) + copy(clone, b) + return reflect.ValueOf(clone).Convert(d.dm.defaultByteStringType).Interface(), nil + } case cborTypeTextString: b, err := d.parseTextString() if err != nil { @@ -1296,7 +1330,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli d.off = tagOff return d.parseToTime() case 2: - b := d.parseByteString() + b, _ := d.parseByteString() bi := new(big.Int).SetBytes(b) if d.dm.bigIntDec == BigIntDecodePointer { @@ -1304,7 +1338,7 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli } return *bi, nil case 3: - b := d.parseByteString() + b, _ := d.parseByteString() bi := new(big.Int).SetBytes(b) bi.Add(bi, big.NewInt(1)) bi.Neg(bi) @@ -1376,15 +1410,16 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return nil, nil } -// parseByteString parses CBOR encoded byte string. It returns a byte slice -// pointing to a copy of parsed data. -func (d *decoder) parseByteString() []byte { +// parseByteString parses a CBOR encoded byte string. The returned byte slice +// may be backed directly by the input. The second return value will be true if +// and only if the slice is backed by a copy of the input. Callers are +// responsible for making a copy if necessary. +func (d *decoder) parseByteString() ([]byte, bool) { _, ai, val := d.getHead() if ai != 31 { - b := make([]byte, int(val)) - copy(b, d.data[d.off:d.off+int(val)]) + b := d.data[d.off : d.off+int(val)] d.off += int(val) - return b + return b, false } // Process indefinite length string chunks. b := []byte{} @@ -1393,7 +1428,7 @@ func (d *decoder) parseByteString() []byte { b = append(b, d.data[d.off:d.off+int(val)]...) d.off += int(val) } - return b + return b, true } // parseTextString parses CBOR encoded text string. 
It returns a byte slice @@ -2082,6 +2117,8 @@ var ( typeBigInt = reflect.TypeOf(big.Int{}) typeUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() + typeString = reflect.TypeOf("") + typeByteSlice = reflect.TypeOf([]byte(nil)) ) func fillNil(_ cborType, v reflect.Value) error { @@ -2184,18 +2221,27 @@ func fillFloat(t cborType, val float64, v reflect.Value) error { return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } -func fillByteString(t cborType, val []byte, v reflect.Value) error { +func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error { if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) { if v.CanAddr() { v = v.Addr() if u, ok := v.Interface().(encoding.BinaryUnmarshaler); ok { + // The contract of BinaryUnmarshaler forbids + // retaining the input bytes, so no copying is + // required even if val is shared. return u.UnmarshalBinary(val) } } return errors.New("cbor: cannot set new value for " + v.Type().String()) } if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { - v.SetBytes(val) + src := val + if shared { + // SetBytes shares the underlying bytes of the source slice. 
+ src = make([]byte, len(val)) + copy(src, val) + } + v.SetBytes(src) return nil } if v.Kind() == reflect.Array && v.Type().Elem().Kind() == reflect.Uint8 { diff --git a/decode_test.go b/decode_test.go index c029f50e..da50bffb 100644 --- a/decode_test.go +++ b/decode_test.go @@ -30,8 +30,6 @@ var ( typeInt64 = reflect.TypeOf(int64(0)) typeFloat32 = reflect.TypeOf(float32(0)) typeFloat64 = reflect.TypeOf(float64(0)) - typeString = reflect.TypeOf("") - typeByteSlice = reflect.TypeOf([]byte(nil)) typeByteArray = reflect.TypeOf([5]byte{}) typeIntSlice = reflect.TypeOf([]int{}) typeStringSlice = reflect.TypeOf([]string{}) @@ -8150,6 +8148,102 @@ func TestDecodeBignumToEmptyInterface(t *testing.T) { } } +func TestDecModeInvalidDefaultByteStringType(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "neither slice nor string", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf(int(42))}, + wantErrorMsg: "cbor: invalid DefaultByteStringType: int is not of kind string or []uint8", + }, + { + name: "slice of non-byte", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf([]int{})}, + wantErrorMsg: "cbor: invalid DefaultByteStringType: []int is not of kind string or []uint8", + }, + { + name: "pointer to byte array", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf(&[42]byte{})}, + wantErrorMsg: "cbor: invalid DefaultByteStringType: *[42]uint8 is not of kind string or []uint8", + }, + { + name: "byte array", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf([42]byte{})}, + wantErrorMsg: "cbor: invalid DefaultByteStringType: [42]uint8 is not of kind string or []uint8", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func 
TestUnmarshalDefaultByteStringType(t *testing.T) { + type namedByteSliceType []byte + + for _, tc := range []struct { + name string + opts DecOptions + in []byte + want interface{} + }{ + { + name: "default to []byte", + opts: DecOptions{}, + in: hexDecode("43414243"), + want: []byte("ABC"), + }, + { + name: "explicitly []byte", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf([]byte(nil))}, + in: hexDecode("43414243"), + want: []byte("ABC"), + }, + { + name: "string", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf("")}, + in: hexDecode("43414243"), + want: "ABC", + }, + { + name: "ByteString", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf(ByteString(""))}, + in: hexDecode("43414243"), + want: ByteString("ABC"), + }, + { + name: "named []byte type", + opts: DecOptions{DefaultByteStringType: reflect.TypeOf(namedByteSliceType(nil))}, + in: hexDecode("43414243"), + want: namedByteSliceType("ABC"), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + var got interface{} + if err := dm.Unmarshal(tc.in, &got); err != nil { + t.Errorf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(tc.want, got) { + t.Errorf("got %#v, want %#v", got, tc.want) + } + }) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) } diff --git a/diagnose.go b/diagnose.go index dcb03002..43e6a14c 100644 --- a/diagnose.go +++ b/diagnose.go @@ -354,7 +354,7 @@ func (di *diagnose) item() error { //nolint:gocyclo return di.writeString(strconv.FormatInt(nValue, 10)) case cborTypeByteString: - b := di.d.parseByteString() + b, _ := di.d.parseByteString() return di.encodeByteString(b) case cborTypeTextString: @@ -418,7 +418,7 @@ func (di *diagnose) item() error { //nolint:gocyclo return errors.New("cbor: tag number 2 must be followed by byte string, got " + nt.String()) } - b := di.d.parseByteString() + b, _ := di.d.parseByteString() bi := 
new(big.Int).SetBytes(b) return di.writeString(bi.String()) @@ -427,7 +427,7 @@ func (di *diagnose) item() error { //nolint:gocyclo return errors.New("cbor: tag number 3 must be followed by byte string, got " + nt.String()) } - b := di.d.parseByteString() + b, _ := di.d.parseByteString() bi := new(big.Int).SetBytes(b) bi.Add(bi, big.NewInt(1)) bi.Neg(bi) From ae20110f264e7fcd5074c4860e132aab38be5cee Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Wed, 20 Dec 2023 12:12:16 -0500 Subject: [PATCH 3/3] Add option to permit decoding CBOR byte strings into Go strings. The unchanged default behavior is to produce an UnmarshalTypeError when decoding a CBOR byte string into a Go string. Signed-off-by: Ben Luddy --- decode.go | 38 +++++++++++++++++++++++++++++---- decode_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 4 deletions(-) diff --git a/decode.go b/decode.go index 06da2b27..0c6df079 100644 --- a/decode.go +++ b/decode.go @@ -412,6 +412,23 @@ func (bidm BigIntDecMode) valid() bool { return bidm >= 0 && bidm < maxBigIntDecMode } +// ByteStringToStringMode specifies the behavior when decoding a CBOR byte string into a Go string. +type ByteStringToStringMode int + +const ( + // ByteStringToStringForbidden generates an error on an attempt to decode a CBOR byte string into a Go string. + ByteStringToStringForbidden ByteStringToStringMode = iota + + // ByteStringToStringAllowed permits decoding a CBOR byte string into a Go string. + ByteStringToStringAllowed + + maxByteStringToStringMode +) + +func (bstsm ByteStringToStringMode) valid() bool { + return bstsm >= 0 && bstsm < maxByteStringToStringMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -473,6 +490,9 @@ type DecOptions struct { // for this option, except for array and pointer-to-array types. If nil, the default is // []byte. 
DefaultByteStringType reflect.Type + + // ByteStringToString specifies the behavior when decoding a CBOR byte string into a Go string. + ByteStringToString ByteStringToStringMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -590,6 +610,9 @@ func (opts DecOptions) decMode() (*decMode, error) { if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) { return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType) } + if !opts.ByteStringToString.valid() { + return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString))) + } dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -606,6 +629,7 @@ func (opts DecOptions) decMode() (*decMode, error) { fieldNameMatching: opts.FieldNameMatching, bigIntDec: opts.BigIntDec, defaultByteStringType: opts.DefaultByteStringType, + byteStringToString: opts.ByteStringToString, } return &dm, nil } @@ -673,6 +697,7 @@ type decMode struct { fieldNameMatching FieldNameMatchingMode bigIntDec BigIntDecMode defaultByteStringType reflect.Type + byteStringToString ByteStringToStringMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -694,6 +719,7 @@ func (dm *decMode) DecOptions() DecOptions { FieldNameMatching: dm.fieldNameMatching, BigIntDec: dm.bigIntDec, DefaultByteStringType: dm.defaultByteStringType, + ByteStringToString: dm.byteStringToString, } } @@ -992,7 +1018,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin case cborTypeByteString: b, copied := d.parseByteString() - return fillByteString(t, b, !copied, v) + return fillByteString(t, b, !copied, v, d.dm.byteStringToString) case cborTypeTextString: b, err := d.parseTextString() @@ -1037,7 +1063,7 @@ func (d *decoder) parseToValue(v 
reflect.Value, tInfo *typeInfo) error { //nolin return nil } if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array { - return fillByteString(t, b, !copied, v) + return fillByteString(t, b, !copied, v, ByteStringToStringForbidden) } if bi.IsUint64() { return fillPositiveInt(t, bi.Uint64(), v) @@ -1059,7 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin return nil } if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array { - return fillByteString(t, b, !copied, v) + return fillByteString(t, b, !copied, v, ByteStringToStringForbidden) } if bi.IsInt64() { return fillNegativeInt(t, bi.Int64(), v) @@ -2221,7 +2247,7 @@ func fillFloat(t cborType, val float64, v reflect.Value) error { return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } -func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error { +func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error { if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) { if v.CanAddr() { v = v.Addr() @@ -2234,6 +2260,10 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error } return errors.New("cbor: cannot set new value for " + v.Type().String()) } + if bsts == ByteStringToStringAllowed && v.Kind() == reflect.String { + v.SetString(string(val)) + return nil + } if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { src := val if shared { diff --git a/decode_test.go b/decode_test.go index da50bffb..8da98997 100644 --- a/decode_test.go +++ b/decode_test.go @@ -8244,6 +8244,64 @@ func TestUnmarshalDefaultByteStringType(t *testing.T) { } } +func TestDecModeInvalidByteStringToStringMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{ByteStringToString: -1}, + wantErrorMsg: "cbor: invalid 
ByteStringToString -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{ByteStringToString: 101}, + wantErrorMsg: "cbor: invalid ByteStringToString 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestUnmarshalByteStringToString(t *testing.T) { + var s string + + derror, err := DecOptions{ByteStringToString: ByteStringToStringForbidden}.DecMode() + if err != nil { + t.Fatal(err) + } + + if err = derror.Unmarshal(hexDecode("43414243"), &s); err == nil { + t.Error("expected non-nil error from Unmarshal") + } + + if s != "" { + t.Errorf("expected destination string to be empty, got %q", s) + } + + dallow, err := DecOptions{ByteStringToString: ByteStringToStringAllowed}.DecMode() + if err != nil { + t.Fatal(err) + } + + if err = dallow.Unmarshal(hexDecode("43414243"), &s); err != nil { + t.Errorf("expected nil error from Unmarshal, got: %v", err) + } + + if s != "ABC" { + t.Errorf("expected destination string to be \"ABC\", got %q", s) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) }