diff --git a/decode.go b/decode.go index 87a7a2f2..4c565063 100644 --- a/decode.go +++ b/decode.go @@ -412,6 +412,23 @@ func (bidm BigIntDecMode) valid() bool { return bidm >= 0 && bidm < maxBigIntDecMode } +// ByteStringToStringMode specifies the behavior when decoding a CBOR byte string into a Go string. +type ByteStringToStringMode int + +const ( + // ByteStringToStringError generates an error on an attempt to decode a CBOR byte string into a Go string. + ByteStringToStringError ByteStringToStringMode = iota + + // ByteStringToStringAllow permits decoding a CBOR byte string into a Go string. + ByteStringToStringAllow + + maxByteStringToStringMode +) + +func (bstsm ByteStringToStringMode) valid() bool { + return bstsm >= 0 && bstsm < maxByteStringToStringMode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -473,6 +490,9 @@ type DecOptions struct { // for this option, except for array and pointer-to-array types. If nil, the default is // []byte. DefaultByteStringType reflect.Type + + // ByteStringToString specifies the behavior when decoding a CBOR byte string into a Go string. + ByteStringToString ByteStringToStringMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -590,6 +610,9 @@ func (opts DecOptions) decMode() (*decMode, error) { if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) { return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType) } + if !opts.ByteStringToString.valid() { + return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString))) + } dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -606,6 +629,7 @@ func (opts DecOptions) decMode() (*decMode, error) { fieldNameMatching: opts.FieldNameMatching, bigIntDec: opts.BigIntDec, defaultByteStringType: opts.DefaultByteStringType, + byteStringToString: opts.ByteStringToString, } return &dm, nil } @@ -673,6 +697,7 @@ type decMode struct { fieldNameMatching FieldNameMatchingMode bigIntDec BigIntDecMode defaultByteStringType reflect.Type + byteStringToString ByteStringToStringMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -694,6 +719,7 @@ func (dm *decMode) DecOptions() DecOptions { FieldNameMatching: dm.fieldNameMatching, BigIntDec: dm.bigIntDec, DefaultByteStringType: dm.defaultByteStringType, + ByteStringToString: dm.byteStringToString, } } @@ -992,7 +1018,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin case cborTypeByteString: b, copied := d.parseByteString() - return fillByteString(t, b, !copied, v) + return fillByteString(t, b, !copied, v, d.dm.byteStringToString) case cborTypeTextString: b, err := d.parseTextString() @@ -1037,7 +1063,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin return nil } if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array { - return fillByteString(t, b, !copied, v) + return fillByteString(t, b, !copied, v, ByteStringToStringError) } if bi.IsUint64() { return fillPositiveInt(t, bi.Uint64(), v) @@ -1059,7 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin return nil } if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array { - return fillByteString(t, b, !copied, v) + return fillByteString(t, b, !copied, v, ByteStringToStringError) } if bi.IsInt64() { return fillNegativeInt(t, bi.Int64(), v) @@ -2219,7 +2245,7 @@ func fillFloat(t cborType, val float64, v reflect.Value) error { return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } -func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error { +func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error { if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) { if v.CanAddr() { v = v.Addr() @@ -2232,6 +2258,10 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error } return errors.New("cbor: cannot set new value for " + v.Type().String()) } + if bsts == ByteStringToStringAllow && v.Kind() == reflect.String { + v.SetString(string(val)) + return nil + } if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { src := val if shared { diff --git a/decode_test.go b/decode_test.go index 937391bb..dc0c467a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -8246,6 +8246,64 @@ func TestUnmarshalDefaultByteStringType(t *testing.T) { } } +func TestDecModeInvalidByteStringToStringMode(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{ByteStringToString: -1}, + wantErrorMsg: "cbor: invalid ByteStringToString -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{ByteStringToString: 101}, + wantErrorMsg: "cbor: invalid ByteStringToString 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestUnmarshalByteStringToString(t *testing.T) { + var s string + + derror, err := DecOptions{ByteStringToString: ByteStringToStringError}.DecMode() + if err != nil { + t.Fatal(err) + } + + if err := derror.Unmarshal(hexDecode("43414243"), &s); err == nil { + t.Error("expected non-nil error from Unmarshal") + } + + if s != "" { + t.Errorf("expected destination string to be empty, got %q", s) + } + + dallow, err := DecOptions{ByteStringToString: ByteStringToStringAllow}.DecMode() + if err != nil { + t.Fatal(err) + } + + if dallow.Unmarshal(hexDecode("43414243"), &s); err != nil { + t.Errorf("expected nil error from Unmarshal, got: %v", err) + } + + if s != "ABC" { + t.Errorf("expected destination string to be \"ABC\", got %q", s) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) }