Skip to content

Commit

Permalink
Merge pull request #24 from 9seconds/fix-23
Browse files Browse the repository at this point in the history
Correctly restore HTTP protocol on doing a request
  • Loading branch information
9seconds authored Jun 15, 2021
2 parents f01fdd3 + fc89ea6 commit 83abecf
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 25 deletions.
4 changes: 2 additions & 2 deletions headers/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (suite *HeadersTestSuite) TearDownTest() {
}

func (suite *HeadersTestSuite) TestCheckHeaders() {
suite.Len(suite.hdrs.Headers, 2)
suite.Len(suite.hdrs.Headers, 3)

headerNames := map[string]bool{
"Accept-Encoding": true,
Expand Down Expand Up @@ -147,7 +147,7 @@ func (suite *HeadersTestSuite) TestSetNoCleanup() {

func (suite *HeadersTestSuite) TestSetUnknown() {
suite.hdrs.Set("hello", "NewValue", false)
suite.Len(suite.hdrs.Headers, 3)
suite.Len(suite.hdrs.Headers, 4)

header := suite.hdrs.GetFirst("hello")

Expand Down
12 changes: 11 additions & 1 deletion headers/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func (r requestHeaderWrapper) Read(rd io.Reader) error {
method := append([]byte(nil), r.ref.Method()...)
requestURI := append([]byte(nil), r.ref.RequestURI()...)
host := append([]byte(nil), r.ref.Host()...)
protocol := append([]byte(nil), r.ref.Protocol()...)

r.ref.Reset()
r.ref.DisableNormalizing()
Expand All @@ -26,6 +27,7 @@ func (r requestHeaderWrapper) Read(rd io.Reader) error {
return errors.Annotate(err, "cannot read request headers", "headers_sync", 0)
}

r.ref.SetProtocolBytes(protocol)
r.ref.SetHostBytes(host)
r.ref.SetMethodBytes(method)
r.ref.SetRequestURIBytes(requestURI)
Expand All @@ -42,7 +44,15 @@ func (r requestHeaderWrapper) ResetConnectionClose() {
}

func (r requestHeaderWrapper) Headers() []byte {
return r.ref.RawHeaders()
buf := append([]byte(nil), r.ref.Method()...)
buf = append(buf, ' ')
buf = append(buf, r.ref.RequestURI()...)
buf = append(buf, ' ')
buf = append(buf, r.ref.Protocol()...)
buf = append(buf, '\r', '\n')
buf = append(buf, r.ref.RawHeaders()...)

return buf
}

type responseHeaderWrapper struct {
Expand Down
7 changes: 4 additions & 3 deletions headers/wrappers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ func (suite *RequestHeaderWrapperTestSuite) TestRawHeaders() {
suite.hdr.SetConnectionClose()

request := []string{
"GET http://example.com HTTP/1.1",
"Host: example.com",
"accept: deflate",
"connection: close",
}
fullRequest := strings.Join(append([]string{"GET / HTTP/1.1"}, request...), "\r\n") + "\r\n\r\n"
fullRequest := strings.Join(request, "\r\n") + "\r\n\r\n"

suite.NoError(suite.wrp.Read(strings.NewReader(fullRequest)))
suite.Equal([]byte(strings.Join(request, "\r\n")+"\r\n\r\n"), suite.wrp.Headers())
Expand Down Expand Up @@ -116,7 +117,7 @@ func (suite *ResponseWrapperTestSuite) TestCorrectRestore() {
}, "\r\n") + "\r\n\r\n"

suite.NoError(suite.wrp.Read(strings.NewReader(request)))
suite.Equal(fasthttp.StatusCreated, suite.hdr.StatusCode())
suite.Equal(fasthttp.StatusCreated, suite.hdr.StatusCode())
}

func (suite *ResponseWrapperTestSuite) TestDisableNormalizing() {
Expand All @@ -137,5 +138,5 @@ func TestRequestHeaderWrapper(t *testing.T) {
}

func TestResponseHeaderWrapper(t *testing.T) {
suite.Run(t, &ResponseWrapperTestSuite{})
suite.Run(t, &ResponseWrapperTestSuite{})
}
19 changes: 0 additions & 19 deletions layers/ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,25 +116,6 @@ func (suite *ContextTestSuite) TestRespond() {
suite.Equal("text/plain", string(resp.Header.ContentType()))
}

func (suite *ContextTestSuite) TestErrorRequest() {
ctx := layers.AcquireContext()
defer layers.ReleaseContext(ctx)

fhttpCtx := &fasthttp.RequestCtx{}
remoteAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 65342,
}

fhttpCtx.Init(&fasthttp.Request{}, remoteAddr, nil)

suite.Error(ctx.Init(fhttpCtx,
"127.0.0.1:8000",
suite.eventsChannel,
"user",
events.RequestTypeTLS))
}

func (suite *ContextTestSuite) TestErrorGeneril() {
suite.ctx.Error(io.EOF)

Expand Down
7 changes: 7 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ func (suite *ServerTestSuite) TestHTTPSAuthRequired() {
suite.Error(err)
}

func (suite *ServerTestSuite) TestGolangOrg() {
resp, err := suite.http.Get("https://golang.org")

suite.NoError(err)
suite.Equal(http.StatusOK, resp.StatusCode)
}

func TestServer(t *testing.T) {
suite.Run(t, &ServerTestSuite{})
}

0 comments on commit 83abecf

Please sign in to comment.