Skip to content

Commit

Permalink
fix(plugin test): close connection in plugin test environment
Browse files Browse the repository at this point in the history
- Addresses a connection leak in the internal plugin test environment's network connection instantiation
  • Loading branch information
zryanl committed Aug 26, 2024
1 parent f494b64 commit d1e6683
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 28 deletions.
28 changes: 15 additions & 13 deletions bridge/bridgetest/bridgetest.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func readPbFrame(conn net.Conn) (data []byte, err error) {
var len uint32
err = binary.Read(conn, binary.LittleEndian, &len)
if err != nil {
return
return nil, err
}

data = make([]byte, len)
Expand All @@ -30,21 +30,21 @@ func readPbFrame(conn net.Conn) (data []byte, err error) {
return nil, err
}

return
return data, nil
}

func writePbFrame(conn net.Conn, data []byte) (err error) {
var len uint32 = uint32(len(data))
err = binary.Write(conn, binary.LittleEndian, len)
if err != nil {
return
return err
}

if len > 0 {
_, err = conn.Write(data)
}

return
return err
}

func Mock(t *testing.T, s []MockStep) net.Conn {
Expand Down Expand Up @@ -119,37 +119,39 @@ func MockFunc(e mockEnvironment) net.Conn {
statusCh := make(chan string, 1)
e.SubscribeStatusChange(statusCh)

var err error
go func() {
for {
d, err := readPbFrame(conB)
if err != nil {
e.Errorf("Can't read method name")
break
}
method := string(d)

d, err = readPbFrame(conB)
if err != nil {
e.Errorf("Can't read method \"%v\" arguments", method)
break
}

d = e.Handle(method, d)

err = writePbFrame(conB, d)
if err != nil {
e.Errorf("Can't write back return values")
break
}
}

select {
case msg := <-statusCh:
if msg == "finished" {
return
}
default: // do nothing
select {
case msg := <-statusCh:
if msg == "finished" {
return
}
default:
if err != nil {
e.Errorf(err.Error())
}
}
conB.Close()
}()
return conA
}
40 changes: 25 additions & 15 deletions test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (req Request) clone() Request {
}
}

func mergeHeaders(h http.Header, in http.Header) http.Header {
func mergeHeaders(h, in http.Header) http.Header {
for k, l := range in {
h.Del(k)
for _, v := range l {
Expand Down Expand Up @@ -122,19 +122,23 @@ func (req *Request) Validate() error {
}
return nil
}
return fmt.Errorf("Unsupported method \"%v\"", req.Method)
return fmt.Errorf("unsupported method \"%v\"", req.Method)
}

func getPort(u *url.URL) int32 {
func getPort(u *url.URL) (int32, error) {
p := u.Port()
if p == "" {
if u.Scheme == "https" {
return 443
return 443, nil
}
return 80
return 80, nil
}
portnum, _ := strconv.Atoi(p)
return int32(portnum)
portnum, err := strconv.ParseInt(p, 10, 32)
if err != nil {
return 0, err
}

return int32(portnum), nil
}

// ToResponse creates a new Response object from a Request,
Expand Down Expand Up @@ -177,7 +181,7 @@ func (res *Response) merge(other Response) {
}

type Ctx struct {
Store map[string]interface{}
Store map[string]interface{}
}

type envState int
Expand All @@ -196,14 +200,14 @@ type TestEnv struct {
ServiceReq Request
ServiceRes Response
ClientRes Response
Ctx Ctx
Ctx Ctx
}

// New creates a new test environment.
func New(t *testing.T, req Request) (env *TestEnv, err error) {
err = req.Validate()
if err != nil {
return
return nil, err
}

env = &TestEnv{
Expand All @@ -213,7 +217,7 @@ func New(t *testing.T, req Request) (env *TestEnv, err error) {
ServiceReq: req.clone(),
ServiceRes: Response{Headers: make(http.Header)},
ClientRes: Response{Headers: make(http.Header)},
Ctx: Ctx{Store: make(map[string]interface{})},
Ctx: Ctx{Store: make(map[string]interface{})},
}

b := bridge.New(bridgetest.MockFunc(env)) // check
Expand All @@ -231,7 +235,7 @@ func New(t *testing.T, req Request) (env *TestEnv, err error) {
ServiceRequest: service_request.Request{PdkBridge: b},
ServiceResponse: service_response.Response{PdkBridge: b},
}
return
return env, nil
}

func (e *TestEnv) noErr(err error) {
Expand All @@ -254,6 +258,8 @@ func (e *TestEnv) Finish() {
if e.stateChange != nil {
e.stateChange <- "finished"
}

e.pdk.Ctx.Close()
}

func LowercaseHeaders(h http.Header) http.Header {
Expand Down Expand Up @@ -349,7 +355,9 @@ func (e *TestEnv) Handle(method string, args_d []byte) []byte {
case "kong.request.get_port":
u, err := url.Parse(e.ClientReq.Url)
e.noErr(err)
out = &kong_plugin_protocol.Int{V: getPort(u)}
p, err := getPort(u)
e.noErr(err)
out = &kong_plugin_protocol.Int{V: p}

case "kong.request.get_forwarded_scheme":
scheme := e.ClientReq.Headers.Get("X-Forwarded-Proto")
Expand All @@ -372,13 +380,15 @@ func (e *TestEnv) Handle(method string, args_d []byte) []byte {
case "kong.request.get_forwarded_port":
port := e.ClientReq.Headers.Get("X-Forwarded-Port")
if port != "" {
p, err := strconv.Atoi(port)
p, err := strconv.ParseInt(port, 10, 32)
e.noErr(err)
out = &kong_plugin_protocol.Int{V: int32(p)}
} else {
u, err := url.Parse(e.ClientReq.Url)
e.noErr(err)
out = &kong_plugin_protocol.Int{V: getPort(u)}
p, err := getPort(u)
e.noErr(err)
out = &kong_plugin_protocol.Int{V: p}
}

case "kong.request.get_http_version":
Expand Down

0 comments on commit d1e6683

Please sign in to comment.