Skip to content

Commit

Permalink
[+] handle timeout for proxy connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Noooste committed Jan 4, 2024
1 parent 561805b commit 2fb64bd
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 16 deletions.
69 changes: 55 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ type Conn struct {

ClientHelloSpec func() *tls.ClientHelloSpec

mu *sync.RWMutex
ctx context.Context
mu *sync.RWMutex

ctx context.Context
cancel context.CancelFunc
}

/*
Expand Down Expand Up @@ -79,6 +81,8 @@ func (cp *ConnPool) Close() {
for _, c := range cp.hosts {
c.Close()
}
cp.hosts = nil
cp.mu = nil
}

func getHost(u *url.URL) string {
Expand Down Expand Up @@ -201,9 +205,54 @@ func (c *Conn) Close() {
_ = c.HTTP2.Close()
c.HTTP2 = nil
}

if c.cancel != nil {
c.cancel()
c.cancel = nil
}
c.PinManager = nil
}

func (s *Session) getProxyConn(conn *Conn, host string) (err error) {
ctx, cancel := context.WithCancel(s.ctx)
s.ProxyDialer.ForceHTTP2 = s.H2Proxy
s.ProxyDialer.tr = s.HTTP2Transport
s.ProxyDialer.Dialer.Timeout = conn.TimeOut

timer := time.NewTimer(conn.TimeOut)
defer timer.Stop()

connChan := make(chan net.Conn, 1)
errChan := make(chan error, 1)

defer close(connChan)
defer close(errChan)

go func() {
proxyConn, dialErr := s.ProxyDialer.DialContext(ctx, "tcp", host)
if dialErr != nil {
errChan <- dialErr
}
connChan <- proxyConn
}()

select {
case <-timer.C:
cancel()
return errors.New("proxy connection timeout")

case c := <-connChan:
conn.Conn = c
conn.cancel = cancel

case err = <-errChan:
cancel()
return err
}

return nil
}

func (s *Session) initConn(req *Request) (conn *Conn, err error) {
// get connection from pool
conn = s.Connections.Get(req.parsedUrl)
Expand All @@ -228,20 +277,12 @@ func (s *Session) initConn(req *Request) (conn *Conn, err error) {
defer conn.mu.Unlock()

if conn.Conn == nil {
var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)

if s.ProxyDialer != nil {
s.ProxyDialer.ForceHTTP2 = s.H2Proxy
s.ProxyDialer.tr = s.HTTP2Transport
dialContext = s.ProxyDialer.DialContext
s.ProxyDialer.Dialer.Timeout = conn.TimeOut
if err = s.getProxyConn(conn, host); err != nil {
return nil, err
}
} else {
dialContext = (&net.Dialer{Timeout: conn.TimeOut}).DialContext
}

conn.Conn, err = dialContext(s.ctx, "tcp", host)
if err != nil {
return
conn.Conn, err = (&net.Dialer{Timeout: conn.TimeOut}).DialContext(s.ctx, "tcp", host)
}
}

Expand Down
2 changes: 1 addition & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (s *Session) do(req *Request, args ...any) (resp *Response, err error) {
req.ctx = s.ctx
}

var reqs = make([]*Request, 0, 10)
var reqs = make([]*Request, 0, req.MaxRedirects+1)

var (
redirectMethod string
Expand Down
2 changes: 2 additions & 0 deletions structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type Session struct {
ctx context.Context // Context for cancellable and timeout operations.

UserAgent string // Headers for User-Agent and Sec-Ch-Ua, respectively.

closed bool
}

// Request represents the details and configuration for an individual HTTP(S)
Expand Down
18 changes: 18 additions & 0 deletions tests/connection_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/Noooste/azuretls-client"
"os"
"testing"
"time"
)

var skipProxy bool
Expand Down Expand Up @@ -192,3 +193,20 @@ func TestProxy4(t *testing.T) {
t.Fatal("TestProxy failed, IP is not changed")
}
}

func TestBadProxy(t *testing.T) {
session := azuretls.NewSession()
session.SetTimeout(1 * time.Second)
defer session.Close()

if err := session.SetProxy("https://test.com"); err != nil {
t.Fatal(err)
}

_, err := session.Get("https://ipinfo.io/ip")

if err == nil || err.Error() != "proxy connection timeout" {
t.Fatal("TestBadProxy failed, expected error, got", err)
}

}
2 changes: 1 addition & 1 deletion tests/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ func TestSession_Connect(t *testing.T) {
func TestSession_TooManyRedirects(t *testing.T) {
session := azuretls.NewSession()

resp, err := session.Get("https://httpbin.org/redirect/10")
resp, err := session.Get("https://httpbin.org/redirect/11")

if err == nil || !strings.Contains(err.Error(), "too many Redirects") {
t.Fatal("TestSession_TooManyRedirects failed, expected: too many Redirects, got: ", err)
Expand Down

0 comments on commit 2fb64bd

Please sign in to comment.