From 1de53527a1f6983700030c5e6a579cb52601b1c0 Mon Sep 17 00:00:00 2001 From: Nate Brennand Date: Sat, 29 Oct 2022 18:00:48 -0700 Subject: [PATCH] add default client method & test --- client/base_client.go | 4 ++++ client/client.go | 11 +++++++++-- client/client_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/client/base_client.go b/client/base_client.go index 2e7b16ae6..57ec235cd 100644 --- a/client/base_client.go +++ b/client/base_client.go @@ -35,5 +35,9 @@ func (w wrapperClient) SendRequestWithCtx(ctx context.Context, method string, ra // wrapBaseClientWithNoopCtx "upgrades" a BaseClient to BaseClientWithCtx so that requests can be // send with a request context. func wrapBaseClientWithNoopCtx(c BaseClient) BaseClientWithCtx { + // the default library client has SendRequestWithCtx, use it if available. + if typedClient, ok := c.(BaseClientWithCtx); ok { + return typedClient + } return wrapperClient{BaseClient: c} } diff --git a/client/client.go b/client/client.go index b7b8f874a..eddfdd059 100644 --- a/client/client.go +++ b/client/client.go @@ -2,6 +2,7 @@ package client import ( + "context" "encoding/json" "fmt" "net/http" @@ -44,7 +45,7 @@ func defaultHTTPClient() *http.Client { } } -func (c *Client) basicAuth() (string, string) { +func (c *Client) basicAuth() (username, password string) { return c.Credentials.Username, c.Credentials.Password } @@ -89,6 +90,12 @@ func (c *Client) doWithErr(req *http.Request) (*http.Response, error) { // SendRequest verifies, constructs, and authorizes an HTTP request. func (c *Client) SendRequest(method string, rawURL string, data url.Values, + headers map[string]interface{}) (*http.Response, error) { + return c.SendRequestWithCtx(context.TODO(), method, rawURL, data, headers) +} + +// SendRequestWithCtx verifies, constructs, and authorizes an HTTP request. +func (c *Client) SendRequestWithCtx(ctx context.Context, method string, rawURL string, data url.Values, headers map[string]interface{}) (*http.Response, error) { u, err := url.Parse(rawURL) if err != nil { @@ -112,7 +119,7 @@ func (c *Client) SendRequest(method string, rawURL string, data url.Values, valueReader = strings.NewReader(data.Encode()) } - req, err := http.NewRequest(method, u.String(), valueReader) + req, err := http.NewRequestWithContext(ctx, method, u.String(), valueReader) if err != nil { return nil, err } diff --git a/client/client_test.go b/client/client_test.go index d1e7d88aa..4f011201c 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client_test import ( + "context" "encoding/json" "io" "net/http" @@ -210,6 +211,33 @@ func TestClient_SetTimeoutTimesOut(t *testing.T) { assert.Error(t, err) } +func TestClient_SetTimeoutTimesOutViaContext(t *testing.T) { + handlerDelay := 100 * time.Microsecond + clientTimeout := 10 * time.Microsecond + assert.True(t, clientTimeout < handlerDelay) + + timeoutServer := httptest.NewServer(http.HandlerFunc( + func(writer http.ResponseWriter, _ *http.Request) { + d := map[string]interface{}{ + "response": "ok", + } + time.Sleep(100 * time.Microsecond) + encoder := json.NewEncoder(writer) + err := encoder.Encode(&d) + if err != nil { + t.Error(err) + } + writer.WriteHeader(http.StatusOK) + })) + defer timeoutServer.Close() + + c := NewClient("user", "pass") + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Microsecond) + defer cancel() + _, err := c.SendRequestWithCtx(ctx, "GET", timeoutServer.URL, nil, nil) //nolint:bodyclose + assert.Error(t, err) +} + func TestClient_SetTimeoutSucceeds(t *testing.T) { timeoutServer := httptest.NewServer(http.HandlerFunc( func(writer http.ResponseWriter, request *http.Request) {