From c03ad0e3e879edfd01cbf36f198283ac41a4af27 Mon Sep 17 00:00:00 2001 From: Dan Wendorf Date: Mon, 8 Jan 2024 13:02:23 -0800 Subject: [PATCH 1/2] Add new `threadsafe` package The root gock package is not threadsafe because it modifies http.DefaultTransport. This can cause a race condition when other code is reading http.DefaultTransport. `threadsafe` is a reimplementation of the entire root package using a new `gock` struct to hold data that was previously global. --- threadsafe/gock.go | 270 ++++++++++++++++++ threadsafe/gock_test.go | 517 +++++++++++++++++++++++++++++++++++ threadsafe/matcher.go | 116 ++++++++ threadsafe/matcher_test.go | 172 ++++++++++++ threadsafe/matchers.go | 240 ++++++++++++++++ threadsafe/matchers_test.go | 253 +++++++++++++++++ threadsafe/mock.go | 172 ++++++++++++ threadsafe/mock_test.go | 143 ++++++++++ threadsafe/options.go | 8 + threadsafe/request.go | 330 ++++++++++++++++++++++ threadsafe/request_test.go | 318 +++++++++++++++++++++ threadsafe/responder.go | 111 ++++++++ threadsafe/responder_test.go | 191 +++++++++++++ threadsafe/response.go | 198 ++++++++++++++ threadsafe/response_test.go | 186 +++++++++++++ threadsafe/store.go | 90 ++++++ threadsafe/store_test.go | 95 +++++++ threadsafe/transport.go | 112 ++++++++ threadsafe/transport_test.go | 55 ++++ 19 files changed, 3577 insertions(+) create mode 100644 threadsafe/gock.go create mode 100644 threadsafe/gock_test.go create mode 100644 threadsafe/matcher.go create mode 100644 threadsafe/matcher_test.go create mode 100644 threadsafe/matchers.go create mode 100644 threadsafe/matchers_test.go create mode 100644 threadsafe/mock.go create mode 100644 threadsafe/mock_test.go create mode 100644 threadsafe/options.go create mode 100644 threadsafe/request.go create mode 100644 threadsafe/request_test.go create mode 100644 threadsafe/responder.go create mode 100644 threadsafe/responder_test.go create mode 100644 threadsafe/response.go create mode 100644 threadsafe/response_test.go create mode 100644 threadsafe/store.go create mode 100644 threadsafe/store_test.go create mode 100644 threadsafe/transport.go create mode 100644 threadsafe/transport_test.go diff --git a/threadsafe/gock.go b/threadsafe/gock.go new file mode 100644 index 0000000..7df7b9a --- /dev/null +++ b/threadsafe/gock.go @@ -0,0 +1,270 @@ +package threadsafe + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "sync" +) + +type Gock struct { + // mutex is used internally for locking thread-sensitive functions. + mutex sync.Mutex + // config global singleton store. + config struct { + Networking bool + NetworkingFilters []FilterRequestFunc + Observer ObserverFunc + } + // DumpRequest is a default implementation of ObserverFunc that dumps + // the HTTP/1.x wire representation of the http request + DumpRequest ObserverFunc + // track unmatched requests so they can be tested for + unmatchedRequests []*http.Request + + // storeMutex is used internally for store synchronization. + storeMutex sync.RWMutex + + // mocks is internally used to store registered mocks. + mocks []Mock + + // DefaultMatcher stores the default Matcher instance used to match mocks. + DefaultMatcher *MockMatcher + + // MatchersHeader exposes a slice of HTTP header specific mock matchers. + MatchersHeader []MatchFunc + // MatchersBody exposes a slice of HTTP body specific built-in mock matchers. + MatchersBody []MatchFunc + // Matchers stores all the built-in mock matchers. + Matchers []MatchFunc + + // BodyTypes stores the supported MIME body types for matching. + // Currently only text-based types. + BodyTypes []string + + // BodyTypeAliases stores a generic MIME type by alias. + BodyTypeAliases map[string]string + + // CompressionSchemes stores the supported Content-Encoding types for decompression. + CompressionSchemes []string + + intercepting bool + + DisableCallback func() + InterceptCallback func() + InterceptingCallback func() bool +} + +func NewGock() *Gock { + g := &Gock{ + DumpRequest: defaultDumpRequest, + + BodyTypes: []string{ + "text/html", + "text/plain", + "application/json", + "application/xml", + "multipart/form-data", + "application/x-www-form-urlencoded", + }, + + BodyTypeAliases: map[string]string{ + "html": "text/html", + "text": "text/plain", + "json": "application/json", + "xml": "application/xml", + "form": "multipart/form-data", + "url": "application/x-www-form-urlencoded", + }, + + // CompressionSchemes stores the supported Content-Encoding types for decompression. + CompressionSchemes: []string{ + "gzip", + }, + } + g.MatchersHeader = []MatchFunc{ + g.MatchMethod, + g.MatchScheme, + g.MatchHost, + g.MatchPath, + g.MatchHeaders, + g.MatchQueryParams, + g.MatchPathParams, + } + g.MatchersBody = []MatchFunc{ + g.MatchBody, + } + g.Matchers = append(g.MatchersHeader, g.MatchersBody...) + + // DefaultMatcher stores the default Matcher instance used to match mocks. + g.DefaultMatcher = g.NewMatcher() + return g +} + +// ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic +type ObserverFunc func(*http.Request, Mock) + +func defaultDumpRequest(request *http.Request, mock Mock) { + bytes, _ := httputil.DumpRequestOut(request, true) + fmt.Println(string(bytes)) + fmt.Printf("\nMatches: %v\n---\n", mock != nil) +} + +// New creates and registers a new HTTP mock with +// default settings and returns the Request DSL for HTTP mock +// definition and set up. +func (g *Gock) New(uri string) *Request { + g.Intercept() + + res := g.NewResponse() + req := g.NewRequest() + req.URLStruct, res.Error = url.Parse(normalizeURI(uri)) + + // Create the new mock expectation + exp := g.NewMock(req, res) + g.Register(exp) + + return req +} + +// Intercepting returns true if gock is currently able to intercept. +func (g *Gock) Intercepting() bool { + g.mutex.Lock() + defer g.mutex.Unlock() + + callbackResponse := true + if g.InterceptingCallback != nil { + callbackResponse = g.InterceptingCallback() + } + + return g.intercepting && callbackResponse +} + +// Intercept enables HTTP traffic interception via http.DefaultTransport. +// If you are using a custom HTTP transport, you have to use `gock.Transport()` +func (g *Gock) Intercept() { + if !g.Intercepting() { + g.mutex.Lock() + g.intercepting = true + + if g.InterceptCallback != nil { + g.InterceptCallback() + } + + g.mutex.Unlock() + } +} + +// InterceptClient allows the developer to intercept HTTP traffic using +// a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation. +func (g *Gock) InterceptClient(cli *http.Client) { + _, ok := cli.Transport.(*Transport) + if ok { + return // if transport already intercepted, just ignore it + } + cli.Transport = g.NewTransport(cli.Transport) +} + +// RestoreClient allows the developer to disable and restore the +// original transport in the given http.Client. +func (g *Gock) RestoreClient(cli *http.Client) { + trans, ok := cli.Transport.(*Transport) + if !ok { + return + } + cli.Transport = trans.Transport +} + +// Disable disables HTTP traffic interception by gock. +func (g *Gock) Disable() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.intercepting = false + + if g.DisableCallback != nil { + g.DisableCallback() + } +} + +// Off disables the default HTTP interceptors and removes +// all the registered mocks, even if they has not been intercepted yet. +func (g *Gock) Off() { + g.Flush() + g.Disable() +} + +// OffAll is like `Off()`, but it also removes the unmatched requests registry. +func (g *Gock) OffAll() { + g.Flush() + g.Disable() + g.CleanUnmatchedRequest() +} + +// Observe provides a hook to support inspection of the request and matched mock +func (g *Gock) Observe(fn ObserverFunc) { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.Observer = fn +} + +// EnableNetworking enables real HTTP networking +func (g *Gock) EnableNetworking() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.Networking = true +} + +// DisableNetworking disables real HTTP networking +func (g *Gock) DisableNetworking() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.Networking = false +} + +// NetworkingFilter determines if an http.Request should be triggered or not. +func (g *Gock) NetworkingFilter(fn FilterRequestFunc) { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.NetworkingFilters = append(g.config.NetworkingFilters, fn) +} + +// DisableNetworkingFilters disables registered networking filters. +func (g *Gock) DisableNetworkingFilters() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.config.NetworkingFilters = []FilterRequestFunc{} +} + +// GetUnmatchedRequests returns all requests that have been received but haven't matched any mock +func (g *Gock) GetUnmatchedRequests() []*http.Request { + g.mutex.Lock() + defer g.mutex.Unlock() + return g.unmatchedRequests +} + +// HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock +func (g *Gock) HasUnmatchedRequest() bool { + return len(g.GetUnmatchedRequests()) > 0 +} + +// CleanUnmatchedRequest cleans the unmatched requests internal registry. +func (g *Gock) CleanUnmatchedRequest() { + g.mutex.Lock() + defer g.mutex.Unlock() + g.unmatchedRequests = []*http.Request{} +} + +func (g *Gock) trackUnmatchedRequest(req *http.Request) { + g.mutex.Lock() + defer g.mutex.Unlock() + g.unmatchedRequests = append(g.unmatchedRequests, req) +} + +func normalizeURI(uri string) string { + if ok, _ := regexp.MatchString("^http[s]?", uri); !ok { + return "http://" + uri + } + return uri +} diff --git a/threadsafe/gock_test.go b/threadsafe/gock_test.go new file mode 100644 index 0000000..a503a28 --- /dev/null +++ b/threadsafe/gock_test.go @@ -0,0 +1,517 @@ +package threadsafe + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/nbio/st" +) + +func TestMockSimple(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(201).JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestMockOff(t *testing.T) { + g := NewGock() + g.New("http://foo.com").Reply(201).JSON(map[string]string{"foo": "bar"}) + g.Off() + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Get("http://127.0.0.1:3123") + st.Reject(t, err, nil) +} + +func TestMockBodyStringResponse(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(200).BodyString("foo bar") + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo bar") +} + +func TestMockBodyMatch(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").BodyString("foo bar").Reply(201).BodyString("foo foo") + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo foo") +} + +func TestMockBodyCannotMatch(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").BodyString("foo foo").Reply(201).BodyString("foo foo") + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Reject(t, err, nil) +} + +func TestMockBodyMatchCompressed(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Compression("gzip").BodyString("foo bar").Reply(201).BodyString("foo foo") + + var compressed bytes.Buffer + w := gzip.NewWriter(&compressed) + w.Write([]byte("foo bar")) + w.Close() + c := &http.Client{} + g.InterceptClient(c) + req, err := http.NewRequest("POST", "http://foo.com", &compressed) + st.Expect(t, err, nil) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "text/plain") + res, err := c.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo foo") +} + +func TestMockBodyCannotMatchCompressed(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Compression("gzip").BodyString("foo bar").Reply(201).BodyString("foo foo") + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Reject(t, err, nil) +} + +func TestMockBodyMatchJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + JSON(map[string]string{"foo": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Post("http://foo.com/bar", "application/json", bytes.NewBuffer([]byte(`{"foo":"bar"}`))) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"bar":"foo"}`) +} + +func TestMockBodyCannotMatchJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + JSON(map[string]string{"bar": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Post("http://foo.com/bar", "application/json", bytes.NewBuffer([]byte(`{"foo":"bar"}`))) + st.Reject(t, err, nil) +} + +func TestMockBodyMatchCompressedJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + Compression("gzip"). + JSON(map[string]string{"foo": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + var compressed bytes.Buffer + w := gzip.NewWriter(&compressed) + w.Write([]byte(`{"foo":"bar"}`)) + w.Close() + c := &http.Client{} + g.InterceptClient(c) + req, err := http.NewRequest("POST", "http://foo.com/bar", &compressed) + st.Expect(t, err, nil) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "application/json") + res, err := c.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"bar":"foo"}`) +} + +func TestMockBodyCannotMatchCompressedJSON(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + JSON(map[string]string{"bar": "bar"}). + Reply(201). + JSON(map[string]string{"bar": "foo"}) + + var compressed bytes.Buffer + w := gzip.NewWriter(&compressed) + w.Write([]byte(`{"foo":"bar"}`)) + w.Close() + c := &http.Client{} + g.InterceptClient(c) + req, err := http.NewRequest("POST", "http://foo.com/bar", &compressed) + st.Expect(t, err, nil) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "application/json") + _, err = c.Do(req) + st.Reject(t, err, nil) +} + +func TestMockMatchHeaders(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + MatchHeader("Content-Type", "(.*)/plain"). + Reply(200). + BodyString("foo foo") + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Post("http://foo.com", "text/plain", bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo foo") +} + +func TestMockMap(t *testing.T) { + g := NewGock() + defer after(g) + + mock := g.New("http://bar.com") + mock.Map(func(req *http.Request) *http.Request { + req.URL.Host = "bar.com" + return req + }) + mock.Reply(201).JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestMockFilter(t *testing.T) { + g := NewGock() + defer after(g) + + mock := g.New("http://foo.com") + mock.Filter(func(req *http.Request) bool { + return req.URL.Host == "foo.com" + }) + mock.Reply(201).JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 201) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestMockCounterDisabled(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(204) + st.Expect(t, len(g.GetAll()), 1) + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get("http://foo.com") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + st.Expect(t, len(g.GetAll()), 0) +} + +func TestMockEnableNetwork(t *testing.T) { + g := NewGock() + defer after(g) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, world") + })) + defer ts.Close() + + g.EnableNetworking() + defer g.DisableNetworking() + + g.New(ts.URL).Reply(204) + st.Expect(t, len(g.GetAll()), 1) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get(ts.URL) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + st.Expect(t, len(g.GetAll()), 0) + + res, err = c.Get(ts.URL) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) +} + +func TestMockEnableNetworkFilter(t *testing.T) { + g := NewGock() + defer after(g) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, world") + })) + defer ts.Close() + + g.EnableNetworking() + defer g.DisableNetworking() + + g.NetworkingFilter(func(req *http.Request) bool { + return strings.Contains(req.URL.Host, "127.0.0.1") + }) + defer g.DisableNetworkingFilters() + + g.New(ts.URL).Reply(0).SetHeader("foo", "bar") + st.Expect(t, len(g.GetAll()), 1) + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Get(ts.URL) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, len(g.GetAll()), 0) +} + +func TestMockPersistent(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Get("/bar"). + Persist(). + Reply(200). + JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + for i := 0; i < 5; i++ { + res, err := c.Get("http://foo.com/bar") + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) + } +} + +func TestMockPersistTimes(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://127.0.0.1:1234"). + Get("/bar"). + Times(4). + Reply(200). + JSON(map[string]string{"foo": "bar"}) + + c := &http.Client{} + g.InterceptClient(c) + for i := 0; i < 5; i++ { + res, err := c.Get("http://127.0.0.1:1234/bar") + if i == 4 { + st.Reject(t, err, nil) + break + } + + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) + } +} + +func TestUnmatched(t *testing.T) { + g := NewGock() + defer after(g) + + // clear out any unmatchedRequests from other tests + g.unmatchedRequests = []*http.Request{} + + g.Intercept() + + c := &http.Client{} + g.InterceptClient(c) + _, err := c.Get("http://server.com/unmatched") + st.Reject(t, err, nil) + + unmatched := g.GetUnmatchedRequests() + st.Expect(t, len(unmatched), 1) + st.Expect(t, unmatched[0].URL.Host, "server.com") + st.Expect(t, unmatched[0].URL.Path, "/unmatched") + st.Expect(t, g.HasUnmatchedRequest(), true) +} + +func TestMultipleMocks(t *testing.T) { + g := NewGock() + defer g.Disable() + + g.New("http://server.com"). + Get("/foo"). + Reply(200). + JSON(map[string]string{"value": "foo"}) + + g.New("http://server.com"). + Get("/bar"). + Reply(200). + JSON(map[string]string{"value": "bar"}) + + g.New("http://server.com"). + Get("/baz"). + Reply(200). + JSON(map[string]string{"value": "baz"}) + + tests := []struct { + path string + }{ + {"/foo"}, + {"/bar"}, + {"/baz"}, + } + + c := &http.Client{} + g.InterceptClient(c) + for _, test := range tests { + res, err := c.Get("http://server.com" + test.path) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:15], `{"value":"`+test.path[1:]+`"}`) + } + + _, err := c.Get("http://server.com/foo") + st.Reject(t, err, nil) +} + +func TestInterceptClient(t *testing.T) { + g := NewGock() + defer after(g) + + g.New("http://foo.com").Reply(204) + st.Expect(t, len(g.GetAll()), 1) + + req, err := http.NewRequest("GET", "http://foo.com", nil) + client := &http.Client{Transport: &http.Transport{}} + g.InterceptClient(client) + + res, err := client.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) +} + +func TestRestoreClient(t *testing.T) { + g := NewGock() + defer after(g) + + g.New("http://foo.com").Reply(204) + st.Expect(t, len(g.GetAll()), 1) + + req, err := http.NewRequest("GET", "http://foo.com", nil) + client := &http.Client{Transport: &http.Transport{}} + g.InterceptClient(client) + trans := client.Transport + + res, err := client.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + + g.RestoreClient(client) + st.Reject(t, trans, client.Transport) +} + +func TestMockRegExpMatching(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com"). + Post("/bar"). + MatchHeader("Authorization", "Bearer (.*)"). + BodyString(`{"foo":".*"}`). + Reply(200). + SetHeader("Server", "gock"). + JSON(map[string]string{"foo": "bar"}) + + req, _ := http.NewRequest("POST", "http://foo.com/bar", bytes.NewBuffer([]byte(`{"foo":"baz"}`))) + req.Header.Set("Authorization", "Bearer s3cr3t") + + c := &http.Client{} + g.InterceptClient(c) + res, err := c.Do(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 200) + st.Expect(t, res.Header.Get("Server"), "gock") + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body)[:13], `{"foo":"bar"}`) +} + +func TestObserve(t *testing.T) { + g := NewGock() + defer after(g) + var observedRequest *http.Request + var observedMock Mock + g.Observe(func(request *http.Request, mock Mock) { + observedRequest = request + observedMock = mock + }) + g.New("http://observe-foo.com").Reply(200) + req, _ := http.NewRequest("POST", "http://observe-foo.com", nil) + + c := &http.Client{} + g.InterceptClient(c) + c.Do(req) + + st.Expect(t, observedRequest.Host, "observe-foo.com") + st.Expect(t, observedMock.Request().URLStruct.Host, "observe-foo.com") +} + +func TestTryCreatingRacesInNew(t *testing.T) { + g := NewGock() + defer after(g) + for i := 0; i < 10; i++ { + go func() { + g.New("http://example.com") + }() + } +} + +func after(g *Gock) { + g.Flush() + g.Disable() +} diff --git a/threadsafe/matcher.go b/threadsafe/matcher.go new file mode 100644 index 0000000..ad10a84 --- /dev/null +++ b/threadsafe/matcher.go @@ -0,0 +1,116 @@ +package threadsafe + +import "net/http" + +// MatchFunc represents the required function +// interface implemented by matchers. +type MatchFunc func(*http.Request, *Request) (bool, error) + +// Matcher represents the required interface implemented by mock matchers. +type Matcher interface { + // Get returns a slice of registered function matchers. + Get() []MatchFunc + + // Add adds a new matcher function. + Add(MatchFunc) + + // Set sets the matchers functions stack. + Set([]MatchFunc) + + // Flush flushes the current matchers function stack. + Flush() + + // Match matches the given http.Request with a mock Request. + Match(*http.Request, *Request) (bool, error) +} + +// MockMatcher implements a mock matcher +type MockMatcher struct { + Matchers []MatchFunc + g *Gock +} + +// NewMatcher creates a new mock matcher +// using the default matcher functions. +func (g *Gock) NewMatcher() *MockMatcher { + m := g.NewEmptyMatcher() + for _, matchFn := range g.Matchers { + m.Add(matchFn) + } + return m +} + +// NewBasicMatcher creates a new matcher with header only mock matchers. +func (g *Gock) NewBasicMatcher() *MockMatcher { + m := g.NewEmptyMatcher() + for _, matchFn := range g.MatchersHeader { + m.Add(matchFn) + } + return m +} + +// NewEmptyMatcher creates a new empty matcher without default matchers. +func (g *Gock) NewEmptyMatcher() *MockMatcher { + return &MockMatcher{g: g, Matchers: []MatchFunc{}} +} + +// Get returns a slice of registered function matchers. +func (m *MockMatcher) Get() []MatchFunc { + m.g.mutex.Lock() + defer m.g.mutex.Unlock() + return m.Matchers +} + +// Add adds a new function matcher. +func (m *MockMatcher) Add(fn MatchFunc) { + m.Matchers = append(m.Matchers, fn) +} + +// Set sets a new stack of matchers functions. +func (m *MockMatcher) Set(stack []MatchFunc) { + m.Matchers = stack +} + +// Flush flushes the current matcher +func (m *MockMatcher) Flush() { + m.Matchers = []MatchFunc{} +} + +// Clone returns a separate MockMatcher instance that has a copy of the same MatcherFuncs +func (m *MockMatcher) Clone() *MockMatcher { + m2 := m.g.NewEmptyMatcher() + for _, mFn := range m.Get() { + m2.Add(mFn) + } + return m2 +} + +// Match matches the given http.Request with a mock request +// returning true in case that the request matches, otherwise false. +func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) { + for _, matcher := range m.Matchers { + matches, err := matcher(req, ereq) + if err != nil { + return false, err + } + if !matches { + return false, nil + } + } + return true, nil +} + +// MatchMock is a helper function that matches the given http.Request +// in the list of registered mocks, returning it if matches or error if it fails. +func (g *Gock) MatchMock(req *http.Request) (Mock, error) { + for _, mock := range g.GetAll() { + matches, err := mock.Match(req) + if err != nil { + return nil, err + } + if matches { + return mock, nil + } + } + return nil, nil +} diff --git a/threadsafe/matcher_test.go b/threadsafe/matcher_test.go new file mode 100644 index 0000000..c33a475 --- /dev/null +++ b/threadsafe/matcher_test.go @@ -0,0 +1,172 @@ +package threadsafe + +import ( + "net/http" + "net/url" + "testing" + + "github.com/nbio/st" +) + +func TestRegisteredMatchers(t *testing.T) { + g := NewGock() + st.Expect(t, len(g.MatchersHeader), 7) + st.Expect(t, len(g.MatchersBody), 1) +} + +func TestNewMatcher(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + // Funcs are not comparable, checking slice length as it's better than nothing + // See https://golang.org/pkg/reflect/#DeepEqual + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + st.Expect(t, len(matcher.Get()), len(g.Matchers)) +} + +func TestNewBasicMatcher(t *testing.T) { + g := NewGock() + matcher := g.NewBasicMatcher() + // Funcs are not comparable, checking slice length as it's better than nothing + // See https://golang.org/pkg/reflect/#DeepEqual + st.Expect(t, len(matcher.Matchers), len(g.MatchersHeader)) + st.Expect(t, len(matcher.Get()), len(g.MatchersHeader)) +} + +func TestNewEmptyMatcher(t *testing.T) { + g := NewGock() + matcher := g.NewEmptyMatcher() + st.Expect(t, len(matcher.Matchers), 0) + st.Expect(t, len(matcher.Get()), 0) +} + +func TestMatcherAdd(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + st.Expect(t, len(matcher.Get()), len(g.Matchers)+1) +} + +func TestMatcherSet(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + matchers := []MatchFunc{} + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + matcher.Set(matchers) + st.Expect(t, matcher.Matchers, matchers) + st.Expect(t, len(matcher.Get()), 0) +} + +func TestMatcherGet(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + matchers := []MatchFunc{} + matcher.Set(matchers) + st.Expect(t, matcher.Get(), matchers) +} + +func TestMatcherFlush(t *testing.T) { + g := NewGock() + matcher := g.NewMatcher() + st.Expect(t, len(matcher.Matchers), len(g.Matchers)) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + st.Expect(t, len(matcher.Get()), len(g.Matchers)+1) + matcher.Flush() + st.Expect(t, len(matcher.Get()), 0) +} + +func TestMatcherClone(t *testing.T) { + g := NewGock() + matcher := g.DefaultMatcher.Clone() + st.Expect(t, len(matcher.Get()), len(g.DefaultMatcher.Get())) +} + +func TestMatcher(t *testing.T) { + cases := []struct { + method string + url string + matches bool + }{ + {"GET", "http://foo.com/bar", true}, + {"GET", "http://foo.com/baz", true}, + {"GET", "http://foo.com/foo", false}, + {"POST", "http://foo.com/bar", false}, + {"POST", "http://bar.com/bar", false}, + {"GET", "http://foo.com", false}, + } + + g := NewGock() + matcher := g.NewMatcher() + matcher.Flush() + st.Expect(t, len(matcher.Matchers), 0) + + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.Method == "GET", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Path == "/baz" || req.URL.Path == "/bar", nil + }) + + for _, test := range cases { + u, _ := url.Parse(test.url) + req := &http.Request{Method: test.method, URL: u} + matches, err := matcher.Match(req, nil) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchMock(t *testing.T) { + cases := []struct { + method string + url string + matches bool + }{ + {"GET", "http://foo.com/bar", true}, + {"GET", "http://foo.com/baz", true}, + {"GET", "http://foo.com/foo", false}, + {"POST", "http://foo.com/bar", false}, + {"POST", "http://bar.com/bar", false}, + {"GET", "http://foo.com", false}, + } + + g := NewGock() + matcher := g.DefaultMatcher + matcher.Flush() + st.Expect(t, len(matcher.Matchers), 0) + + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.Method == "GET", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Path == "/baz" || req.URL.Path == "/bar", nil + }) + + for _, test := range cases { + g.Flush() + mock := g.New(test.url).method(test.method, "").Mock + + u, _ := url.Parse(test.url) + req := &http.Request{Method: test.method, URL: u} + + match, err := g.MatchMock(req) + st.Expect(t, err, nil) + if test.matches { + st.Expect(t, match, mock) + } else { + st.Expect(t, match, nil) + } + } + + g.DefaultMatcher.Matchers = g.Matchers +} diff --git a/threadsafe/matchers.go b/threadsafe/matchers.go new file mode 100644 index 0000000..9b1c0b3 --- /dev/null +++ b/threadsafe/matchers.go @@ -0,0 +1,240 @@ +package threadsafe + +import ( + "compress/gzip" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "reflect" + "regexp" + "strings" + + "github.com/h2non/parth" +) + +// EOL represents the end of line character. +const EOL = 0xa + +// MatchMethod matches the HTTP method of the given request. +func (g *Gock) MatchMethod(req *http.Request, ereq *Request) (bool, error) { + return ereq.Method == "" || req.Method == ereq.Method, nil +} + +// MatchScheme matches the request URL protocol scheme. +func (g *Gock) MatchScheme(req *http.Request, ereq *Request) (bool, error) { + return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil +} + +// MatchHost matches the HTTP host header field of the given request. +func (g *Gock) MatchHost(req *http.Request, ereq *Request) (bool, error) { + url := ereq.URLStruct + if strings.EqualFold(url.Host, req.URL.Host) { + return true, nil + } + if !ereq.Options.DisableRegexpHost { + return regexp.MatchString(url.Host, req.URL.Host) + } + return false, nil +} + +// MatchPath matches the HTTP URL path of the given request. +func (g *Gock) MatchPath(req *http.Request, ereq *Request) (bool, error) { + if req.URL.Path == ereq.URLStruct.Path { + return true, nil + } + return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path) +} + +// MatchHeaders matches the headers fields of the given request. +func (g *Gock) MatchHeaders(req *http.Request, ereq *Request) (bool, error) { + for key, value := range ereq.Header { + var err error + var match bool + var matchEscaped bool + + for _, field := range req.Header[key] { + match, err = regexp.MatchString(value[0], field) + // Some values may contain reserved regex params e.g. "()", try matching with these escaped. + matchEscaped, err = regexp.MatchString(regexp.QuoteMeta(value[0]), field) + + if err != nil { + return false, err + } + if match || matchEscaped { + break + } + + } + + if !match && !matchEscaped { + return false, nil + } + } + return true, nil +} + +// MatchQueryParams matches the URL query params fields of the given request. +func (g *Gock) MatchQueryParams(req *http.Request, ereq *Request) (bool, error) { + for key, value := range ereq.URLStruct.Query() { + var err error + var match bool + + for _, field := range req.URL.Query()[key] { + match, err = regexp.MatchString(value[0], field) + if err != nil { + return false, err + } + if match { + break + } + } + + if !match { + return false, nil + } + } + return true, nil +} + +// MatchPathParams matches the URL path parameters of the given request. +func (g *Gock) MatchPathParams(req *http.Request, ereq *Request) (bool, error) { + for key, value := range ereq.PathParams { + var s string + + if err := parth.Sequent(req.URL.Path, key, &s); err != nil { + return false, nil + } + + if s != value { + return false, nil + } + } + return true, nil +} + +// MatchBody tries to match the request body. +// TODO: not too smart now, needs several improvements. +func (g *Gock) MatchBody(req *http.Request, ereq *Request) (bool, error) { + // If match body is empty, just continue + if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 { + return true, nil + } + + // Only can match certain MIME body types + if !g.supportedType(req, ereq) { + return false, nil + } + + // Can only match certain compression schemes + if !g.supportedCompressionScheme(req) { + return false, nil + } + + // Create a reader for the body depending on compression type + bodyReader := req.Body + if ereq.CompressionScheme != "" { + if ereq.CompressionScheme != req.Header.Get("Content-Encoding") { + return false, nil + } + compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme) + if err != nil { + return false, err + } + bodyReader = compressedBodyReader + } + + // Read the whole request body + body, err := ioutil.ReadAll(bodyReader) + if err != nil { + return false, err + } + + // Restore body reader stream + req.Body = createReadCloser(body) + + // If empty, ignore the match + if len(body) == 0 && len(ereq.BodyBuffer) != 0 { + return false, nil + } + + // Match body by atomic string comparison + bodyStr := castToString(body) + matchStr := castToString(ereq.BodyBuffer) + if bodyStr == matchStr { + return true, nil + } + + // Match request body by regexp + match, _ := regexp.MatchString(matchStr, bodyStr) + if match == true { + return true, nil + } + + // todo - add conditional do only perform the conversion of body bytes + // representation of JSON to a map and then compare them for equality. + + // Check if the key + value pairs match + var bodyMap map[string]interface{} + var matchMap map[string]interface{} + + // Ensure that both byte bodies that that should be JSON can be converted to maps. + umErr := json.Unmarshal(body, &bodyMap) + umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap) + if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) { + return true, nil + } + + return false, nil +} + +func (g *Gock) supportedType(req *http.Request, ereq *Request) bool { + mime := req.Header.Get("Content-Type") + if mime == "" { + return true + } + + mimeToMatch := ereq.Header.Get("Content-Type") + if mimeToMatch != "" { + return mime == mimeToMatch + } + + for _, kind := range g.BodyTypes { + if match, _ := regexp.MatchString(kind, mime); match { + return true + } + } + return false +} + +func (g *Gock) supportedCompressionScheme(req *http.Request) bool { + encoding := req.Header.Get("Content-Encoding") + if encoding == "" { + return true + } + + for _, kind := range g.CompressionSchemes { + if match, _ := regexp.MatchString(kind, encoding); match { + return true + } + } + return false +} + +func castToString(buf []byte) string { + str := string(buf) + tail := len(str) - 1 + if str[tail] == EOL { + str = str[:tail] + } + return str +} + +func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) { + switch scheme { + case "gzip": + return gzip.NewReader(r) + default: + return r, nil + } +} diff --git a/threadsafe/matchers_test.go b/threadsafe/matchers_test.go new file mode 100644 index 0000000..6db6c08 --- /dev/null +++ b/threadsafe/matchers_test.go @@ -0,0 +1,253 @@ +package threadsafe + +import ( + "net/http" + "net/url" + "testing" + + "github.com/nbio/st" +) + +func TestMatchMethod(t *testing.T) { + cases := []struct { + value string + method string + matches bool + }{ + {"GET", "GET", true}, + {"POST", "POST", true}, + {"", "POST", true}, + {"POST", "GET", false}, + {"PUT", "GET", false}, + } + + for _, test := range cases { + req := &http.Request{Method: test.method} + ereq := &Request{Method: test.value} + matches, err := NewGock().MatchMethod(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchScheme(t *testing.T) { + cases := []struct { + value string + scheme string + matches bool + }{ + {"http", "http", true}, + {"https", "https", true}, + {"http", "https", false}, + {"", "https", true}, + {"https", "", true}, + } + + for _, test := range cases { + req := &http.Request{URL: &url.URL{Scheme: test.scheme}} + ereq := &Request{URLStruct: &url.URL{Scheme: test.value}} + matches, err := NewGock().MatchScheme(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchHost(t *testing.T) { + cases := []struct { + value string + url string + matches bool + matchesNonRegexp bool + }{ + {"foo.com", "foo.com", true, true}, + {"FOO.com", "foo.com", true, true}, + {"foo.net", "foo.com", false, false}, + {"foo.bar.net", "foo-bar.net", true, false}, + {"foo", "foo.com", true, false}, + {"(.*).com", "foo.com", true, false}, + {"127.0.0.1", "127.0.0.1", true, true}, + {"127.0.0.2", "127.0.0.1", false, false}, + {"127.0.0.*", "127.0.0.1", true, false}, + {"127.0.0.[0-9]", "127.0.0.7", true, false}, + } + + for _, test := range cases { + req := &http.Request{URL: &url.URL{Host: test.url}} + ereq := &Request{URLStruct: &url.URL{Host: test.value}} + matches, err := NewGock().MatchHost(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + ereq.WithOptions(Options{DisableRegexpHost: true}) + matches, err = NewGock().MatchHost(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matchesNonRegexp) + } +} + +func TestMatchPath(t *testing.T) { + cases := []struct { + value string + path string + matches bool + }{ + {"/foo", "/foo", true}, + {"/foo", "/foo/bar", true}, + {"bar", "/foo/bar", true}, + {"foo", "/foo/bar", true}, + {"bar$", "/foo/bar", true}, + {"/foo/*", "/foo/bar", true}, + {"/foo/[a-z]+", "/foo/bar", true}, + {"/foo/baz", "/foo/bar", false}, + {"/foo/baz", "/foo/bar", false}, + {"/foo/bar%3F+%C3%A9", "/foo/bar%3F+%C3%A9", true}, + } + + for _, test := range cases { + u, _ := url.Parse("http://foo.com" + test.path) + mu, _ := url.Parse("http://foo.com" + test.value) + req := &http.Request{URL: u} + ereq := &Request{URLStruct: mu} + matches, err := NewGock().MatchPath(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchHeaders(t *testing.T) { + cases := []struct { + values http.Header + headers http.Header + matches bool + }{ + {http.Header{"foo": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, true}, + {http.Header{"foo": []string{"bar"}}, http.Header{"foo": []string{"barbar"}}, true}, + {http.Header{"bar": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, false}, + {http.Header{"foofoo": []string{"bar"}}, http.Header{"foo": []string{"bar"}}, false}, + {http.Header{"foo": []string{"bar(.*)"}}, http.Header{"foo": []string{"barbar"}}, true}, + {http.Header{"foo": []string{"b(.*)"}}, http.Header{"foo": []string{"barbar"}}, true}, + {http.Header{"foo": []string{"^bar$"}}, http.Header{"foo": []string{"bar"}}, true}, + {http.Header{"foo": []string{"^bar$"}}, http.Header{"foo": []string{"barbar"}}, false}, + {http.Header{"UPPERCASE": []string{"bar"}}, http.Header{"UPPERCASE": []string{"bar"}}, true}, + {http.Header{"Mixed-CASE": []string{"bar"}}, http.Header{"Mixed-CASE": []string{"bar"}}, true}, + {http.Header{"User-Agent": []string{"Agent (version1.0)"}}, http.Header{"User-Agent": []string{"Agent (version1.0)"}}, true}, + {http.Header{"Content-Type": []string{"(.*)/plain"}}, http.Header{"Content-Type": []string{"text/plain"}}, true}, + } + + for _, test := range cases { + req := &http.Request{Header: test.headers} + ereq := &Request{Header: test.values} + matches, err := NewGock().MatchHeaders(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchQueryParams(t *testing.T) { + cases := []struct { + value string + path string + matches bool + }{ + {"foo=bar", "foo=bar", true}, + {"foo=bar", "foo=foo&foo=bar", true}, + {"foo=b*", "foo=bar", true}, + {"foo=.*", "foo=bar", true}, + {"foo=f[o]{2}", "foo=foo", true}, + {"foo=bar&bar=foo", "foo=bar&foo=foo&bar=foo", true}, + {"foo=", "foo=bar", true}, + {"foo=foo", "foo=bar", false}, + {"bar=bar", "foo=bar bar", false}, + } + + for _, test := range cases { + u, _ := url.Parse("http://foo.com/?" + test.path) + mu, _ := url.Parse("http://foo.com/?" + test.value) + req := &http.Request{URL: u} + ereq := &Request{URLStruct: mu} + matches, err := NewGock().MatchQueryParams(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchPathParams(t *testing.T) { + cases := []struct { + key string + value string + path string + matches bool + }{ + {"foo", "bar", "/foo/bar", true}, + {"foo", "bar", "/foo/test/bar", false}, + {"foo", "bar", "/test/foo/bar/ack", true}, + {"foo", "bar", "/foo", false}, + } + + for i, test := range cases { + u, _ := url.Parse("http://foo.com" + test.path) + mu, _ := url.Parse("http://foo.com" + test.path) + req := &http.Request{URL: u} + ereq := &Request{ + URLStruct: mu, + PathParams: map[string]string{test.key: test.value}, + } + matches, err := NewGock().MatchPathParams(req, ereq) + st.Expect(t, err, nil, i) + st.Expect(t, matches, test.matches, i) + } +} + +func TestMatchBody(t *testing.T) { + cases := []struct { + value string + body string + matches bool + }{ + {"foo bar", "foo bar\n", true}, + {"foo", "foo bar\n", true}, + {"f[o]+", "foo\n", true}, + {`"foo"`, `{"foo":"bar"}\n`, true}, + {`{"foo":"bar"}`, `{"foo":"bar"}\n`, true}, + {`{"foo":"foo"}`, `{"foo":"bar"}\n`, false}, + + {`{"foo":"bar","bar":"foo"}`, `{"bar":"foo","foo":"bar"}`, true}, + {`{"bar":"foo","foo":{"two":"three","three":"two"}}`, `{"foo":{"three":"two","two":"three"},"bar":"foo"}`, true}, + } + + g := NewGock() + for _, test := range cases { + req := &http.Request{Body: createReadCloser([]byte(test.body))} + ereq := &Request{BodyBuffer: []byte(test.value)} + matches, err := g.MatchBody(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} + +func TestMatchBody_MatchType(t *testing.T) { + body := `{"foo":"bar"}` + cases := []struct { + body string + requestContentType string + customBodyType string + matches bool + }{ + {body, "application/vnd.apiname.v1+json", "foobar", false}, + {body, "application/vnd.apiname.v1+json", "application/vnd.apiname.v1+json", true}, + {body, "application/json", "foobar", false}, + {body, "application/json", "", true}, + {"", "", "", true}, + } + + g := NewGock() + for _, test := range cases { + req := &http.Request{ + Header: http.Header{"Content-Type": []string{test.requestContentType}}, + Body: createReadCloser([]byte(test.body)), + } + ereq := g.NewRequest().BodyString(test.body).MatchType(test.customBodyType) + matches, err := g.MatchBody(req, ereq) + st.Expect(t, err, nil) + st.Expect(t, matches, test.matches) + } +} diff --git a/threadsafe/mock.go b/threadsafe/mock.go new file mode 100644 index 0000000..004263a --- /dev/null +++ b/threadsafe/mock.go @@ -0,0 +1,172 @@ +package threadsafe + +import ( + "net/http" + "sync" +) + +// Mock represents the required interface that must +// be implemented by HTTP mock instances. +type Mock interface { + // Disable disables the current mock manually. + Disable() + + // Done returns true if the current mock is disabled. + Done() bool + + // Request returns the mock Request instance. + Request() *Request + + // Response returns the mock Response instance. + Response() *Response + + // Match matches the given http.Request with the current mock. + Match(*http.Request) (bool, error) + + // AddMatcher adds a new matcher function. + AddMatcher(MatchFunc) + + // SetMatcher uses a new matcher implementation. + SetMatcher(Matcher) +} + +// Mocker implements a Mock capable interface providing +// a default mock configuration used internally to store mocks. +type Mocker struct { + // disabler stores a disabler for thread safety checking current mock is disabled + disabler *disabler + + // mutex stores the mock mutex for thread safety. + mutex sync.Mutex + + // matcher stores a Matcher capable instance to match the given http.Request. + matcher Matcher + + // request stores the mock Request to match. + request *Request + + // response stores the mock Response to use in case of match. + response *Response +} + +type disabler struct { + // disabled stores if the current mock is disabled. + disabled bool + + // mutex stores the disabler mutex for thread safety. + mutex sync.RWMutex +} + +func (d *disabler) isDisabled() bool { + d.mutex.RLock() + defer d.mutex.RUnlock() + return d.disabled +} + +func (d *disabler) Disable() { + d.mutex.Lock() + defer d.mutex.Unlock() + d.disabled = true +} + +// NewMock creates a new HTTP mock based on the given request and response instances. +// It's mostly used internally. +func (g *Gock) NewMock(req *Request, res *Response) *Mocker { + mock := &Mocker{ + disabler: new(disabler), + request: req, + response: res, + matcher: g.DefaultMatcher.Clone(), + } + res.Mock = mock + req.Mock = mock + req.Response = res + return mock +} + +// Disable disables the current mock manually. +func (m *Mocker) Disable() { + m.disabler.Disable() +} + +// Done returns true in case that the current mock +// instance is disabled and therefore must be removed. +func (m *Mocker) Done() bool { + // prevent deadlock with m.mutex + if m.disabler.isDisabled() { + return true + } + + m.mutex.Lock() + defer m.mutex.Unlock() + return !m.request.Persisted && m.request.Counter == 0 +} + +// Request returns the Request instance +// configured for the current HTTP mock. +func (m *Mocker) Request() *Request { + return m.request +} + +// Response returns the Response instance +// configured for the current HTTP mock. +func (m *Mocker) Response() *Response { + return m.response +} + +// Match matches the given http.Request with the current Request +// mock expectation, returning true if matches. +func (m *Mocker) Match(req *http.Request) (bool, error) { + if m.disabler.isDisabled() { + return false, nil + } + + // Filter + for _, filter := range m.request.Filters { + if !filter(req) { + return false, nil + } + } + + // Map + for _, mapper := range m.request.Mappers { + if treq := mapper(req); treq != nil { + req = treq + } + } + + // Match + matches, err := m.matcher.Match(req, m.request) + if matches { + m.decrement() + } + + return matches, err +} + +// SetMatcher sets a new matcher implementation +// for the current mock expectation. +func (m *Mocker) SetMatcher(matcher Matcher) { + m.matcher = matcher +} + +// AddMatcher adds a new matcher function +// for the current mock expectation. +func (m *Mocker) AddMatcher(fn MatchFunc) { + m.matcher.Add(fn) +} + +// decrement decrements the current mock Request counter. +func (m *Mocker) decrement() { + if m.request.Persisted { + return + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + m.request.Counter-- + if m.request.Counter == 0 { + m.disabler.Disable() + } +} diff --git a/threadsafe/mock_test.go b/threadsafe/mock_test.go new file mode 100644 index 0000000..f277842 --- /dev/null +++ b/threadsafe/mock_test.go @@ -0,0 +1,143 @@ +package threadsafe + +import ( + "net/http" + "testing" + + "github.com/nbio/st" +) + +func TestNewMock(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get())) + + st.Expect(t, mock.Request(), req) + st.Expect(t, mock.Request().Mock, mock) + st.Expect(t, mock.Response(), res) + st.Expect(t, mock.Response().Mock, mock) +} + +func TestMockDisable(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + st.Expect(t, mock.disabler.isDisabled(), false) + mock.Disable() + st.Expect(t, mock.disabler.isDisabled(), true) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, matches, false) +} + +func TestMockDone(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + + mock := g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, mock.Done(), false) + + mock = g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + mock.Disable() + st.Expect(t, mock.Done(), true) + + mock = g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + mock.request.Counter = 0 + st.Expect(t, mock.Done(), true) + + mock = g.NewMock(req, res) + st.Expect(t, mock.disabler.isDisabled(), false) + mock.request.Persisted = true + st.Expect(t, mock.Done(), false) +} + +func TestMockSetMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get())) + matcher := g.NewMatcher() + matcher.Flush() + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + mock.SetMatcher(matcher) + st.Expect(t, len(mock.matcher.Get()), 1) + st.Expect(t, mock.disabler.isDisabled(), false) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, matches, true) +} + +func TestMockAddMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + st.Expect(t, len(mock.matcher.Get()), len(g.DefaultMatcher.Get())) + matcher := g.NewMatcher() + matcher.Flush() + mock.SetMatcher(matcher) + mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + return true, nil + }) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, mock.matcher, matcher) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, matches, true) +} + +func TestMockMatch(t *testing.T) { + g := NewGock() + defer after(g) + + req := g.NewRequest() + res := g.NewResponse() + mock := g.NewMock(req, res) + + matcher := g.NewMatcher() + matcher.Flush() + mock.SetMatcher(matcher) + calls := 0 + mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + calls++ + return true, nil + }) + mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + calls++ + return true, nil + }) + st.Expect(t, mock.disabler.isDisabled(), false) + st.Expect(t, mock.matcher, matcher) + + matches, err := mock.Match(&http.Request{}) + st.Expect(t, err, nil) + st.Expect(t, calls, 2) + st.Expect(t, matches, true) +} diff --git a/threadsafe/options.go b/threadsafe/options.go new file mode 100644 index 0000000..98497f9 --- /dev/null +++ b/threadsafe/options.go @@ -0,0 +1,8 @@ +package threadsafe + +// Options represents customized option for gock +type Options struct { + // DisableRegexpHost stores if the host is only a plain string rather than regular expression, + // if DisableRegexpHost is true, host sets in gock.New(...) will be treated as plain string + DisableRegexpHost bool +} diff --git a/threadsafe/request.go b/threadsafe/request.go new file mode 100644 index 0000000..3508bbb --- /dev/null +++ b/threadsafe/request.go @@ -0,0 +1,330 @@ +package threadsafe + +import ( + "encoding/base64" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" +) + +// MapRequestFunc represents the required function interface for request mappers. +type MapRequestFunc func(*http.Request) *http.Request + +// FilterRequestFunc represents the required function interface for request filters. +type FilterRequestFunc func(*http.Request) bool + +// Request represents the high-level HTTP request used to store +// request fields used to match intercepted requests. +type Request struct { + g *Gock + + // Mock stores the parent mock reference for the current request mock used for method delegation. + Mock Mock + + // Response stores the current Response instance for the current matches Request. + Response *Response + + // Error stores the latest mock request configuration error. + Error error + + // Counter stores the pending times that the current mock should be active. + Counter int + + // Persisted stores if the current mock should be always active. + Persisted bool + + // Options stores options for current Request. + Options Options + + // URLStruct stores the parsed URL as *url.URL struct. + URLStruct *url.URL + + // Method stores the Request HTTP method to match. + Method string + + // CompressionScheme stores the Request Compression scheme to match and use for decompression. + CompressionScheme string + + // Header stores the HTTP header fields to match. + Header http.Header + + // Cookies stores the Request HTTP cookies values to match. + Cookies []*http.Cookie + + // PathParams stores the path parameters to match. + PathParams map[string]string + + // BodyBuffer stores the body data to match. + BodyBuffer []byte + + // Mappers stores the request functions mappers used for matching. + Mappers []MapRequestFunc + + // Filters stores the request functions filters used for matching. + Filters []FilterRequestFunc +} + +// NewRequest creates a new Request instance. +func (g *Gock) NewRequest() *Request { + return &Request{ + g: g, + Counter: 1, + URLStruct: &url.URL{}, + Header: make(http.Header), + PathParams: make(map[string]string), + } +} + +// URL defines the mock URL to match. +func (r *Request) URL(uri string) *Request { + r.URLStruct, r.Error = url.Parse(uri) + return r +} + +// SetURL defines the url.URL struct to be used for matching. +func (r *Request) SetURL(u *url.URL) *Request { + r.URLStruct = u + return r +} + +// Path defines the mock URL path value to match. +func (r *Request) Path(path string) *Request { + r.URLStruct.Path = path + return r +} + +// Get specifies the GET method and the given URL path to match. +func (r *Request) Get(path string) *Request { + return r.method("GET", path) +} + +// Post specifies the POST method and the given URL path to match. +func (r *Request) Post(path string) *Request { + return r.method("POST", path) +} + +// Put specifies the PUT method and the given URL path to match. +func (r *Request) Put(path string) *Request { + return r.method("PUT", path) +} + +// Delete specifies the DELETE method and the given URL path to match. +func (r *Request) Delete(path string) *Request { + return r.method("DELETE", path) +} + +// Patch specifies the PATCH method and the given URL path to match. +func (r *Request) Patch(path string) *Request { + return r.method("PATCH", path) +} + +// Head specifies the HEAD method and the given URL path to match. +func (r *Request) Head(path string) *Request { + return r.method("HEAD", path) +} + +// method is a DRY shortcut used to declare the expected HTTP method and URL path. +func (r *Request) method(method, path string) *Request { + if path != "/" { + r.URLStruct.Path = path + } + r.Method = strings.ToUpper(method) + return r +} + +// Body defines the body data to match based on a io.Reader interface. +func (r *Request) Body(body io.Reader) *Request { + r.BodyBuffer, r.Error = ioutil.ReadAll(body) + return r +} + +// BodyString defines the body to match based on a given string. +func (r *Request) BodyString(body string) *Request { + r.BodyBuffer = []byte(body) + return r +} + +// File defines the body to match based on the given file path string. +func (r *Request) File(path string) *Request { + r.BodyBuffer, r.Error = ioutil.ReadFile(path) + return r +} + +// Compression defines the request compression scheme, and enables automatic body decompression. +// Supports only the "gzip" scheme so far. +func (r *Request) Compression(scheme string) *Request { + r.Header.Set("Content-Encoding", scheme) + r.CompressionScheme = scheme + return r +} + +// JSON defines the JSON body to match based on a given structure. +func (r *Request) JSON(data interface{}) *Request { + if r.Header.Get("Content-Type") == "" { + r.Header.Set("Content-Type", "application/json") + } + r.BodyBuffer, r.Error = readAndDecode(data, "json") + return r +} + +// XML defines the XML body to match based on a given structure. +func (r *Request) XML(data interface{}) *Request { + if r.Header.Get("Content-Type") == "" { + r.Header.Set("Content-Type", "application/xml") + } + r.BodyBuffer, r.Error = readAndDecode(data, "xml") + return r +} + +// MatchType defines the request Content-Type MIME header field. +// Supports custom MIME types and type aliases. E.g: json, xml, form, text... +func (r *Request) MatchType(kind string) *Request { + mime := r.g.BodyTypeAliases[kind] + if mime != "" { + kind = mime + } + r.Header.Set("Content-Type", kind) + return r +} + +// BasicAuth defines a username and password for HTTP Basic Authentication +func (r *Request) BasicAuth(username, password string) *Request { + r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) + return r +} + +// MatchHeader defines a new key and value header to match. +func (r *Request) MatchHeader(key, value string) *Request { + r.Header.Set(key, value) + return r +} + +// HeaderPresent defines that a header field must be present in the request. +func (r *Request) HeaderPresent(key string) *Request { + r.Header.Set(key, ".*") + return r +} + +// MatchHeaders defines a map of key-value headers to match. +func (r *Request) MatchHeaders(headers map[string]string) *Request { + for key, value := range headers { + r.Header.Set(key, value) + } + return r +} + +// MatchParam defines a new key and value URL query param to match. +func (r *Request) MatchParam(key, value string) *Request { + query := r.URLStruct.Query() + query.Set(key, value) + r.URLStruct.RawQuery = query.Encode() + return r +} + +// MatchParams defines a map of URL query param key-value to match. +func (r *Request) MatchParams(params map[string]string) *Request { + query := r.URLStruct.Query() + for key, value := range params { + query.Set(key, value) + } + r.URLStruct.RawQuery = query.Encode() + return r +} + +// ParamPresent matches if the given query param key is present in the URL. +func (r *Request) ParamPresent(key string) *Request { + r.MatchParam(key, ".*") + return r +} + +// PathParam matches if a given path parameter key is present in the URL. +// +// The value is representative of the restful resource the key defines, e.g. +// +// // /users/123/name +// r.PathParam("users", "123") +// +// would match. +func (r *Request) PathParam(key, val string) *Request { + r.PathParams[key] = val + + return r +} + +// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it. +func (r *Request) Persist() *Request { + r.Persisted = true + return r +} + +// WithOptions sets the options for the request. +func (r *Request) WithOptions(options Options) *Request { + r.Options = options + return r +} + +// Times defines the number of times that the current HTTP mock should remain active. +func (r *Request) Times(num int) *Request { + r.Counter = num + return r +} + +// AddMatcher adds a new matcher function to match the request. +func (r *Request) AddMatcher(fn MatchFunc) *Request { + r.Mock.AddMatcher(fn) + return r +} + +// SetMatcher sets a new matcher function to match the request. +func (r *Request) SetMatcher(matcher Matcher) *Request { + r.Mock.SetMatcher(matcher) + return r +} + +// Map adds a new request mapper function to map http.Request before the matching process. +func (r *Request) Map(fn MapRequestFunc) *Request { + r.Mappers = append(r.Mappers, fn) + return r +} + +// Filter filters a new request filter function to filter http.Request before the matching process. +func (r *Request) Filter(fn FilterRequestFunc) *Request { + r.Filters = append(r.Filters, fn) + return r +} + +// EnableNetworking enables the use real networking for the current mock. +func (r *Request) EnableNetworking() *Request { + if r.Response != nil { + r.Response.UseNetwork = true + } + return r +} + +// Reply defines the Response status code and returns the mock Response DSL. +func (r *Request) Reply(status int) *Response { + return r.Response.Status(status) +} + +// ReplyError defines the Response simulated error. +func (r *Request) ReplyError(err error) *Response { + return r.Response.SetError(err) +} + +// ReplyFunc allows the developer to define the mock response via a custom function. +func (r *Request) ReplyFunc(replier func(*Response)) *Response { + replier(r.Response) + return r.Response +} + +// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt +// "To receive authorization, the client sends the userid and password, +// separated by a single colon (":") character, within a base64 +// encoded string in the credentials." +// It is not meant to be urlencoded. +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} diff --git a/threadsafe/request_test.go b/threadsafe/request_test.go new file mode 100644 index 0000000..e67614f --- /dev/null +++ b/threadsafe/request_test.go @@ -0,0 +1,318 @@ +package threadsafe + +import ( + "bytes" + "net/http" + "net/url" + "path/filepath" + "testing" + + "github.com/nbio/st" +) + +func TestNewRequest(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.URL("http://foo.com") + st.Expect(t, req.URLStruct.Host, "foo.com") + st.Expect(t, req.URLStruct.Scheme, "http") + req.MatchHeader("foo", "bar") + st.Expect(t, req.Header.Get("foo"), "bar") +} + +func TestRequestSetURL(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.URL("http://foo.com") + req.SetURL(&url.URL{Host: "bar.com", Path: "/foo"}) + st.Expect(t, req.URLStruct.Host, "bar.com") + st.Expect(t, req.URLStruct.Path, "/foo") +} + +func TestRequestPath(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.URL("http://foo.com") + req.Path("/foo") + st.Expect(t, req.URLStruct.Scheme, "http") + st.Expect(t, req.URLStruct.Host, "foo.com") + st.Expect(t, req.URLStruct.Path, "/foo") +} + +func TestRequestBody(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.Body(bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, string(req.BodyBuffer), "foo bar") +} + +func TestRequestBodyString(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.BodyString("foo bar") + st.Expect(t, string(req.BodyBuffer), "foo bar") +} + +func TestRequestFile(t *testing.T) { + g := NewGock() + req := g.NewRequest() + absPath, err := filepath.Abs("../version.go") + st.Expect(t, err, nil) + req.File(absPath) + st.Expect(t, string(req.BodyBuffer)[:12], "package gock") +} + +func TestRequestJSON(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.JSON(map[string]string{"foo": "bar"}) + st.Expect(t, string(req.BodyBuffer)[:13], `{"foo":"bar"}`) + st.Expect(t, req.Header.Get("Content-Type"), "application/json") +} + +func TestRequestXML(t *testing.T) { + g := NewGock() + req := g.NewRequest() + type xml struct { + Data string `xml:"data"` + } + req.XML(xml{Data: "foo"}) + st.Expect(t, string(req.BodyBuffer), `foo`) + st.Expect(t, req.Header.Get("Content-Type"), "application/xml") +} + +func TestRequestMatchType(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchType("json") + st.Expect(t, req.Header.Get("Content-Type"), "application/json") + + req = g.NewRequest() + req.MatchType("html") + st.Expect(t, req.Header.Get("Content-Type"), "text/html") + + req = g.NewRequest() + req.MatchType("foo/bar") + st.Expect(t, req.Header.Get("Content-Type"), "foo/bar") +} + +func TestRequestBasicAuth(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.BasicAuth("bob", "qwerty") + st.Expect(t, req.Header.Get("Authorization"), "Basic Ym9iOnF3ZXJ0eQ==") +} + +func TestRequestMatchHeader(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchHeader("foo", "bar") + req.MatchHeader("bar", "baz") + req.MatchHeader("UPPERCASE", "bat") + req.MatchHeader("Mixed-CASE", "foo") + + st.Expect(t, req.Header.Get("foo"), "bar") + st.Expect(t, req.Header.Get("bar"), "baz") + st.Expect(t, req.Header.Get("UPPERCASE"), "bat") + st.Expect(t, req.Header.Get("Mixed-CASE"), "foo") +} + +func TestRequestHeaderPresent(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.HeaderPresent("foo") + req.HeaderPresent("bar") + req.HeaderPresent("UPPERCASE") + req.HeaderPresent("Mixed-CASE") + st.Expect(t, req.Header.Get("foo"), ".*") + st.Expect(t, req.Header.Get("bar"), ".*") + st.Expect(t, req.Header.Get("UPPERCASE"), ".*") + st.Expect(t, req.Header.Get("Mixed-CASE"), ".*") +} + +func TestRequestMatchParam(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchParam("foo", "bar") + req.MatchParam("bar", "baz") + st.Expect(t, req.URLStruct.Query().Get("foo"), "bar") + st.Expect(t, req.URLStruct.Query().Get("bar"), "baz") +} + +func TestRequestMatchParams(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.MatchParams(map[string]string{"foo": "bar", "bar": "baz"}) + st.Expect(t, req.URLStruct.Query().Get("foo"), "bar") + st.Expect(t, req.URLStruct.Query().Get("bar"), "baz") +} + +func TestRequestPresentParam(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.ParamPresent("key") + st.Expect(t, req.URLStruct.Query().Get("key"), ".*") +} + +func TestRequestPathParam(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.PathParam("key", "value") + st.Expect(t, req.PathParams["key"], "value") +} + +func TestRequestPersist(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, req.Persisted, false) + req.Persist() + st.Expect(t, req.Persisted, true) +} + +func TestRequestTimes(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, req.Counter, 1) + req.Times(3) + st.Expect(t, req.Counter, 3) +} + +func TestRequestMap(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, len(req.Mappers), 0) + req.Map(func(req *http.Request) *http.Request { + return req + }) + st.Expect(t, len(req.Mappers), 1) +} + +func TestRequestFilter(t *testing.T) { + g := NewGock() + req := g.NewRequest() + st.Expect(t, len(req.Filters), 0) + req.Filter(func(req *http.Request) bool { + return true + }) + st.Expect(t, len(req.Filters), 1) +} + +func TestRequestEnableNetworking(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.Response = &Response{} + st.Expect(t, req.Response.UseNetwork, false) + req.EnableNetworking() + st.Expect(t, req.Response.UseNetwork, true) +} + +func TestRequestResponse(t *testing.T) { + g := NewGock() + req := g.NewRequest() + res := g.NewResponse() + req.Response = res + chain := req.Reply(200) + st.Expect(t, chain, res) + st.Expect(t, chain.StatusCode, 200) +} + +func TestRequestReplyFunc(t *testing.T) { + g := NewGock() + req := g.NewRequest() + res := g.NewResponse() + req.Response = res + chain := req.ReplyFunc(func(r *Response) { + r.Status(204) + }) + st.Expect(t, chain, res) + st.Expect(t, chain.StatusCode, 204) +} + +func TestRequestMethods(t *testing.T) { + g := NewGock() + req := g.NewRequest() + req.Get("/foo") + st.Expect(t, req.Method, "GET") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Post("/foo") + st.Expect(t, req.Method, "POST") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Put("/foo") + st.Expect(t, req.Method, "PUT") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Delete("/foo") + st.Expect(t, req.Method, "DELETE") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Patch("/foo") + st.Expect(t, req.Method, "PATCH") + st.Expect(t, req.URLStruct.Path, "/foo") + + req = g.NewRequest() + req.Head("/foo") + st.Expect(t, req.Method, "HEAD") + st.Expect(t, req.URLStruct.Path, "/foo") +} + +func TestRequestSetMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + matcher := g.NewEmptyMatcher() + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { + return req.Header.Get("foo") == "bar", nil + }) + ereq := g.NewRequest() + mock := g.NewMock(ereq, &Response{}) + mock.SetMatcher(matcher) + ereq.Mock = mock + + headers := make(http.Header) + headers.Set("foo", "bar") + req := &http.Request{ + URL: &url.URL{Host: "foo.com", Path: "/bar"}, + Header: headers, + } + + match, err := ereq.Mock.Match(req) + st.Expect(t, err, nil) + st.Expect(t, match, true) +} + +func TestRequestAddMatcher(t *testing.T) { + g := NewGock() + defer after(g) + + ereq := g.NewRequest() + mock := g.NewMock(ereq, &Response{}) + mock.matcher = g.NewMatcher() + ereq.Mock = mock + + ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + return req.URL.Host == "foo.com", nil + }) + ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { + return req.Header.Get("foo") == "bar", nil + }) + + headers := make(http.Header) + headers.Set("foo", "bar") + req := &http.Request{ + URL: &url.URL{Host: "foo.com", Path: "/bar"}, + Header: headers, + } + + match, err := ereq.Mock.Match(req) + st.Expect(t, err, nil) + st.Expect(t, match, true) +} diff --git a/threadsafe/responder.go b/threadsafe/responder.go new file mode 100644 index 0000000..5dc4a7d --- /dev/null +++ b/threadsafe/responder.go @@ -0,0 +1,111 @@ +package threadsafe + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "strconv" + "time" +) + +// Responder builds a mock http.Response based on the given Response mock. +func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) { + // If error present, reply it + err := mock.Error + if err != nil { + return nil, err + } + + if res == nil { + res = createResponse(req) + } + + // Apply response filter + for _, filter := range mock.Filters { + if !filter(res) { + return res, nil + } + } + + // Define mock status code + if mock.StatusCode != 0 { + res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode) + res.StatusCode = mock.StatusCode + } + + // Define headers by merging fields + res.Header = mergeHeaders(res, mock) + + // Define mock body, if present + if len(mock.BodyBuffer) > 0 { + res.ContentLength = int64(len(mock.BodyBuffer)) + res.Body = createReadCloser(mock.BodyBuffer) + } + + // Set raw mock body, if exist + if mock.BodyGen != nil { + res.ContentLength = -1 + res.Body = mock.BodyGen() + } + + // Apply response mappers + for _, mapper := range mock.Mappers { + if tres := mapper(res); tres != nil { + res = tres + } + } + + // Sleep to simulate delay, if necessary + if mock.ResponseDelay > 0 { + // allow escaping from sleep due to request context expiration or cancellation + t := time.NewTimer(mock.ResponseDelay) + select { + case <-t.C: + case <-req.Context().Done(): + // cleanly stop the timer + if !t.Stop() { + <-t.C + } + } + } + + // check if the request context has ended. we could put this up in the delay code above, but putting it here + // has the added benefit of working even when there is no delay (very small timeouts, already-done contexts, etc.) + if err = req.Context().Err(); err != nil { + // cleanly close the response and return the context error + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + return nil, err + } + + return res, err +} + +// createResponse creates a new http.Response with default fields. +func createResponse(req *http.Request) *http.Response { + return &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Proto: "HTTP/1.1", + Request: req, + Header: make(http.Header), + Body: createReadCloser([]byte{}), + } +} + +// mergeHeaders copies the mock headers. +func mergeHeaders(res *http.Response, mres *Response) http.Header { + for key, values := range mres.Header { + for _, value := range values { + res.Header.Add(key, value) + } + } + return res.Header +} + +// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an +// http response body. +func createReadCloser(body []byte) io.ReadCloser { + return ioutil.NopCloser(bytes.NewReader(body)) +} diff --git a/threadsafe/responder_test.go b/threadsafe/responder_test.go new file mode 100644 index 0000000..7d18820 --- /dev/null +++ b/threadsafe/responder_test.go @@ -0,0 +1,191 @@ +package threadsafe + +import ( + "context" + "errors" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + "github.com/nbio/st" +) + +func TestResponder(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Reply(200).BodyString("foo") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") +} + +func TestResponder_ReadTwice(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Reply(200).BodyString("foo") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") + + body, err = ioutil.ReadAll(res.Body) + st.Expect(t, err, nil) + st.Expect(t, body, []byte{}) +} + +func TestResponderBodyGenerator(t *testing.T) { + g := NewGock() + defer after(g) + generator := func() io.ReadCloser { + return io.NopCloser(strings.NewReader("foo")) + } + mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator) + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") +} + +func TestResponderBodyGenerator_ReadTwice(t *testing.T) { + g := NewGock() + defer after(g) + generator := func() io.ReadCloser { + return io.NopCloser(strings.NewReader("foo")) + } + mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator) + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") + + body, err = ioutil.ReadAll(res.Body) + st.Expect(t, err, nil) + st.Expect(t, body, []byte{}) +} + +func TestResponderBodyGenerator_Override(t *testing.T) { + g := NewGock() + defer after(g) + generator := func() io.ReadCloser { + return io.NopCloser(strings.NewReader("foo")) + } + mres := g.New("http://foo.com").Reply(200).BodyGenerator(generator).BodyString("bar") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Status, "200 OK") + st.Expect(t, res.StatusCode, 200) + + body, _ := ioutil.ReadAll(res.Body) + st.Expect(t, string(body), "foo") + + body, err = ioutil.ReadAll(res.Body) + st.Expect(t, err, nil) + st.Expect(t, body, []byte{}) +} + +func TestResponderSupportsMultipleHeadersWithSameKey(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo"). + Reply(200). + AddHeader("Set-Cookie", "a=1"). + AddHeader("Set-Cookie", "b=2") + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err, nil) + st.Expect(t, res.Header, http.Header{"Set-Cookie": []string{"a=1", "b=2"}}) +} + +func TestResponderError(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").ReplyError(errors.New("error")) + req := &http.Request{} + + res, err := Responder(req, mres, nil) + st.Expect(t, err.Error(), "error") + st.Expect(t, res == nil, true) +} + +func TestResponderCancelledContext(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Get("").Reply(200).Delay(20 * time.Millisecond).BodyString("foo") + + // create a context and schedule a call to cancel in 10ms + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil) + + res, err := Responder(req, mres, nil) + + // verify that we got a context cancellation error and nil response + st.Expect(t, err, context.Canceled) + st.Expect(t, res == nil, true) +} + +func TestResponderExpiredContext(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Get("").Reply(200).Delay(20 * time.Millisecond).BodyString("foo") + + // create a context that is set to expire in 10ms + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil) + + res, err := Responder(req, mres, nil) + + // verify that we got a context cancellation error and nil response + st.Expect(t, err, context.DeadlineExceeded) + st.Expect(t, res == nil, true) +} + +func TestResponderPreExpiredContext(t *testing.T) { + g := NewGock() + defer after(g) + mres := g.New("http://foo.com").Get("").Reply(200).BodyString("foo") + + // create a context and wait to ensure it is expired + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Microsecond) + defer cancel() + time.Sleep(1 * time.Millisecond) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://foo.com", nil) + + res, err := Responder(req, mres, nil) + + // verify that we got a context cancellation error and nil response + st.Expect(t, err, context.DeadlineExceeded) + st.Expect(t, res == nil, true) +} diff --git a/threadsafe/response.go b/threadsafe/response.go new file mode 100644 index 0000000..04de096 --- /dev/null +++ b/threadsafe/response.go @@ -0,0 +1,198 @@ +package threadsafe + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "io" + "io/ioutil" + "net/http" + "time" +) + +// MapResponseFunc represents the required function interface impletemed by response mappers. +type MapResponseFunc func(*http.Response) *http.Response + +// FilterResponseFunc represents the required function interface impletemed by response filters. +type FilterResponseFunc func(*http.Response) bool + +// Response represents high-level HTTP fields to configure +// and define HTTP responses intercepted by gock. +type Response struct { + g *Gock + + // Mock stores the parent mock reference for the current response mock used for method delegation. + Mock Mock + + // Error stores the latest response configuration or injected error. + Error error + + // UseNetwork enables the use of real network for the current mock. + UseNetwork bool + + // StatusCode stores the response status code. + StatusCode int + + // Headers stores the response headers. + Header http.Header + + // Cookies stores the response cookie fields. + Cookies []*http.Cookie + + // BodyGen stores a io.ReadCloser generator to be returned. + BodyGen func() io.ReadCloser + + // BodyBuffer stores the array of bytes to use as body. + BodyBuffer []byte + + // ResponseDelay stores the simulated response delay. + ResponseDelay time.Duration + + // Mappers stores the request functions mappers used for matching. + Mappers []MapResponseFunc + + // Filters stores the request functions filters used for matching. + Filters []FilterResponseFunc +} + +// NewResponse creates a new Response. +func (g *Gock) NewResponse() *Response { + return &Response{g: g, Header: make(http.Header)} +} + +// Status defines the desired HTTP status code to reply in the current response. +func (r *Response) Status(code int) *Response { + r.StatusCode = code + return r +} + +// Type defines the response Content-Type MIME header field. +// Supports type alias. E.g: json, xml, form, text... +func (r *Response) Type(kind string) *Response { + mime := r.g.BodyTypeAliases[kind] + if mime != "" { + kind = mime + } + r.Header.Set("Content-Type", kind) + return r +} + +// SetHeader sets a new header field in the mock response. +func (r *Response) SetHeader(key, value string) *Response { + r.Header.Set(key, value) + return r +} + +// AddHeader adds a new header field in the mock response +// with out removing an existent one. +func (r *Response) AddHeader(key, value string) *Response { + r.Header.Add(key, value) + return r +} + +// SetHeaders sets a map of header fields in the mock response. +func (r *Response) SetHeaders(headers map[string]string) *Response { + for key, value := range headers { + r.Header.Add(key, value) + } + return r +} + +// Body sets the HTTP response body to be used. +func (r *Response) Body(body io.Reader) *Response { + r.BodyBuffer, r.Error = ioutil.ReadAll(body) + return r +} + +// BodyGenerator accepts a io.ReadCloser generator, returning custom io.ReadCloser +// for every response. This will take priority than other Body methods used. +func (r *Response) BodyGenerator(generator func() io.ReadCloser) *Response { + r.BodyGen = generator + return r +} + +// BodyString defines the response body as string. +func (r *Response) BodyString(body string) *Response { + r.BodyBuffer = []byte(body) + return r +} + +// File defines the response body reading the data +// from disk based on the file path string. +func (r *Response) File(path string) *Response { + r.BodyBuffer, r.Error = ioutil.ReadFile(path) + return r +} + +// JSON defines the response body based on a JSON based input. +func (r *Response) JSON(data interface{}) *Response { + r.Header.Set("Content-Type", "application/json") + r.BodyBuffer, r.Error = readAndDecode(data, "json") + return r +} + +// XML defines the response body based on a XML based input. +func (r *Response) XML(data interface{}) *Response { + r.Header.Set("Content-Type", "application/xml") + r.BodyBuffer, r.Error = readAndDecode(data, "xml") + return r +} + +// SetError defines the response simulated error. +func (r *Response) SetError(err error) *Response { + r.Error = err + return r +} + +// Delay defines the response simulated delay. +// This feature is still experimental and will be improved in the future. +func (r *Response) Delay(delay time.Duration) *Response { + r.ResponseDelay = delay + return r +} + +// Map adds a new response mapper function to map http.Response before the matching process. +func (r *Response) Map(fn MapResponseFunc) *Response { + r.Mappers = append(r.Mappers, fn) + return r +} + +// Filter filters a new request filter function to filter http.Request before the matching process. +func (r *Response) Filter(fn FilterResponseFunc) *Response { + r.Filters = append(r.Filters, fn) + return r +} + +// EnableNetworking enables the use real networking for the current mock. +func (r *Response) EnableNetworking() *Response { + r.UseNetwork = true + return r +} + +// Done returns true if the mock was done and disabled. +func (r *Response) Done() bool { + return r.Mock.Done() +} + +func readAndDecode(data interface{}, kind string) ([]byte, error) { + buf := &bytes.Buffer{} + + switch data.(type) { + case string: + buf.WriteString(data.(string)) + case []byte: + buf.Write(data.([]byte)) + default: + var err error + if kind == "xml" { + err = xml.NewEncoder(buf).Encode(data) + } else { + err = json.NewEncoder(buf).Encode(data) + } + if err != nil { + return nil, err + } + } + + return ioutil.ReadAll(buf) +} diff --git a/threadsafe/response_test.go b/threadsafe/response_test.go new file mode 100644 index 0000000..0a44c63 --- /dev/null +++ b/threadsafe/response_test.go @@ -0,0 +1,186 @@ +package threadsafe + +import ( + "bytes" + "errors" + "io" + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/nbio/st" +) + +func TestNewResponse(t *testing.T) { + g := NewGock() + res := g.NewResponse() + + res.Status(200) + st.Expect(t, res.StatusCode, 200) + + res.SetHeader("foo", "bar") + st.Expect(t, res.Header.Get("foo"), "bar") + + res.Delay(1000 * time.Millisecond) + st.Expect(t, res.ResponseDelay, 1000*time.Millisecond) + + res.EnableNetworking() + st.Expect(t, res.UseNetwork, true) +} + +func TestResponseStatus(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.StatusCode, 0) + res.Status(200) + st.Expect(t, res.StatusCode, 200) +} + +func TestResponseType(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.Type("json") + st.Expect(t, res.Header.Get("Content-Type"), "application/json") + + res = g.NewResponse() + res.Type("xml") + st.Expect(t, res.Header.Get("Content-Type"), "application/xml") + + res = g.NewResponse() + res.Type("foo/bar") + st.Expect(t, res.Header.Get("Content-Type"), "foo/bar") +} + +func TestResponseSetHeader(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.SetHeader("foo", "bar") + res.SetHeader("bar", "baz") + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, res.Header.Get("bar"), "baz") +} + +func TestResponseAddHeader(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.AddHeader("foo", "bar") + res.AddHeader("foo", "baz") + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, res.Header["Foo"][1], "baz") +} + +func TestResponseSetHeaders(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.SetHeaders(map[string]string{"foo": "bar", "bar": "baz"}) + st.Expect(t, res.Header.Get("foo"), "bar") + st.Expect(t, res.Header.Get("bar"), "baz") +} + +func TestResponseBody(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.Body(bytes.NewBuffer([]byte("foo bar"))) + st.Expect(t, string(res.BodyBuffer), "foo bar") +} + +func TestResponseBodyGenerator(t *testing.T) { + g := NewGock() + res := g.NewResponse() + generator := func() io.ReadCloser { + return io.NopCloser(bytes.NewBuffer([]byte("foo bar"))) + } + res.BodyGenerator(generator) + bytes, err := io.ReadAll(res.BodyGen()) + st.Expect(t, err, nil) + st.Expect(t, string(bytes), "foo bar") +} + +func TestResponseBodyString(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.BodyString("foo bar") + st.Expect(t, string(res.BodyBuffer), "foo bar") +} + +func TestResponseFile(t *testing.T) { + g := NewGock() + res := g.NewResponse() + absPath, err := filepath.Abs("../version.go") + st.Expect(t, err, nil) + res.File(absPath) + st.Expect(t, string(res.BodyBuffer)[:12], "package gock") +} + +func TestResponseJSON(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.JSON(map[string]string{"foo": "bar"}) + st.Expect(t, string(res.BodyBuffer)[:13], `{"foo":"bar"}`) + st.Expect(t, res.Header.Get("Content-Type"), "application/json") +} + +func TestResponseXML(t *testing.T) { + g := NewGock() + res := g.NewResponse() + type xml struct { + Data string `xml:"data"` + } + res.XML(xml{Data: "foo"}) + st.Expect(t, string(res.BodyBuffer), `foo`) + st.Expect(t, res.Header.Get("Content-Type"), "application/xml") +} + +func TestResponseMap(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, len(res.Mappers), 0) + res.Map(func(res *http.Response) *http.Response { + return res + }) + st.Expect(t, len(res.Mappers), 1) +} + +func TestResponseFilter(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, len(res.Filters), 0) + res.Filter(func(res *http.Response) bool { + return true + }) + st.Expect(t, len(res.Filters), 1) +} + +func TestResponseSetError(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.Error, nil) + res.SetError(errors.New("foo error")) + st.Expect(t, res.Error.Error(), "foo error") +} + +func TestResponseDelay(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.ResponseDelay, 0*time.Microsecond) + res.Delay(100 * time.Millisecond) + st.Expect(t, res.ResponseDelay, 100*time.Millisecond) +} + +func TestResponseEnableNetworking(t *testing.T) { + g := NewGock() + res := g.NewResponse() + st.Expect(t, res.UseNetwork, false) + res.EnableNetworking() + st.Expect(t, res.UseNetwork, true) +} + +func TestResponseDone(t *testing.T) { + g := NewGock() + res := g.NewResponse() + res.Mock = &Mocker{request: &Request{Counter: 1}, disabler: new(disabler)} + st.Expect(t, res.Done(), false) + res.Mock.Disable() + st.Expect(t, res.Done(), true) +} diff --git a/threadsafe/store.go b/threadsafe/store.go new file mode 100644 index 0000000..d22a02e --- /dev/null +++ b/threadsafe/store.go @@ -0,0 +1,90 @@ +package threadsafe + +// Register registers a new mock in the current mocks stack. +func (g *Gock) Register(mock Mock) { + if g.Exists(mock) { + return + } + + // Make ops thread safe + g.storeMutex.Lock() + defer g.storeMutex.Unlock() + + // Expose mock in request/response for delegation + mock.Request().Mock = mock + mock.Response().Mock = mock + + // Registers the mock in the global store + g.mocks = append(g.mocks, mock) +} + +// GetAll returns the current stack of registered mocks. +func (g *Gock) GetAll() []Mock { + g.storeMutex.RLock() + defer g.storeMutex.RUnlock() + return g.mocks +} + +// Exists checks if the given Mock is already registered. +func (g *Gock) Exists(m Mock) bool { + g.storeMutex.RLock() + defer g.storeMutex.RUnlock() + for _, mock := range g.mocks { + if mock == m { + return true + } + } + return false +} + +// Remove removes a registered mock by reference. +func (g *Gock) Remove(m Mock) { + for i, mock := range g.mocks { + if mock == m { + g.storeMutex.Lock() + g.mocks = append(g.mocks[:i], g.mocks[i+1:]...) + g.storeMutex.Unlock() + } + } +} + +// Flush flushes the current stack of registered mocks. +func (g *Gock) Flush() { + g.storeMutex.Lock() + defer g.storeMutex.Unlock() + g.mocks = []Mock{} +} + +// Pending returns an slice of pending mocks. +func (g *Gock) Pending() []Mock { + g.Clean() + g.storeMutex.RLock() + defer g.storeMutex.RUnlock() + return g.mocks +} + +// IsDone returns true if all the registered mocks has been triggered successfully. +func (g *Gock) IsDone() bool { + return !g.IsPending() +} + +// IsPending returns true if there are pending mocks. +func (g *Gock) IsPending() bool { + return len(g.Pending()) > 0 +} + +// Clean cleans the mocks store removing disabled or obsolete mocks. +func (g *Gock) Clean() { + g.storeMutex.Lock() + defer g.storeMutex.Unlock() + + buf := []Mock{} + for _, mock := range g.mocks { + if mock.Done() { + continue + } + buf = append(buf, mock) + } + + g.mocks = buf +} diff --git a/threadsafe/store_test.go b/threadsafe/store_test.go new file mode 100644 index 0000000..e4081bb --- /dev/null +++ b/threadsafe/store_test.go @@ -0,0 +1,95 @@ +package threadsafe + +import ( + "testing" + + "github.com/nbio/st" +) + +func TestStoreRegister(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + g.Register(mock) + st.Expect(t, len(g.mocks), 1) + st.Expect(t, mock.Request().Mock, mock) + st.Expect(t, mock.Response().Mock, mock) +} + +func TestStoreGetAll(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + store := g.GetAll() + st.Expect(t, len(g.mocks), 1) + st.Expect(t, len(store), 1) + st.Expect(t, store[0], mock) +} + +func TestStoreExists(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + st.Expect(t, len(g.mocks), 1) + st.Expect(t, g.Exists(mock), true) +} + +func TestStorePending(t *testing.T) { + g := NewGock() + defer after(g) + g.New("foo") + st.Expect(t, g.mocks, g.Pending()) +} + +func TestStoreIsPending(t *testing.T) { + g := NewGock() + defer after(g) + g.New("foo") + st.Expect(t, g.IsPending(), true) + g.Flush() + st.Expect(t, g.IsPending(), false) +} + +func TestStoreIsDone(t *testing.T) { + g := NewGock() + defer after(g) + g.New("foo") + st.Expect(t, g.IsDone(), false) + g.Flush() + st.Expect(t, g.IsDone(), true) +} + +func TestStoreRemove(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + mock := g.New("foo").Mock + st.Expect(t, len(g.mocks), 1) + st.Expect(t, g.Exists(mock), true) + + g.Remove(mock) + st.Expect(t, g.Exists(mock), false) + + g.Remove(mock) + st.Expect(t, g.Exists(mock), false) +} + +func TestStoreFlush(t *testing.T) { + g := NewGock() + defer after(g) + st.Expect(t, len(g.mocks), 0) + + mock1 := g.New("foo").Mock + mock2 := g.New("foo").Mock + st.Expect(t, len(g.mocks), 2) + st.Expect(t, g.Exists(mock1), true) + st.Expect(t, g.Exists(mock2), true) + + g.Flush() + st.Expect(t, len(g.mocks), 0) + st.Expect(t, g.Exists(mock1), false) + st.Expect(t, g.Exists(mock2), false) +} diff --git a/threadsafe/transport.go b/threadsafe/transport.go new file mode 100644 index 0000000..ee1af5e --- /dev/null +++ b/threadsafe/transport.go @@ -0,0 +1,112 @@ +package threadsafe + +import ( + "errors" + "net/http" + "sync" +) + +var ( + // ErrCannotMatch store the error returned in case of no matches. + ErrCannotMatch = errors.New("gock: cannot match any request") +) + +// Transport implements http.RoundTripper, which fulfills single http requests issued by +// an http.Client. +// +// gock's Transport encapsulates a given or default http.Transport for further +// delegation, if needed. +type Transport struct { + g *Gock + + // mutex is used to make transport thread-safe of concurrent uses across goroutines. + mutex sync.Mutex + + // Transport encapsulates the original http.RoundTripper transport interface for delegation. + Transport http.RoundTripper +} + +// NewTransport creates a new *Transport with no responders. +func (g *Gock) NewTransport(transport http.RoundTripper) *Transport { + return &Transport{g: g, Transport: transport} +} + +// transport is used to always return a non-nil transport. This is the same as `(http.Client).transport`, and is what +// would be invoked if gock's transport were not present. +func (m *Transport) transport() http.RoundTripper { + if m.Transport != nil { + return m.Transport + } + return http.DefaultTransport +} + +// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to +// implement the http.RoundTripper interface. You will not interact with this directly, instead +// the *http.Client you are using will call it for you. +func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + // Just act as a proxy if not intercepting + if !m.g.Intercepting() { + return m.transport().RoundTrip(req) + } + + m.mutex.Lock() + defer m.g.Clean() + + var err error + var res *http.Response + + // Match mock for the incoming http.Request + mock, err := m.g.MatchMock(req) + if err != nil { + m.mutex.Unlock() + return nil, err + } + + // Invoke the observer with the intercepted http.Request and matched mock + if m.g.config.Observer != nil { + m.g.config.Observer(req, mock) + } + + // Verify if should use real networking + networking := shouldUseNetwork(m.g, req, mock) + if !networking && mock == nil { + m.mutex.Unlock() + m.g.trackUnmatchedRequest(req) + return nil, ErrCannotMatch + } + + // Ensure me unlock the mutex before building the response + m.mutex.Unlock() + + // Perform real networking via original transport + if networking { + res, err = m.transport().RoundTrip(req) + // In no mock matched, continue with the response + if err != nil || mock == nil { + return res, err + } + } + + return Responder(req, mock.Response(), res) +} + +// CancelRequest is a no-op function. +func (m *Transport) CancelRequest(req *http.Request) {} + +func shouldUseNetwork(g *Gock, req *http.Request, mock Mock) bool { + if mock != nil && mock.Response().UseNetwork { + return true + } + if !g.config.Networking { + return false + } + if len(g.config.NetworkingFilters) == 0 { + return true + } + for _, filter := range g.config.NetworkingFilters { + if !filter(req) { + return false + } + } + return true +} diff --git a/threadsafe/transport_test.go b/threadsafe/transport_test.go new file mode 100644 index 0000000..5215da6 --- /dev/null +++ b/threadsafe/transport_test.go @@ -0,0 +1,55 @@ +package threadsafe + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/nbio/st" +) + +func TestTransportMatch(t *testing.T) { + g := NewGock() + defer after(g) + const uri = "http://foo.com" + g.New(uri).Reply(204) + u, _ := url.Parse(uri) + req := &http.Request{URL: u} + res, err := g.NewTransport(http.DefaultTransport).RoundTrip(req) + st.Expect(t, err, nil) + st.Expect(t, res.StatusCode, 204) + st.Expect(t, res.Request, req) +} + +func TestTransportCannotMatch(t *testing.T) { + g := NewGock() + defer after(g) + g.New("http://foo.com").Reply(204) + u, _ := url.Parse("http://127.0.0.1:1234") + req := &http.Request{URL: u} + _, err := g.NewTransport(http.DefaultTransport).RoundTrip(req) + st.Expect(t, err, ErrCannotMatch) +} + +func TestTransportNotIntercepting(t *testing.T) { + g := NewGock() + defer after(g) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, world") + })) + defer ts.Close() + + g.New(ts.URL).Reply(200) + g.Disable() + + u, _ := url.Parse(ts.URL) + req := &http.Request{URL: u, Header: make(http.Header)} + + res, err := g.NewTransport(http.DefaultTransport).RoundTrip(req) + st.Expect(t, err, nil) + st.Expect(t, g.Intercepting(), false) + st.Expect(t, res.StatusCode, 200) +} From 73f34a17092d9c20818bd04654cf53ae7565cc2d Mon Sep 17 00:00:00 2001 From: Dan Wendorf Date: Mon, 8 Jan 2024 13:04:45 -0800 Subject: [PATCH 2/2] Reimplement root package using threadsafe To avoid having and maintaining duplicate code, the root not-threadsafe package is reimplemented using a globl instance of *threadsafe.Gock. I tried to make this as non-breaking of a change as possible, but it could not be done without some breaking changes: * Exported types are just exposing threadsafe types. For example, `type MatchFunc` is changed from `type MatchFunc func(*http.Request, *Request) (bool, error)` to `type MatchFunc = threadsafe.MatchFunc`. The ergonomics of using these types should be unchanged, but it is technically breaking. * Some package-level variables were exposed to allow dynamic configuration, like MatchersHeader. To correctly use the *threadsafe.Gock instance, I had to replace the var with a getter function and add a setter function. For getter use cases, users will just have to append `()` to call the function, but for setter use cases they will need to modify their code a little more (especially if they were doing something like appending to the slice). Other notable things: * I tried to leave as much of the original test suite as possible to prove that this refactor is correct. That means there are some unnecessarily duplicated tests between the root package and `threadsafe`, so there's an opportunity for cleanup. * Some root-level tests relied on unexported symbols which are no longer available to those tests. Some were able to be updated using exported getters, but some were deleted. I believe the deleted tests were not providing additional value because of the above-mentioned duplication. * To correctly maintain the getting and setting of http.DefaultTransport, I added "callback" methods for *threadsafe.Gock: DisableCallback, InterceptCallback, and InterceptingCallback. The root package sets these on the `var g *threadsafe.Gock` variable, and the functions are responsible for reading or writing http.DefaultTransport. Implementing this logic in the original functions (e.g. `gock.Disable`) proved too odd since the some of the functions call others. We would have to retain some duplicate implementation logic to run the logic in the right place, so the callback methods felt like the cleanest workaround. --- gock.go | 128 ++++++------------- gock_test.go | 2 +- matcher.go | 138 ++++++--------------- matcher_test.go | 51 ++++---- matchers.go | 243 +++++------------------------------- matchers_test.go | 9 ++ mock.go | 161 +----------------------- mock_test.go | 65 ---------- options.go | 8 +- request.go | 315 +---------------------------------------------- request_test.go | 1 - responder.go | 105 +--------------- response.go | 186 +--------------------------- response_test.go | 2 +- store.go | 72 ++--------- store_test.go | 24 ++-- transport.go | 87 +------------ 17 files changed, 189 insertions(+), 1408 deletions(-) diff --git a/gock.go b/gock.go index c3a8333..5f1e4c8 100644 --- a/gock.go +++ b/gock.go @@ -1,57 +1,43 @@ package gock import ( - "fmt" "net/http" - "net/http/httputil" - "net/url" - "regexp" "sync" + + "github.com/h2non/gock/threadsafe" ) +var g = threadsafe.NewGock() + +func init() { + g.DisableCallback = disable + g.InterceptCallback = intercept + g.InterceptingCallback = intercepting +} + // mutex is used interally for locking thread-sensitive functions. var mutex = &sync.Mutex{} -// config global singleton store. -var config = struct { - Networking bool - NetworkingFilters []FilterRequestFunc - Observer ObserverFunc -}{} - // ObserverFunc is implemented by users to inspect the outgoing intercepted HTTP traffic -type ObserverFunc func(*http.Request, Mock) +type ObserverFunc = threadsafe.ObserverFunc // DumpRequest is a default implementation of ObserverFunc that dumps // the HTTP/1.x wire representation of the http request -var DumpRequest ObserverFunc = func(request *http.Request, mock Mock) { - bytes, _ := httputil.DumpRequestOut(request, true) - fmt.Println(string(bytes)) - fmt.Printf("\nMatches: %v\n---\n", mock != nil) -} - -// track unmatched requests so they can be tested for -var unmatchedRequests = []*http.Request{} +var DumpRequest = g.DumpRequest // New creates and registers a new HTTP mock with // default settings and returns the Request DSL for HTTP mock // definition and set up. func New(uri string) *Request { - Intercept() - - res := NewResponse() - req := NewRequest() - req.URLStruct, res.Error = url.Parse(normalizeURI(uri)) - - // Create the new mock expectation - exp := NewMock(req, res) - Register(exp) - - return req + return g.New(uri) } // Intercepting returns true if gock is currently able to intercept. func Intercepting() bool { + return g.Intercepting() +} + +func intercepting() bool { mutex.Lock() defer mutex.Unlock() return http.DefaultTransport == DefaultTransport @@ -60,37 +46,33 @@ func Intercepting() bool { // Intercept enables HTTP traffic interception via http.DefaultTransport. // If you are using a custom HTTP transport, you have to use `gock.Transport()` func Intercept() { - if !Intercepting() { - mutex.Lock() - http.DefaultTransport = DefaultTransport - mutex.Unlock() - } + g.Intercept() +} + +func intercept() { + mutex.Lock() + http.DefaultTransport = DefaultTransport + mutex.Unlock() } // InterceptClient allows the developer to intercept HTTP traffic using // a custom http.Client who uses a non default http.Transport/http.RoundTripper implementation. func InterceptClient(cli *http.Client) { - _, ok := cli.Transport.(*Transport) - if ok { - return // if transport already intercepted, just ignore it - } - trans := NewTransport() - trans.Transport = cli.Transport - cli.Transport = trans + g.InterceptClient(cli) } // RestoreClient allows the developer to disable and restore the // original transport in the given http.Client. func RestoreClient(cli *http.Client) { - trans, ok := cli.Transport.(*Transport) - if !ok { - return - } - cli.Transport = trans.Transport + g.RestoreClient(cli) } // Disable disables HTTP traffic interception by gock. func Disable() { + g.Disable() +} + +func disable() { mutex.Lock() defer mutex.Unlock() http.DefaultTransport = NativeTransport @@ -99,80 +81,50 @@ func Disable() { // Off disables the default HTTP interceptors and removes // all the registered mocks, even if they has not been intercepted yet. func Off() { - Flush() - Disable() + g.Off() } // OffAll is like `Off()`, but it also removes the unmatched requests registry. func OffAll() { - Flush() - Disable() - CleanUnmatchedRequest() + g.OffAll() } // Observe provides a hook to support inspection of the request and matched mock func Observe(fn ObserverFunc) { - mutex.Lock() - defer mutex.Unlock() - config.Observer = fn + g.Observe(fn) } // EnableNetworking enables real HTTP networking func EnableNetworking() { - mutex.Lock() - defer mutex.Unlock() - config.Networking = true + g.EnableNetworking() } // DisableNetworking disables real HTTP networking func DisableNetworking() { - mutex.Lock() - defer mutex.Unlock() - config.Networking = false + g.DisableNetworking() } // NetworkingFilter determines if an http.Request should be triggered or not. func NetworkingFilter(fn FilterRequestFunc) { - mutex.Lock() - defer mutex.Unlock() - config.NetworkingFilters = append(config.NetworkingFilters, fn) + g.NetworkingFilter(fn) } // DisableNetworkingFilters disables registered networking filters. func DisableNetworkingFilters() { - mutex.Lock() - defer mutex.Unlock() - config.NetworkingFilters = []FilterRequestFunc{} + g.DisableNetworkingFilters() } // GetUnmatchedRequests returns all requests that have been received but haven't matched any mock func GetUnmatchedRequests() []*http.Request { - mutex.Lock() - defer mutex.Unlock() - return unmatchedRequests + return g.GetUnmatchedRequests() } // HasUnmatchedRequest returns true if gock has received any requests that didn't match a mock func HasUnmatchedRequest() bool { - return len(GetUnmatchedRequests()) > 0 + return g.HasUnmatchedRequest() } // CleanUnmatchedRequest cleans the unmatched requests internal registry. func CleanUnmatchedRequest() { - mutex.Lock() - defer mutex.Unlock() - unmatchedRequests = []*http.Request{} -} - -func trackUnmatchedRequest(req *http.Request) { - mutex.Lock() - defer mutex.Unlock() - unmatchedRequests = append(unmatchedRequests, req) -} - -func normalizeURI(uri string) string { - if ok, _ := regexp.MatchString("^http[s]?", uri); !ok { - return "http://" + uri - } - return uri + g.CleanUnmatchedRequest() } diff --git a/gock_test.go b/gock_test.go index 7df68fe..8ad128b 100644 --- a/gock_test.go +++ b/gock_test.go @@ -304,7 +304,7 @@ func TestUnmatched(t *testing.T) { defer after() // clear out any unmatchedRequests from other tests - unmatchedRequests = []*http.Request{} + CleanUnmatchedRequest() Intercept() diff --git a/matcher.go b/matcher.go index 11a1d7e..633734f 100644 --- a/matcher.go +++ b/matcher.go @@ -1,137 +1,75 @@ package gock -import "net/http" +import ( + "net/http" + + "github.com/h2non/gock/threadsafe" +) // MatchersHeader exposes an slice of HTTP header specific mock matchers. -var MatchersHeader = []MatchFunc{ - MatchMethod, - MatchScheme, - MatchHost, - MatchPath, - MatchHeaders, - MatchQueryParams, - MatchPathParams, +func MatchersHeader() []MatchFunc { + return g.MatchersHeader +} + +func SetMatchersHeader(matchers []MatchFunc) { + g.MatchersHeader = matchers } // MatchersBody exposes an slice of HTTP body specific built-in mock matchers. -var MatchersBody = []MatchFunc{ - MatchBody, +func MatchersBody() []MatchFunc { + return g.MatchersBody +} + +func SetMatchersBody(matchers []MatchFunc) { + g.MatchersBody = matchers } // Matchers stores all the built-in mock matchers. -var Matchers = append(MatchersHeader, MatchersBody...) +func Matchers() []MatchFunc { + return g.Matchers +} + +func SetMatchers(matchers []MatchFunc) { + g.Matchers = matchers +} // DefaultMatcher stores the default Matcher instance used to match mocks. -var DefaultMatcher = NewMatcher() +func DefaultMatcher() *MockMatcher { + return g.DefaultMatcher +} + +func SetDefaultMatcher(matcher *MockMatcher) { + g.DefaultMatcher = matcher +} // MatchFunc represents the required function // interface implemented by matchers. -type MatchFunc func(*http.Request, *Request) (bool, error) +type MatchFunc = threadsafe.MatchFunc // Matcher represents the required interface implemented by mock matchers. -type Matcher interface { - // Get returns a slice of registered function matchers. - Get() []MatchFunc - - // Add adds a new matcher function. - Add(MatchFunc) - - // Set sets the matchers functions stack. - Set([]MatchFunc) - - // Flush flushes the current matchers function stack. - Flush() - - // Match matches the given http.Request with a mock Request. - Match(*http.Request, *Request) (bool, error) -} +type Matcher = threadsafe.Matcher // MockMatcher implements a mock matcher -type MockMatcher struct { - Matchers []MatchFunc -} +type MockMatcher = threadsafe.MockMatcher // NewMatcher creates a new mock matcher // using the default matcher functions. func NewMatcher() *MockMatcher { - m := NewEmptyMatcher() - for _, matchFn := range Matchers { - m.Add(matchFn) - } - return m + return g.NewMatcher() } // NewBasicMatcher creates a new matcher with header only mock matchers. func NewBasicMatcher() *MockMatcher { - m := NewEmptyMatcher() - for _, matchFn := range MatchersHeader { - m.Add(matchFn) - } - return m + return g.NewBasicMatcher() } // NewEmptyMatcher creates a new empty matcher without default matchers. func NewEmptyMatcher() *MockMatcher { - return &MockMatcher{Matchers: []MatchFunc{}} -} - -// Get returns a slice of registered function matchers. -func (m *MockMatcher) Get() []MatchFunc { - mutex.Lock() - defer mutex.Unlock() - return m.Matchers -} - -// Add adds a new function matcher. -func (m *MockMatcher) Add(fn MatchFunc) { - m.Matchers = append(m.Matchers, fn) -} - -// Set sets a new stack of matchers functions. -func (m *MockMatcher) Set(stack []MatchFunc) { - m.Matchers = stack -} - -// Flush flushes the current matcher -func (m *MockMatcher) Flush() { - m.Matchers = []MatchFunc{} -} - -// Clone returns a separate MockMatcher instance that has a copy of the same MatcherFuncs -func (m *MockMatcher) Clone() *MockMatcher { - m2 := NewEmptyMatcher() - for _, mFn := range m.Get() { - m2.Add(mFn) - } - return m2 -} - -// Match matches the given http.Request with a mock request -// returning true in case that the request matches, otherwise false. -func (m *MockMatcher) Match(req *http.Request, ereq *Request) (bool, error) { - for _, matcher := range m.Matchers { - matches, err := matcher(req, ereq) - if err != nil { - return false, err - } - if !matches { - return false, nil - } - } - return true, nil + return g.NewEmptyMatcher() } // MatchMock is a helper function that matches the given http.Request // in the list of registered mocks, returning it if matches or error if it fails. func MatchMock(req *http.Request) (Mock, error) { - for _, mock := range GetAll() { - matches, err := mock.Match(req) - if err != nil { - return nil, err - } - if matches { - return mock, nil - } - } - return nil, nil + return g.MatchMock(req) } diff --git a/matcher_test.go b/matcher_test.go index d96c00c..c7e842c 100644 --- a/matcher_test.go +++ b/matcher_test.go @@ -9,24 +9,24 @@ import ( ) func TestRegisteredMatchers(t *testing.T) { - st.Expect(t, len(MatchersHeader), 7) - st.Expect(t, len(MatchersBody), 1) + st.Expect(t, len(MatchersHeader()), 7) + st.Expect(t, len(MatchersBody()), 1) } func TestNewMatcher(t *testing.T) { matcher := NewMatcher() // Funcs are not comparable, checking slice length as it's better than nothing // See https://golang.org/pkg/reflect/#DeepEqual - st.Expect(t, len(matcher.Matchers), len(Matchers)) - st.Expect(t, len(matcher.Get()), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) + st.Expect(t, len(matcher.Get()), len(Matchers())) } func TestNewBasicMatcher(t *testing.T) { matcher := NewBasicMatcher() // Funcs are not comparable, checking slice length as it's better than nothing // See https://golang.org/pkg/reflect/#DeepEqual - st.Expect(t, len(matcher.Matchers), len(MatchersHeader)) - st.Expect(t, len(matcher.Get()), len(MatchersHeader)) + st.Expect(t, len(matcher.Matchers), len(MatchersHeader())) + st.Expect(t, len(matcher.Get()), len(MatchersHeader())) } func TestNewEmptyMatcher(t *testing.T) { @@ -37,17 +37,17 @@ func TestNewEmptyMatcher(t *testing.T) { func TestMatcherAdd(t *testing.T) { matcher := NewMatcher() - st.Expect(t, len(matcher.Matchers), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) - st.Expect(t, len(matcher.Get()), len(Matchers)+1) + st.Expect(t, len(matcher.Get()), len(Matchers())+1) } func TestMatcherSet(t *testing.T) { matcher := NewMatcher() matchers := []MatchFunc{} - st.Expect(t, len(matcher.Matchers), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) matcher.Set(matchers) st.Expect(t, matcher.Matchers, matchers) st.Expect(t, len(matcher.Get()), 0) @@ -62,18 +62,18 @@ func TestMatcherGet(t *testing.T) { func TestMatcherFlush(t *testing.T) { matcher := NewMatcher() - st.Expect(t, len(matcher.Matchers), len(Matchers)) + st.Expect(t, len(matcher.Matchers), len(Matchers())) matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) - st.Expect(t, len(matcher.Get()), len(Matchers)+1) + st.Expect(t, len(matcher.Get()), len(Matchers())+1) matcher.Flush() st.Expect(t, len(matcher.Get()), 0) } func TestMatcherClone(t *testing.T) { - matcher := DefaultMatcher.Clone() - st.Expect(t, len(matcher.Get()), len(DefaultMatcher.Get())) + matcher := DefaultMatcher().Clone() + st.Expect(t, len(matcher.Get()), len(DefaultMatcher().Get())) } func TestMatcher(t *testing.T) { @@ -115,19 +115,20 @@ func TestMatcher(t *testing.T) { func TestMatchMock(t *testing.T) { cases := []struct { - method string - url string - matches bool + method string + methodFn func(r *Request, path string) *Request + url string + matches bool }{ - {"GET", "http://foo.com/bar", true}, - {"GET", "http://foo.com/baz", true}, - {"GET", "http://foo.com/foo", false}, - {"POST", "http://foo.com/bar", false}, - {"POST", "http://bar.com/bar", false}, - {"GET", "http://foo.com", false}, + {"GET", (*Request).Get, "http://foo.com/bar", true}, + {"GET", (*Request).Get, "http://foo.com/baz", true}, + {"GET", (*Request).Get, "http://foo.com/foo", false}, + {"POST", (*Request).Post, "http://foo.com/bar", false}, + {"POST", (*Request).Post, "http://bar.com/bar", false}, + {"GET", (*Request).Get, "http://foo.com", false}, } - matcher := DefaultMatcher + matcher := DefaultMatcher() matcher.Flush() st.Expect(t, len(matcher.Matchers), 0) @@ -143,7 +144,7 @@ func TestMatchMock(t *testing.T) { for _, test := range cases { Flush() - mock := New(test.url).method(test.method, "").Mock + mock := test.methodFn(New(test.url), "").Mock u, _ := url.Parse(test.url) req := &http.Request{Method: test.method, URL: u} @@ -157,5 +158,5 @@ func TestMatchMock(t *testing.T) { } } - DefaultMatcher.Matchers = Matchers + DefaultMatcher().Matchers = Matchers() } diff --git a/matchers.go b/matchers.go index 658c9a6..8c3ac77 100644 --- a/matchers.go +++ b/matchers.go @@ -1,266 +1,79 @@ package gock import ( - "compress/gzip" - "encoding/json" - "io" - "io/ioutil" "net/http" - "reflect" - "regexp" - "strings" - "github.com/h2non/parth" + "github.com/h2non/gock/threadsafe" ) // EOL represents the end of line character. -const EOL = 0xa +const EOL = threadsafe.EOL // BodyTypes stores the supported MIME body types for matching. // Currently only text-based types. -var BodyTypes = []string{ - "text/html", - "text/plain", - "application/json", - "application/xml", - "multipart/form-data", - "application/x-www-form-urlencoded", +func BodyTypes() []string { + return g.BodyTypes +} + +func SetBodyTypes(types []string) { + g.BodyTypes = types } // BodyTypeAliases stores a generic MIME type by alias. -var BodyTypeAliases = map[string]string{ - "html": "text/html", - "text": "text/plain", - "json": "application/json", - "xml": "application/xml", - "form": "multipart/form-data", - "url": "application/x-www-form-urlencoded", +func BodyTypeAliases() map[string]string { + return g.BodyTypeAliases +} + +func SetBodyTypeAliases(aliases map[string]string) { + g.BodyTypeAliases = aliases } // CompressionSchemes stores the supported Content-Encoding types for decompression. -var CompressionSchemes = []string{ - "gzip", +func CompressionSchemes() []string { + return g.CompressionSchemes +} + +func SetCompressionSchemes(schemes []string) { + g.CompressionSchemes = schemes } // MatchMethod matches the HTTP method of the given request. func MatchMethod(req *http.Request, ereq *Request) (bool, error) { - return ereq.Method == "" || req.Method == ereq.Method, nil + return g.MatchMethod(req, ereq) } // MatchScheme matches the request URL protocol scheme. func MatchScheme(req *http.Request, ereq *Request) (bool, error) { - return ereq.URLStruct.Scheme == "" || req.URL.Scheme == "" || ereq.URLStruct.Scheme == req.URL.Scheme, nil + return g.MatchScheme(req, ereq) } // MatchHost matches the HTTP host header field of the given request. func MatchHost(req *http.Request, ereq *Request) (bool, error) { - url := ereq.URLStruct - if strings.EqualFold(url.Host, req.URL.Host) { - return true, nil - } - if !ereq.Options.DisableRegexpHost { - return regexp.MatchString(url.Host, req.URL.Host) - } - return false, nil + return g.MatchHost(req, ereq) } // MatchPath matches the HTTP URL path of the given request. func MatchPath(req *http.Request, ereq *Request) (bool, error) { - if req.URL.Path == ereq.URLStruct.Path { - return true, nil - } - return regexp.MatchString(ereq.URLStruct.Path, req.URL.Path) + return g.MatchPath(req, ereq) } // MatchHeaders matches the headers fields of the given request. func MatchHeaders(req *http.Request, ereq *Request) (bool, error) { - for key, value := range ereq.Header { - var err error - var match bool - var matchEscaped bool - - for _, field := range req.Header[key] { - match, err = regexp.MatchString(value[0], field) - // Some values may contain reserved regex params e.g. "()", try matching with these escaped. - matchEscaped, err = regexp.MatchString(regexp.QuoteMeta(value[0]), field) - - if err != nil { - return false, err - } - if match || matchEscaped { - break - } - - } - - if !match && !matchEscaped { - return false, nil - } - } - return true, nil + return g.MatchHeaders(req, ereq) } // MatchQueryParams matches the URL query params fields of the given request. func MatchQueryParams(req *http.Request, ereq *Request) (bool, error) { - for key, value := range ereq.URLStruct.Query() { - var err error - var match bool - - for _, field := range req.URL.Query()[key] { - match, err = regexp.MatchString(value[0], field) - if err != nil { - return false, err - } - if match { - break - } - } - - if !match { - return false, nil - } - } - return true, nil + return g.MatchQueryParams(req, ereq) } // MatchPathParams matches the URL path parameters of the given request. func MatchPathParams(req *http.Request, ereq *Request) (bool, error) { - for key, value := range ereq.PathParams { - var s string - - if err := parth.Sequent(req.URL.Path, key, &s); err != nil { - return false, nil - } - - if s != value { - return false, nil - } - } - return true, nil + return g.MatchPathParams(req, ereq) } // MatchBody tries to match the request body. // TODO: not too smart now, needs several improvements. func MatchBody(req *http.Request, ereq *Request) (bool, error) { - // If match body is empty, just continue - if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 { - return true, nil - } - - // Only can match certain MIME body types - if !supportedType(req, ereq) { - return false, nil - } - - // Can only match certain compression schemes - if !supportedCompressionScheme(req) { - return false, nil - } - - // Create a reader for the body depending on compression type - bodyReader := req.Body - if ereq.CompressionScheme != "" { - if ereq.CompressionScheme != req.Header.Get("Content-Encoding") { - return false, nil - } - compressedBodyReader, err := compressionReader(req.Body, ereq.CompressionScheme) - if err != nil { - return false, err - } - bodyReader = compressedBodyReader - } - - // Read the whole request body - body, err := ioutil.ReadAll(bodyReader) - if err != nil { - return false, err - } - - // Restore body reader stream - req.Body = createReadCloser(body) - - // If empty, ignore the match - if len(body) == 0 && len(ereq.BodyBuffer) != 0 { - return false, nil - } - - // Match body by atomic string comparison - bodyStr := castToString(body) - matchStr := castToString(ereq.BodyBuffer) - if bodyStr == matchStr { - return true, nil - } - - // Match request body by regexp - match, _ := regexp.MatchString(matchStr, bodyStr) - if match == true { - return true, nil - } - - // todo - add conditional do only perform the conversion of body bytes - // representation of JSON to a map and then compare them for equality. - - // Check if the key + value pairs match - var bodyMap map[string]interface{} - var matchMap map[string]interface{} - - // Ensure that both byte bodies that that should be JSON can be converted to maps. - umErr := json.Unmarshal(body, &bodyMap) - umErr2 := json.Unmarshal(ereq.BodyBuffer, &matchMap) - if umErr == nil && umErr2 == nil && reflect.DeepEqual(bodyMap, matchMap) { - return true, nil - } - - return false, nil -} - -func supportedType(req *http.Request, ereq *Request) bool { - mime := req.Header.Get("Content-Type") - if mime == "" { - return true - } - - mimeToMatch := ereq.Header.Get("Content-Type") - if mimeToMatch != "" { - return mime == mimeToMatch - } - - for _, kind := range BodyTypes { - if match, _ := regexp.MatchString(kind, mime); match { - return true - } - } - return false -} - -func supportedCompressionScheme(req *http.Request) bool { - encoding := req.Header.Get("Content-Encoding") - if encoding == "" { - return true - } - - for _, kind := range CompressionSchemes { - if match, _ := regexp.MatchString(kind, encoding); match { - return true - } - } - return false -} - -func castToString(buf []byte) string { - str := string(buf) - tail := len(str) - 1 - if str[tail] == EOL { - str = str[:tail] - } - return str -} - -func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) { - switch scheme { - case "gzip": - return gzip.NewReader(r) - default: - return r, nil - } + return g.MatchBody(req, ereq) } diff --git a/matchers_test.go b/matchers_test.go index 56aaa01..cbe30d6 100644 --- a/matchers_test.go +++ b/matchers_test.go @@ -1,6 +1,9 @@ package gock import ( + "bytes" + "io" + "io/ioutil" "net/http" "net/url" "testing" @@ -249,3 +252,9 @@ func TestMatchBody_MatchType(t *testing.T) { st.Expect(t, matches, test.matches) } } + +// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an +// http response body. +func createReadCloser(body []byte) io.ReadCloser { + return ioutil.NopCloser(bytes.NewReader(body)) +} diff --git a/mock.go b/mock.go index d28875b..aa388ff 100644 --- a/mock.go +++ b/mock.go @@ -1,172 +1,19 @@ package gock import ( - "net/http" - "sync" + "github.com/h2non/gock/threadsafe" ) // Mock represents the required interface that must // be implemented by HTTP mock instances. -type Mock interface { - // Disable disables the current mock manually. - Disable() - - // Done returns true if the current mock is disabled. - Done() bool - - // Request returns the mock Request instance. - Request() *Request - - // Response returns the mock Response instance. - Response() *Response - - // Match matches the given http.Request with the current mock. - Match(*http.Request) (bool, error) - - // AddMatcher adds a new matcher function. - AddMatcher(MatchFunc) - - // SetMatcher uses a new matcher implementation. - SetMatcher(Matcher) -} +type Mock = threadsafe.Mock // Mocker implements a Mock capable interface providing // a default mock configuration used internally to store mocks. -type Mocker struct { - // disabler stores a disabler for thread safety checking current mock is disabled - disabler *disabler - - // mutex stores the mock mutex for thread safety. - mutex sync.Mutex - - // matcher stores a Matcher capable instance to match the given http.Request. - matcher Matcher - - // request stores the mock Request to match. - request *Request - - // response stores the mock Response to use in case of match. - response *Response -} - -type disabler struct { - // disabled stores if the current mock is disabled. - disabled bool - - // mutex stores the disabler mutex for thread safety. - mutex sync.RWMutex -} - -func (d *disabler) isDisabled() bool { - d.mutex.RLock() - defer d.mutex.RUnlock() - return d.disabled -} - -func (d *disabler) Disable() { - d.mutex.Lock() - defer d.mutex.Unlock() - d.disabled = true -} +type Mocker = threadsafe.Mocker // NewMock creates a new HTTP mock based on the given request and response instances. // It's mostly used internally. func NewMock(req *Request, res *Response) *Mocker { - mock := &Mocker{ - disabler: new(disabler), - request: req, - response: res, - matcher: DefaultMatcher.Clone(), - } - res.Mock = mock - req.Mock = mock - req.Response = res - return mock -} - -// Disable disables the current mock manually. -func (m *Mocker) Disable() { - m.disabler.Disable() -} - -// Done returns true in case that the current mock -// instance is disabled and therefore must be removed. -func (m *Mocker) Done() bool { - // prevent deadlock with m.mutex - if m.disabler.isDisabled() { - return true - } - - m.mutex.Lock() - defer m.mutex.Unlock() - return !m.request.Persisted && m.request.Counter == 0 -} - -// Request returns the Request instance -// configured for the current HTTP mock. -func (m *Mocker) Request() *Request { - return m.request -} - -// Response returns the Response instance -// configured for the current HTTP mock. -func (m *Mocker) Response() *Response { - return m.response -} - -// Match matches the given http.Request with the current Request -// mock expectation, returning true if matches. -func (m *Mocker) Match(req *http.Request) (bool, error) { - if m.disabler.isDisabled() { - return false, nil - } - - // Filter - for _, filter := range m.request.Filters { - if !filter(req) { - return false, nil - } - } - - // Map - for _, mapper := range m.request.Mappers { - if treq := mapper(req); treq != nil { - req = treq - } - } - - // Match - matches, err := m.matcher.Match(req, m.request) - if matches { - m.decrement() - } - - return matches, err -} - -// SetMatcher sets a new matcher implementation -// for the current mock expectation. -func (m *Mocker) SetMatcher(matcher Matcher) { - m.matcher = matcher -} - -// AddMatcher adds a new matcher function -// for the current mock expectation. -func (m *Mocker) AddMatcher(fn MatchFunc) { - m.matcher.Add(fn) -} - -// decrement decrements the current mock Request counter. -func (m *Mocker) decrement() { - if m.request.Persisted { - return - } - - m.mutex.Lock() - defer m.mutex.Unlock() - - m.request.Counter-- - if m.request.Counter == 0 { - m.disabler.Disable() - } + return g.NewMock(req, res) } diff --git a/mock_test.go b/mock_test.go index 01e6fca..70b0765 100644 --- a/mock_test.go +++ b/mock_test.go @@ -7,63 +7,6 @@ import ( "github.com/nbio/st" ) -func TestNewMock(t *testing.T) { - defer after() - - req := NewRequest() - res := NewResponse() - mock := NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get())) - - st.Expect(t, mock.Request(), req) - st.Expect(t, mock.Request().Mock, mock) - st.Expect(t, mock.Response(), res) - st.Expect(t, mock.Response().Mock, mock) -} - -func TestMockDisable(t *testing.T) { - defer after() - - req := NewRequest() - res := NewResponse() - mock := NewMock(req, res) - - st.Expect(t, mock.disabler.isDisabled(), false) - mock.Disable() - st.Expect(t, mock.disabler.isDisabled(), true) - - matches, err := mock.Match(&http.Request{}) - st.Expect(t, err, nil) - st.Expect(t, matches, false) -} - -func TestMockDone(t *testing.T) { - defer after() - - req := NewRequest() - res := NewResponse() - - mock := NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, mock.Done(), false) - - mock = NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - mock.Disable() - st.Expect(t, mock.Done(), true) - - mock = NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - mock.request.Counter = 0 - st.Expect(t, mock.Done(), true) - - mock = NewMock(req, res) - st.Expect(t, mock.disabler.isDisabled(), false) - mock.request.Persisted = true - st.Expect(t, mock.Done(), false) -} - func TestMockSetMatcher(t *testing.T) { defer after() @@ -71,15 +14,12 @@ func TestMockSetMatcher(t *testing.T) { res := NewResponse() mock := NewMock(req, res) - st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get())) matcher := NewMatcher() matcher.Flush() matcher.Add(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) mock.SetMatcher(matcher) - st.Expect(t, len(mock.matcher.Get()), 1) - st.Expect(t, mock.disabler.isDisabled(), false) matches, err := mock.Match(&http.Request{}) st.Expect(t, err, nil) @@ -93,15 +33,12 @@ func TestMockAddMatcher(t *testing.T) { res := NewResponse() mock := NewMock(req, res) - st.Expect(t, len(mock.matcher.Get()), len(DefaultMatcher.Get())) matcher := NewMatcher() matcher.Flush() mock.SetMatcher(matcher) mock.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { return true, nil }) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, mock.matcher, matcher) matches, err := mock.Match(&http.Request{}) st.Expect(t, err, nil) @@ -127,8 +64,6 @@ func TestMockMatch(t *testing.T) { calls++ return true, nil }) - st.Expect(t, mock.disabler.isDisabled(), false) - st.Expect(t, mock.matcher, matcher) matches, err := mock.Match(&http.Request{}) st.Expect(t, err, nil) diff --git a/options.go b/options.go index 188aa58..f754563 100644 --- a/options.go +++ b/options.go @@ -1,8 +1,6 @@ package gock +import "github.com/h2non/gock/threadsafe" + // Options represents customized option for gock -type Options struct { - // DisableRegexpHost stores if the host is only a plain string rather than regular expression, - // if DisableRegexpHost is true, host sets in gock.New(...) will be treated as plain string - DisableRegexpHost bool -} +type Options = threadsafe.Options diff --git a/request.go b/request.go index 5702417..1563c37 100644 --- a/request.go +++ b/request.go @@ -1,325 +1,20 @@ package gock import ( - "encoding/base64" - "io" - "io/ioutil" - "net/http" - "net/url" - "strings" + "github.com/h2non/gock/threadsafe" ) // MapRequestFunc represents the required function interface for request mappers. -type MapRequestFunc func(*http.Request) *http.Request +type MapRequestFunc = threadsafe.MapRequestFunc // FilterRequestFunc represents the required function interface for request filters. -type FilterRequestFunc func(*http.Request) bool +type FilterRequestFunc = threadsafe.FilterRequestFunc // Request represents the high-level HTTP request used to store // request fields used to match intercepted requests. -type Request struct { - // Mock stores the parent mock reference for the current request mock used for method delegation. - Mock Mock - - // Response stores the current Response instance for the current matches Request. - Response *Response - - // Error stores the latest mock request configuration error. - Error error - - // Counter stores the pending times that the current mock should be active. - Counter int - - // Persisted stores if the current mock should be always active. - Persisted bool - - // Options stores options for current Request. - Options Options - - // URLStruct stores the parsed URL as *url.URL struct. - URLStruct *url.URL - - // Method stores the Request HTTP method to match. - Method string - - // CompressionScheme stores the Request Compression scheme to match and use for decompression. - CompressionScheme string - - // Header stores the HTTP header fields to match. - Header http.Header - - // Cookies stores the Request HTTP cookies values to match. - Cookies []*http.Cookie - - // PathParams stores the path parameters to match. - PathParams map[string]string - - // BodyBuffer stores the body data to match. - BodyBuffer []byte - - // Mappers stores the request functions mappers used for matching. - Mappers []MapRequestFunc - - // Filters stores the request functions filters used for matching. - Filters []FilterRequestFunc -} +type Request = threadsafe.Request // NewRequest creates a new Request instance. func NewRequest() *Request { - return &Request{ - Counter: 1, - URLStruct: &url.URL{}, - Header: make(http.Header), - PathParams: make(map[string]string), - } -} - -// URL defines the mock URL to match. -func (r *Request) URL(uri string) *Request { - r.URLStruct, r.Error = url.Parse(uri) - return r -} - -// SetURL defines the url.URL struct to be used for matching. -func (r *Request) SetURL(u *url.URL) *Request { - r.URLStruct = u - return r -} - -// Path defines the mock URL path value to match. -func (r *Request) Path(path string) *Request { - r.URLStruct.Path = path - return r -} - -// Get specifies the GET method and the given URL path to match. -func (r *Request) Get(path string) *Request { - return r.method("GET", path) -} - -// Post specifies the POST method and the given URL path to match. -func (r *Request) Post(path string) *Request { - return r.method("POST", path) -} - -// Put specifies the PUT method and the given URL path to match. -func (r *Request) Put(path string) *Request { - return r.method("PUT", path) -} - -// Delete specifies the DELETE method and the given URL path to match. -func (r *Request) Delete(path string) *Request { - return r.method("DELETE", path) -} - -// Patch specifies the PATCH method and the given URL path to match. -func (r *Request) Patch(path string) *Request { - return r.method("PATCH", path) -} - -// Head specifies the HEAD method and the given URL path to match. -func (r *Request) Head(path string) *Request { - return r.method("HEAD", path) -} - -// method is a DRY shortcut used to declare the expected HTTP method and URL path. -func (r *Request) method(method, path string) *Request { - if path != "/" { - r.URLStruct.Path = path - } - r.Method = strings.ToUpper(method) - return r -} - -// Body defines the body data to match based on a io.Reader interface. -func (r *Request) Body(body io.Reader) *Request { - r.BodyBuffer, r.Error = ioutil.ReadAll(body) - return r -} - -// BodyString defines the body to match based on a given string. -func (r *Request) BodyString(body string) *Request { - r.BodyBuffer = []byte(body) - return r -} - -// File defines the body to match based on the given file path string. -func (r *Request) File(path string) *Request { - r.BodyBuffer, r.Error = ioutil.ReadFile(path) - return r -} - -// Compression defines the request compression scheme, and enables automatic body decompression. -// Supports only the "gzip" scheme so far. -func (r *Request) Compression(scheme string) *Request { - r.Header.Set("Content-Encoding", scheme) - r.CompressionScheme = scheme - return r -} - -// JSON defines the JSON body to match based on a given structure. -func (r *Request) JSON(data interface{}) *Request { - if r.Header.Get("Content-Type") == "" { - r.Header.Set("Content-Type", "application/json") - } - r.BodyBuffer, r.Error = readAndDecode(data, "json") - return r -} - -// XML defines the XML body to match based on a given structure. -func (r *Request) XML(data interface{}) *Request { - if r.Header.Get("Content-Type") == "" { - r.Header.Set("Content-Type", "application/xml") - } - r.BodyBuffer, r.Error = readAndDecode(data, "xml") - return r -} - -// MatchType defines the request Content-Type MIME header field. -// Supports custom MIME types and type aliases. E.g: json, xml, form, text... -func (r *Request) MatchType(kind string) *Request { - mime := BodyTypeAliases[kind] - if mime != "" { - kind = mime - } - r.Header.Set("Content-Type", kind) - return r -} - -// BasicAuth defines a username and password for HTTP Basic Authentication -func (r *Request) BasicAuth(username, password string) *Request { - r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) - return r -} - -// MatchHeader defines a new key and value header to match. -func (r *Request) MatchHeader(key, value string) *Request { - r.Header.Set(key, value) - return r -} - -// HeaderPresent defines that a header field must be present in the request. -func (r *Request) HeaderPresent(key string) *Request { - r.Header.Set(key, ".*") - return r -} - -// MatchHeaders defines a map of key-value headers to match. -func (r *Request) MatchHeaders(headers map[string]string) *Request { - for key, value := range headers { - r.Header.Set(key, value) - } - return r -} - -// MatchParam defines a new key and value URL query param to match. -func (r *Request) MatchParam(key, value string) *Request { - query := r.URLStruct.Query() - query.Set(key, value) - r.URLStruct.RawQuery = query.Encode() - return r -} - -// MatchParams defines a map of URL query param key-value to match. -func (r *Request) MatchParams(params map[string]string) *Request { - query := r.URLStruct.Query() - for key, value := range params { - query.Set(key, value) - } - r.URLStruct.RawQuery = query.Encode() - return r -} - -// ParamPresent matches if the given query param key is present in the URL. -func (r *Request) ParamPresent(key string) *Request { - r.MatchParam(key, ".*") - return r -} - -// PathParam matches if a given path parameter key is present in the URL. -// -// The value is representative of the restful resource the key defines, e.g. -// // /users/123/name -// r.PathParam("users", "123") -// would match. -func (r *Request) PathParam(key, val string) *Request { - r.PathParams[key] = val - - return r -} - -// Persist defines the current HTTP mock as persistent and won't be removed after intercepting it. -func (r *Request) Persist() *Request { - r.Persisted = true - return r -} - -// WithOptions sets the options for the request. -func (r *Request) WithOptions(options Options) *Request { - r.Options = options - return r -} - -// Times defines the number of times that the current HTTP mock should remain active. -func (r *Request) Times(num int) *Request { - r.Counter = num - return r -} - -// AddMatcher adds a new matcher function to match the request. -func (r *Request) AddMatcher(fn MatchFunc) *Request { - r.Mock.AddMatcher(fn) - return r -} - -// SetMatcher sets a new matcher function to match the request. -func (r *Request) SetMatcher(matcher Matcher) *Request { - r.Mock.SetMatcher(matcher) - return r -} - -// Map adds a new request mapper function to map http.Request before the matching process. -func (r *Request) Map(fn MapRequestFunc) *Request { - r.Mappers = append(r.Mappers, fn) - return r -} - -// Filter filters a new request filter function to filter http.Request before the matching process. -func (r *Request) Filter(fn FilterRequestFunc) *Request { - r.Filters = append(r.Filters, fn) - return r -} - -// EnableNetworking enables the use real networking for the current mock. -func (r *Request) EnableNetworking() *Request { - if r.Response != nil { - r.Response.UseNetwork = true - } - return r -} - -// Reply defines the Response status code and returns the mock Response DSL. -func (r *Request) Reply(status int) *Response { - return r.Response.Status(status) -} - -// ReplyError defines the Response simulated error. -func (r *Request) ReplyError(err error) *Response { - return r.Response.SetError(err) -} - -// ReplyFunc allows the developer to define the mock response via a custom function. -func (r *Request) ReplyFunc(replier func(*Response)) *Response { - replier(r.Response) - return r.Response -} - -// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt -// "To receive authorization, the client sends the userid and password, -// separated by a single colon (":") character, within a base64 -// encoded string in the credentials." -// It is not meant to be urlencoded. -func basicAuth(username, password string) string { - auth := username + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) + return g.NewRequest() } diff --git a/request_test.go b/request_test.go index 463e784..011a0f9 100644 --- a/request_test.go +++ b/request_test.go @@ -266,7 +266,6 @@ func TestRequestAddMatcher(t *testing.T) { ereq := NewRequest() mock := NewMock(ereq, &Response{}) - mock.matcher = NewMatcher() ereq.Mock = mock ereq.AddMatcher(func(req *http.Request, ereq *Request) (bool, error) { diff --git a/responder.go b/responder.go index f0f16bb..0ec2de5 100644 --- a/responder.go +++ b/responder.go @@ -1,111 +1,12 @@ package gock import ( - "bytes" - "io" - "io/ioutil" "net/http" - "strconv" - "time" + + "github.com/h2non/gock/threadsafe" ) // Responder builds a mock http.Response based on the given Response mock. func Responder(req *http.Request, mock *Response, res *http.Response) (*http.Response, error) { - // If error present, reply it - err := mock.Error - if err != nil { - return nil, err - } - - if res == nil { - res = createResponse(req) - } - - // Apply response filter - for _, filter := range mock.Filters { - if !filter(res) { - return res, nil - } - } - - // Define mock status code - if mock.StatusCode != 0 { - res.Status = strconv.Itoa(mock.StatusCode) + " " + http.StatusText(mock.StatusCode) - res.StatusCode = mock.StatusCode - } - - // Define headers by merging fields - res.Header = mergeHeaders(res, mock) - - // Define mock body, if present - if len(mock.BodyBuffer) > 0 { - res.ContentLength = int64(len(mock.BodyBuffer)) - res.Body = createReadCloser(mock.BodyBuffer) - } - - // Set raw mock body, if exist - if mock.BodyGen != nil { - res.ContentLength = -1 - res.Body = mock.BodyGen() - } - - // Apply response mappers - for _, mapper := range mock.Mappers { - if tres := mapper(res); tres != nil { - res = tres - } - } - - // Sleep to simulate delay, if necessary - if mock.ResponseDelay > 0 { - // allow escaping from sleep due to request context expiration or cancellation - t := time.NewTimer(mock.ResponseDelay) - select { - case <-t.C: - case <-req.Context().Done(): - // cleanly stop the timer - if !t.Stop() { - <-t.C - } - } - } - - // check if the request context has ended. we could put this up in the delay code above, but putting it here - // has the added benefit of working even when there is no delay (very small timeouts, already-done contexts, etc.) - if err = req.Context().Err(); err != nil { - // cleanly close the response and return the context error - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - return nil, err - } - - return res, err -} - -// createResponse creates a new http.Response with default fields. -func createResponse(req *http.Request) *http.Response { - return &http.Response{ - ProtoMajor: 1, - ProtoMinor: 1, - Proto: "HTTP/1.1", - Request: req, - Header: make(http.Header), - Body: createReadCloser([]byte{}), - } -} - -// mergeHeaders copies the mock headers. -func mergeHeaders(res *http.Response, mres *Response) http.Header { - for key, values := range mres.Header { - for _, value := range values { - res.Header.Add(key, value) - } - } - return res.Header -} - -// createReadCloser creates an io.ReadCloser from a byte slice that is suitable for use as an -// http response body. -func createReadCloser(body []byte) io.ReadCloser { - return ioutil.NopCloser(bytes.NewReader(body)) + return threadsafe.Responder(req, mock, res) } diff --git a/response.go b/response.go index 3e62b9e..0eeb314 100644 --- a/response.go +++ b/response.go @@ -1,196 +1,20 @@ package gock import ( - "bytes" - "encoding/json" - "encoding/xml" - "io" - "io/ioutil" - "net/http" - "time" + "github.com/h2non/gock/threadsafe" ) // MapResponseFunc represents the required function interface impletemed by response mappers. -type MapResponseFunc func(*http.Response) *http.Response +type MapResponseFunc = threadsafe.MapResponseFunc // FilterResponseFunc represents the required function interface impletemed by response filters. -type FilterResponseFunc func(*http.Response) bool +type FilterResponseFunc = threadsafe.FilterResponseFunc // Response represents high-level HTTP fields to configure // and define HTTP responses intercepted by gock. -type Response struct { - // Mock stores the parent mock reference for the current response mock used for method delegation. - Mock Mock - - // Error stores the latest response configuration or injected error. - Error error - - // UseNetwork enables the use of real network for the current mock. - UseNetwork bool - - // StatusCode stores the response status code. - StatusCode int - - // Headers stores the response headers. - Header http.Header - - // Cookies stores the response cookie fields. - Cookies []*http.Cookie - - // BodyGen stores a io.ReadCloser generator to be returned. - BodyGen func() io.ReadCloser - - // BodyBuffer stores the array of bytes to use as body. - BodyBuffer []byte - - // ResponseDelay stores the simulated response delay. - ResponseDelay time.Duration - - // Mappers stores the request functions mappers used for matching. - Mappers []MapResponseFunc - - // Filters stores the request functions filters used for matching. - Filters []FilterResponseFunc -} +type Response = threadsafe.Response // NewResponse creates a new Response. func NewResponse() *Response { - return &Response{Header: make(http.Header)} -} - -// Status defines the desired HTTP status code to reply in the current response. -func (r *Response) Status(code int) *Response { - r.StatusCode = code - return r -} - -// Type defines the response Content-Type MIME header field. -// Supports type alias. E.g: json, xml, form, text... -func (r *Response) Type(kind string) *Response { - mime := BodyTypeAliases[kind] - if mime != "" { - kind = mime - } - r.Header.Set("Content-Type", kind) - return r -} - -// SetHeader sets a new header field in the mock response. -func (r *Response) SetHeader(key, value string) *Response { - r.Header.Set(key, value) - return r -} - -// AddHeader adds a new header field in the mock response -// with out removing an existent one. -func (r *Response) AddHeader(key, value string) *Response { - r.Header.Add(key, value) - return r -} - -// SetHeaders sets a map of header fields in the mock response. -func (r *Response) SetHeaders(headers map[string]string) *Response { - for key, value := range headers { - r.Header.Add(key, value) - } - return r -} - -// Body sets the HTTP response body to be used. -func (r *Response) Body(body io.Reader) *Response { - r.BodyBuffer, r.Error = ioutil.ReadAll(body) - return r -} - -// BodyGenerator accepts a io.ReadCloser generator, returning custom io.ReadCloser -// for every response. This will take priority than other Body methods used. -func (r *Response) BodyGenerator(generator func() io.ReadCloser) *Response { - r.BodyGen = generator - return r -} - -// BodyString defines the response body as string. -func (r *Response) BodyString(body string) *Response { - r.BodyBuffer = []byte(body) - return r -} - -// File defines the response body reading the data -// from disk based on the file path string. -func (r *Response) File(path string) *Response { - r.BodyBuffer, r.Error = ioutil.ReadFile(path) - return r -} - -// JSON defines the response body based on a JSON based input. -func (r *Response) JSON(data interface{}) *Response { - r.Header.Set("Content-Type", "application/json") - r.BodyBuffer, r.Error = readAndDecode(data, "json") - return r -} - -// XML defines the response body based on a XML based input. -func (r *Response) XML(data interface{}) *Response { - r.Header.Set("Content-Type", "application/xml") - r.BodyBuffer, r.Error = readAndDecode(data, "xml") - return r -} - -// SetError defines the response simulated error. -func (r *Response) SetError(err error) *Response { - r.Error = err - return r -} - -// Delay defines the response simulated delay. -// This feature is still experimental and will be improved in the future. -func (r *Response) Delay(delay time.Duration) *Response { - r.ResponseDelay = delay - return r -} - -// Map adds a new response mapper function to map http.Response before the matching process. -func (r *Response) Map(fn MapResponseFunc) *Response { - r.Mappers = append(r.Mappers, fn) - return r -} - -// Filter filters a new request filter function to filter http.Request before the matching process. -func (r *Response) Filter(fn FilterResponseFunc) *Response { - r.Filters = append(r.Filters, fn) - return r -} - -// EnableNetworking enables the use real networking for the current mock. -func (r *Response) EnableNetworking() *Response { - r.UseNetwork = true - return r -} - -// Done returns true if the mock was done and disabled. -func (r *Response) Done() bool { - return r.Mock.Done() -} - -func readAndDecode(data interface{}, kind string) ([]byte, error) { - buf := &bytes.Buffer{} - - switch data.(type) { - case string: - buf.WriteString(data.(string)) - case []byte: - buf.Write(data.([]byte)) - default: - var err error - if kind == "xml" { - err = xml.NewEncoder(buf).Encode(data) - } else { - err = json.NewEncoder(buf).Encode(data) - } - if err != nil { - return nil, err - } - } - - return ioutil.ReadAll(buf) + return g.NewResponse() } diff --git a/response_test.go b/response_test.go index 412ca53..c27781a 100644 --- a/response_test.go +++ b/response_test.go @@ -158,7 +158,7 @@ func TestResponseEnableNetworking(t *testing.T) { func TestResponseDone(t *testing.T) { res := NewResponse() - res.Mock = &Mocker{request: &Request{Counter: 1}, disabler: new(disabler)} + res.Mock = NewMock(&Request{Counter: 1}, res) st.Expect(t, res.Done(), false) res.Mock.Disable() st.Expect(t, res.Done(), true) diff --git a/store.go b/store.go index 7ed1316..adc5021 100644 --- a/store.go +++ b/store.go @@ -1,100 +1,46 @@ package gock -import ( - "sync" -) - -// storeMutex is used interally for store synchronization. -var storeMutex = sync.RWMutex{} - -// mocks is internally used to store registered mocks. -var mocks = []Mock{} - // Register registers a new mock in the current mocks stack. func Register(mock Mock) { - if Exists(mock) { - return - } - - // Make ops thread safe - storeMutex.Lock() - defer storeMutex.Unlock() - - // Expose mock in request/response for delegation - mock.Request().Mock = mock - mock.Response().Mock = mock - - // Registers the mock in the global store - mocks = append(mocks, mock) + g.Register(mock) } // GetAll returns the current stack of registered mocks. func GetAll() []Mock { - storeMutex.RLock() - defer storeMutex.RUnlock() - return mocks + return g.GetAll() } // Exists checks if the given Mock is already registered. func Exists(m Mock) bool { - storeMutex.RLock() - defer storeMutex.RUnlock() - for _, mock := range mocks { - if mock == m { - return true - } - } - return false + return g.Exists(m) } // Remove removes a registered mock by reference. func Remove(m Mock) { - for i, mock := range mocks { - if mock == m { - storeMutex.Lock() - mocks = append(mocks[:i], mocks[i+1:]...) - storeMutex.Unlock() - } - } + g.Remove(m) } // Flush flushes the current stack of registered mocks. func Flush() { - storeMutex.Lock() - defer storeMutex.Unlock() - mocks = []Mock{} + g.Flush() } // Pending returns an slice of pending mocks. func Pending() []Mock { - Clean() - storeMutex.RLock() - defer storeMutex.RUnlock() - return mocks + return g.Pending() } // IsDone returns true if all the registered mocks has been triggered successfully. func IsDone() bool { - return !IsPending() + return g.IsDone() } // IsPending returns true if there are pending mocks. func IsPending() bool { - return len(Pending()) > 0 + return g.IsPending() } // Clean cleans the mocks store removing disabled or obsolete mocks. func Clean() { - storeMutex.Lock() - defer storeMutex.Unlock() - - buf := []Mock{} - for _, mock := range mocks { - if mock.Done() { - continue - } - buf = append(buf, mock) - } - - mocks = buf + g.Clean() } diff --git a/store_test.go b/store_test.go index 4ab4c83..b40a078 100644 --- a/store_test.go +++ b/store_test.go @@ -8,36 +8,36 @@ import ( func TestStoreRegister(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock Register(mock) - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, mock.Request().Mock, mock) st.Expect(t, mock.Response().Mock, mock) } func TestStoreGetAll(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock store := GetAll() - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, len(store), 1) st.Expect(t, store[0], mock) } func TestStoreExists(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, Exists(mock), true) } func TestStorePending(t *testing.T) { defer after() New("foo") - st.Expect(t, mocks, Pending()) + st.Expect(t, GetAll(), Pending()) } func TestStoreIsPending(t *testing.T) { @@ -58,9 +58,9 @@ func TestStoreIsDone(t *testing.T) { func TestStoreRemove(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock := New("foo").Mock - st.Expect(t, len(mocks), 1) + st.Expect(t, len(GetAll()), 1) st.Expect(t, Exists(mock), true) Remove(mock) @@ -72,16 +72,16 @@ func TestStoreRemove(t *testing.T) { func TestStoreFlush(t *testing.T) { defer after() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) mock1 := New("foo").Mock mock2 := New("foo").Mock - st.Expect(t, len(mocks), 2) + st.Expect(t, len(GetAll()), 2) st.Expect(t, Exists(mock1), true) st.Expect(t, Exists(mock2), true) Flush() - st.Expect(t, len(mocks), 0) + st.Expect(t, len(GetAll()), 0) st.Expect(t, Exists(mock1), false) st.Expect(t, Exists(mock2), false) } diff --git a/transport.go b/transport.go index 5b2bba2..985e3ad 100644 --- a/transport.go +++ b/transport.go @@ -1,9 +1,9 @@ package gock import ( - "errors" "net/http" - "sync" + + "github.com/h2non/gock/threadsafe" ) // var mutex *sync.Mutex = &sync.Mutex{} @@ -19,7 +19,7 @@ var ( var ( // ErrCannotMatch store the error returned in case of no matches. - ErrCannotMatch = errors.New("gock: cannot match any request") + ErrCannotMatch = threadsafe.ErrCannotMatch ) // Transport implements http.RoundTripper, which fulfills single http requests issued by @@ -27,86 +27,9 @@ var ( // // gock's Transport encapsulates a given or default http.Transport for further // delegation, if needed. -type Transport struct { - // mutex is used to make transport thread-safe of concurrent uses across goroutines. - mutex sync.Mutex - - // Transport encapsulates the original http.RoundTripper transport interface for delegation. - Transport http.RoundTripper -} +type Transport = threadsafe.Transport // NewTransport creates a new *Transport with no responders. func NewTransport() *Transport { - return &Transport{Transport: NativeTransport} -} - -// RoundTrip receives HTTP requests and routes them to the appropriate responder. It is required to -// implement the http.RoundTripper interface. You will not interact with this directly, instead -// the *http.Client you are using will call it for you. -func (m *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - // Just act as a proxy if not intercepting - if !Intercepting() { - return m.Transport.RoundTrip(req) - } - - m.mutex.Lock() - defer Clean() - - var err error - var res *http.Response - - // Match mock for the incoming http.Request - mock, err := MatchMock(req) - if err != nil { - m.mutex.Unlock() - return nil, err - } - - // Invoke the observer with the intercepted http.Request and matched mock - if config.Observer != nil { - config.Observer(req, mock) - } - - // Verify if should use real networking - networking := shouldUseNetwork(req, mock) - if !networking && mock == nil { - m.mutex.Unlock() - trackUnmatchedRequest(req) - return nil, ErrCannotMatch - } - - // Ensure me unlock the mutex before building the response - m.mutex.Unlock() - - // Perform real networking via original transport - if networking { - res, err = m.Transport.RoundTrip(req) - // In no mock matched, continue with the response - if err != nil || mock == nil { - return res, err - } - } - - return Responder(req, mock.Response(), res) -} - -// CancelRequest is a no-op function. -func (m *Transport) CancelRequest(req *http.Request) {} - -func shouldUseNetwork(req *http.Request, mock Mock) bool { - if mock != nil && mock.Response().UseNetwork { - return true - } - if !config.Networking { - return false - } - if len(config.NetworkingFilters) == 0 { - return true - } - for _, filter := range config.NetworkingFilters { - if !filter(req) { - return false - } - } - return true + return g.NewTransport(NativeTransport) }