Skip to content

Commit

Permalink
Use MapIter.SetKey/SetValue and sync.Pool to improve memory allocation
Browse files Browse the repository at this point in the history
Since go 1.18, the reflect package introduces MapIter.SetKey and
MapIter.SetValue that will do fewer memory allocation for map
iteration which is frequently used for CBOR encode operation. Plus,
usage of sync.Pool will further reduce memory allocation by reusing
the shared memory in the pool. Lastly, the Value.SetZero method
(available since go 1.20) is helpful to release memory allocation
to the GC when is no longer needed.

Signed-off-by: Vu Dinh <[email protected]>
  • Loading branch information
dinhxuanvu committed Jan 12, 2024
1 parent cd0553c commit 8c36ffb
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 149 deletions.
148 changes: 0 additions & 148 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"math"
"math/big"
"reflect"
"sort"
"strconv"
"sync"
"time"
Expand Down Expand Up @@ -947,38 +946,6 @@ func (ae arrayEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value)
return nil
}

type mapEncodeFunc struct {
kf, ef encodeFunc
}

func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error {
if v.IsNil() && em.nilContainers == NilContainerAsNull {
e.Write(cborNil)
return nil
}
if b := em.encTagBytes(v.Type()); b != nil {
e.Write(b)
}
mlen := v.Len()
if mlen == 0 {
return e.WriteByte(byte(cborTypeMap))
}
if em.sort != SortNone {
return me.encodeCanonical(e, em, v)
}
encodeHead(e, byte(cborTypeMap), uint64(mlen))
iter := v.MapRange()
for iter.Next() {
if err := me.kf(e, em, iter.Key()); err != nil {
return err
}
if err := me.ef(e, em, iter.Value()); err != nil {
return err
}
}
return nil
}

type keyValue struct {
keyCBORData, keyValueCBORData []byte
keyLen, keyValueLen int
Expand Down Expand Up @@ -1044,52 +1011,6 @@ func putKeyValues(x *[]keyValue) {
keyValuePool.Put(x)
}

func (me mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error {
kve := getEncoderBuffer() // accumulated cbor encoded key-values
kvsp := getKeyValues(v.Len()) // for sorting keys
kvs := *kvsp
iter := v.MapRange()
for i := 0; iter.Next(); i++ {
off := kve.Len()
if err := me.kf(kve, em, iter.Key()); err != nil {
putEncoderBuffer(kve)
putKeyValues(kvsp)
return err
}
n1 := kve.Len() - off
if err := me.ef(kve, em, iter.Value()); err != nil {
putEncoderBuffer(kve)
putKeyValues(kvsp)
return err
}
n2 := kve.Len() - off
// Save key and keyvalue length to create slice later.
kvs[i] = keyValue{keyLen: n1, keyValueLen: n2}
}

b := kve.Bytes()
for i, off := 0, 0; i < len(kvs); i++ {
kvs[i].keyCBORData = b[off : off+kvs[i].keyLen]
kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen]
off += kvs[i].keyValueLen
}

if em.sort == SortBytewiseLexical {
sort.Sort(&bytewiseKeyValueSorter{kvs})
} else {
sort.Sort(&lengthFirstKeyValueSorter{kvs})
}

encodeHead(e, byte(cborTypeMap), uint64(len(kvs)))
for i := 0; i < len(kvs); i++ {
e.Write(kvs[i].keyValueCBORData)
}

putEncoderBuffer(kve)
putKeyValues(kvsp)
return nil
}

func encodeStructToArray(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {
structType, err := getEncodingStructType(v.Type())
if err != nil {
Expand Down Expand Up @@ -1383,75 +1304,6 @@ var (
typeByteString = reflect.TypeOf(ByteString(""))
)

func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
k := t.Kind()
if k == reflect.Ptr {
return getEncodeIndirectValueFunc(t), isEmptyPtr
}
switch t {
case typeSimpleValue:
return encodeMarshalerType, isEmptyUint
case typeTag:
return encodeTag, alwaysNotEmpty
case typeTime:
return encodeTime, alwaysNotEmpty
case typeBigInt:
return encodeBigInt, alwaysNotEmpty
case typeRawMessage:
return encodeMarshalerType, isEmptySlice
case typeByteString:
return encodeMarshalerType, isEmptyString
}
if reflect.PtrTo(t).Implements(typeMarshaler) {
return encodeMarshalerType, alwaysNotEmpty
}
if reflect.PtrTo(t).Implements(typeBinaryMarshaler) {
return encodeBinaryMarshalerType, isEmptyBinaryMarshaler
}
switch k {
case reflect.Bool:
return encodeBool, isEmptyBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return encodeInt, isEmptyInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return encodeUint, isEmptyUint
case reflect.Float32, reflect.Float64:
return encodeFloat, isEmptyFloat
case reflect.String:
return encodeString, isEmptyString
case reflect.Slice, reflect.Array:
if t.Elem().Kind() == reflect.Uint8 {
return encodeByteString, isEmptySlice
}
f, _ := getEncodeFunc(t.Elem())
if f == nil {
return nil, nil
}
return arrayEncodeFunc{f: f}.encode, isEmptySlice
case reflect.Map:
kf, _ := getEncodeFunc(t.Key())
ef, _ := getEncodeFunc(t.Elem())
if kf == nil || ef == nil {
return nil, nil
}
return mapEncodeFunc{kf: kf, ef: ef}.encode, isEmptyMap
case reflect.Struct:
// Get struct's special field "_" tag options
if f, ok := t.FieldByName("_"); ok {
tag := f.Tag.Get("cbor")
if tag != "-" {
if hasToArrayOption(tag) {
return encodeStructToArray, isEmptyStruct
}
}
}
return encodeStruct, isEmptyStruct
case reflect.Interface:
return encodeIntf, isEmptyIntf
}
return nil, nil
}

func getEncodeIndirectValueFunc(t reflect.Type) encodeFunc {
for t.Kind() == reflect.Ptr {
t = t.Elem()
Expand Down
199 changes: 199 additions & 0 deletions encode_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright (c) Faye Amacker. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

//go:build go1.20

package cbor

import (
"reflect"
"sort"
"sync"
)

type mapEncodeFunc struct {
kf, ef encodeFunc
kpool, vpool sync.Pool
}

func (me *mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error {
if v.IsNil() && em.nilContainers == NilContainerAsNull {
e.Write(cborNil)
return nil
}
if b := em.encTagBytes(v.Type()); b != nil {
e.Write(b)
}
mlen := v.Len()
if mlen == 0 {
return e.WriteByte(byte(cborTypeMap))
}
if em.sort != SortNone {
return me.encodeCanonical(e, em, v)
}
encodeHead(e, byte(cborTypeMap), uint64(mlen))
iterk := me.kpool.Get().(*reflect.Value)
defer func() {
iterk.SetZero()
me.kpool.Put(iterk)
}()
iterv := me.vpool.Get().(*reflect.Value)
defer func() {
iterv.SetZero()
me.vpool.Put(iterv)
}()
iter := v.MapRange()
for iter.Next() {
iterk.SetIterKey(iter)
iterv.SetIterValue(iter)
if err := me.kf(e, em, *iterk); err != nil {
return err
}
if err := me.ef(e, em, *iterv); err != nil {
return err
}
}
return nil
}

func (me *mapEncodeFunc) encodeCanonical(e *encoderBuffer, em *encMode, v reflect.Value) error {
kve := getEncoderBuffer() // accumulated cbor encoded key-values
kvsp := getKeyValues(v.Len()) // for sorting keys
kvs := *kvsp
iterk := me.kpool.Get().(*reflect.Value)
defer func() {
iterk.SetZero()
me.kpool.Put(iterk)
}()
iterv := me.vpool.Get().(*reflect.Value)
defer func() {
iterv.SetZero()
me.vpool.Put(iterv)
}()
iter := v.MapRange()
for i := 0; iter.Next(); i++ {
iterk.SetIterKey(iter)
iterv.SetIterValue(iter)
off := kve.Len()
if err := me.kf(kve, em, *iterk); err != nil {
putEncoderBuffer(kve)
putKeyValues(kvsp)
return err
}
n1 := kve.Len() - off
if err := me.ef(kve, em, *iterv); err != nil {
putEncoderBuffer(kve)
putKeyValues(kvsp)
return err
}
n2 := kve.Len() - off
// Save key and keyvalue length to create slice later.
kvs[i] = keyValue{keyLen: n1, keyValueLen: n2}
}

b := kve.Bytes()
for i, off := 0, 0; i < len(kvs); i++ {
kvs[i].keyCBORData = b[off : off+kvs[i].keyLen]
kvs[i].keyValueCBORData = b[off : off+kvs[i].keyValueLen]
off += kvs[i].keyValueLen
}

if em.sort == SortBytewiseLexical {
sort.Sort(&bytewiseKeyValueSorter{kvs})
} else {
sort.Sort(&lengthFirstKeyValueSorter{kvs})
}

encodeHead(e, byte(cborTypeMap), uint64(len(kvs)))
for i := 0; i < len(kvs); i++ {
e.Write(kvs[i].keyValueCBORData)
}

putEncoderBuffer(kve)
putKeyValues(kvsp)
return nil
}

func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
k := t.Kind()
if k == reflect.Ptr {
return getEncodeIndirectValueFunc(t), isEmptyPtr
}
switch t {
case typeSimpleValue:
return encodeMarshalerType, isEmptyUint
case typeTag:
return encodeTag, alwaysNotEmpty
case typeTime:
return encodeTime, alwaysNotEmpty
case typeBigInt:
return encodeBigInt, alwaysNotEmpty
case typeRawMessage:
return encodeMarshalerType, isEmptySlice
case typeByteString:
return encodeMarshalerType, isEmptyString
}
if reflect.PtrTo(t).Implements(typeMarshaler) {
return encodeMarshalerType, alwaysNotEmpty
}
if reflect.PtrTo(t).Implements(typeBinaryMarshaler) {
return encodeBinaryMarshalerType, isEmptyBinaryMarshaler
}
switch k {
case reflect.Bool:
return encodeBool, isEmptyBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return encodeInt, isEmptyInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return encodeUint, isEmptyUint
case reflect.Float32, reflect.Float64:
return encodeFloat, isEmptyFloat
case reflect.String:
return encodeString, isEmptyString
case reflect.Slice, reflect.Array:
if t.Elem().Kind() == reflect.Uint8 {
return encodeByteString, isEmptySlice
}
f, _ := getEncodeFunc(t.Elem())
if f == nil {
return nil, nil
}
return arrayEncodeFunc{f: f}.encode, isEmptySlice
case reflect.Map:
kf, _ := getEncodeFunc(t.Key())
ef, _ := getEncodeFunc(t.Elem())
if kf == nil || ef == nil {
return nil, nil
}
return (&mapEncodeFunc{
kf: kf,
ef: ef,
kpool: sync.Pool{
New: func() interface{} {
rk := reflect.New(t.Key()).Elem()
return &rk
},
},
vpool: sync.Pool{
New: func() interface{} {
rv := reflect.New(t.Elem()).Elem()
return &rv
},
},
}).encode, isEmptyMap
case reflect.Struct:
// Get struct's special field "_" tag options
if f, ok := t.FieldByName("_"); ok {
tag := f.Tag.Get("cbor")
if tag != "-" {
if hasToArrayOption(tag) {
return encodeStructToArray, isEmptyStruct
}
}
}
return encodeStruct, isEmptyStruct
case reflect.Interface:
return encodeIntf, isEmptyIntf
}
return nil, nil
}
Loading

0 comments on commit 8c36ffb

Please sign in to comment.