diff --git a/bridge/bridgetest/bridgetest.go b/bridge/bridgetest/bridgetest.go index ca95e0f..5b050b3 100644 --- a/bridge/bridgetest/bridgetest.go +++ b/bridge/bridgetest/bridgetest.go @@ -20,7 +20,7 @@ func readPbFrame(conn net.Conn) (data []byte, err error) { var len uint32 err = binary.Read(conn, binary.LittleEndian, &len) if err != nil { - return + return nil, err } data = make([]byte, len) @@ -30,21 +30,21 @@ func readPbFrame(conn net.Conn) (data []byte, err error) { return nil, err } - return + return data, nil } func writePbFrame(conn net.Conn, data []byte) (err error) { var len uint32 = uint32(len(data)) err = binary.Write(conn, binary.LittleEndian, len) if err != nil { - return + return err } if len > 0 { _, err = conn.Write(data) } - return + return err } func Mock(t *testing.T, s []MockStep) net.Conn { @@ -119,18 +119,17 @@ func MockFunc(e mockEnvironment) net.Conn { statusCh := make(chan string, 1) e.SubscribeStatusChange(statusCh) + var err error go func() { for { d, err := readPbFrame(conB) if err != nil { - e.Errorf("Can't read method name") break } method := string(d) d, err = readPbFrame(conB) if err != nil { - e.Errorf("Can't read method \"%v\" arguments", method) break } @@ -138,18 +137,21 @@ func MockFunc(e mockEnvironment) net.Conn { err = writePbFrame(conB, d) if err != nil { - e.Errorf("Can't write back return values") break } + } - select { - case msg := <-statusCh: - if msg == "finished" { - return - } - default: // do nothing + select { + case msg := <-statusCh: + if msg == "finished" { + return + } + default: + if err != nil { + e.Errorf(err.Error()) } } + conB.Close() }() return conA } diff --git a/test/test.go b/test/test.go index 6b23a44..81eb3dc 100644 --- a/test/test.go +++ b/test/test.go @@ -90,7 +90,7 @@ func (req Request) clone() Request { } } -func mergeHeaders(h http.Header, in http.Header) http.Header { +func mergeHeaders(h, in http.Header) http.Header { for k, l := range in { h.Del(k) for _, v := range l { @@ -122,19 +122,23 @@ func (req *Request) Validate() error { } return nil } - return fmt.Errorf("Unsupported method \"%v\"", req.Method) + return fmt.Errorf("unsupported method \"%v\"", req.Method) } -func getPort(u *url.URL) int32 { +func getPort(u *url.URL) (int32, error) { p := u.Port() if p == "" { if u.Scheme == "https" { - return 443 + return 443, nil } - return 80 + return 80, nil } - portnum, _ := strconv.Atoi(p) - return int32(portnum) + portnum, err := strconv.ParseInt(p, 10, 32) + if err != nil { + return 0, err + } + + return int32(portnum), nil } // ToResponse creates a new Response object from a Request, @@ -177,7 +181,7 @@ func (res *Response) merge(other Response) { } type Ctx struct { - Store map[string]interface{} + Store map[string]interface{} } type envState int @@ -196,14 +200,14 @@ type TestEnv struct { ServiceReq Request ServiceRes Response ClientRes Response - Ctx Ctx + Ctx Ctx } // New creates a new test environment. func New(t *testing.T, req Request) (env *TestEnv, err error) { err = req.Validate() if err != nil { - return + return nil, err } env = &TestEnv{ @@ -213,7 +217,7 @@ func New(t *testing.T, req Request) (env *TestEnv, err error) { ServiceReq: req.clone(), ServiceRes: Response{Headers: make(http.Header)}, ClientRes: Response{Headers: make(http.Header)}, - Ctx: Ctx{Store: make(map[string]interface{})}, + Ctx: Ctx{Store: make(map[string]interface{})}, } b := bridge.New(bridgetest.MockFunc(env)) // check @@ -231,7 +235,7 @@ func New(t *testing.T, req Request) (env *TestEnv, err error) { ServiceRequest: service_request.Request{PdkBridge: b}, ServiceResponse: service_response.Response{PdkBridge: b}, } - return + return env, nil } func (e *TestEnv) noErr(err error) { @@ -254,6 +258,8 @@ func (e *TestEnv) Finish() { if e.stateChange != nil { e.stateChange <- "finished" } + + e.pdk.Ctx.Close() } func LowercaseHeaders(h http.Header) http.Header { @@ -349,7 +355,9 @@ func (e *TestEnv) Handle(method string, args_d []byte) []byte { case "kong.request.get_port": u, err := url.Parse(e.ClientReq.Url) e.noErr(err) - out = &kong_plugin_protocol.Int{V: getPort(u)} + p, err := getPort(u) + e.noErr(err) + out = &kong_plugin_protocol.Int{V: p} case "kong.request.get_forwarded_scheme": scheme := e.ClientReq.Headers.Get("X-Forwarded-Proto") @@ -372,13 +380,15 @@ func (e *TestEnv) Handle(method string, args_d []byte) []byte { case "kong.request.get_forwarded_port": port := e.ClientReq.Headers.Get("X-Forwarded-Port") if port != "" { - p, err := strconv.Atoi(port) + p, err := strconv.ParseInt(port, 10, 32) e.noErr(err) out = &kong_plugin_protocol.Int{V: int32(p)} } else { u, err := url.Parse(e.ClientReq.Url) e.noErr(err) - out = &kong_plugin_protocol.Int{V: getPort(u)} + p, err := getPort(u) + e.noErr(err) + out = &kong_plugin_protocol.Int{V: p} } case "kong.request.get_http_version":