diff --git a/encoding/encoder.go b/encoding/encoder.go index 1d6c5164..511c073a 100644 --- a/encoding/encoder.go +++ b/encoding/encoder.go @@ -9,7 +9,7 @@ import ( "sync" ) -type encoderFunc func(v reflect.Value) interface{} +type encoderFunc func(v reflect.Value) (interface{}, error) // Encode returns the encoded value of v. // @@ -30,10 +30,10 @@ func Encode(v interface{}) (ev interface{}, err error) { } }() - return encode(reflect.ValueOf(v)), nil + return encode(reflect.ValueOf(v)) } -func encode(v reflect.Value) interface{} { +func encode(v reflect.Value) (interface{}, error) { return valueEncoder(v)(v) } @@ -65,7 +65,7 @@ func typeEncoder(t reflect.Type) encoderFunc { encoderCache.Lock() var wg sync.WaitGroup wg.Add(1) - encoderCache.m[t] = func(v reflect.Value) interface{} { + encoderCache.m[t] = func(v reflect.Value) (interface{}, error) { wg.Wait() return f(v) } diff --git a/encoding/encoder_test.go b/encoding/encoder_test.go index 8900fe9a..66e7621d 100644 --- a/encoding/encoder_test.go +++ b/encoding/encoder_test.go @@ -1,6 +1,7 @@ package encoding import ( + "errors" "image" "reflect" "testing" @@ -436,8 +437,8 @@ func TestEncodeCustomTypeEncodingValue(t *testing.T) { } SetTypeEncoding(reflect.TypeOf(innerType{}), - func(v interface{}) interface{} { - return map[string]interface{}{"someval": v.(innerType).Val} + func(v interface{}) (interface{}, error) { + return map[string]interface{}{"someval": v.(innerType).Val}, nil }, nil) out, err := Encode(outer) @@ -464,8 +465,8 @@ func TestEncodeCustomTypeEncodingPointer(t *testing.T) { } SetTypeEncoding(reflect.TypeOf((*innerType)(nil)), - func(v interface{}) interface{} { - return map[string]interface{}{"someval": v.(*innerType).Val} + func(v interface{}) (interface{}, error) { + return map[string]interface{}{"someval": v.(*innerType).Val}, nil }, nil) out, err := Encode(outer) @@ -488,8 +489,8 @@ func TestEncodeCustomRootTypeEncodingValue(t *testing.T) { } SetTypeEncoding(reflect.TypeOf(cType{}), - func(v interface{}) interface{} { - return map[string]interface{}{"someval": v.(cType).Val} + func(v interface{}) (interface{}, error) { + return map[string]interface{}{"someval": v.(cType).Val}, nil }, nil) out, err := Encode(in) @@ -512,8 +513,8 @@ func TestEncodeCustomRootTypeEncodingPointer(t *testing.T) { } SetTypeEncoding(reflect.TypeOf((*cType)(nil)), - func(v interface{}) interface{} { - return map[string]interface{}{"someval": v.(*cType).Val} + func(v interface{}) (interface{}, error) { + return map[string]interface{}{"someval": v.(*cType).Val}, nil }, nil) out, err := Encode(&in) @@ -524,3 +525,25 @@ func TestEncodeCustomRootTypeEncodingPointer(t *testing.T) { t.Errorf("got %q, want %q", out, want) } } + +func TestEncodeCustomRootTypeEncodingError(t *testing.T) { + type cType struct { + Val int + } + in := cType{Val: 5} + + cerr := errors.New("encode error") + + SetTypeEncoding(reflect.TypeOf((*cType)(nil)), + func(v interface{}) (interface{}, error) { + return nil, cerr + }, nil) + + _, err := Encode(&in) + if err == nil { + t.Errorf("got nil error, expected %v", cerr) + } + if err != cerr { + t.Errorf("got %q, want %q", err, cerr) + } +} diff --git a/encoding/encoder_types.go b/encoding/encoder_types.go index 20f16cad..c6d30fe2 100644 --- a/encoding/encoder_types.go +++ b/encoding/encoder_types.go @@ -60,85 +60,85 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { } } -func invalidValueEncoder(v reflect.Value) interface{} { - return nil +func invalidValueEncoder(v reflect.Value) (interface{}, error) { + return nil, nil } -func doNothingEncoder(v reflect.Value) interface{} { - return v.Interface() +func doNothingEncoder(v reflect.Value) (interface{}, error) { + return v.Interface(), nil } -func marshalerEncoder(v reflect.Value) interface{} { +func marshalerEncoder(v reflect.Value) (interface{}, error) { if v.Kind() == reflect.Ptr && v.IsNil() { - return nil + return nil, nil } m := v.Interface().(Marshaler) ev, err := m.MarshalRQL() if err != nil { - panic(&MarshalerError{v.Type(), err}) + return nil, &MarshalerError{v.Type(), err} } - return ev + return ev, nil } -func addrMarshalerEncoder(v reflect.Value) interface{} { +func addrMarshalerEncoder(v reflect.Value) (interface{}, error) { va := v.Addr() if va.IsNil() { - return nil + return nil, nil } m := va.Interface().(Marshaler) ev, err := m.MarshalRQL() if err != nil { - panic(&MarshalerError{v.Type(), err}) + return nil, &MarshalerError{v.Type(), err} } - return ev + return ev, nil } -func boolEncoder(v reflect.Value) interface{} { +func boolEncoder(v reflect.Value) (interface{}, error) { if v.Bool() { - return true + return true, nil } else { - return false + return false, nil } } -func intEncoder(v reflect.Value) interface{} { - return v.Int() +func intEncoder(v reflect.Value) (interface{}, error) { + return v.Int(), nil } -func uintEncoder(v reflect.Value) interface{} { - return v.Uint() +func uintEncoder(v reflect.Value) (interface{}, error) { + return v.Uint(), nil } -func floatEncoder(v reflect.Value) interface{} { - return v.Float() +func floatEncoder(v reflect.Value) (interface{}, error) { + return v.Float(), nil } -func stringEncoder(v reflect.Value) interface{} { - return v.String() +func stringEncoder(v reflect.Value) (interface{}, error) { + return v.String(), nil } -func interfaceEncoder(v reflect.Value) interface{} { +func interfaceEncoder(v reflect.Value) (interface{}, error) { if v.IsNil() { - return nil + return nil, nil } return encode(v.Elem()) } -func funcEncoder(v reflect.Value) interface{} { +func funcEncoder(v reflect.Value) (interface{}, error) { if v.IsNil() { - return nil + return nil, nil } - return v.Interface() + return v.Interface(), nil } -func asStringEncoder(v reflect.Value) interface{} { - return fmt.Sprintf("%v", v.Interface()) +func asStringEncoder(v reflect.Value) (interface{}, error) { + return fmt.Sprintf("%v", v.Interface()), nil } -func unsupportedTypeEncoder(v reflect.Value) interface{} { - panic(&UnsupportedTypeError{v.Type()}) +func unsupportedTypeEncoder(v reflect.Value) (interface{}, error) { + return nil, &UnsupportedTypeError{v.Type()} } type structEncoder struct { @@ -146,7 +146,7 @@ type structEncoder struct { fieldEncs []encoderFunc } -func (se *structEncoder) encode(v reflect.Value) interface{} { +func (se *structEncoder) encode(v reflect.Value) (interface{}, error) { m := make(map[string]interface{}) for i, f := range se.fields { fv := fieldByIndex(v, f.index) @@ -154,7 +154,10 @@ func (se *structEncoder) encode(v reflect.Value) interface{} { continue } - encField := se.fieldEncs[i](fv) + encField, err := se.fieldEncs[i](fv) + if err != nil { + return nil, err + } // If this field is a referenced field then attempt to extract the value. if f.reference { @@ -179,7 +182,7 @@ func (se *structEncoder) encode(v reflect.Value) interface{} { m[f.name] = encField } - return m + return m, nil } func getReferenceField(f field, v reflect.Value, encField interface{}) interface{} { @@ -240,18 +243,26 @@ type mapEncoder struct { keyEnc, elemEnc encoderFunc } -func (me *mapEncoder) encode(v reflect.Value) interface{} { +func (me *mapEncoder) encode(v reflect.Value) (interface{}, error) { if v.IsNil() { - return nil + return nil, nil } m := make(map[string]interface{}) for _, k := range v.MapKeys() { - m[me.keyEnc(k).(string)] = me.elemEnc(v.MapIndex(k)) + encV, err := me.elemEnc(v.MapIndex(k)) + if err != nil { + return nil, err + } + encK, err := me.keyEnc(k) + if err != nil { + return nil, err + } + m[encK.(string)] = encV } - return m + return m, nil } func newMapEncoder(t reflect.Type) encoderFunc { @@ -282,9 +293,9 @@ type sliceEncoder struct { arrayEnc encoderFunc } -func (se *sliceEncoder) encode(v reflect.Value) interface{} { +func (se *sliceEncoder) encode(v reflect.Value) (interface{}, error) { if v.IsNil() { - return []interface{}(nil) + return []interface{}(nil), nil } return se.arrayEnc(v) } @@ -302,15 +313,19 @@ type arrayEncoder struct { elemEnc encoderFunc } -func (ae *arrayEncoder) encode(v reflect.Value) interface{} { +func (ae *arrayEncoder) encode(v reflect.Value) (interface{}, error) { n := v.Len() a := make([]interface{}, n) for i := 0; i < n; i++ { - a[i] = ae.elemEnc(v.Index(i)) + var err error + a[i], err = ae.elemEnc(v.Index(i)) + if err != nil { + return nil, err + } } - return a + return a, nil } func newArrayEncoder(t reflect.Type) encoderFunc { @@ -325,9 +340,9 @@ type ptrEncoder struct { elemEnc encoderFunc } -func (pe *ptrEncoder) encode(v reflect.Value) interface{} { +func (pe *ptrEncoder) encode(v reflect.Value) (interface{}, error) { if v.IsNil() { - return nil + return nil, nil } return pe.elemEnc(v.Elem()) } @@ -341,7 +356,7 @@ type condAddrEncoder struct { canAddrEnc, elseEnc encoderFunc } -func (ce *condAddrEncoder) encode(v reflect.Value) interface{} { +func (ce *condAddrEncoder) encode(v reflect.Value) (interface{}, error) { if v.CanAddr() { return ce.canAddrEnc(v) } else { @@ -359,7 +374,7 @@ func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { // Pseudo-type encoders // Encode a time.Time value to the TIME RQL type -func timePseudoTypeEncoder(v reflect.Value) interface{} { +func timePseudoTypeEncoder(v reflect.Value) (interface{}, error) { t := v.Interface().(time.Time) timeVal := float64(t.UnixNano()) / float64(time.Second) @@ -374,11 +389,11 @@ func timePseudoTypeEncoder(v reflect.Value) interface{} { "$reql_type$": "TIME", "epoch_time": timeVal, "timezone": t.Format("-07:00"), - } + }, nil } // Encode a byte slice to the BINARY RQL type -func encodeByteSlice(v reflect.Value) interface{} { +func encodeByteSlice(v reflect.Value) (interface{}, error) { var b []byte if !v.IsNil() { b = v.Bytes() @@ -390,11 +405,11 @@ func encodeByteSlice(v reflect.Value) interface{} { return map[string]interface{}{ "$reql_type$": "BINARY", "data": string(dst), - } + }, nil } // Encode a byte array to the BINARY RQL type -func encodeByteArray(v reflect.Value) interface{} { +func encodeByteArray(v reflect.Value) (interface{}, error) { b := make([]byte, v.Len()) for i := 0; i < v.Len(); i++ { b[i] = v.Index(i).Interface().(byte) @@ -406,5 +421,5 @@ func encodeByteArray(v reflect.Value) interface{} { return map[string]interface{}{ "$reql_type$": "BINARY", "data": string(dst), - } + }, nil } diff --git a/encoding/encoding.go b/encoding/encoding.go index 99e8bbe5..ad631d78 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -43,11 +43,11 @@ func IgnoreType(t reflect.Type) { func SetTypeEncoding( t reflect.Type, - encode func(value interface{}) interface{}, + encode func(value interface{}) (interface{}, error), decode func(encoded interface{}, value reflect.Value) error, ) { encoderCache.Lock() - encoderCache.m[t] = func(v reflect.Value) interface{} { + encoderCache.m[t] = func(v reflect.Value) (interface{}, error) { return encode(v.Interface()) } encoderCache.Unlock()