Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http_proxy: support multi-addresses listening #654

Closed
wants to merge 8 commits into from
6 changes: 6 additions & 0 deletions bind/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix
"The server address to listen on. "+
"If the host is empty, the server will listen on all available interfaces. ")

fs.StringSliceVarP(&cfg.OptionalAddrs,
namePrefix+"optional-addresses", "", cfg.OptionalAddrs, "<host:port,...>"+
"Optional server addresses to listen on. "+
"The server will continue to run if the bind fails. "+
"Can be specified multiple times.")

if schemes == nil {
schemes = []forwarder.Scheme{
forwarder.HTTPScheme,
Expand Down
5 changes: 5 additions & 0 deletions e2e/forwarder/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ func GRPCTestService() *Service {
return s.WithProtocol("h2")
}

func (s *Service) WithOptionalAddresses(addresses ...string) *Service {
s.Environment["FORWARDER_OPTIONAL_ADDRESSES"] = strings.Join(addresses, ",")
return s
}

func (s *Service) WithProtocol(protocol string) *Service {
s.Environment["FORWARDER_PROTOCOL"] = protocol

Expand Down
16 changes: 16 additions & 0 deletions e2e/setups.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func AllSetups() []setup.Setup {
SetupFlagDenyDomains(l)
SetupFlagDirectDomains(l)
SetupFlagRateLimit(l)
SetupFlagOptionalAddresses(l)
SetupSC2450(l)

return l.Build()
Expand Down Expand Up @@ -453,6 +454,21 @@ func SetupFlagRateLimit(l *setupList) {
)
}

func SetupFlagOptionalAddresses(l *setupList) {
l.Add(
setup.Setup{
Name: "flag-address-multiple-ports",
Compose: compose.NewBuilder().
AddService(
forwarder.HttpbinService()).
AddService(
forwarder.ProxyService().
WithOptionalAddresses(":4567,:5678")).
MustBuild(),
Run: "^TestFlagOptionalAddresses$",
})
}

func SetupSC2450(l *setupList) {
l.Add(setup.Setup{
Name: "sc-2450",
Expand Down
14 changes: 14 additions & 0 deletions e2e/tests/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"time"

"github.com/saucelabs/forwarder/e2e/forwarder"
"github.com/saucelabs/forwarder/utils/httpexpect"
)

func TestFlagProxyLocalhost(t *testing.T) {
Expand Down Expand Up @@ -224,3 +225,16 @@ func testRateLimitHelper(t *testing.T, workers int, expectedTime, epsilon time.D
t.Fatalf("Expected request to take approximately %s, took %s", expectedTime, elapsed)
}
}

func TestFlagOptionalAddresses(t *testing.T) {
for _, port := range []string{"3128", "4567", "5678"} {
newClient(t, httpbin, func(tr *http.Transport) {
proxy := serviceScheme("FORWARDER_PROTOCOL") + "://proxy:" + port
proxyURL, err := httpexpect.NewURLWithBasicAuth(proxy, basicAuth)
if err != nil {
t.Fatal(err)
}
tr.Proxy = http.ProxyURL(proxyURL)
}).GET("/status/200").ExpectStatus(http.StatusOK)
}
}
7 changes: 5 additions & 2 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher,
}
hp.listener = l

hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol)
for _, addr := range l.Addrs() {
hp.log.Infof("PROXY server listen address=%s protocol=%s", addr, hp.config.Protocol)
}

return hp, nil
}
Expand Down Expand Up @@ -602,7 +604,7 @@ func (hp *HTTPProxy) Run(ctx context.Context) error {
return nil
}

func (hp *HTTPProxy) listen() (net.Listener, error) {
func (hp *HTTPProxy) listen() (*Listener, error) {
switch hp.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
default:
Expand All @@ -611,6 +613,7 @@ func (hp *HTTPProxy) listen() (net.Listener, error) {

l := Listener{
Address: hp.config.Addr,
OptionalAddresses: hp.config.OptionalAddrs,
Log: hp.log,
TLSConfig: hp.tlsConfig,
TLSHandshakeTimeout: hp.config.TLSServerConfig.HandshakeTimeout,
Expand Down
36 changes: 20 additions & 16 deletions http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ func h2TLSConfigTemplate() *tls.Config {
}

type HTTPServerConfig struct {
Protocol Scheme
Addr string
Protocol Scheme
Addr string
OptionalAddrs []string
TLSServerConfig
IdleTimeout time.Duration
ReadTimeout time.Duration
Expand Down Expand Up @@ -149,7 +150,9 @@ func NewHTTPServer(cfg *HTTPServerConfig, h http.Handler, log log.Logger) (*HTTP
}
hs.listener = l

hs.log.Infof("HTTP server listen address=%s protocol=%s", l.Addr(), hs.config.Protocol)
for _, addr := range l.Addrs() {
hs.log.Infof("HTTP server listen address=%s protocol=%s", addr, hs.config.Protocol)
}

return hs, nil
}
Expand Down Expand Up @@ -212,10 +215,8 @@ func (hs *HTTPServer) Run(ctx context.Context) error {

var srvErr error
switch hs.config.Protocol {
case HTTPScheme:
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
srvErr = hs.srv.Serve(hs.listener)
case HTTP2Scheme, HTTPSScheme:
srvErr = hs.srv.ServeTLS(hs.listener, "", "")
default:
return fmt.Errorf("invalid protocol %q", hs.config.Protocol)
}
Expand All @@ -231,17 +232,20 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
return nil
}

func (hs *HTTPServer) listen() (net.Listener, error) {
switch hs.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
listener, err := Listen("tcp", hs.srv.Addr)
if err != nil {
return nil, fmt.Errorf("failed to open listener on address %s: %w", hs.srv.Addr, err)
}
return listener, nil
default:
return nil, fmt.Errorf("invalid protocol %q", hs.config.Protocol)
func (hs *HTTPServer) listen() (*Listener, error) {
l := Listener{
Address: hs.config.Addr,
OptionalAddresses: hs.config.OptionalAddrs,
Log: hs.log,
TLSConfig: hs.srv.TLSConfig,
TLSHandshakeTimeout: hs.config.HandshakeTimeout,
}

if err := l.Listen(); err != nil {
return nil, err
}

return &l, nil
}

// Addr returns the address the server is listening on.
Expand Down
109 changes: 101 additions & 8 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import (
"context"
"crypto/tls"
"net"
"sync"
"syscall"
"time"

"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/ratelimit"
"go.uber.org/multierr"
)

type DialConfig struct {
Expand Down Expand Up @@ -87,25 +89,64 @@ type ListenerCallbacks interface {
// OnAccept is called when a new connection is successfully accepted.
OnAccept(net.Conn)

// OnBindError is called when a listener fails to bind to an address.
OnBindError(address string, err error)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should go before TLS things.


// OnTLSHandshakeError is called after a TLS handshake errors out.
OnTLSHandshakeError(*tls.Conn, error)
}

// Listener is a multi-address listener with TLS support, rate limiting and callbacks.
// The Listener must successfully bind on Address, but it may fail to bind on OptionalAddresses.
type Listener struct {
Address string
OptionalAddresses []string
Log log.Logger
TLSConfig *tls.Config
TLSHandshakeTimeout time.Duration
ReadLimit int64
WriteLimit int64
Callbacks ListenerCallbacks

listener net.Listener
listeners []net.Listener
acceptCh chan acceptResult
wg sync.WaitGroup
closeCh chan struct{}
closeOnce sync.Once
}

type acceptResult struct {
c net.Conn
err error
}

// Listen starts listening on the provided addresses.
// The method should be called only once.
func (l *Listener) Listen() error {
ll, err := Listen("tcp", l.Address)
l.acceptCh = make(chan acceptResult)
l.closeCh = make(chan struct{})

if err := l.listen(l.Address); err != nil {
return err
}

// OptionalAddresses may fail to bind.
for _, addr := range l.OptionalAddresses {
if err := l.listen(addr); err != nil {
l.Log.Errorf("failed to listen on %s: %v", addr, err)
continue
}
}

return nil
}

func (l *Listener) listen(addr string) error {
ll, err := Listen("tcp", addr)
if err != nil {
if l.Callbacks != nil {
l.Callbacks.OnBindError(addr, err)
}
return err
}

Expand All @@ -116,13 +157,42 @@ func (l *Listener) Listen() error {
ll = ratelimit.NewListener(ll, wl, rl)
}

l.listener = ll
l.listeners = append(l.listeners, ll)
l.wg.Add(1)
go l.acceptLoop(ll)

return nil
}

func (l *Listener) acceptLoop(ll net.Listener) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd use errgroup with anonymous function instead of creating this and closeCh.

defer l.wg.Done()
for {
c, err := ll.Accept()
select {
case l.acceptCh <- acceptResult{c, err}:
case <-l.closeCh:
if c != nil {
if cerr := c.Close(); cerr != nil {
l.Log.Errorf("failed to close connection: %v", cerr)
}
}
return
}
}
}

func (l *Listener) Accept() (net.Conn, error) {
for {
c, err := l.listener.Accept()
var (
c net.Conn
err error
)
select {
case <-l.closeCh:
return nil, net.ErrClosed
case res := <-l.acceptCh:
c, err = res.c, res.err
}
if err != nil {
return nil, err
}
Expand All @@ -137,9 +207,9 @@ func (l *Listener) Accept() (net.Conn, error) {

tc, err := l.withTLS(c)
if err != nil {
l.Log.Errorf("Failed to perform TLS handshake: %v", err)
l.Log.Errorf("failed to perform TLS handshake: %v", err)
if cerr := tc.Close(); cerr != nil {
l.Log.Errorf("Failed to close TLS connection: %v", cerr)
l.Log.Errorf("failed to close connection: %v", cerr)
}
continue
}
Expand Down Expand Up @@ -169,9 +239,32 @@ func (l *Listener) withTLS(conn net.Conn) (*tls.Conn, error) {
}

func (l *Listener) Addr() net.Addr {
return l.listener.Addr()
if len(l.listeners) == 0 {
return &net.IPAddr{}
}

return l.listeners[0].Addr()
}

func (l *Listener) Addrs() []net.Addr {
addrs := make([]net.Addr, 0, len(l.listeners))
for _, ll := range l.listeners {
addrs = append(addrs, ll.Addr())
}
return addrs
}

func (l *Listener) Close() error {
return l.listener.Close()
l.closeOnce.Do(func() { close(l.closeCh) })

var merr error
for _, ll := range l.listeners {
if err := ll.Close(); err != nil {
merr = multierr.Append(merr, err)
}
}

l.wg.Wait()

return merr
}
Loading