diff --git a/httperrors/errors.go b/httperrors/errors.go index 6ba2680..92382b3 100644 --- a/httperrors/errors.go +++ b/httperrors/errors.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "reflect" + "strings" jsoniter "github.com/json-iterator/go" "github.com/mitchellh/mapstructure" @@ -14,15 +15,13 @@ var json = jsoniter.ConfigCompatibleWithStandardLibrary // New returns a new http Error object func New(httpCode int, format string, a ...any) *Error { - - text := fmt.Sprintf(format, a...) - if text == "" { - text = http.StatusText(httpCode) + if strings.TrimSpace(format) == "" { + format = http.StatusText(httpCode) } return &Error{ HTTPCode: httpCode, - Err: errors.New(text), + Err: fmt.Errorf(format, a...), } } @@ -45,6 +44,24 @@ func (e *Error) Error() string { return fmt.Sprintf("(%d): %s", e.HTTPCode, e.Err.Error()) } +func (e *Error) Clone() *Error { + err := &Error{ + HTTPCode: e.HTTPCode, + Err: e.Err, + Code: e.Code, + Meta: e.Meta, + } + + if e.Fields != nil { + err.Fields = make(map[string]any, len(e.Fields)) + } + + for k, v := range e.Fields { + err.Fields[k] = v + } + return err +} + // SetHTTPCode sets the error's http code. func (e *Error) SetHTTPCode(httpCode int) *Error { e.HTTPCode = httpCode @@ -102,6 +119,13 @@ func (e *Error) Unwrap() error { return e.Err } +func (e *Error) Is(target error) bool { + if t, ok := target.(*Error); ok { + return errors.Is(e.Err, t.Err) + } + return errors.Is(e.Err, target) +} + // MarshalJSON implements the json.Marshaler interface. func (e Error) MarshalJSON() ([]byte, error) { jsonData := map[string]any{} @@ -129,7 +153,9 @@ func (e Error) MarshalJSON() ([]byte, error) { jsonData[key.String()] = value.MapIndex(key).Interface() } default: - if _, ok := e.Meta.(error); !ok { + if err, ok := e.Meta.(error); ok { + jsonData["meta"] = err.Error() + } else { jsonData["meta"] = e.Meta } } @@ -160,24 +186,3 @@ func (e Error) MarshalJSON() ([]byte, error) { func As(err error) (t *Error, ok bool) { return t, errors.As(err, &t) } - -// Wrap httperrors wrapper helper -func Wrap(err error, httpCode ...int) (e *Error) { - if err == nil { - return nil - } - - if errors.As(err, &e) { - return e - } - - var code int - if len(httpCode) > 0 { - code = httpCode[0] - } - - return &Error{ - HTTPCode: code, - Err: err, - } -} diff --git a/httperrors/errors_test.go b/httperrors/errors_test.go index 0931438..da4fd18 100644 --- a/httperrors/errors_test.go +++ b/httperrors/errors_test.go @@ -8,6 +8,12 @@ import ( "github.com/stretchr/testify/assert" ) +type E string + +func (e E) Error() string { + return string(e) +} + type ErrorInfo struct { Err string `json:"error"` Reqid string `json:"reqid"` @@ -59,11 +65,12 @@ func TestError(t *testing.T) { "ftype": "HANDLER", }) + assert.True(err.HasField("x-request-id")) + assert.False(err.HasField("x-access-key")) + data, e := json.Marshal(err) assert.NoError(e) - fmt.Printf("data: %s\n", string(data)) - var obj map[string]any e = json.Unmarshal(data, &obj) assert.NoError(e) @@ -81,4 +88,100 @@ func TestError(t *testing.T) { assert.Equal("1.799868", fmt.Sprintf("%v", obj["latency"])) assert.Equal("HANDLER", obj["ftype"]) } + + { + err2 := err.Clone() + _ = err2.SetMeta(map[string]any{ + "error": "invalid arguments", + "reqid": "F4CD:20C1B9:2894CD0:3468624:6692A040", + "details": []string{ + "title field is required", + "content field is required", + }, + }) + + _ = err2.SetCode("INVALID_ARGUMENTS") + + data, e := json.Marshal(err2) + assert.NoError(e) + + var obj map[string]any + e = json.Unmarshal(data, &obj) + assert.NoError(e) + assert.Equal("invalid arguments", obj["error"]) + assert.Equal("F4CD:20C1B9:2894CD0:3468624:6692A040", obj["reqid"]) + assert.Equal([]any{ + "title field is required", + "content field is required", + }, obj["details"]) + assert.Equal("INVALID_ARGUMENTS", fmt.Sprintf("%v", obj["code"])) + } + + { + e := New(400, "") + assert.Equal(400, e.HTTPCode) + assert.Equal(400, e.StatusCode()) + assert.Equal("(400): Bad Request", e.Error()) + } + + { + err2 := &Error{} + assert.Equal("", err2.Error()) + _ = err2.SetHTTPCode(500) + _ = err2.SetCode("internal_error") + + assert.Equal(500, err2.HTTPCode) + assert.Equal("internal_error", err2.Code) + + assert.Nil(err2.Fields) + assert.False(err2.HasField("foo")) + + _ = err2.AddFields(map[string]any{"foo": "bar"}) + assert.NotNil(err2.Fields) + assert.Equal(map[string]any{"foo": "bar"}, err2.Fields) + + _ = err2.SetMeta(E("database internal error")) + data, e := json.Marshal(err2) + assert.NoError(e) + + var obj map[string]any + e = json.Unmarshal(data, &obj) + assert.NoError(e) + assert.Equal("database internal error", obj["meta"]) + + _ = err2.SetMeta(123) + data, e = json.Marshal(err2) + assert.NoError(e) + + e = json.Unmarshal(data, &obj) + assert.NoError(e) + assert.Equal(float64(123), obj["meta"]) + } + + { + e := err.Clone() + assert.Equal(400, e.HTTPCode) + assert.Equal(400, e.StatusCode()) + assert.Equal(err.Error(), e.Error()) + assert.Equal(errors.New("invalid arguments"), e.Unwrap()) + assert.Equal(err, e) + assert.True(errors.Is(err, e)) + assert.True(e.Is(err)) + assert.False(e.Is(errors.New("invalid arguments"))) + } + + { + e := fmt.Errorf("error: %w", err) + assert.True(errors.Is(e, err)) + } + + { + e, ok := As(err) + assert.True(ok) + assert.Equal(err, e) + + e, ok = As(errors.New("error")) + assert.False(ok) + assert.NotEqual(err, e) + } }