diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 71b2761..75c4629 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: uses: actions/checkout@v2 - name: Start Centrifugo - run: docker run -d -p 8000:8000 centrifugo/centrifugo:latest centrifugo --client_insecure + run: docker run -d -p 8000:8000 centrifugo/centrifugo:latest centrifugo --client.insecure - name: Test run: go test -race -v diff --git a/errors.go b/errors.go index d79fdc5..6b9ff93 100644 --- a/errors.go +++ b/errors.go @@ -33,6 +33,10 @@ func (t TransportError) Error() string { return fmt.Sprintf("transport error: %v", t.Err) } +func (t TransportError) Unwrap() error { + return t.Err +} + type ConnectError struct { Err error } @@ -41,6 +45,10 @@ func (c ConnectError) Error() string { return fmt.Sprintf("connect error: %v", c.Err) } +func (c ConnectError) Unwrap() error { + return c.Err +} + type RefreshError struct { Err error } @@ -49,12 +57,20 @@ func (r RefreshError) Error() string { return fmt.Sprintf("refresh error: %v", r.Err) } +func (r RefreshError) Unwrap() error { + return r.Err +} + type ConfigurationError struct { Err error } -func (r ConfigurationError) Error() string { - return fmt.Sprintf("configuration error: %v", r.Err) +func (c ConfigurationError) Error() string { + return fmt.Sprintf("configuration error: %v", c.Err) +} + +func (c ConfigurationError) Unwrap() error { + return c.Err } type SubscriptionSubscribeError struct { @@ -65,6 +81,10 @@ func (s SubscriptionSubscribeError) Error() string { return fmt.Sprintf("subscribe error: %v", s.Err) } +func (s SubscriptionSubscribeError) Unwrap() error { + return s.Err +} + type SubscriptionRefreshError struct { Err error } @@ -72,3 +92,7 @@ type SubscriptionRefreshError struct { func (s SubscriptionRefreshError) Error() string { return fmt.Sprintf("refresh error: %v", s.Err) } + +func (s SubscriptionRefreshError) Unwrap() error { + return s.Err +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..7e7214c --- /dev/null +++ b/errors_test.go @@ -0,0 +1,73 @@ +package centrifuge_test + +import ( + "errors" + "strings" + "testing" + + "github.com/centrifugal/centrifuge-go" +) + +func TestErrors(t *testing.T) { + cases := []struct { + name string + rootError error + factory func(err error) error + }{ + { + name: "SubscriptionSubscribeError", + rootError: centrifuge.ErrTimeout, + factory: func(err error) error { + return centrifuge.SubscriptionSubscribeError{Err: err} + }, + }, + { + name: "SubscriptionRefreshError", + rootError: centrifuge.ErrUnauthorized, + factory: func(err error) error { + return centrifuge.SubscriptionRefreshError{Err: err} + }, + }, + { + name: "ConfigurationError", + rootError: centrifuge.ErrClientClosed, + factory: func(err error) error { + return centrifuge.ConfigurationError{Err: err} + }, + }, + { + name: "RefreshError", + rootError: centrifuge.ErrClientClosed, + factory: func(err error) error { + return centrifuge.RefreshError{Err: err} + }, + }, + { + name: "ConnectError", + rootError: centrifuge.ErrClientClosed, + factory: func(err error) error { + return centrifuge.ConnectError{Err: err} + }, + }, + { + name: "TransportError", + rootError: centrifuge.ErrClientClosed, + factory: func(err error) error { + return centrifuge.TransportError{Err: err} + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := c.factory(c.rootError) + parts := strings.Split(err.Error(), ": ") + if parts[1] != c.rootError.Error() { + t.Errorf("unexpected error string: %v", err) + } + + if !errors.Is(err, c.rootError) { + t.Errorf("expected root error to be wrapped") + } + }) + } +}