diff --git a/encode.go b/encode.go index eba09c53..db111530 100644 --- a/encode.go +++ b/encode.go @@ -12,7 +12,6 @@ import ( "math" "math/big" "reflect" - "sort" "strconv" "sync" "time" @@ -947,38 +946,6 @@ func (ae arrayEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) return nil } -type mapEncodeFunc struct { - kf, ef encodeFunc -} - -func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { - if v.IsNil() && em.nilContainers == NilContainerAsNull { - e.Write(cborNil) - return nil - } - if b := em.encTagBytes(v.Type()); b != nil { - e.Write(b) - } - mlen := v.Len() - if mlen == 0 { - return e.WriteByte(byte(cborTypeMap)) - } - if em.sort != SortNone { - return me.encodeCanonical(e, em, v) - } - encodeHead(e, byte(cborTypeMap), uint64(mlen)) - iter := v.MapRange() - for iter.Next() { - if err := me.kf(e, em, iter.Key()); err != nil { - return err - } - if err := me.ef(e, em, iter.Value()); err != nil { - return err - } - } - return nil -} - type keyValue struct { keyCBORData, keyValueCBORData []byte keyLen, keyValueLen int @@ -1044,52 +1011,6 @@ func putKeyValues(x *[]keyValue) { keyValuePool.Put(x) } -func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error { - kve := getEncoderBuffer() // accumulated cbor encoded key-values - kvsp := getKeyValues(v.Len()) // for sorting keys - kvs := *kvsp - iter := v.MapRange() - for i := 0; iter.Next(); i++ { - off := kve.Len() - if err := me.kf(kve, em, iter.Key()); err != nil { - putEncoderBuffer(kve) - putKeyValues(kvsp) - return err - } - n1 := kve.Len() - off - if err := me.ef(kve, em, iter.Value()); err != nil { - putEncoderBuffer(kve) - putKeyValues(kvsp) - return err - } - n2 := kve.Len() - off - // Save key and keyvalue length to create slice later. - kvs[i] = keyValue{keyLen: n1, keyValueLen: n2} - } - - b := kve.Bytes() - for i, off := 0, 0; i < len(kvs); i++ { - kvs[i].keyCBORData = b[off : off+kvs[i].keyLen] - kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen] - off += kvs[i].keyValueLen - } - - if em.sort == SortBytewiseLexical { - sort.Sort(&bytewiseKeyValueSorter{kvs}) - } else { - sort.Sort(&lengthFirstKeyValueSorter{kvs}) - } - - encodeHead(e, byte(cborTypeMap), uint64(len(kvs))) - for i := 0; i < len(kvs); i++ { - e.Write(kvs[i].keyValueCBORData) - } - - putEncoderBuffer(kve) - putKeyValues(kvsp) - return nil -} - func encodeStructToArray(e *encoderBuffer, em *encMode, v reflect.Value) (err error) { structType, err := getEncodingStructType(v.Type()) if err != nil { @@ -1383,75 +1304,6 @@ var ( typeByteString = reflect.TypeOf(ByteString("")) ) -func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) { - k := t.Kind() - if k == reflect.Ptr { - return getEncodeIndirectValueFunc(t), isEmptyPtr - } - switch t { - case typeSimpleValue: - return encodeMarshalerType, isEmptyUint - case typeTag: - return encodeTag, alwaysNotEmpty - case typeTime: - return encodeTime, alwaysNotEmpty - case typeBigInt: - return encodeBigInt, alwaysNotEmpty - case typeRawMessage: - return encodeMarshalerType, isEmptySlice - case typeByteString: - return encodeMarshalerType, isEmptyString - } - if reflect.PtrTo(t).Implements(typeMarshaler) { - return encodeMarshalerType, alwaysNotEmpty - } - if reflect.PtrTo(t).Implements(typeBinaryMarshaler) { - return encodeBinaryMarshalerType, isEmptyBinaryMarshaler - } - switch k { - case reflect.Bool: - return encodeBool, isEmptyBool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return encodeInt, isEmptyInt - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return encodeUint, isEmptyUint - case reflect.Float32, reflect.Float64: - return encodeFloat, isEmptyFloat - case reflect.String: - return encodeString, isEmptyString - case reflect.Slice, reflect.Array: - if t.Elem().Kind() == reflect.Uint8 { - return encodeByteString, isEmptySlice - } - f, _ := getEncodeFunc(t.Elem()) - if f == nil { - return nil, nil - } - return arrayEncodeFunc{f: f}.encode, isEmptySlice - case reflect.Map: - kf, _ := getEncodeFunc(t.Key()) - ef, _ := getEncodeFunc(t.Elem()) - if kf == nil || ef == nil { - return nil, nil - } - return mapEncodeFunc{kf: kf, ef: ef}.encode, isEmptyMap - case reflect.Struct: - // Get struct's special field "_" tag options - if f, ok := t.FieldByName("_"); ok { - tag := f.Tag.Get("cbor") - if tag != "-" { - if hasToArrayOption(tag) { - return encodeStructToArray, isEmptyStruct - } - } - } - return encodeStruct, isEmptyStruct - case reflect.Interface: - return encodeIntf, isEmptyIntf - } - return nil, nil -} - func getEncodeIndirectValueFunc(t reflect.Type) encodeFunc { for t.Kind() == reflect.Ptr { t = t.Elem() diff --git a/encode_map.go b/encode_map.go new file mode 100644 index 00000000..1a5a5e96 --- /dev/null +++ b/encode_map.go @@ -0,0 +1,199 @@ +// Copyright (c) Faye Amacker. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +//go:build go1.20 + +package cbor + +import ( + "reflect" + "sort" + "sync" +) + +type mapEncodeFunc struct { + kf, ef encodeFunc + kpool, vpool sync.Pool +} + +func (me *mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { + if v.IsNil() && em.nilContainers == NilContainerAsNull { + e.Write(cborNil) + return nil + } + if b := em.encTagBytes(v.Type()); b != nil { + e.Write(b) + } + mlen := v.Len() + if mlen == 0 { + return e.WriteByte(byte(cborTypeMap)) + } + if em.sort != SortNone { + return me.encodeCanonical(e, em, v) + } + encodeHead(e, byte(cborTypeMap), uint64(mlen)) + iterk := me.kpool.Get().(*reflect.Value) + defer func() { + iterk.SetZero() + me.kpool.Put(iterk) + }() + iterv := me.vpool.Get().(*reflect.Value) + defer func() { + iterv.SetZero() + me.vpool.Put(iterv) + }() + iter := v.MapRange() + for iter.Next() { + iterk.SetIterKey(iter) + iterv.SetIterValue(iter) + if err := me.kf(e, em, *iterk); err != nil { + return err + } + if err := me.ef(e, em, *iterv); err != nil { + return err + } + } + return nil +} + +func (me *mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error { + kve := getEncoderBuffer() // accumulated cbor encoded key-values + kvsp := getKeyValues(v.Len()) // for sorting keys + kvs := *kvsp + iterk := me.kpool.Get().(*reflect.Value) + defer func() { + iterk.SetZero() + me.kpool.Put(iterk) + }() + iterv := me.vpool.Get().(*reflect.Value) + defer func() { + iterv.SetZero() + me.vpool.Put(iterv) + }() + iter := v.MapRange() + for i := 0; iter.Next(); i++ { + iterk.SetIterKey(iter) + iterv.SetIterValue(iter) + off := kve.Len() + if err := me.kf(kve, em, *iterk); err != nil { + putEncoderBuffer(kve) + putKeyValues(kvsp) + return err + } + n1 := kve.Len() - off + if err := me.ef(kve, em, *iterv); err != nil { + putEncoderBuffer(kve) + putKeyValues(kvsp) + return err + } + n2 := kve.Len() - off + // Save key and keyvalue length to create slice later. + kvs[i] = keyValue{keyLen: n1, keyValueLen: n2} + } + + b := kve.Bytes() + for i, off := 0, 0; i < len(kvs); i++ { + kvs[i].keyCBORData = b[off : off+kvs[i].keyLen] + kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen] + off += kvs[i].keyValueLen + } + + if em.sort == SortBytewiseLexical { + sort.Sort(&bytewiseKeyValueSorter{kvs}) + } else { + sort.Sort(&lengthFirstKeyValueSorter{kvs}) + } + + encodeHead(e, byte(cborTypeMap), uint64(len(kvs))) + for i := 0; i < len(kvs); i++ { + e.Write(kvs[i].keyValueCBORData) + } + + putEncoderBuffer(kve) + putKeyValues(kvsp) + return nil +} + +func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) { + k := t.Kind() + if k == reflect.Ptr { + return getEncodeIndirectValueFunc(t), isEmptyPtr + } + switch t { + case typeSimpleValue: + return encodeMarshalerType, isEmptyUint + case typeTag: + return encodeTag, alwaysNotEmpty + case typeTime: + return encodeTime, alwaysNotEmpty + case typeBigInt: + return encodeBigInt, alwaysNotEmpty + case typeRawMessage: + return encodeMarshalerType, isEmptySlice + case typeByteString: + return encodeMarshalerType, isEmptyString + } + if reflect.PtrTo(t).Implements(typeMarshaler) { + return encodeMarshalerType, alwaysNotEmpty + } + if reflect.PtrTo(t).Implements(typeBinaryMarshaler) { + return encodeBinaryMarshalerType, isEmptyBinaryMarshaler + } + switch k { + case reflect.Bool: + return encodeBool, isEmptyBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt, isEmptyInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return encodeUint, isEmptyUint + case reflect.Float32, reflect.Float64: + return encodeFloat, isEmptyFloat + case reflect.String: + return encodeString, isEmptyString + case reflect.Slice, reflect.Array: + if t.Elem().Kind() == reflect.Uint8 { + return encodeByteString, isEmptySlice + } + f, _ := getEncodeFunc(t.Elem()) + if f == nil { + return nil, nil + } + return arrayEncodeFunc{f: f}.encode, isEmptySlice + case reflect.Map: + kf, _ := getEncodeFunc(t.Key()) + ef, _ := getEncodeFunc(t.Elem()) + if kf == nil || ef == nil { + return nil, nil + } + return (&mapEncodeFunc{ + kf: kf, + ef: ef, + kpool: sync.Pool{ + New: func() interface{} { + rk := reflect.New(t.Key()).Elem() + return &rk + }, + }, + vpool: sync.Pool{ + New: func() interface{} { + rv := reflect.New(t.Elem()).Elem() + return &rv + }, + }, + }).encode, isEmptyMap + case reflect.Struct: + // Get struct's special field "_" tag options + if f, ok := t.FieldByName("_"); ok { + tag := f.Tag.Get("cbor") + if tag != "-" { + if hasToArrayOption(tag) { + return encodeStructToArray, isEmptyStruct + } + } + } + return encodeStruct, isEmptyStruct + case reflect.Interface: + return encodeIntf, isEmptyIntf + } + return nil, nil +} diff --git a/encode_map_go117.go b/encode_map_go117.go new file mode 100644 index 00000000..0970e1d1 --- /dev/null +++ b/encode_map_go117.go @@ -0,0 +1,158 @@ +// Copyright (c) Faye Amacker. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +//go:build !go1.20 + +package cbor + +import ( + "reflect" + "sort" +) + +type mapEncodeFunc struct { + kf, ef encodeFunc +} + +func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { + if v.IsNil() && em.nilContainers == NilContainerAsNull { + e.Write(cborNil) + return nil + } + if b := em.encTagBytes(v.Type()); b != nil { + e.Write(b) + } + mlen := v.Len() + if mlen == 0 { + return e.WriteByte(byte(cborTypeMap)) + } + if em.sort != SortNone { + return me.encodeCanonical(e, em, v) + } + encodeHead(e, byte(cborTypeMap), uint64(mlen)) + iter := v.MapRange() + for iter.Next() { + if err := me.kf(e, em, iter.Key()); err != nil { + return err + } + if err := me.ef(e, em, iter.Value()); err != nil { + return err + } + } + return nil +} + +func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error { + kve := getEncoderBuffer() // accumulated cbor encoded key-values + kvsp := getKeyValues(v.Len()) // for sorting keys + kvs := *kvsp + iter := v.MapRange() + for i := 0; iter.Next(); i++ { + off := kve.Len() + if err := me.kf(kve, em, iter.Key()); err != nil { + putEncoderBuffer(kve) + putKeyValues(kvsp) + return err + } + n1 := kve.Len() - off + if err := me.ef(kve, em, iter.Value()); err != nil { + putEncoderBuffer(kve) + putKeyValues(kvsp) + return err + } + n2 := kve.Len() - off + // Save key and keyvalue length to create slice later. + kvs[i] = keyValue{keyLen: n1, keyValueLen: n2} + } + + b := kve.Bytes() + for i, off := 0, 0; i < len(kvs); i++ { + kvs[i].keyCBORData = b[off : off+kvs[i].keyLen] + kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen] + off += kvs[i].keyValueLen + } + + if em.sort == SortBytewiseLexical { + sort.Sort(&bytewiseKeyValueSorter{kvs}) + } else { + sort.Sort(&lengthFirstKeyValueSorter{kvs}) + } + + encodeHead(e, byte(cborTypeMap), uint64(len(kvs))) + for i := 0; i < len(kvs); i++ { + e.Write(kvs[i].keyValueCBORData) + } + + putEncoderBuffer(kve) + putKeyValues(kvsp) + return nil +} + +func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) { + k := t.Kind() + if k == reflect.Ptr { + return getEncodeIndirectValueFunc(t), isEmptyPtr + } + switch t { + case typeSimpleValue: + return encodeMarshalerType, isEmptyUint + case typeTag: + return encodeTag, alwaysNotEmpty + case typeTime: + return encodeTime, alwaysNotEmpty + case typeBigInt: + return encodeBigInt, alwaysNotEmpty + case typeRawMessage: + return encodeMarshalerType, isEmptySlice + case typeByteString: + return encodeMarshalerType, isEmptyString + } + if reflect.PtrTo(t).Implements(typeMarshaler) { + return encodeMarshalerType, alwaysNotEmpty + } + if reflect.PtrTo(t).Implements(typeBinaryMarshaler) { + return encodeBinaryMarshalerType, isEmptyBinaryMarshaler + } + switch k { + case reflect.Bool: + return encodeBool, isEmptyBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt, isEmptyInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return encodeUint, isEmptyUint + case reflect.Float32, reflect.Float64: + return encodeFloat, isEmptyFloat + case reflect.String: + return encodeString, isEmptyString + case reflect.Slice, reflect.Array: + if t.Elem().Kind() == reflect.Uint8 { + return encodeByteString, isEmptySlice + } + f, _ := getEncodeFunc(t.Elem()) + if f == nil { + return nil, nil + } + return arrayEncodeFunc{f: f}.encode, isEmptySlice + case reflect.Map: + kf, _ := getEncodeFunc(t.Key()) + ef, _ := getEncodeFunc(t.Elem()) + if kf == nil || ef == nil { + return nil, nil + } + return mapEncodeFunc{kf: kf, ef: ef}.encode, isEmptyMap + case reflect.Struct: + // Get struct's special field "_" tag options + if f, ok := t.FieldByName("_"); ok { + tag := f.Tag.Get("cbor") + if tag != "-" { + if hasToArrayOption(tag) { + return encodeStructToArray, isEmptyStruct + } + } + } + return encodeStruct, isEmptyStruct + case reflect.Interface: + return encodeIntf, isEmptyIntf + } + return nil, nil +} diff --git a/go.mod b/go.mod index 49d74dbd..e6de70ab 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/fxamacker/cbor/v2 -go 1.12 +go 1.20 require github.com/x448/float16 v0.8.4