From 490821525cf6d5303e4290321ddfb513efdcfdd9 Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Thu, 18 Jan 2024 14:04:47 -0500 Subject: [PATCH] Add option to control the output CBOR type of struct field names. Signed-off-by: Ben Luddy --- cache.go | 6 ++++ encode.go | 60 ++++++++++++++++++++++++------- encode_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++ structfields.go | 23 ++++++------ 4 files changed, 159 insertions(+), 24 deletions(-) diff --git a/cache.go b/cache.go index 2fdf114b..b3a68aa9 100644 --- a/cache.go +++ b/cache.go @@ -224,6 +224,12 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) { copy(flds[i].cborName[n:], flds[i].name) e.Reset() + encodeHead(e, byte(cborTypeByteString), uint64(len(flds[i].name))) + flds[i].cborNameByteString = make([]byte, e.Len()+len(flds[i].name)) + n = copy(flds[i].cborNameByteString, e.Bytes()) + copy(flds[i].cborNameByteString[n:], flds[i].name) + e.Reset() + hasKeyAsStr = true } diff --git a/encode.go b/encode.go index eba09c53..8db51b08 100644 --- a/encode.go +++ b/encode.go @@ -340,6 +340,23 @@ func (om OmitEmptyMode) valid() bool { return om >= 0 && om < maxOmitEmptyMode } +// StructFieldNameMode specifies the CBOR type to use when encoding struct field names. +type StructFieldNameMode int + +const ( + // StructFieldNameToTextString encodes struct fields to CBOR text string (major type 3). + StructFieldNameToTextString StructFieldNameMode = iota + + // StructFieldNameToTextString encodes struct fields to CBOR byte string (major type 2). + StructFieldNameToByteString + + maxStructFieldNameMode +) + +func (sfnm StructFieldNameMode) valid() bool { + return sfnm >= 0 && sfnm < maxStructFieldNameMode +} + // EncOptions specifies encoding options. type EncOptions struct { // Sort specifies sorting order. @@ -381,6 +398,9 @@ type EncOptions struct { // - CBOR text string (major type 3) is default // - CBOR byte string (major type 2) String StringMode + + // StructFieldName specifies the CBOR type to use when encoding struct field names. + StructFieldName StructFieldNameMode } // CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding, @@ -563,6 +583,9 @@ func (opts EncOptions) encMode() (*encMode, error) { if err != nil { return nil, err } + if !opts.StructFieldName.valid() { + return nil, errors.New("cbor: invalid StructFieldName " + strconv.Itoa(int(opts.StructFieldName))) + } em := encMode{ sort: opts.Sort, shortestFloat: opts.ShortestFloat, @@ -577,6 +600,7 @@ func (opts EncOptions) encMode() (*encMode, error) { omitEmpty: opts.OmitEmpty, stringType: opts.String, stringMajorType: stringMajorType, + structFieldName: opts.StructFieldName, } return &em, nil } @@ -603,6 +627,7 @@ type encMode struct { omitEmpty OmitEmptyMode stringType StringMode stringMajorType cborType + structFieldName StructFieldNameMode } var defaultEncMode, _ = EncOptions{}.encMode() @@ -610,17 +635,18 @@ var defaultEncMode, _ = EncOptions{}.encMode() // EncOptions returns user specified options used to create this EncMode. func (em *encMode) EncOptions() EncOptions { return EncOptions{ - Sort: em.sort, - ShortestFloat: em.shortestFloat, - NaNConvert: em.nanConvert, - InfConvert: em.infConvert, - BigIntConvert: em.bigIntConvert, - Time: em.time, - TimeTag: em.timeTag, - IndefLength: em.indefLength, - TagsMd: em.tagsMd, - OmitEmpty: em.omitEmpty, - String: em.stringType, + Sort: em.sort, + ShortestFloat: em.shortestFloat, + NaNConvert: em.nanConvert, + InfConvert: em.infConvert, + BigIntConvert: em.bigIntConvert, + Time: em.time, + TimeTag: em.timeTag, + IndefLength: em.indefLength, + TagsMd: em.tagsMd, + OmitEmpty: em.omitEmpty, + String: em.stringType, + StructFieldName: em.structFieldName, } } @@ -1137,7 +1163,11 @@ func encodeFixedLengthStruct(e *encoderBuffer, em *encMode, v reflect.Value, fld for i := 0; i < len(flds); i++ { f := flds[i] - e.Write(f.cborName) + if !f.keyAsInt && em.structFieldName == StructFieldNameToByteString { + e.Write(f.cborNameByteString) + } else { // int or text string + e.Write(f.cborName) + } fv := v.Field(f.idx[0]) if err := f.ef(e, em, fv); err != nil { @@ -1189,7 +1219,11 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) { } } - kve.Write(f.cborName) + if !f.keyAsInt && em.structFieldName == StructFieldNameToByteString { + kve.Write(f.cborNameByteString) + } else { // int or text string + kve.Write(f.cborName) + } if err := f.ef(kve, em, fv); err != nil { putEncoderBuffer(kve) diff --git a/encode_test.go b/encode_test.go index ec92334f..77513f18 100644 --- a/encode_test.go +++ b/encode_test.go @@ -3698,6 +3698,34 @@ func TestEncModeStringType(t *testing.T) { } } +func TestEncModeInvalidStructFieldNameMode(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + wantErrorMsg string + }{ + { + name: "", + opts: EncOptions{StructFieldName: -1}, + wantErrorMsg: "cbor: invalid StructFieldName -1", + }, + { + name: "", + opts: EncOptions{StructFieldName: 101}, + wantErrorMsg: "cbor: invalid StructFieldName 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.EncMode() + if err == nil { + t.Errorf("EncMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("EncMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + func TestEncIndefiniteLengthOption(t *testing.T) { // Default option allows indefinite length items var buf bytes.Buffer @@ -4055,3 +4083,69 @@ func TestMarshalStringType(t *testing.T) { }) } } + +func TestMarshalStructFieldNameType(t *testing.T) { + for _, tc := range []struct { + name string + opts EncOptions + in interface{} + want []byte + }{ + { + name: "fixed-length to text string", + opts: EncOptions{StructFieldName: StructFieldNameToTextString}, + in: struct { + F1 int `cbor:"1,keyasint"` + F2 int `cbor:"a"` + F3 int `cbor:"-3,keyasint"` + }{}, + want: hexDecode("a301006161002200"), + }, + { + name: "fixed-length to byte string", + opts: EncOptions{StructFieldName: StructFieldNameToByteString}, + in: struct { + F1 int `cbor:"1,keyasint"` + F2 int `cbor:"a"` + F3 int `cbor:"-3,keyasint"` + }{}, + want: hexDecode("a301004161002200"), + }, + { + name: "variable-length to text string", + opts: EncOptions{StructFieldName: StructFieldNameToTextString}, + in: struct { + F1 int `cbor:"1,omitempty,keyasint"` + F2 int `cbor:"a,omitempty"` + F3 int `cbor:"-3,omitempty,keyasint"` + }{F1: 7, F2: 7, F3: 7}, + want: hexDecode("a301076161072207"), + }, + { + name: "variable-length to byte string", + opts: EncOptions{StructFieldName: StructFieldNameToByteString}, + in: struct { + F1 int `cbor:"1,omitempty,keyasint"` + F2 int `cbor:"a,omitempty"` + F3 int `cbor:"-3,omitempty,keyasint"` + }{F1: 7, F2: 7, F3: 7}, + want: hexDecode("a301074161072207"), + }, + } { + t.Run(tc.name, func(t *testing.T) { + em, err := tc.opts.EncMode() + if err != nil { + t.Fatal(err) + } + + got, err := em.Marshal(tc.in) + if err != nil { + t.Errorf("unexpected error from Marshal(%q): %v", tc.in, err) + } + + if !bytes.Equal(got, tc.want) { + t.Errorf("Marshal(%q): wanted %x, got %x", tc.in, tc.want, got) + } + }) + } +} diff --git a/structfields.go b/structfields.go index e811aa1e..23a12bee 100644 --- a/structfields.go +++ b/structfields.go @@ -10,17 +10,18 @@ import ( ) type field struct { - name string - nameAsInt int64 // used to decoder to match field name with CBOR int - cborName []byte - idx []int - typ reflect.Type - ef encodeFunc - ief isEmptyFunc - typInfo *typeInfo // used to decoder to reuse type info - tagged bool // used to choose dominant field (at the same level tagged fields dominate untagged fields) - omitEmpty bool // used to skip empty field - keyAsInt bool // used to encode/decode field name as int + name string + nameAsInt int64 // used to decoder to match field name with CBOR int + cborName []byte + cborNameByteString []byte // major type 2 name encoding iff cborName has major type 3 + idx []int + typ reflect.Type + ef encodeFunc + ief isEmptyFunc + typInfo *typeInfo // used to decoder to reuse type info + tagged bool // used to choose dominant field (at the same level tagged fields dominate untagged fields) + omitEmpty bool // used to skip empty field + keyAsInt bool // used to encode/decode field name as int } type fields []*field