Skip to content

Commit

Permalink
Fix decoding CBOR nil to *cbor.SimpleValue
Browse files Browse the repository at this point in the history
Unmarshalling CBOR nil or CBOR undefined into a Go pointer should
always set the pointer to nil.

This commit fixes crash bug when unmarshalling CBOR nil or CBOR
undefined into *cbor.SimpleValue by setting the pointer to nil.

Also, added more tests for decoding to uninitialized and initialized
pointer values.

Separately (not part of this commit), the fuzzer was updated to
attempt unmarshaling to *cbor.SimpleValue.
  • Loading branch information
fxamacker committed Jan 1, 2024
1 parent 5ff9771 commit 4a2755d
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 66 deletions.
33 changes: 23 additions & 10 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,13 @@ const (
// and does not perform bounds checking.
func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolint:gocyclo

// Decode CBOR nil/undefined to pointer value by setting pointer value to nil.
if d.nextCBORNil() && v.Kind() == reflect.Ptr {
d.skip()
v.Set(reflect.Zero(v.Type()))
return nil
}

if tInfo.spclType == specialTypeIface {
if !v.IsNil() {
// Use value type
Expand Down Expand Up @@ -864,18 +871,17 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}
}

// Create new value for the pointer v to point to if CBOR value is not nil/undefined.
if !d.nextCBORNil() {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
if !v.CanSet() {
d.skip()
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
v.Set(reflect.New(v.Type().Elem()))
// Create new value for the pointer v to point to.
// At this point, CBOR value is not nil/undefined if v is a pointer.
for v.Kind() == reflect.Ptr {
if v.IsNil() {
if !v.CanSet() {
d.skip()
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
v = v.Elem()
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}

// Strip self-described CBOR tag number.
Expand Down Expand Up @@ -949,6 +955,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
case cborTypePositiveInt:
_, _, val := d.getHead()
return fillPositiveInt(t, val, v)

case cborTypeNegativeInt:
_, _, val := d.getHead()
if val > math.MaxInt64 {
Expand All @@ -970,15 +977,18 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}
nValue := int64(-1) ^ int64(val)
return fillNegativeInt(t, nValue, v)

case cborTypeByteString:
b := d.parseByteString()
return fillByteString(t, b, v)

case cborTypeTextString:
b, err := d.parseTextString()
if err != nil {
return err
}
return fillTextString(t, b, v)

case cborTypePrimitives:
_, ai, val := d.getHead()
switch ai {
Expand Down Expand Up @@ -1054,6 +1064,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}
}
return d.parseToValue(v, tInfo)

case cborTypeArray:
if tInfo.nonPtrKind == reflect.Slice {
return d.parseArrayToSlice(v, tInfo)
Expand All @@ -1064,6 +1075,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}
d.skip()
return &UnmarshalTypeError{CBORType: t.String(), GoType: tInfo.nonPtrType.String()}

case cborTypeMap:
if tInfo.nonPtrKind == reflect.Struct {
return d.parseMapToStruct(v, tInfo)
Expand All @@ -1073,6 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
d.skip()
return &UnmarshalTypeError{CBORType: t.String(), GoType: tInfo.nonPtrType.String()}
}

return nil
}

Expand Down
215 changes: 159 additions & 56 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1764,14 +1764,12 @@ var unmarshalTests = []unmarshalTest{
// This example is not well-formed because Simple value (with 5-bit value 24) must be >= 32.
// See RFC 7049 section 2.3 for details, instead of the incorrect example in RFC 7049 Appendex A.
// I reported an errata to RFC 7049 and Carsten Bormann confirmed at https://github.com/fxamacker/cbor/issues/46
/*
{
data: hexDecode("f818"),
wantInterfaceValue: uint64(24),
wantValues: []interface{}{uint8(24), uint16(24), uint32(24), uint64(24), uint(24), int8(24), int16(24), int32(24), int64(24), int(24), float32(24), float64(24)},
wrongTypes: []reflect.Type{typeByteSlice, typeString, typeBool, typeIntSlice, typeMapStringInt},
},
*/
// {
// data: hexDecode("f818"),
// wantInterfaceValue: uint64(24),
// wantValues: []interface{}{uint8(24), uint16(24), uint32(24), uint64(24), uint(24), int8(24), int16(24), int32(24), int64(24), int(24), float32(24), float64(24)},
// wrongTypes: []reflect.Type{typeByteSlice, typeString, typeBool, typeIntSlice, typeMapStringInt},
// },
{
data: hexDecode("f820"),
wantInterfaceValue: SimpleValue(32),
Expand Down Expand Up @@ -2252,48 +2250,169 @@ func TestUnmarshalToEmptyInterface(t *testing.T) {

func TestUnmarshalToRawMessage(t *testing.T) {
for _, tc := range unmarshalTests {
var r RawMessage
if err := Unmarshal(tc.data, &r); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err)
continue
testUnmarshalToRawMessage(t, tc.data)
}
}

func testUnmarshalToRawMessage(t *testing.T, data []byte) {
cborNil := isCBORNil(data)

// Decode to RawMessage
var r RawMessage
if err := Unmarshal(data, &r); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
} else if !bytes.Equal(r, data) {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", data, r, data)
}

// Decode to *RawMesage (pr is nil)
var pr *RawMessage
if err := Unmarshal(data, &pr); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
} else {
if cborNil {
if pr != nil {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want nil *RawMessage", data, *pr)
}
} else {
if !bytes.Equal(*pr, data) {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", data, *pr, data)
}
}
if !bytes.Equal(r, tc.data) {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", tc.data, r, tc.data)
}

// Decode to *RawMessage (pr is not nil)
var ir RawMessage
pr = &ir
if err := Unmarshal(data, &pr); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
} else {
if cborNil {
if pr != nil {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want nil *RawMessage", data, *pr)
}
} else {
if !bytes.Equal(*pr, data) {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", data, *pr, data)
}
}
}
}

func TestUnmarshalToCompatibleTypes(t *testing.T) {
for _, tc := range unmarshalTests {
for _, wantValue := range tc.wantValues {
rv := reflect.New(reflect.TypeOf(wantValue))
if err := Unmarshal(tc.data, rv.Interface()); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err)
continue
}
compareNonFloats(t, tc.data, rv.Elem().Interface(), wantValue)
testUnmarshalToCompatibleType(t, tc.data, wantValue, func(gotValue interface{}) {
compareNonFloats(t, tc.data, gotValue, wantValue)
})
}
}
}

func testUnmarshalToCompatibleType(t *testing.T, data []byte, wantValue interface{}, compare func(gotValue interface{})) {
var rv reflect.Value

cborNil := isCBORNil(data)
wantType := reflect.TypeOf(wantValue)

// Decode to wantType, same as:
// var v wantType
// Unmarshal(tc.data, &v)

rv = reflect.New(wantType)
if err := Unmarshal(data, rv.Interface()); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
return
}
compare(rv.Elem().Interface())

// Decode to *wantType (pv is nil), same as:
// var pv *wantType
// Unmarshal(tc.data, &pv)

rv = reflect.New(reflect.PointerTo(wantType))

Check failure on line 2333 in decode_test.go

View workflow job for this annotation

GitHub Actions / test ubuntu-latest go-1.17

undefined: reflect.PointerTo
if err := Unmarshal(data, rv.Interface()); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
return
}
if cborNil {
if !rv.Elem().IsNil() {
t.Errorf("Unmarshal(0x%x) = %v (%T), want nil", data, rv.Elem().Interface(), rv.Elem().Interface())
}
} else {
compare(rv.Elem().Elem().Interface())
}

// Decode to *wantType (pv is not nil), same as:
// var v wantType
// pv := &v
// Unmarshal(tc.data, &pv)

irv := reflect.New(wantType)
rv = reflect.New(reflect.PointerTo(wantType))

Check failure on line 2352 in decode_test.go

View workflow job for this annotation

GitHub Actions / test ubuntu-latest go-1.17

undefined: reflect.PointerTo
rv.Elem().Set(irv)
if err := Unmarshal(data, rv.Interface()); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", data, err)
return
}
if cborNil {
if !rv.Elem().IsNil() {
t.Errorf("Unmarshal(0x%x) = %v (%T), want nil", data, rv.Elem().Interface(), rv.Elem().Interface())
}
} else {
compare(rv.Elem().Elem().Interface())
}
}

func TestUnmarshalToIncompatibleTypes(t *testing.T) {
for _, tc := range unmarshalTests {
for _, wrongType := range tc.wrongTypes {
rv := reflect.New(wrongType)
err := Unmarshal(tc.data, rv.Interface())
if err == nil {
t.Errorf("Unmarshal(0x%x, %s) didn't return an error", tc.data, wrongType.String())
continue
}
if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", tc.data, err)
} else if !strings.Contains(err.Error(), "cannot unmarshal") {
t.Errorf("Unmarshal(0x%x) returned error %q, want error containing %q", tc.data, err.Error(), "cannot unmarshal")
}
testUnmarshalToIncompatibleType(t, tc.data, wrongType)
}
}
}

func testUnmarshalToIncompatibleType(t *testing.T, data []byte, wrongType reflect.Type) {
var rv reflect.Value

// Decode to wrongType, same as:
// var v wrongType
// Unmarshal(tc.data, &v)

rv = reflect.New(wrongType)
if err := Unmarshal(data, rv.Interface()); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error", data)
} else if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", data, err)
}

// Decode to *wrongType (pv is nil), same as:
// var pv *wrongType
// Unmarshal(tc.data, &pv)

rv = reflect.New(reflect.PointerTo(wrongType))

Check failure on line 2393 in decode_test.go

View workflow job for this annotation

GitHub Actions / test ubuntu-latest go-1.17

undefined: reflect.PointerTo
if err := Unmarshal(data, rv.Interface()); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error", data)
} else if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", data, err)
}

// Decode to *wrongType (pv is not nil), same as:
// var v wrongType
// pv := &v
// Unmarshal(tc.data, &pv)

irv := reflect.New(wrongType)
rv = reflect.New(reflect.PointerTo(wrongType))

Check failure on line 2406 in decode_test.go

View workflow job for this annotation

GitHub Actions / test ubuntu-latest go-1.17

undefined: reflect.PointerTo
rv.Elem().Set(irv)

if err := Unmarshal(data, rv.Interface()); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error", data)
} else if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", data, err)
}
}

func compareNonFloats(t *testing.T, data []byte, got interface{}, want interface{}) {
switch tm := want.(type) {
case time.Time:
Expand Down Expand Up @@ -2321,44 +2440,24 @@ func TestUnmarshalFloatToEmptyInterface(t *testing.T) {

func TestUnmarshalFloatToRawMessage(t *testing.T) {
for _, tc := range unmarshalFloatTests {
var r RawMessage
if err := Unmarshal(tc.data, &r); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err)
continue
}
if !bytes.Equal(r, tc.data) {
t.Errorf("Unmarshal(0x%x) returned RawMessage %v, want %v", tc.data, r, tc.data)
}
testUnmarshalToRawMessage(t, tc.data)
}
}

func TestUnmarshalFloatToCompatibleTypes(t *testing.T) {
for _, tc := range unmarshalFloatTests {
for _, wantValue := range tc.wantValues {
rv := reflect.New(reflect.TypeOf(wantValue))
if err := Unmarshal(tc.data, rv.Interface()); err != nil {
t.Errorf("Unmarshal(0x%x) returned error %v", tc.data, err)
continue
}
compareFloats(t, tc.data, rv.Elem().Interface(), wantValue, tc.equalityThreshold)
testUnmarshalToCompatibleType(t, tc.data, wantValue, func(gotValue interface{}) {
compareFloats(t, tc.data, gotValue, wantValue, tc.equalityThreshold)
})
}
}
}

func TestUnmarshalFloatToIncompatibleTypes(t *testing.T) {
for _, tc := range unmarshalFloatTests {
for _, wrongType := range unmarshalFloatWrongTypes {
rv := reflect.New(wrongType)
err := Unmarshal(tc.data, rv.Interface())
if err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error", tc.data)
continue
}
if _, ok := err.(*UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*UnmarshalTypeError)", tc.data, err)
} else if !strings.Contains(err.Error(), "cannot unmarshal") {
t.Errorf("Unmarshal(0x%x) returned error %q, want error containing %q", tc.data, err.Error(), "cannot unmarshal")
}
testUnmarshalToIncompatibleType(t, tc.data, wrongType)
}
}
}
Expand Down Expand Up @@ -7992,3 +8091,7 @@ func TestDecodeBignumToEmptyInterface(t *testing.T) {
})
}
}

func isCBORNil(data []byte) bool {
return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7)
}

0 comments on commit 4a2755d

Please sign in to comment.