diff --git a/decode.go b/decode.go index 0b44124d..4981794d 100644 --- a/decode.go +++ b/decode.go @@ -601,69 +601,96 @@ const ( defaultMaxMapPairs = 131072 minMaxMapPairs = 16 maxMaxMapPairs = 2147483647 + + defaultMaxNestedLevels = 32 + minMaxNestedLevels = 4 + maxMaxNestedLevels = 65535 ) func (opts DecOptions) decMode() (*decMode, error) { if !opts.DupMapKey.valid() { return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey))) } + if !opts.TimeTag.valid() { return nil, errors.New("cbor: invalid TimeTag " + strconv.Itoa(int(opts.TimeTag))) } + if !opts.IndefLength.valid() { return nil, errors.New("cbor: invalid IndefLength " + strconv.Itoa(int(opts.IndefLength))) } + if !opts.TagsMd.valid() { return nil, errors.New("cbor: invalid TagsMd " + strconv.Itoa(int(opts.TagsMd))) } + if !opts.IntDec.valid() { return nil, errors.New("cbor: invalid IntDec " + strconv.Itoa(int(opts.IntDec))) } + if !opts.MapKeyByteString.valid() { return nil, errors.New("cbor: invalid MapKeyByteString " + strconv.Itoa(int(opts.MapKeyByteString))) } + if opts.MaxNestedLevels == 0 { - opts.MaxNestedLevels = 32 - } else if opts.MaxNestedLevels < 4 || opts.MaxNestedLevels > 65535 { - return nil, errors.New("cbor: invalid MaxNestedLevels " + strconv.Itoa(opts.MaxNestedLevels) + " (range is [4, 65535])") + opts.MaxNestedLevels = defaultMaxNestedLevels + } else if opts.MaxNestedLevels < minMaxNestedLevels || opts.MaxNestedLevels > maxMaxNestedLevels { + return nil, errors.New("cbor: invalid MaxNestedLevels " + strconv.Itoa(opts.MaxNestedLevels) + + " (range is [" + strconv.Itoa(minMaxNestedLevels) + ", " + strconv.Itoa(maxMaxNestedLevels) + "])") } + if opts.MaxArrayElements == 0 { opts.MaxArrayElements = defaultMaxArrayElements } else if opts.MaxArrayElements < minMaxArrayElements || opts.MaxArrayElements > maxMaxArrayElements { - return nil, errors.New("cbor: invalid MaxArrayElements " + strconv.Itoa(opts.MaxArrayElements) + " (range is [" + strconv.Itoa(minMaxArrayElements) + ", " + strconv.Itoa(maxMaxArrayElements) + "])") + return nil, errors.New("cbor: invalid MaxArrayElements " + strconv.Itoa(opts.MaxArrayElements) + + " (range is [" + strconv.Itoa(minMaxArrayElements) + ", " + strconv.Itoa(maxMaxArrayElements) + "])") } + if opts.MaxMapPairs == 0 { opts.MaxMapPairs = defaultMaxMapPairs } else if opts.MaxMapPairs < minMaxMapPairs || opts.MaxMapPairs > maxMaxMapPairs { - return nil, errors.New("cbor: invalid MaxMapPairs " + strconv.Itoa(opts.MaxMapPairs) + " (range is [" + strconv.Itoa(minMaxMapPairs) + ", " + strconv.Itoa(maxMaxMapPairs) + "])") + return nil, errors.New("cbor: invalid MaxMapPairs " + strconv.Itoa(opts.MaxMapPairs) + + " (range is [" + strconv.Itoa(minMaxMapPairs) + ", " + strconv.Itoa(maxMaxMapPairs) + "])") } + if !opts.ExtraReturnErrors.valid() { return nil, errors.New("cbor: invalid ExtraReturnErrors " + strconv.Itoa(int(opts.ExtraReturnErrors))) } + if opts.DefaultMapType != nil && opts.DefaultMapType.Kind() != reflect.Map { return nil, fmt.Errorf("cbor: invalid DefaultMapType %s", opts.DefaultMapType) } + if !opts.UTF8.valid() { return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8))) } + if !opts.FieldNameMatching.valid() { return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching))) } + if !opts.BigIntDec.valid() { return nil, errors.New("cbor: invalid BigIntDec " + strconv.Itoa(int(opts.BigIntDec))) } - if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) { + + if opts.DefaultByteStringType != nil && + opts.DefaultByteStringType.Kind() != reflect.String && + (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) { return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType) } + if !opts.ByteStringToString.valid() { return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString))) } + if !opts.FieldNameByteString.valid() { return nil, errors.New("cbor: invalid FieldNameByteString " + strconv.Itoa(int(opts.FieldNameByteString))) } + if !opts.UnrecognizedTagToAny.valid() { return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny))) } + dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -684,6 +711,7 @@ func (opts DecOptions) decMode() (*decMode, error) { fieldNameByteString: opts.FieldNameByteString, unrecognizedTagToAny: opts.UnrecognizedTagToAny, } + return &dm, nil } @@ -790,9 +818,9 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error { d := decoder{data: data, dm: dm} // Check well-formedness. - off := d.off // Save offset before data validation - err := d.wellformed(false) // don't allow any extra data after valid data item. - d.off = off // Restore offset + off := d.off // Save offset before data validation + err := d.wellformed(false, false) // don't allow any extra data after valid data item. + d.off = off // Restore offset if err != nil { return err } @@ -810,9 +838,9 @@ func (dm *decMode) UnmarshalFirst(data []byte, v interface{}) (rest []byte, err d := decoder{data: data, dm: dm} // check well-formedness. - off := d.off // Save offset before data validation - err = d.wellformed(true) // allow extra data after well-formed data item - d.off = off // Restore offset + off := d.off // Save offset before data validation + err = d.wellformed(true, false) // allow extra data after well-formed data item + d.off = off // Restore offset // If it is well-formed, parse the value. This is structured like this to allow // better test coverage @@ -853,7 +881,7 @@ func (dm *decMode) Valid(data []byte) error { // an ExtraneousDataError is returned. func (dm *decMode) Wellformed(data []byte) error { d := decoder{data: data, dm: dm} - return d.wellformed(false) + return d.wellformed(false, false) } // NewDecoder returns a new decoder that reads from r using dm DecMode. diff --git a/diagnose.go b/diagnose.go index 43e6a14c..7b4a4676 100644 --- a/diagnose.go +++ b/diagnose.go @@ -235,7 +235,7 @@ func (di *diagnose) diagFirst() (string, []byte, error) { func (di *diagnose) wellformed(allowExtraData bool) error { off := di.d.off - err := di.d.wellformed(allowExtraData) + err := di.d.wellformed(allowExtraData, false) di.d.off = off return err } diff --git a/encode.go b/encode.go index 86cc47ab..f04559fe 100644 --- a/encode.go +++ b/encode.go @@ -100,6 +100,23 @@ type Marshaler interface { MarshalCBOR() ([]byte, error) } +// MarshalerError represents error from checking encoded CBOR data item +// returned from MarshalCBOR for well-formedness and some very limited tag validation. +type MarshalerError struct { + typ reflect.Type + err error +} + +func (e *MarshalerError) Error() string { + return "cbor: error calling MarshalCBOR for type " + + e.typ.String() + + ": " + e.err.Error() +} + +func (e *MarshalerError) Unwrap() error { + return e.err +} + // UnsupportedTypeError is returned by Marshal when attempting to encode value // of an unsupported type. type UnsupportedTypeError struct { @@ -632,6 +649,75 @@ type encMode struct { var defaultEncMode, _ = EncOptions{}.encMode() +// These four decoding modes are used by getMarshalerDecMode. +// maxNestedLevels, maxArrayElements, and maxMapPairs are +// set to max allowed limits to avoid rejecting Marshaler +// output that would have been the allowable output of a +// non-Marshaler object that exceeds default limits. +var ( + marshalerForbidIndefLengthForbidTagsDecMode = decMode{ + maxNestedLevels: maxMaxNestedLevels, + maxArrayElements: maxMaxArrayElements, + maxMapPairs: maxMaxMapPairs, + indefLength: IndefLengthForbidden, + tagsMd: TagsForbidden, + } + + marshalerAllowIndefLengthForbidTagsDecMode = decMode{ + maxNestedLevels: maxMaxNestedLevels, + maxArrayElements: maxMaxArrayElements, + maxMapPairs: maxMaxMapPairs, + indefLength: IndefLengthAllowed, + tagsMd: TagsForbidden, + } + + marshalerForbidIndefLengthAllowTagsDecMode = decMode{ + maxNestedLevels: maxMaxNestedLevels, + maxArrayElements: maxMaxArrayElements, + maxMapPairs: maxMaxMapPairs, + indefLength: IndefLengthForbidden, + tagsMd: TagsAllowed, + } + + marshalerAllowIndefLengthAllowTagsDecMode = decMode{ + maxNestedLevels: maxMaxNestedLevels, + maxArrayElements: maxMaxArrayElements, + maxMapPairs: maxMaxMapPairs, + indefLength: IndefLengthAllowed, + tagsMd: TagsAllowed, + } +) + +// getMarshalerDecMode returns one of four existing decoding modes +// which can be reused (safe for parallel use) for the purpose of +// checking if data returned by Marshaler is well-formed. +func getMarshalerDecMode(indefLength IndefLengthMode, tagsMd TagsMode) *decMode { + switch { + case indefLength == IndefLengthAllowed && tagsMd == TagsAllowed: + return &marshalerAllowIndefLengthAllowTagsDecMode + + case indefLength == IndefLengthAllowed && tagsMd == TagsForbidden: + return &marshalerAllowIndefLengthForbidTagsDecMode + + case indefLength == IndefLengthForbidden && tagsMd == TagsAllowed: + return &marshalerForbidIndefLengthAllowTagsDecMode + + case indefLength == IndefLengthForbidden && tagsMd == TagsForbidden: + return &marshalerForbidIndefLengthForbidTagsDecMode + + default: + // This should never happen, unless we add new options to + // IndefLengthMode or TagsMode without updating this function. + return &decMode{ + maxNestedLevels: maxMaxNestedLevels, + maxArrayElements: maxMaxArrayElements, + maxMapPairs: maxMaxMapPairs, + indefLength: indefLength, + tagsMd: tagsMd, + } + } +} + // EncOptions returns user specified options used to create this EncMode. func (em *encMode) EncOptions() EncOptions { return EncOptions{ @@ -1345,6 +1431,14 @@ func encodeMarshalerType(e *encoderBuffer, em *encMode, v reflect.Value) error { if err != nil { return err } + + // Verify returned CBOR data item from MarshalCBOR() is well-formed and passes tag validity for builtin tags 0-3. + d := decoder{data: data, dm: getMarshalerDecMode(em.indefLength, em.tagsMd)} + err = d.wellformed(false, true) + if err != nil { + return &MarshalerError{typ: v.Type(), err: err} + } + e.Write(data) return nil } diff --git a/encode_test.go b/encode_test.go index bb4bdade..2ed416bd 100644 --- a/encode_test.go +++ b/encode_test.go @@ -4183,3 +4183,163 @@ func TestMarshalFieldNameType(t *testing.T) { }) } } + +func TestMarshalRawMessageContainingMalformedCBORData(t *testing.T) { + testCases := []struct { + name string + value interface{} + wantErrorMsg string + }{ + // Nil RawMessage and empty RawMessage are encoded as CBOR nil. + { + name: "truncated data", + value: RawMessage{0xa6}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawMessage: unexpected EOF", + }, + { + name: "malformed data", + value: RawMessage{0x1f}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawMessage: cbor: invalid additional information 31 for type positive integer", + }, + { + name: "extraneous data", + value: RawMessage{0x01, 0x01}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawMessage: cbor: 1 bytes of extraneous data starting at index 1", + }, + { + name: "invalid builtin tag", + value: RawMessage{0xc0, 0x01}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawMessage: cbor: tag number 0 must be followed by text string, got positive integer", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b, err := Marshal(tc.value) + if err == nil { + t.Errorf("Marshal(%v) didn't return an error, want error %q", tc.value, tc.wantErrorMsg) + } else if _, ok := err.(*MarshalerError); !ok { + t.Errorf("Marshal(%v) error type %T, want *MarshalerError", tc.value, err) + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("Marshal(%v) error %q, want %q", tc.value, err.Error(), tc.wantErrorMsg) + } + if b != nil { + t.Errorf("Marshal(%v) = 0x%x, want nil", tc.value, b) + } + }) + } +} + +type marshaler struct { + data []byte +} + +func (m marshaler) MarshalCBOR() (data []byte, err error) { + return m.data, nil +} + +func TestMarshalerReturnsMalformedCBORData(t *testing.T) { + + testCases := []struct { + name string + value interface{} + wantErrorMsg string + }{ + { + name: "truncated data", + value: marshaler{data: []byte{0xa6}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.marshaler: unexpected EOF", + }, + { + name: "malformed data", + value: marshaler{data: []byte{0x1f}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.marshaler: cbor: invalid additional information 31 for type positive integer", + }, + { + name: "extraneous data", + value: marshaler{data: []byte{0x01, 0x01}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.marshaler: cbor: 1 bytes of extraneous data starting at index 1", + }, + { + name: "invalid builtin tag", + value: marshaler{data: []byte{0xc0, 0x01}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.marshaler: cbor: tag number 0 must be followed by text string, got positive integer", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b, err := Marshal(tc.value) + if err == nil { + t.Errorf("Marshal(%v) didn't return an error, want error %q", tc.value, tc.wantErrorMsg) + } else if _, ok := err.(*MarshalerError); !ok { + t.Errorf("Marshal(%v) error type %T, want *MarshalerError", tc.value, err) + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("Marshal(%v) error %q, want %q", tc.value, err.Error(), tc.wantErrorMsg) + } + if b != nil { + t.Errorf("Marshal(%v) = 0x%x, want nil", tc.value, b) + } + }) + } +} + +func TestMarshalerReturnsDisallowedCBORData(t *testing.T) { + + testCases := []struct { + name string + encOpts EncOptions + value interface{} + wantErrorMsg string + }{ + { + name: "enc mode forbids indefinite length, data has indefinite length", + encOpts: EncOptions{IndefLength: IndefLengthForbidden}, + value: marshaler{data: hexDecode("5f42010243030405ff")}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.marshaler: cbor: indefinite-length byte string isn't allowed", + }, + { + name: "enc mode allows indefinite length, data has indefinite length", + encOpts: EncOptions{IndefLength: IndefLengthAllowed}, + value: marshaler{data: hexDecode("5f42010243030405ff")}, + }, + { + name: "enc mode forbids tags, data has tags", + encOpts: EncOptions{TagsMd: TagsForbidden}, + value: marshaler{data: hexDecode("c074323031332d30332d32315432303a30343a30305a")}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.marshaler: cbor: CBOR tag isn't allowed", + }, + { + name: "enc mode allows tags, data has tags", + encOpts: EncOptions{TagsMd: TagsAllowed}, + value: marshaler{data: hexDecode("c074323031332d30332d32315432303a30343a30305a")}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + em, err := tc.encOpts.EncMode() + if err != nil { + t.Fatal(err) + } + + b, err := em.Marshal(tc.value) + if tc.wantErrorMsg == "" { + if err != nil { + t.Errorf("Marshal(%v) returned error %q", tc.value, err) + } + } else { + if err == nil { + t.Errorf("Marshal(%v) didn't return an error, want error %q", tc.value, tc.wantErrorMsg) + } else if _, ok := err.(*MarshalerError); !ok { + t.Errorf("Marshal(%v) error type %T, want *MarshalerError", tc.value, err) + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("Marshal(%v) error %q, want %q", tc.value, err.Error(), tc.wantErrorMsg) + } + if b != nil { + t.Errorf("Marshal(%v) = 0x%x, want nil", tc.value, b) + } + } + }) + } +} diff --git a/stream.go b/stream.go index 02fea43c..7ce0c43c 100644 --- a/stream.go +++ b/stream.go @@ -84,7 +84,7 @@ func (dec *Decoder) readNext() (int, error) { if dec.off < len(dec.buf) { dec.d.reset(dec.buf[dec.off:]) off := dec.off // Save offset before data validation - validErr = dec.d.wellformed(true) + validErr = dec.d.wellformed(true, false) dec.off = off // Restore offset if validErr == nil { diff --git a/tag_test.go b/tag_test.go index 4f243314..dec31815 100644 --- a/tag_test.go +++ b/tag_test.go @@ -1444,3 +1444,49 @@ func TestDecodeRegisterTagForUnmarshaler(t *testing.T) { t.Errorf("Marshal(%v) returned %v, want %v", v2, b, data) } } + +func TestMarshalRawTagContainingMalformedCBORData(t *testing.T) { + testCases := []struct { + name string + value interface{} + wantErrorMsg string + }{ + // Nil RawMessage and empty RawMessage are encoded as CBOR nil. + { + name: "truncated data", + value: RawTag{Number: 100, Content: RawMessage{0xa6}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawTag: unexpected EOF", + }, + { + name: "malformed data", + value: RawTag{Number: 100, Content: RawMessage{0x1f}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawTag: cbor: invalid additional information 31 for type positive integer", + }, + { + name: "extraneous data", + value: RawTag{Number: 100, Content: RawMessage{0x01, 0x01}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawTag: cbor: 1 bytes of extraneous data starting at index 3", + }, + { + name: "invalid builtin tag", + value: RawTag{Number: 0, Content: RawMessage{0x01}}, + wantErrorMsg: "cbor: error calling MarshalCBOR for type cbor.RawTag: cbor: tag number 0 must be followed by text string, got positive integer", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b, err := Marshal(tc.value) + if err == nil { + t.Errorf("Marshal(%v) didn't return an error, want error %q", tc.value, tc.wantErrorMsg) + } else if _, ok := err.(*MarshalerError); !ok { + t.Errorf("Marshal(%v) error type %T, want *MarshalerError", tc.value, err) + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("Marshal(%v) error %q, want %q", tc.value, err.Error(), tc.wantErrorMsg) + } + if b != nil { + t.Errorf("Marshal(%v) = 0x%x, want nil", tc.value, b) + } + }) + } +} diff --git a/valid.go b/valid.go index a5213d06..11013faa 100644 --- a/valid.go +++ b/valid.go @@ -82,11 +82,11 @@ func (e *ExtraneousDataError) Error() string { // allowExtraData indicates if extraneous data is allowed after the CBOR data item. // - use allowExtraData = true when using Decoder.Decode() // - use allowExtraData = false when using Unmarshal() -func (d *decoder) wellformed(allowExtraData bool) error { +func (d *decoder) wellformed(allowExtraData bool, checkBuiltinTags bool) error { if len(d.data) == d.off { return io.EOF } - _, err := d.wellformedInternal(0) + _, err := d.wellformedInternal(0, checkBuiltinTags) if err == nil { if !allowExtraData && d.off != len(d.data) { err = &ExtraneousDataError{len(d.data) - d.off, d.off} @@ -96,7 +96,7 @@ func (d *decoder) wellformed(allowExtraData bool) error { } // wellformedInternal checks data's well-formedness and returns max depth and error. -func (d *decoder) wellformedInternal(depth int) (int, error) { +func (d *decoder) wellformedInternal(depth int, checkBuiltinTags bool) (int, error) { t, ai, val, err := d.wellformedHead() if err != nil { return 0, err @@ -108,7 +108,7 @@ func (d *decoder) wellformedInternal(depth int) (int, error) { if d.dm.indefLength == IndefLengthForbidden { return 0, &IndefiniteLengthError{t} } - return d.wellformedIndefiniteString(t, depth) + return d.wellformedIndefiniteString(t, depth, checkBuiltinTags) } valInt := int(val) if valInt < 0 { @@ -119,6 +119,7 @@ func (d *decoder) wellformedInternal(depth int) (int, error) { return 0, io.ErrUnexpectedEOF } d.off += valInt + case cborTypeArray, cborTypeMap: depth++ if depth > d.dm.maxNestedLevels { @@ -129,7 +130,7 @@ func (d *decoder) wellformedInternal(depth int) (int, error) { if d.dm.indefLength == IndefLengthForbidden { return 0, &IndefiniteLengthError{t} } - return d.wellformedIndefiniteArrayOrMap(t, depth) + return d.wellformedIndefiniteArrayOrMap(t, depth, checkBuiltinTags) } valInt := int(val) @@ -156,7 +157,7 @@ func (d *decoder) wellformedInternal(depth int) (int, error) { for j := 0; j < count; j++ { for i := 0; i < valInt; i++ { var dpt int - if dpt, err = d.wellformedInternal(depth); err != nil { + if dpt, err = d.wellformedInternal(depth, checkBuiltinTags); err != nil { return 0, err } if dpt > maxDepth { @@ -165,20 +166,29 @@ func (d *decoder) wellformedInternal(depth int) (int, error) { } } depth = maxDepth + case cborTypeTag: if d.dm.tagsMd == TagsForbidden { return 0, &TagsMdError{} } + tagNum := val + // Scan nested tag numbers to avoid recursion. for { if len(d.data) == d.off { // Tag number must be followed by tag content. return 0, io.ErrUnexpectedEOF } + if checkBuiltinTags { + err = validBuiltinTag(tagNum, d.data[d.off]) + if err != nil { + return 0, err + } + } if cborType(d.data[d.off]&0xe0) != cborTypeTag { break } - if _, _, _, err = d.wellformedHead(); err != nil { + if _, _, tagNum, err = d.wellformedHead(); err != nil { return 0, err } depth++ @@ -187,13 +197,14 @@ func (d *decoder) wellformedInternal(depth int) (int, error) { } } // Check tag content. - return d.wellformedInternal(depth) + return d.wellformedInternal(depth, checkBuiltinTags) } + return depth, nil } // wellformedIndefiniteString checks indefinite length byte/text string's well-formedness and returns max depth and error. -func (d *decoder) wellformedIndefiniteString(t cborType, depth int) (int, error) { +func (d *decoder) wellformedIndefiniteString(t cborType, depth int, checkBuiltinTags bool) (int, error) { var err error for { if len(d.data) == d.off { @@ -211,7 +222,7 @@ func (d *decoder) wellformedIndefiniteString(t cborType, depth int) (int, error) if (d.data[d.off] & 0x1f) == 31 { return 0, &SyntaxError{"cbor: indefinite-length " + t.String() + " chunk is not definite-length"} } - if depth, err = d.wellformedInternal(depth); err != nil { + if depth, err = d.wellformedInternal(depth, checkBuiltinTags); err != nil { return 0, err } } @@ -219,7 +230,7 @@ func (d *decoder) wellformedIndefiniteString(t cborType, depth int) (int, error) } // wellformedIndefiniteArrayOrMap checks indefinite length array/map's well-formedness and returns max depth and error. -func (d *decoder) wellformedIndefiniteArrayOrMap(t cborType, depth int) (int, error) { +func (d *decoder) wellformedIndefiniteArrayOrMap(t cborType, depth int, checkBuiltinTags bool) (int, error) { var err error maxDepth := depth i := 0 @@ -232,7 +243,7 @@ func (d *decoder) wellformedIndefiniteArrayOrMap(t cborType, depth int) (int, er break } var dpt int - if dpt, err = d.wellformedInternal(depth); err != nil { + if dpt, err = d.wellformedInternal(depth, checkBuiltinTags); err != nil { return 0, err } if dpt > maxDepth { diff --git a/valid_test.go b/valid_test.go index 64a7e585..f4637b87 100644 --- a/valid_test.go +++ b/valid_test.go @@ -63,7 +63,7 @@ func TestValidOnStreamingData(t *testing.T) { } d := decoder{data: buf.Bytes(), dm: defaultDecMode} for i := 0; i < len(marshalTests); i++ { - if err := d.wellformed(true); err != nil { + if err := d.wellformed(true, false); err != nil { t.Errorf("wellformed() returned error %v", err) } } @@ -111,7 +111,7 @@ func TestDepth(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { d := decoder{data: tc.data, dm: defaultDecMode} - depth, err := d.wellformedInternal(0) + depth, err := d.wellformedInternal(0, false) if err != nil { t.Errorf("wellformed(0x%x) returned error %v", tc.data, err) } @@ -176,7 +176,7 @@ func TestDepthError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { dm, _ := tc.opts.decMode() d := decoder{data: tc.data, dm: dm} - if _, err := d.wellformedInternal(0); err == nil { + if _, err := d.wellformedInternal(0, false); err == nil { t.Errorf("wellformed(0x%x) didn't return an error", tc.data) } else if _, ok := err.(*MaxNestedLevelError); !ok { t.Errorf("wellformed(0x%x) returned wrong error type %T, want (*MaxNestedLevelError)", tc.data, err) @@ -186,3 +186,113 @@ func TestDepthError(t *testing.T) { }) } } + +func TestValidBuiltinTagTest(t *testing.T) { + testCases := []struct { + name string + data []byte + }{ + { + name: "tag 0", + data: hexDecode("c074323031332d30332d32315432303a30343a30305a"), + }, + { + name: "tag 1", + data: hexDecode("c11a514b67b0"), + }, + { + name: "tag 2", + data: hexDecode("c249010000000000000000"), + }, + { + name: "tag 3", + data: hexDecode("c349010000000000000000"), + }, + { + name: "nested tag 0", + data: hexDecode("d9d9f7c074323031332d30332d32315432303a30343a30305a"), + }, + { + name: "nested tag 1", + data: hexDecode("d9d9f7c11a514b67b0"), + }, + { + name: "nested tag 2", + data: hexDecode("d9d9f7c249010000000000000000"), + }, + { + name: "nested tag 3", + data: hexDecode("d9d9f7c349010000000000000000"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + d := decoder{data: tc.data, dm: defaultDecMode} + if err := d.wellformed(true, true); err != nil { + t.Errorf("wellformed(0x%x) returned error %v", tc.data, err) + } + }) + } +} + +func TestInvalidBuiltinTagTest(t *testing.T) { + testCases := []struct { + name string + data []byte + wantErrorMsg string + }{ + { + name: "tag 0", + data: hexDecode("c01a514b67b0"), + wantErrorMsg: "cbor: tag number 0 must be followed by text string, got positive integer", + }, + { + name: "tag 1", + data: hexDecode("c174323031332d30332d32315432303a30343a30305a"), + wantErrorMsg: "cbor: tag number 1 must be followed by integer or floating-point number, got UTF-8 text string", + }, + { + name: "tag 2", + data: hexDecode("c269010000000000000000"), + wantErrorMsg: "cbor: tag number 2 or 3 must be followed by byte string, got UTF-8 text string", + }, + { + name: "tag 3", + data: hexDecode("c300"), + wantErrorMsg: "cbor: tag number 2 or 3 must be followed by byte string, got positive integer", + }, + { + name: "nested tag 0", + data: hexDecode("d9d9f7c01a514b67b0"), + wantErrorMsg: "cbor: tag number 0 must be followed by text string, got positive integer", + }, + { + name: "nested tag 1", + data: hexDecode("d9d9f7c174323031332d30332d32315432303a30343a30305a"), + wantErrorMsg: "cbor: tag number 1 must be followed by integer or floating-point number, got UTF-8 text string", + }, + { + name: "nested tag 2", + data: hexDecode("d9d9f7c269010000000000000000"), + wantErrorMsg: "cbor: tag number 2 or 3 must be followed by byte string, got UTF-8 text string", + }, + { + name: "nested tag 3", + data: hexDecode("d9d9f7c300"), + wantErrorMsg: "cbor: tag number 2 or 3 must be followed by byte string, got positive integer", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + d := decoder{data: tc.data, dm: defaultDecMode} + err := d.wellformed(true, true) + if err == nil { + t.Errorf("wellformed(0x%x) didn't return an error", tc.data) + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("wellformed(0x%x) error %q, want %q", tc.data, err.Error(), tc.wantErrorMsg) + } + }) + } +}