diff --git a/decode_test.go b/decode_test.go index a7a583f3..99265b63 100644 --- a/decode_test.go +++ b/decode_test.go @@ -7514,7 +7514,7 @@ func TestUnmarshalToInterface(t *testing.T) { if err != nil { t.Errorf("Marshal(%+v) returned error %v", tc.v, err) } else if !bytes.Equal(data, tc.data) { - t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", tc.v, data, tc.v) + t.Errorf("Marshal(%+v) = 0x%x, want 0x%x", tc.v, data, tc.data) } // Unmarshal to empty interface diff --git a/encode.go b/encode.go index 5e9c5a4e..087f95f3 100644 --- a/encode.go +++ b/encode.go @@ -973,8 +973,13 @@ func (ae arrayEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) return nil } +// encodeKeyValueFunc encodes key/value pairs in map (v). +// If kvs is provided (having the same length as v), length of encoded key and value are stored in kvs. +// kvs is used for canonical encoding of map. +type encodeKeyValueFunc func(e *encoderBuffer, em *encMode, v reflect.Value, kvs []keyValue) error + type mapEncodeFunc struct { - kf, ef encodeFunc + e encodeKeyValueFunc } func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { @@ -993,16 +998,8 @@ func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) e return me.encodeCanonical(e, em, v) } encodeHead(e, byte(cborTypeMap), uint64(mlen)) - iter := v.MapRange() - for iter.Next() { - if err := me.kf(e, em, iter.Key()); err != nil { - return err - } - if err := me.ef(e, em, iter.Value()); err != nil { - return err - } - } - return nil + + return me.e(e, em, v, nil) } type keyValue struct { @@ -1071,26 +1068,17 @@ func putKeyValues(x *[]keyValue) { } func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error { - kve := getEncoderBuffer() // accumulated cbor encoded key-values + kve := getEncoderBuffer() // accumulated cbor encoded key-values + defer putEncoderBuffer(kve) + kvsp := getKeyValues(v.Len()) // for sorting keys + defer putKeyValues(kvsp) + kvs := *kvsp - iter := v.MapRange() - for i := 0; iter.Next(); i++ { - off := kve.Len() - if err := me.kf(kve, em, iter.Key()); err != nil { - putEncoderBuffer(kve) - putKeyValues(kvsp) - return err - } - n1 := kve.Len() - off - if err := me.ef(kve, em, iter.Value()); err != nil { - putEncoderBuffer(kve) - putKeyValues(kvsp) - return err - } - n2 := kve.Len() - off - // Save key and keyvalue length to create slice later. - kvs[i] = keyValue{keyLen: n1, keyValueLen: n2} + + err := me.e(kve, em, v, kvs) + if err != nil { + return err } b := kve.Bytes() @@ -1111,8 +1099,6 @@ func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect e.Write(kvs[i].keyValueCBORData) } - putEncoderBuffer(kve) - putKeyValues(kvsp) return nil } @@ -1463,12 +1449,11 @@ func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) { } return arrayEncodeFunc{f: f}.encode, isEmptySlice case reflect.Map: - kf, _ := getEncodeFunc(t.Key()) - ef, _ := getEncodeFunc(t.Elem()) - if kf == nil || ef == nil { + f := getEncodeMapFunc(t) + if f == nil { return nil, nil } - return mapEncodeFunc{kf: kf, ef: ef}.encode, isEmptyMap + return f, isEmptyMap case reflect.Struct: // Get struct's special field "_" tag options if f, ok := t.FieldByName("_"); ok { diff --git a/encode_map_go117.go b/encode_map_go117.go new file mode 100644 index 00000000..e94f7c87 --- /dev/null +++ b/encode_map_go117.go @@ -0,0 +1,49 @@ +// Copyright (c) Faye Amacker. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +package cbor + +import ( + "reflect" +) + +type mapKeyValueEncodeFunc struct { + kf, ef encodeFunc +} + +func (me *mapKeyValueEncodeFunc) encodeKeyValues(e *encoderBuffer, em *encMode, v reflect.Value, kvs []keyValue) error { + trackKeyValueLength := len(kvs) == v.Len() + + iter := v.MapRange() + for i := 0; iter.Next(); i++ { + off := e.Len() + + if err := me.kf(e, em, iter.Key()); err != nil { + return err + } + if trackKeyValueLength { + kvs[i].keyLen = e.Len() - off + } + + if err := me.ef(e, em, iter.Value()); err != nil { + return err + } + if trackKeyValueLength { + kvs[i].keyValueLen = e.Len() - off + } + } + + return nil +} + +func getEncodeMapFunc(t reflect.Type) encodeFunc { + kf, _ := getEncodeFunc(t.Key()) + ef, _ := getEncodeFunc(t.Elem()) + if kf == nil || ef == nil { + return nil + } + mkv := &mapKeyValueEncodeFunc{kf: kf, ef: ef} + return mapEncodeFunc{ + e: mkv.encodeKeyValues, + }.encode +}