diff --git a/decode.go b/decode.go index 248e8184..7068e36f 100644 --- a/decode.go +++ b/decode.go @@ -834,6 +834,13 @@ const ( // and does not perform bounds checking. func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolint:gocyclo + // Decode CBOR nil/undefined to pointer value by setting pointer value to nil. + if d.nextCBORNil() && v.Kind() == reflect.Ptr { + d.skip() + v.Set(reflect.Zero(v.Type())) + return nil + } + if tInfo.spclType == specialTypeIface { if !v.IsNil() { // Use value type @@ -864,18 +871,17 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin } } - // Create new value for the pointer v to point to if CBOR value is not nil/undefined. - if !d.nextCBORNil() { - for v.Kind() == reflect.Ptr { - if v.IsNil() { - if !v.CanSet() { - d.skip() - return errors.New("cbor: cannot set new value for " + v.Type().String()) - } - v.Set(reflect.New(v.Type().Elem())) + // Create new value for the pointer v to point to. + // At this point, CBOR value is not nil/undefined if v is a pointer. + for v.Kind() == reflect.Ptr { + if v.IsNil() { + if !v.CanSet() { + d.skip() + return errors.New("cbor: cannot set new value for " + v.Type().String()) } - v = v.Elem() + v.Set(reflect.New(v.Type().Elem())) } + v = v.Elem() } // Strip self-described CBOR tag number. @@ -949,6 +955,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin case cborTypePositiveInt: _, _, val := d.getHead() return fillPositiveInt(t, val, v) + case cborTypeNegativeInt: _, _, val := d.getHead() if val > math.MaxInt64 { @@ -970,15 +977,18 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin } nValue := int64(-1) ^ int64(val) return fillNegativeInt(t, nValue, v) + case cborTypeByteString: b := d.parseByteString() return fillByteString(t, b, v) + case cborTypeTextString: b, err := d.parseTextString() if err != nil { return err } return fillTextString(t, b, v) + case cborTypePrimitives: _, ai, val := d.getHead() switch ai { @@ -1054,6 +1064,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin } } return d.parseToValue(v, tInfo) + case cborTypeArray: if tInfo.nonPtrKind == reflect.Slice { return d.parseArrayToSlice(v, tInfo) @@ -1064,6 +1075,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin } d.skip() return &UnmarshalTypeError{CBORType: t.String(), GoType: tInfo.nonPtrType.String()} + case cborTypeMap: if tInfo.nonPtrKind == reflect.Struct { return d.parseMapToStruct(v, tInfo) @@ -1073,6 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin d.skip() return &UnmarshalTypeError{CBORType: t.String(), GoType: tInfo.nonPtrType.String()} } + return nil } diff --git a/decode_test.go b/decode_test.go index 047c4282..6b7f6d62 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1764,14 +1764,12 @@ var unmarshalTests = []unmarshalTest{ // This example is not well-formed because Simple value (with 5-bit value 24) must be >= 32. // See RFC 7049 section 2.3 for details, instead of the incorrect example in RFC 7049 Appendex A. // I reported an errata to RFC 7049 and Carsten Bormann confirmed at https://github.com/fxamacker/cbor/issues/46 - /* - { - data: hexDecode("f818"), - wantInterfaceValue: uint64(24), - wantValues: []interface{}{uint8(24), uint16(24), uint32(24), uint64(24), uint(24), int8(24), int16(24), int32(24), int64(24), int(24), float32(24), float64(24)}, - wrongTypes: []reflect.Type{typeByteSlice, typeString, typeBool, typeIntSlice, typeMapStringInt}, - }, - */ + // { + // data: hexDecode("f818"), + // wantInterfaceValue: uint64(24), + // wantValues: []interface{}{uint8(24), uint16(24), uint32(24), uint64(24), uint(24), int8(24), int16(24), int32(24), int64(24), int(24), float32(24), float64(24)}, + // wrongTypes: []reflect.Type{typeByteSlice, typeString, typeBool, typeIntSlice, typeMapStringInt}, + // }, { data: hexDecode("f820"), wantInterfaceValue: SimpleValue(32), @@ -2252,13 +2250,51 @@ func TestUnmarshalToEmptyInterface(t *testing.T) { func TestUnmarshalToRawMessage(t *testing.T) { for _, tc := range unmarshalTests { - var r RawMessage - if err := Unmarshal(tc.data, &r); err != nil { - t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err) - continue + testUnmarshalToRawMessage(t, tc.data) + } +} + +func testUnmarshalToRawMessage(t *testing.T, data []byte) { + cborNil := isCBORNil(data) + + // Decode to RawMessage + var r RawMessage + if err := Unmarshal(data, &r); err != nil { + t.Errorf("Unmarshal(0x%x) returned error %v", data, err) + } else if !bytes.Equal(r, data) { + t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", data, r, data) + } + + // Decode to *RawMesage (pr is nil) + var pr *RawMessage + if err := Unmarshal(data, &pr); err != nil { + t.Errorf("Unmarshal(0x%x) returned error %v", data, err) + } else { + if cborNil { + if pr != nil { + t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want nil *RawMessage", data, *pr) + } + } else { + if !bytes.Equal(*pr, data) { + t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", data, *pr, data) + } } - if !bytes.Equal(r, tc.data) { - t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", tc.data, r, tc.data) + } + + // Decode to *RawMessage (pr is not nil) + var ir RawMessage + pr = &ir + if err := Unmarshal(data, &pr); err != nil { + t.Errorf("Unmarshal(0x%x) returned error %v", data, err) + } else { + if cborNil { + if pr != nil { + t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want nil *RawMessage", data, *pr) + } + } else { + if !bytes.Equal(*pr, data) { + t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", data, *pr, data) + } } } } @@ -2266,34 +2302,117 @@ func TestUnmarshalToRawMessage(t *testing.T) { func TestUnmarshalToCompatibleTypes(t *testing.T) { for _, tc := range unmarshalTests { for _, wantValue := range tc.wantValues { - rv := reflect.New(reflect.TypeOf(wantValue)) - if err := Unmarshal(tc.data, rv.Interface()); err != nil { - t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err) - continue - } - compareNonFloats(t, tc.data, rv.Elem().Interface(), wantValue) + testUnmarshalToCompatibleType(t, tc.data, wantValue, func(gotValue interface{}) { + compareNonFloats(t, tc.data, gotValue, wantValue) + }) } } } +func testUnmarshalToCompatibleType(t *testing.T, data []byte, wantValue interface{}, compare func(gotValue interface{})) { + var rv reflect.Value + + cborNil := isCBORNil(data) + wantType := reflect.TypeOf(wantValue) + + // Decode to wantType, same as: + // var v wantType + // Unmarshal(tc.data, &v) + + rv = reflect.New(wantType) + if err := Unmarshal(data, rv.Interface()); err != nil { + t.Errorf("Unmarshal(0x%x) returned error %v", data, err) + return + } + compare(rv.Elem().Interface()) + + // Decode to *wantType (pv is nil), same as: + // var pv *wantType + // Unmarshal(tc.data, &pv) + + rv = reflect.New(reflect.PointerTo(wantType)) + if err := Unmarshal(data, rv.Interface()); err != nil { + t.Errorf("Unmarshal(0x%x) returned error %v", data, err) + return + } + if cborNil { + if !rv.Elem().IsNil() { + t.Errorf("Unmarshal(0x%x) = %v (%T), want nil", data, rv.Elem().Interface(), rv.Elem().Interface()) + } + } else { + compare(rv.Elem().Elem().Interface()) + } + + // Decode to *wantType (pv is not nil), same as: + // var v wantType + // pv := &v + // Unmarshal(tc.data, &pv) + + irv := reflect.New(wantType) + rv = reflect.New(reflect.PointerTo(wantType)) + rv.Elem().Set(irv) + if err := Unmarshal(data, rv.Interface()); err != nil { + t.Errorf("Unmarshal(0x%x) returned error %v", data, err) + return + } + if cborNil { + if !rv.Elem().IsNil() { + t.Errorf("Unmarshal(0x%x) = %v (%T), want nil", data, rv.Elem().Interface(), rv.Elem().Interface()) + } + } else { + compare(rv.Elem().Elem().Interface()) + } +} + func TestUnmarshalToIncompatibleTypes(t *testing.T) { for _, tc := range unmarshalTests { for _, wrongType := range tc.wrongTypes { - rv := reflect.New(wrongType) - err := Unmarshal(tc.data, rv.Interface()) - if err == nil { - t.Errorf("Unmarshal(0x%x, %s) didn't return an error", tc.data, wrongType.String()) - continue - } - if _, ok := err.(*UnmarshalTypeError); !ok { - t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", tc.data, err) - } else if !strings.Contains(err.Error(), "cannot unmarshal") { - t.Errorf("Unmarshal(0x%x) returned error %q, want error containing %q", tc.data, err.Error(), "cannot unmarshal") - } + testUnmarshalToIncompatibleType(t, tc.data, wrongType) } } } +func testUnmarshalToIncompatibleType(t *testing.T, data []byte, wrongType reflect.Type) { + var rv reflect.Value + + // Decode to wrongType, same as: + // var v wrongType + // Unmarshal(tc.data, &v) + + rv = reflect.New(wrongType) + if err := Unmarshal(data, rv.Interface()); err == nil { + t.Errorf("Unmarshal(0x%x) didn't return an error", data) + } else if _, ok := err.(*UnmarshalTypeError); !ok { + t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", data, err) + } + + // Decode to *wrongType (pv is nil), same as: + // var pv *wrongType + // Unmarshal(tc.data, &pv) + + rv = reflect.New(reflect.PointerTo(wrongType)) + if err := Unmarshal(data, rv.Interface()); err == nil { + t.Errorf("Unmarshal(0x%x) didn't return an error", data) + } else if _, ok := err.(*UnmarshalTypeError); !ok { + t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", data, err) + } + + // Decode to *wrongType (pv is not nil), same as: + // var v wrongType + // pv := &v + // Unmarshal(tc.data, &pv) + + irv := reflect.New(wrongType) + rv = reflect.New(reflect.PointerTo(wrongType)) + rv.Elem().Set(irv) + + if err := Unmarshal(data, rv.Interface()); err == nil { + t.Errorf("Unmarshal(0x%x) didn't return an error", data) + } else if _, ok := err.(*UnmarshalTypeError); !ok { + t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", data, err) + } +} + func compareNonFloats(t *testing.T, data []byte, got interface{}, want interface{}) { switch tm := want.(type) { case time.Time: @@ -2321,26 +2440,16 @@ func TestUnmarshalFloatToEmptyInterface(t *testing.T) { func TestUnmarshalFloatToRawMessage(t *testing.T) { for _, tc := range unmarshalFloatTests { - var r RawMessage - if err := Unmarshal(tc.data, &r); err != nil { - t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err) - continue - } - if !bytes.Equal(r, tc.data) { - t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", tc.data, r, tc.data) - } + testUnmarshalToRawMessage(t, tc.data) } } func TestUnmarshalFloatToCompatibleTypes(t *testing.T) { for _, tc := range unmarshalFloatTests { for _, wantValue := range tc.wantValues { - rv := reflect.New(reflect.TypeOf(wantValue)) - if err := Unmarshal(tc.data, rv.Interface()); err != nil { - t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err) - continue - } - compareFloats(t, tc.data, rv.Elem().Interface(), wantValue, tc.equalityThreshold) + testUnmarshalToCompatibleType(t, tc.data, wantValue, func(gotValue interface{}) { + compareFloats(t, tc.data, gotValue, wantValue, tc.equalityThreshold) + }) } } } @@ -2348,17 +2457,7 @@ func TestUnmarshalFloatToCompatibleTypes(t *testing.T) { func TestUnmarshalFloatToIncompatibleTypes(t *testing.T) { for _, tc := range unmarshalFloatTests { for _, wrongType := range unmarshalFloatWrongTypes { - rv := reflect.New(wrongType) - err := Unmarshal(tc.data, rv.Interface()) - if err == nil { - t.Errorf("Unmarshal(0x%x) didn't return an error", tc.data) - continue - } - if _, ok := err.(*UnmarshalTypeError); !ok { - t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", tc.data, err) - } else if !strings.Contains(err.Error(), "cannot unmarshal") { - t.Errorf("Unmarshal(0x%x) returned error %q, want error containing %q", tc.data, err.Error(), "cannot unmarshal") - } + testUnmarshalToIncompatibleType(t, tc.data, wrongType) } } } @@ -7992,3 +8091,7 @@ func TestDecodeBignumToEmptyInterface(t *testing.T) { }) } } + +func isCBORNil(data []byte) bool { + return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) +}