Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle network errors and don't cancel context in request before reading response #369

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions okta/requestExecutor.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ func (re *RequestExecutor) Do(ctx context.Context, req *http.Request, v interfac
re.freshCache = false
}
if !inCache {
resp, err := re.doWithRetries(ctx, req)
resp, done, err := re.doWithRetries(ctx, req)
defer done()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -492,12 +493,13 @@ func (o *oktaBackoff) Context() context.Context {
return o.ctx
}

func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) (*http.Response, error) {
func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) (*http.Response, func(), error) {
var bodyReader func() io.ReadCloser
done := func() {}
if req.Body != nil {
buf, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
return nil, done, err
}
bodyReader = func() io.ReadCloser {
return io.NopCloser(bytes.NewReader(buf))
Expand All @@ -508,9 +510,7 @@ func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request)
err error
)
if re.config.Okta.Client.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Second*time.Duration(re.config.Okta.Client.RequestTimeout))
defer cancel()
ctx, done = context.WithTimeout(ctx, time.Second*time.Duration(re.config.Okta.Client.RequestTimeout))
}
bOff := &oktaBackoff{
ctx: ctx,
Expand Down Expand Up @@ -549,7 +549,7 @@ func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request)
return errors.New("too many requests")
}
err = backoff.Retry(operation, bOff)
return resp, err
return resp, done, err
}

func tooManyRequests(resp *http.Response) bool {
Expand Down Expand Up @@ -649,7 +649,10 @@ func CheckResponseForError(resp *http.Response) error {
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
Expand All @@ -668,7 +671,10 @@ func buildResponse(resp *http.Response, re *RequestExecutor, v interface{}) (*Re
if err != nil {
return response, err
}
bodyBytes, _ := io.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close() // close it to avoid memory leaks
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/request_executor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package unit

import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"

"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/tests"
"github.com/stretchr/testify/assert"
)

// readerFun makes it easier to implement an inline reader as a closure function.
type readerFun func(p []byte) (n int, err error)

// Read, part of io.Reader interface.
func (r readerFun) Read(p []byte) (n int, err error) { return r(p) }

// slowTransport provides a dummy http-like transport serving fixed content, but slowly.
type slowTransport struct{}

// RoundTrip, part of http.Transport interface. This servers 42 as a JSON response, but slowly.
// In particular, we serve the response immediately, but getting the body takes some milliseconds.
func (t slowTransport) RoundTrip(req *http.Request) (*http.Response, error) {
realBody := strings.NewReader("42")
// This body takes 1 millisecond to read. It also needs a valid context for the whole duration.
slowBody := func(p []byte) (n int, err error) {
select {
case <-req.Context().Done():
return 0, req.Context().Err()
case <-time.After(1 * time.Millisecond):
return realBody.Read(p)
}
}

rsp := &http.Response{
StatusCode: 200,
Body: io.NopCloser(readerFun(slowBody)),
Header: http.Header{},
Request: req,
}
rsp.Header.Set("Content-Type", "application/json")
return rsp, nil
}

// TestExecuteRequest tests that the request executor can handle a slow response.
// In particular, we want to make sure that the context is properly passed through
// and not canceled too early.
func TestExecuteRequest(t *testing.T) {
cfg := []okta.ConfigSetter{
okta.WithOrgUrl("https://fakeurl"), // This is ignored, but required for validator.
okta.WithToken("foo"), // ditto.
okta.WithHttpClientPtr(&http.Client{Transport: slowTransport{}}), // Use our more realistic transport.
okta.WithRequestTimeout(10), // The context issues are gated with actually having a timeout.
}
ctx, cl, err := tests.NewClient(context.Background(), cfg...)
assert.NoError(t, err, "Basic client errors")
req, err := http.NewRequest("GET", "https://fakeurl", http.NoBody)
assert.NoError(t, err, "Request building")
var out int
rs, err := cl.GetRequestExecutor().Do(ctx, req, &out)
assert.NoError(t, err, "Request execution")
if rs.StatusCode != 200 || out != 42 {
t.Errorf("Got val=%d status=%d, want 42 status=200", out, rs.StatusCode)
}
}