diff --git a/decode.go b/decode.go index f7af4a21..248e8184 100644 --- a/decode.go +++ b/decode.go @@ -264,21 +264,40 @@ func (tm TagsMode) valid() bool { return tm >= 0 && tm < maxTagsMode } -// IntDecMode specifies which Go int type (int64 or uint64) should -// be used when decoding CBOR int (major type 0 and 1) to Go interface{}. +// IntDecMode specifies which Go type (int64, uint64, or big.Int) should +// be used when decoding CBOR integers (major type 0 and 1) to Go interface{}. type IntDecMode int const ( - // IntDecConvertNone affects how CBOR int (major type 0 and 1) decodes to Go interface{}. - // It makes CBOR positive int (major type 0) decode to uint64 value, and - // CBOR negative int (major type 1) decode to int64 value. + // IntDecConvertNone affects how CBOR integers (major type 0 and 1) decode to Go interface{}. + // It decodes CBOR unsigned integer (major type 0) to: + // - uint64 + // It decodes CBOR negative integer (major type 1) to: + // - int64 if value fits + // - big.Int or *big.Int (see BigIntDecMode) if value doesn't fit into int64 IntDecConvertNone IntDecMode = iota - // IntDecConvertSigned affects how CBOR int (major type 0 and 1) decodes to Go interface{}. - // It makes CBOR positive/negative int (major type 0 and 1) decode to int64 value. - // If value overflows int64, UnmarshalTypeError is returned. + // IntDecConvertSigned affects how CBOR integers (major type 0 and 1) decode to Go interface{}. + // It decodes CBOR integers (major type 0 and 1) to: + // - int64 if value fits + // - big.Int or *big.Int (see BigIntDecMode) if value < math.MinInt64 + // - return UnmarshalTypeError if value > math.MaxInt64 + // Deprecated: IntDecConvertSigned should not be used. + // Please use other options, such as IntDecConvertSignedOrError, IntDecConvertSignedOrBigInt, IntDecConvertNone. IntDecConvertSigned + // IntDecConvertSignedOrFail affects how CBOR integers (major type 0 and 1) decode to Go interface{}. + // It decodes CBOR integers (major type 0 and 1) to: + // - int64 if value fits + // - return UnmarshalTypeError if value doesn't fit into int64 + IntDecConvertSignedOrFail + + // IntDecConvertSigned affects how CBOR integers (major type 0 and 1) decode to Go interface{}. + // It makes CBOR integers (major type 0 and 1) decode to: + // - int64 if value fits + // - big.Int or *big.Int (see BigIntDecMode) if value doesn't fit into int64 + IntDecConvertSignedOrBigInt + maxIntDec ) @@ -1194,32 +1213,63 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli switch t { case cborTypePositiveInt: _, _, val := d.getHead() - if d.dm.intDec == IntDecConvertNone { + + switch d.dm.intDec { + case IntDecConvertNone: return val, nil - } - if val > math.MaxInt64 { - return nil, &UnmarshalTypeError{ - CBORType: t.String(), - GoType: reflect.TypeOf(int64(0)).String(), - errorMsg: strconv.FormatUint(val, 10) + " overflows Go's int64", + + case IntDecConvertSigned, IntDecConvertSignedOrFail: + if val > math.MaxInt64 { + return nil, &UnmarshalTypeError{ + CBORType: t.String(), + GoType: reflect.TypeOf(int64(0)).String(), + errorMsg: strconv.FormatUint(val, 10) + " overflows Go's int64", + } } + + return int64(val), nil + + case IntDecConvertSignedOrBigInt: + if val > math.MaxInt64 { + bi := new(big.Int).SetUint64(val) + if d.dm.bigIntDec == BigIntDecodePointer { + return bi, nil + } + return *bi, nil + } + + return int64(val), nil + + default: + // not reachable } - return int64(val), nil + case cborTypeNegativeInt: _, _, val := d.getHead() + if val > math.MaxInt64 { // CBOR negative integer value overflows Go int64, use big.Int instead. bi := new(big.Int).SetUint64(val) bi.Add(bi, big.NewInt(1)) bi.Neg(bi) + if d.dm.intDec == IntDecConvertSignedOrFail { + return nil, &UnmarshalTypeError{ + CBORType: t.String(), + GoType: reflect.TypeOf(int64(0)).String(), + errorMsg: bi.String() + " overflows Go's int64", + } + } + if d.dm.bigIntDec == BigIntDecodePointer { return bi, nil } return *bi, nil } + nValue := int64(-1) ^ int64(val) return nValue, nil + case cborTypeByteString: return d.parseByteString(), nil case cborTypeTextString: diff --git a/decode_test.go b/decode_test.go index 9231c8f1..fc2edbc6 100644 --- a/decode_test.go +++ b/decode_test.go @@ -4704,8 +4704,168 @@ func TestDecModeInvalidIntDec(t *testing.T) { } } -func TestIntDec(t *testing.T) { - dm, err := DecOptions{IntDec: IntDecConvertSigned}.DecMode() +func TestIntDecConvertNone(t *testing.T) { + dm, err := DecOptions{ + IntDec: IntDecConvertNone, + BigIntDec: BigIntDecodePointer, + }.DecMode() + if err != nil { + t.Errorf("DecMode() returned an error %+v", err) + } + + testCases := []struct { + name string + cborData []byte + wantObj interface{} + }{ + { + name: "CBOR pos int", + cborData: hexDecode("1a000f4240"), + wantObj: uint64(1000000), + }, + { + name: "CBOR pos int overflows int64", + cborData: hexDecode("1b8000000000000000"), // math.MaxInt64+1 + wantObj: uint64(math.MaxInt64 + 1), + }, + { + name: "CBOR neg int", + cborData: hexDecode("3903e7"), + wantObj: int64(-1000), + }, + { + name: "CBOR neg int overflows int64", + cborData: hexDecode("3b8000000000000000"), // math.MinInt64-1 + wantObj: new(big.Int).Sub(big.NewInt(math.MinInt64), big.NewInt(1)), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var v interface{} + err := dm.Unmarshal(tc.cborData, &v) + if err == nil { + if !reflect.DeepEqual(v, tc.wantObj) { + t.Errorf("Unmarshal(0x%x) return %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantObj, tc.wantObj) + } + } else { + t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err) + } + }) + } +} + +func TestIntDecConvertSigned(t *testing.T) { + dm, err := DecOptions{ + IntDec: IntDecConvertSigned, + BigIntDec: BigIntDecodePointer, + }.DecMode() + if err != nil { + t.Errorf("DecMode() returned an error %+v", err) + } + + testCases := []struct { + name string + cborData []byte + wantObj interface{} + wantErrorMsg string + }{ + { + name: "CBOR pos int", + cborData: hexDecode("1a000f4240"), + wantObj: int64(1000000), + }, + { + name: "CBOR pos int overflows int64", + cborData: hexDecode("1b8000000000000000"), // math.MaxInt64+1 + wantErrorMsg: "9223372036854775808 overflows Go's int64", + }, + { + name: "CBOR neg int", + cborData: hexDecode("3903e7"), + wantObj: int64(-1000), + }, + { + name: "CBOR neg int overflows int64", + cborData: hexDecode("3b8000000000000000"), // math.MinInt64-1 + wantObj: new(big.Int).Sub(big.NewInt(math.MinInt64), big.NewInt(1)), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var v interface{} + err := dm.Unmarshal(tc.cborData, &v) + if err == nil { + if tc.wantErrorMsg != "" { + t.Errorf("Unmarshal(0x%x) didn't return an error, want %q", tc.cborData, tc.wantErrorMsg) + } else if !reflect.DeepEqual(v, tc.wantObj) { + t.Errorf("Unmarshal(0x%x) return %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantObj, tc.wantObj) + } + } else { + if tc.wantErrorMsg == "" { + t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err) + } else if !strings.Contains(err.Error(), tc.wantErrorMsg) { + t.Errorf("Unmarshal(0x%x) returned error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg) + } + } + }) + } +} + +func TestIntDecConvertSignedOrBigInt(t *testing.T) { + dm, err := DecOptions{ + IntDec: IntDecConvertSignedOrBigInt, + BigIntDec: BigIntDecodePointer, + }.DecMode() + if err != nil { + t.Errorf("DecMode() returned an error %+v", err) + } + + testCases := []struct { + name string + cborData []byte + wantObj interface{} + }{ + { + name: "CBOR pos int", + cborData: hexDecode("1a000f4240"), + wantObj: int64(1000000), + }, + { + name: "CBOR pos int overflows int64", + cborData: hexDecode("1b8000000000000000"), + wantObj: new(big.Int).Add(big.NewInt(math.MaxInt64), big.NewInt(1)), + }, + { + name: "CBOR neg int", + cborData: hexDecode("3903e7"), + wantObj: int64(-1000), + }, + { + name: "CBOR neg int overflows int64", + cborData: hexDecode("3b8000000000000000"), + wantObj: new(big.Int).Sub(big.NewInt(math.MinInt64), big.NewInt(1)), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var v interface{} + err := dm.Unmarshal(tc.cborData, &v) + if err == nil { + if !reflect.DeepEqual(v, tc.wantObj) { + t.Errorf("Unmarshal(0x%x) return %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantObj, tc.wantObj) + } + } else { + t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err) + } + }) + } +} + +func TestIntDecConvertSignedOrError(t *testing.T) { + dm, err := DecOptions{ + IntDec: IntDecConvertSignedOrFail, + BigIntDec: BigIntDecodePointer, + }.DecMode() if err != nil { t.Errorf("DecMode() returned an error %+v", err) } @@ -4723,14 +4883,19 @@ func TestIntDec(t *testing.T) { }, { name: "CBOR pos int overflows int64", - cborData: hexDecode("1bffffffffffffffff"), - wantErrorMsg: "18446744073709551615 overflows Go's int64", + cborData: hexDecode("1b8000000000000000"), // math.MaxInt64+1 + wantErrorMsg: "9223372036854775808 overflows Go's int64", }, { name: "CBOR neg int", cborData: hexDecode("3903e7"), wantObj: int64(-1000), }, + { + name: "CBOR neg int overflows int64", + cborData: hexDecode("3b8000000000000000"), // math.MinInt64-1 + wantErrorMsg: "-9223372036854775809 overflows Go's int64", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) {