diff --git a/pkg/healthsupport/health_support.go b/pkg/healthsupport/health_support.go index b53e096e..a99554c7 100644 --- a/pkg/healthsupport/health_support.go +++ b/pkg/healthsupport/health_support.go @@ -29,10 +29,19 @@ type response struct { Pass string `json:"pass"` } -func HealthHandlerFunction(w http.ResponseWriter, r *http.Request) { - checks := make([]HealthCheck, 0) - checks = append(checks, &NoopCheck{}) - HealthHandlerFunctionWithChecks(w, r, checks) +type HealthInfo struct { + Status string `json:"status"` +} + +type httpClient interface { + Get(url string) (*http.Response, error) +} + +func HealthHandlerFunction(w http.ResponseWriter, _ *http.Request) { + data, _ := json.Marshal(&HealthInfo{"pass"}) + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } func HealthHandlerFunctionWithChecks(w http.ResponseWriter, _ *http.Request, checks []HealthCheck) { // todo - move to func? @@ -50,10 +59,10 @@ func HealthHandlerFunctionWithChecks(w http.ResponseWriter, _ *http.Request, che } func WaitForHealthy(server *http.Server) { - WaitForHealthyWithClient(server, &http.Client{}, fmt.Sprintf("http://%s/health", server.Addr)) + WaitForHealthyWithClient(server, http.DefaultClient, fmt.Sprintf("http://%s/health", server.Addr)) } -func WaitForHealthyWithClient(server *http.Server, client *http.Client, url string) { +func WaitForHealthyWithClient(server *http.Server, client httpClient, url string) { var isLive bool for !isLive { resp, err := client.Get(url) diff --git a/pkg/healthsupport/health_support_test.go b/pkg/healthsupport/health_support_test.go index 1c450e23..fa2d799e 100644 --- a/pkg/healthsupport/health_support_test.go +++ b/pkg/healthsupport/health_support_test.go @@ -28,6 +28,37 @@ func TestHealth(t *testing.T) { resp, _ := http.Get(fmt.Sprintf("http://%s/health", server.Addr)) body, _ := io.ReadAll(resp.Body) - assert.Equal(t, "[{\"name\":\"noop\",\"pass\":\"true\"}]", string(body)) + assert.Equal(t, "{\"status\":\"pass\"}", string(body)) _ = server.Shutdown(context.Background()) } + +func TestWaitForHealthyWithClient(t *testing.T) { + listener, _ := net.Listen("tcp", "localhost:0") + router := mux.NewRouter() + router.HandleFunc("/health", healthsupport.HealthHandlerFunction).Methods("GET") + server := &http.Server{ + Addr: listener.Addr().String(), + Handler: router, + } + go func() { + _ = server.Serve(listener) + }() + defer server.Shutdown(context.Background()) + + client := &mockClient{} + healthsupport.WaitForHealthyWithClient(server, client, fmt.Sprintf("http://%s/health", server.Addr)) + + resp, _ := http.Get(fmt.Sprintf("http://%s/health", server.Addr)) + body, _ := io.ReadAll(resp.Body) + assert.Equal(t, "{\"status\":\"pass\"}", string(body)) + assert.True(t, client.getCalled) +} + +type mockClient struct { + getCalled bool +} + +func (m *mockClient) Get(_ string) (*http.Response, error) { + m.getCalled = true + return &http.Response{StatusCode: http.StatusOK}, nil +}