Skip to content

Commit

Permalink
encoder func return error
Browse files Browse the repository at this point in the history
  • Loading branch information
CMogilko committed Aug 29, 2018
1 parent 3b438a0 commit b0a1d0a
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 67 deletions.
8 changes: 4 additions & 4 deletions encoding/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"sync"
)

type encoderFunc func(v reflect.Value) interface{}
type encoderFunc func(v reflect.Value) (interface{}, error)

// Encode returns the encoded value of v.
//
Expand All @@ -30,10 +30,10 @@ func Encode(v interface{}) (ev interface{}, err error) {
}
}()

return encode(reflect.ValueOf(v)), nil
return encode(reflect.ValueOf(v))
}

func encode(v reflect.Value) interface{} {
func encode(v reflect.Value) (interface{}, error) {
return valueEncoder(v)(v)
}

Expand Down Expand Up @@ -65,7 +65,7 @@ func typeEncoder(t reflect.Type) encoderFunc {
encoderCache.Lock()
var wg sync.WaitGroup
wg.Add(1)
encoderCache.m[t] = func(v reflect.Value) interface{} {
encoderCache.m[t] = func(v reflect.Value) (interface{}, error) {
wg.Wait()
return f(v)
}
Expand Down
39 changes: 31 additions & 8 deletions encoding/encoder_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package encoding

import (
"errors"
"image"
"reflect"
"testing"
Expand Down Expand Up @@ -436,8 +437,8 @@ func TestEncodeCustomTypeEncodingValue(t *testing.T) {
}

SetTypeEncoding(reflect.TypeOf(innerType{}),
func(v interface{}) interface{} {
return map[string]interface{}{"someval": v.(innerType).Val}
func(v interface{}) (interface{}, error) {
return map[string]interface{}{"someval": v.(innerType).Val}, nil
}, nil)

out, err := Encode(outer)
Expand All @@ -464,8 +465,8 @@ func TestEncodeCustomTypeEncodingPointer(t *testing.T) {
}

SetTypeEncoding(reflect.TypeOf((*innerType)(nil)),
func(v interface{}) interface{} {
return map[string]interface{}{"someval": v.(*innerType).Val}
func(v interface{}) (interface{}, error) {
return map[string]interface{}{"someval": v.(*innerType).Val}, nil
}, nil)

out, err := Encode(outer)
Expand All @@ -488,8 +489,8 @@ func TestEncodeCustomRootTypeEncodingValue(t *testing.T) {
}

SetTypeEncoding(reflect.TypeOf(cType{}),
func(v interface{}) interface{} {
return map[string]interface{}{"someval": v.(cType).Val}
func(v interface{}) (interface{}, error) {
return map[string]interface{}{"someval": v.(cType).Val}, nil
}, nil)

out, err := Encode(in)
Expand All @@ -512,8 +513,8 @@ func TestEncodeCustomRootTypeEncodingPointer(t *testing.T) {
}

SetTypeEncoding(reflect.TypeOf((*cType)(nil)),
func(v interface{}) interface{} {
return map[string]interface{}{"someval": v.(*cType).Val}
func(v interface{}) (interface{}, error) {
return map[string]interface{}{"someval": v.(*cType).Val}, nil
}, nil)

out, err := Encode(&in)
Expand All @@ -524,3 +525,25 @@ func TestEncodeCustomRootTypeEncodingPointer(t *testing.T) {
t.Errorf("got %q, want %q", out, want)
}
}

func TestEncodeCustomRootTypeEncodingError(t *testing.T) {
type cType struct {
Val int
}
in := cType{Val: 5}

cerr := errors.New("encode error")

SetTypeEncoding(reflect.TypeOf((*cType)(nil)),
func(v interface{}) (interface{}, error) {
return nil, cerr
}, nil)

_, err := Encode(&in)
if err == nil {
t.Errorf("got nil error, expected %v", cerr)
}
if err != cerr {
t.Errorf("got %q, want %q", err, cerr)
}
}
121 changes: 68 additions & 53 deletions encoding/encoder_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,101 +60,104 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
}
}

func invalidValueEncoder(v reflect.Value) interface{} {
return nil
func invalidValueEncoder(v reflect.Value) (interface{}, error) {
return nil, nil
}

func doNothingEncoder(v reflect.Value) interface{} {
return v.Interface()
func doNothingEncoder(v reflect.Value) (interface{}, error) {
return v.Interface(), nil
}

func marshalerEncoder(v reflect.Value) interface{} {
func marshalerEncoder(v reflect.Value) (interface{}, error) {
if v.Kind() == reflect.Ptr && v.IsNil() {
return nil
return nil, nil
}
m := v.Interface().(Marshaler)
ev, err := m.MarshalRQL()
if err != nil {
panic(&MarshalerError{v.Type(), err})
return nil, &MarshalerError{v.Type(), err}
}

return ev
return ev, nil
}

func addrMarshalerEncoder(v reflect.Value) interface{} {
func addrMarshalerEncoder(v reflect.Value) (interface{}, error) {
va := v.Addr()
if va.IsNil() {
return nil
return nil, nil
}
m := va.Interface().(Marshaler)
ev, err := m.MarshalRQL()
if err != nil {
panic(&MarshalerError{v.Type(), err})
return nil, &MarshalerError{v.Type(), err}
}

return ev
return ev, nil
}

func boolEncoder(v reflect.Value) interface{} {
func boolEncoder(v reflect.Value) (interface{}, error) {
if v.Bool() {
return true
return true, nil
} else {
return false
return false, nil
}
}

func intEncoder(v reflect.Value) interface{} {
return v.Int()
func intEncoder(v reflect.Value) (interface{}, error) {
return v.Int(), nil
}

func uintEncoder(v reflect.Value) interface{} {
return v.Uint()
func uintEncoder(v reflect.Value) (interface{}, error) {
return v.Uint(), nil
}

func floatEncoder(v reflect.Value) interface{} {
return v.Float()
func floatEncoder(v reflect.Value) (interface{}, error) {
return v.Float(), nil
}

func stringEncoder(v reflect.Value) interface{} {
return v.String()
func stringEncoder(v reflect.Value) (interface{}, error) {
return v.String(), nil
}

func interfaceEncoder(v reflect.Value) interface{} {
func interfaceEncoder(v reflect.Value) (interface{}, error) {
if v.IsNil() {
return nil
return nil, nil
}
return encode(v.Elem())
}

func funcEncoder(v reflect.Value) interface{} {
func funcEncoder(v reflect.Value) (interface{}, error) {
if v.IsNil() {
return nil
return nil, nil
}
return v.Interface()
return v.Interface(), nil
}

func asStringEncoder(v reflect.Value) interface{} {
return fmt.Sprintf("%v", v.Interface())
func asStringEncoder(v reflect.Value) (interface{}, error) {
return fmt.Sprintf("%v", v.Interface()), nil
}

func unsupportedTypeEncoder(v reflect.Value) interface{} {
panic(&UnsupportedTypeError{v.Type()})
func unsupportedTypeEncoder(v reflect.Value) (interface{}, error) {
return nil, &UnsupportedTypeError{v.Type()}
}

type structEncoder struct {
fields []field
fieldEncs []encoderFunc
}

func (se *structEncoder) encode(v reflect.Value) interface{} {
func (se *structEncoder) encode(v reflect.Value) (interface{}, error) {
m := make(map[string]interface{})
for i, f := range se.fields {
fv := fieldByIndex(v, f.index)
if !fv.IsValid() || f.omitEmpty && se.isEmptyValue(fv) {
continue
}

encField := se.fieldEncs[i](fv)
encField, err := se.fieldEncs[i](fv)
if err != nil {
return nil, err
}

// If this field is a referenced field then attempt to extract the value.
if f.reference {
Expand All @@ -179,7 +182,7 @@ func (se *structEncoder) encode(v reflect.Value) interface{} {
m[f.name] = encField
}

return m
return m, nil
}

func getReferenceField(f field, v reflect.Value, encField interface{}) interface{} {
Expand Down Expand Up @@ -240,18 +243,26 @@ type mapEncoder struct {
keyEnc, elemEnc encoderFunc
}

func (me *mapEncoder) encode(v reflect.Value) interface{} {
func (me *mapEncoder) encode(v reflect.Value) (interface{}, error) {
if v.IsNil() {
return nil
return nil, nil
}

m := make(map[string]interface{})

for _, k := range v.MapKeys() {
m[me.keyEnc(k).(string)] = me.elemEnc(v.MapIndex(k))
encV, err := me.elemEnc(v.MapIndex(k))
if err != nil {
return nil, err
}
encK, err := me.keyEnc(k)
if err != nil {
return nil, err
}
m[encK.(string)] = encV
}

return m
return m, nil
}

func newMapEncoder(t reflect.Type) encoderFunc {
Expand Down Expand Up @@ -282,9 +293,9 @@ type sliceEncoder struct {
arrayEnc encoderFunc
}

func (se *sliceEncoder) encode(v reflect.Value) interface{} {
func (se *sliceEncoder) encode(v reflect.Value) (interface{}, error) {
if v.IsNil() {
return []interface{}(nil)
return []interface{}(nil), nil
}
return se.arrayEnc(v)
}
Expand All @@ -302,15 +313,19 @@ type arrayEncoder struct {
elemEnc encoderFunc
}

func (ae *arrayEncoder) encode(v reflect.Value) interface{} {
func (ae *arrayEncoder) encode(v reflect.Value) (interface{}, error) {
n := v.Len()

a := make([]interface{}, n)
for i := 0; i < n; i++ {
a[i] = ae.elemEnc(v.Index(i))
var err error
a[i], err = ae.elemEnc(v.Index(i))
if err != nil {
return nil, err
}
}

return a
return a, nil
}

func newArrayEncoder(t reflect.Type) encoderFunc {
Expand All @@ -325,9 +340,9 @@ type ptrEncoder struct {
elemEnc encoderFunc
}

func (pe *ptrEncoder) encode(v reflect.Value) interface{} {
func (pe *ptrEncoder) encode(v reflect.Value) (interface{}, error) {
if v.IsNil() {
return nil
return nil, nil
}
return pe.elemEnc(v.Elem())
}
Expand All @@ -341,7 +356,7 @@ type condAddrEncoder struct {
canAddrEnc, elseEnc encoderFunc
}

func (ce *condAddrEncoder) encode(v reflect.Value) interface{} {
func (ce *condAddrEncoder) encode(v reflect.Value) (interface{}, error) {
if v.CanAddr() {
return ce.canAddrEnc(v)
} else {
Expand All @@ -359,7 +374,7 @@ func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc {
// Pseudo-type encoders

// Encode a time.Time value to the TIME RQL type
func timePseudoTypeEncoder(v reflect.Value) interface{} {
func timePseudoTypeEncoder(v reflect.Value) (interface{}, error) {
t := v.Interface().(time.Time)

timeVal := float64(t.UnixNano()) / float64(time.Second)
Expand All @@ -374,11 +389,11 @@ func timePseudoTypeEncoder(v reflect.Value) interface{} {
"$reql_type$": "TIME",
"epoch_time": timeVal,
"timezone": t.Format("-07:00"),
}
}, nil
}

// Encode a byte slice to the BINARY RQL type
func encodeByteSlice(v reflect.Value) interface{} {
func encodeByteSlice(v reflect.Value) (interface{}, error) {
var b []byte
if !v.IsNil() {
b = v.Bytes()
Expand All @@ -390,11 +405,11 @@ func encodeByteSlice(v reflect.Value) interface{} {
return map[string]interface{}{
"$reql_type$": "BINARY",
"data": string(dst),
}
}, nil
}

// Encode a byte array to the BINARY RQL type
func encodeByteArray(v reflect.Value) interface{} {
func encodeByteArray(v reflect.Value) (interface{}, error) {
b := make([]byte, v.Len())
for i := 0; i < v.Len(); i++ {
b[i] = v.Index(i).Interface().(byte)
Expand All @@ -406,5 +421,5 @@ func encodeByteArray(v reflect.Value) interface{} {
return map[string]interface{}{
"$reql_type$": "BINARY",
"data": string(dst),
}
}, nil
}
Loading

0 comments on commit b0a1d0a

Please sign in to comment.