Skip to content

Commit

Permalink
Hide request handler from public access
Browse files Browse the repository at this point in the history
  • Loading branch information
nhatthm committed Apr 29, 2021
1 parent da7f95e commit 86f61ab
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
23 changes: 10 additions & 13 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ import (
// Header is a list of HTTP headers.
type Header map[string]string

// RequestHandler handles the request and returns a result or an error.
type RequestHandler func(r *http.Request) ([]byte, error)

// Request is an expectation.
type Request struct {
parent *Server
Expand All @@ -34,8 +31,8 @@ type Request struct {
StatusCode int
// ResponseHeader is a list of response headers to be sent to client when the request is handled.
ResponseHeader Header
// Do handles the request and returns a result or an error.
Do RequestHandler

handle func(r *http.Request) ([]byte, error)

// The number of times to return the return arguments when setting
// expectations. 0 means to always return the value.
Expand All @@ -60,7 +57,7 @@ func newRequest(parent *Server, method string, requestURI string) *Request {
Repeatability: 0,
waitFor: nil,

Do: func(r *http.Request) ([]byte, error) {
handle: func(r *http.Request) ([]byte, error) {
return nil, nil
},
}
Expand Down Expand Up @@ -214,7 +211,7 @@ func (r *Request) Return(v interface{}) *Request {
panic(fmt.Errorf("%w: unexpected response data type: %T", ErrUnsupportedDataType, body))
}

return r.Handler(func(_ *http.Request) ([]byte, error) {
return r.WithHandler(func(_ *http.Request) ([]byte, error) {
return body, nil
})
}
Expand All @@ -232,7 +229,7 @@ func (r *Request) Returnf(format string, args ...interface{}) *Request {
// Server.Expect(http.MethodGet, "/path").
// ReturnJSON(map[string]string{"foo": "bar"})
func (r *Request) ReturnJSON(body interface{}) *Request {
return r.Handler(func(_ *http.Request) ([]byte, error) {
return r.WithHandler(func(_ *http.Request) ([]byte, error) {
return json.Marshal(body)
})
}
Expand All @@ -249,22 +246,22 @@ func (r *Request) ReturnFile(filePath string) *Request {
panic(err)
}

return r.Handler(func(_ *http.Request) ([]byte, error) {
return r.WithHandler(func(_ *http.Request) ([]byte, error) {
// nolint:gosec // filePath is cleaned above.
return ioutil.ReadFile(filePath)
})
}

// Handler sets the handler to handle a given request.
// WithHandler sets the handler to handle a given request.
//
// Server.Expect(http.MethodGet, "/path").
// Handler(func(_ *http.Request) ([]byte, error) {
// WithHandler(func(_ *http.Request) ([]byte, error) {
// return []byte("hello world!"), nil
// })
func (r *Request) Handler(handler RequestHandler) *Request {
func (r *Request) WithHandler(handler func(r *http.Request) ([]byte, error)) *Request {
r.lock()
defer r.unlock()
r.Do = handler
r.handle = handler

return r
}
Expand Down
8 changes: 4 additions & 4 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestRequest_Return(t *testing.T) {
})
} else {
r.Return(tc.body)
result, err := r.Do(nil)
result, err := r.handle(nil)

assert.Equal(t, tc.expectedBody, result)
assert.NoError(t, err)
Expand All @@ -232,7 +232,7 @@ func TestRequest_Returnf(t *testing.T) {
t.Parallel()

r := &Request{parent: &Server{}}
result, err := r.Returnf("hello %s", "john").Do(nil)
result, err := r.Returnf("hello %s", "john").handle(nil)

expectedBody := []byte(`hello john`)

Expand All @@ -246,7 +246,7 @@ func TestRequest_ReturnJSON(t *testing.T) {
r := &Request{parent: &Server{}}
r.ReturnJSON(map[string]string{"foo": "bar"})

result, err := r.Do(nil)
result, err := r.handle(nil)

assert.Equal(t, `{"foo":"bar"}`, string(result))
assert.NoError(t, err)
Expand All @@ -264,7 +264,7 @@ func TestRequest_ReturnFile(t *testing.T) {

r.ReturnFile("resources/fixtures/response.txt")

result, err := r.Do(nil)
result, err := r.handle(nil)

assert.Equal(t, `hello world!`, strings.TrimSpace(string(result)))
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {

w.WriteHeader(expected.StatusCode)

body, err := expected.Do(r)
body, err := expected.handle(r)
require.NoError(s.test, err)

_, err = w.Write(body)
Expand Down

0 comments on commit 86f61ab

Please sign in to comment.