From 789d87c6aa016936835e19b6f21496ea1924248d Mon Sep 17 00:00:00 2001 From: Carl Amko Date: Tue, 26 Sep 2023 07:49:14 -0600 Subject: [PATCH] Update forwarding API --- CHANGELOG.md | 5 + VERSION | 2 +- config/common.go | 12 ++ config/config_test.go | 58 ++------- config/forwards_to.go | 21 ++-- config/http.go | 10 +- config/http_handler.go | 28 +++-- config/http_handler_test.go | 50 -------- config/labeled.go | 11 +- config/tcp.go | 11 +- config/tls.go | 12 +- config/tunnel_config.go | 1 + examples/go.mod | 2 +- examples/go.sum | 4 +- examples/ngrok-forward-lite/main.go | 75 ++++++++++++ examples/ngrok-http-lite/main.go | 67 ++++++++++ examples/ngrok-lite/main.go | 76 ------------ forward.go | 181 ++++++++++++++++++++++++++++ go.mod | 16 ++- go.sum | 47 +++----- go.work.sum | 22 +++- online_test.go | 17 ++- session.go | 114 ++++++++++++++---- tunnel.go | 86 ++++++++++++- tunnel_config.go | 1 + 25 files changed, 637 insertions(+), 292 deletions(-) delete mode 100644 config/http_handler_test.go create mode 100644 examples/ngrok-forward-lite/main.go create mode 100644 examples/ngrok-http-lite/main.go delete mode 100644 examples/ngrok-lite/main.go create mode 100644 forward.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 64b7d63..620f74d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 1.5.0 + +- Added new forwarding API. See `[Session].ListenAndForward` and `[Session].ListenAndServeHTTP`. +- Deprecates `WithHTTPServer` and `WithHTTPHandler`. Use `[Session].ListenAndServeHTTP` instead. + ## 1.4.0 - Switch to `connect.ngrok-agent.com:443` as the default server address diff --git a/VERSION b/VERSION index 347f583..bc80560 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.4.1 +1.5.0 diff --git a/config/common.go b/config/common.go index c289a81..30e05d9 100644 --- a/config/common.go +++ b/config/common.go @@ -14,9 +14,21 @@ type commonOpts struct { ForwardsTo string } +type CommonOptionsFunc func(cfg *commonOpts) + +type CommonOption interface { + ApplyCommon(cfg *commonOpts) +} + +func (of CommonOptionsFunc) ApplyCommon(cfg *commonOpts) { + of(cfg) +} + func (cfg *commonOpts) getForwardsTo() string { if cfg.ForwardsTo == "" { return defaultForwardsTo() } return cfg.ForwardsTo } + +func (cfg commonOpts) tunnelOptions() {} diff --git a/config/config_test.go b/config/config_test.go index bcc28b2..b163b4d 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,6 @@ package config import ( - "net/http" "reflect" "testing" @@ -20,14 +19,6 @@ func assertSlice[T any](opts []any) []T { return out } -func handlerPtr(h http.Handler) *http.Handler { - return &h -} - -func serverPtr(srv *http.Server) **http.Server { - return &srv -} - func labelPtr(labels map[string]*string) *map[string]*string { return &labels } @@ -55,16 +46,14 @@ func (m matchBindExtra) RequireMatches(t *testing.T, actual proto.BindExtra) { } type testCase[T tunnelConfigPrivate, O any] struct { - name string - opts Tunnel - expectForwardsTo *string - expectProto *string - expectExtra *matchBindExtra - expectLabels *map[string]*string - expectHTTPServer **http.Server - expectHTTPHandler *http.Handler - expectOpts func(t *testing.T, opts *O) - expectNilOpts bool + name string + opts Tunnel + expectForwardsTo *string + expectProto *string + expectExtra *matchBindExtra + expectLabels *map[string]*string + expectOpts func(t *testing.T, opts *O) + expectNilOpts bool } type testCases[T tunnelConfigPrivate, O any] []testCase[T, O] @@ -101,37 +90,6 @@ func (tc testCase[T, O]) Run(t *testing.T) { require.Truef(t, ok, "Opts has the type %v", reflect.TypeOf((*O)(nil))) tc.expectOpts(t, opts) } - - if tc.expectHTTPServer != nil { - withHTTPServer, ok := tc.opts.(interface { - HTTPServer() *http.Server - }) - if *tc.expectHTTPServer != nil { - require.True(t, ok, "opts should have the HTTPServer method") - actual := withHTTPServer.HTTPServer() - require.Equal(t, *tc.expectHTTPServer, actual) - } else if ok { - require.Nil(t, withHTTPServer.HTTPServer()) - } - } - - if tc.expectHTTPHandler != nil { - withHTTPServer, ok := tc.opts.(interface { - HTTPServer() *http.Server - }) - if *tc.expectHTTPHandler != nil { - require.True(t, ok, "opts should have the HTTPServer method") - actualServer := withHTTPServer.HTTPServer() - require.NotNil(t, actualServer) - actual := actualServer.Handler - require.Equal(t, *tc.expectHTTPHandler, actual) - } else if ok { - actualServer := withHTTPServer.HTTPServer() - if actualServer != nil { - require.Nil(t, actualServer.Handler) - } - } - } }) } diff --git a/config/forwards_to.go b/config/forwards_to.go index 3efd167..ba998f1 100644 --- a/config/forwards_to.go +++ b/config/forwards_to.go @@ -6,33 +6,32 @@ import ( "path/filepath" ) +type forwardsToOption string + // WithForwardsTo sets the ForwardsTo string for this tunnel. // This can be veiwed via the API or dashboard. -func WithForwardsTo(meta string) interface { - HTTPEndpointOption - LabeledTunnelOption - TCPEndpointOption - TLSEndpointOption -} { +func WithForwardsTo(meta string) Options { return forwardsToOption(meta) } -type forwardsToOption string +func (fwd forwardsToOption) ApplyCommon(cfg *commonOpts) { + cfg.ForwardsTo = string(fwd) +} func (fwd forwardsToOption) ApplyHTTP(cfg *httpOptions) { - cfg.commonOpts.ForwardsTo = string(fwd) + fwd.ApplyCommon(&cfg.commonOpts) } func (fwd forwardsToOption) ApplyTCP(cfg *tcpOptions) { - cfg.commonOpts.ForwardsTo = string(fwd) + fwd.ApplyCommon(&cfg.commonOpts) } func (fwd forwardsToOption) ApplyTLS(cfg *tlsOptions) { - cfg.commonOpts.ForwardsTo = string(fwd) + fwd.ApplyCommon(&cfg.commonOpts) } func (fwd forwardsToOption) ApplyLabeled(cfg *labeledOptions) { - cfg.commonOpts.ForwardsTo = string(fwd) + fwd.ApplyCommon(&cfg.commonOpts) } func defaultForwardsTo() string { diff --git a/config/http.go b/config/http.go index d00d59e..00de81f 100644 --- a/config/http.go +++ b/config/http.go @@ -45,6 +45,7 @@ type httpOptions struct { // If non-nil, start a goroutine which runs this http server // accepting connections from the http tunnel + // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. httpServer *http.Server // Certificates to use for client authentication at the ngrok edge. @@ -121,22 +122,27 @@ func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint { return opts } -func (cfg httpOptions) tunnelOptions() {} - func (cfg httpOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } + +func (cfg httpOptions) WithForwardsTo(hostname string) { + cfg.commonOpts.ForwardsTo = hostname +} + func (cfg httpOptions) Extra() proto.BindExtra { return proto.BindExtra{ Metadata: cfg.Metadata, } } + func (cfg httpOptions) Proto() string { if cfg.Scheme == "" { return string(SchemeHTTPS) } return string(cfg.Scheme) } + func (cfg httpOptions) Opts() any { return cfg.toProtoConfig() } diff --git a/config/http_handler.go b/config/http_handler.go index 47cc240..ffbb41c 100644 --- a/config/http_handler.go +++ b/config/http_handler.go @@ -8,6 +8,18 @@ type httpServerOption struct { Server *http.Server } +type Options interface { + HTTPEndpointOption + TLSEndpointOption + TCPEndpointOption + LabeledTunnelOption + CommonOption +} + +func (opt *httpServerOption) ApplyCommon(cfg *commonOpts) { + +} + func (opt *httpServerOption) ApplyHTTP(cfg *httpOptions) { cfg.httpServer = opt.Server } @@ -26,22 +38,14 @@ func (opt *httpServerOption) ApplyLabeled(cfg *labeledOptions) { // WithHTTPHandler adds the provided credentials to the list of basic // authentication credentials. -func WithHTTPHandler(h http.Handler) interface { - HTTPEndpointOption - TLSEndpointOption - TCPEndpointOption - LabeledTunnelOption -} { +// Deprecated: Use session.ListenAndServeHTTP instead. +func WithHTTPHandler(h http.Handler) Options { return WithHTTPServer(&http.Server{Handler: h}) } // WithHTTPServer adds the provided credentials to the list of basic // authentication credentials. -func WithHTTPServer(srv *http.Server) interface { - HTTPEndpointOption - TLSEndpointOption - TCPEndpointOption - LabeledTunnelOption -} { +// Deprecated: Use session.ListenAndServeHTTP instead. +func WithHTTPServer(srv *http.Server) Options { return &httpServerOption{Server: srv} } diff --git a/config/http_handler_test.go b/config/http_handler_test.go deleted file mode 100644 index 456ffcf..0000000 --- a/config/http_handler_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package config - -import ( - "net/http" - "testing" - - _ "embed" -) - -type nopHandler struct{} - -func (f nopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {} - -func testHTTPServer[T tunnelConfigPrivate, OT any](t *testing.T, - makeOpts func(...OT) Tunnel, -) { - optsFunc := func(opts ...any) Tunnel { - return makeOpts(assertSlice[OT](opts)...) - } - - handler := nopHandler{} - srv := &http.Server{} - - cases := testCases[T, any]{ - { - name: "absent", - opts: optsFunc(), - expectHTTPServer: serverPtr(nil), - }, - { - name: "with server", - opts: optsFunc(WithHTTPServer(srv)), - expectHTTPServer: serverPtr(srv), - }, - { - name: "with handler", - opts: optsFunc(WithHTTPHandler(handler)), - expectHTTPHandler: handlerPtr(handler), - }, - } - - cases.runAll(t) -} - -func TestHTTPServer(t *testing.T) { - testHTTPServer[httpOptions](t, HTTPEndpoint) - testHTTPServer[tlsOptions](t, TLSEndpoint) - testHTTPServer[tcpOptions](t, TCPEndpoint) - testHTTPServer[labeledOptions](t, LabeledTunnel) -} diff --git a/config/labeled.go b/config/labeled.go index 18d59b8..9211228 100644 --- a/config/labeled.go +++ b/config/labeled.go @@ -34,6 +34,7 @@ type labeledOptions struct { labels map[string]string // An HTTP Server to run traffic on + // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. httpServer *http.Server } @@ -48,22 +49,28 @@ func WithLabel(label, value string) LabeledTunnelOption { }) } -func (cfg labeledOptions) tunnelOptions() {} - func (cfg labeledOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } + +func (cfg labeledOptions) WithForwardsTo(hostname string) { + cfg.commonOpts.ForwardsTo = hostname +} + func (cfg labeledOptions) Extra() proto.BindExtra { return proto.BindExtra{ Metadata: cfg.Metadata, } } + func (cfg labeledOptions) Proto() string { return "" } + func (cfg labeledOptions) Opts() any { return nil } + func (cfg labeledOptions) Labels() map[string]string { return cfg.labels } diff --git a/config/tcp.go b/config/tcp.go index d136693..dec3213 100644 --- a/config/tcp.go +++ b/config/tcp.go @@ -32,6 +32,7 @@ type tcpOptions struct { // The TCP address to request for this edge. RemoteAddr string // An HTTP Server to run traffic on + // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. httpServer *http.Server } @@ -50,22 +51,28 @@ func (cfg *tcpOptions) toProtoConfig() *proto.TCPEndpoint { } } -func (cfg tcpOptions) tunnelOptions() {} - func (cfg tcpOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } + +func (cfg tcpOptions) WithForwardsTo(hostname string) { + cfg.commonOpts.ForwardsTo = hostname +} + func (cfg tcpOptions) Extra() proto.BindExtra { return proto.BindExtra{ Metadata: cfg.Metadata, } } + func (cfg tcpOptions) Proto() string { return "tcp" } + func (cfg tcpOptions) Opts() any { return cfg.toProtoConfig() } + func (cfg tcpOptions) Labels() map[string]string { return nil } diff --git a/config/tls.go b/config/tls.go index fc9e9a3..c2c4a60 100644 --- a/config/tls.go +++ b/config/tls.go @@ -52,6 +52,7 @@ type tlsOptions struct { CertPEM []byte // An HTTP Server to run traffic on + // Deprecated: Pass HTTP server refs via session.ListenAndServeHTTP instead. httpServer *http.Server } @@ -80,25 +81,32 @@ func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint { return opts } -func (cfg tlsOptions) tunnelOptions() {} - func (cfg tlsOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } + +func (cfg tlsOptions) WithForwardsTo(hostname string) { + cfg.commonOpts.ForwardsTo = hostname +} + func (cfg tlsOptions) Extra() proto.BindExtra { return proto.BindExtra{ Metadata: cfg.Metadata, } } + func (cfg tlsOptions) Proto() string { return "tls" } + func (cfg tlsOptions) Opts() any { return cfg.toProtoConfig() } + func (cfg tlsOptions) Labels() map[string]string { return nil } + func (cfg tlsOptions) HTTPServer() *http.Server { return cfg.httpServer } diff --git a/config/tunnel_config.go b/config/tunnel_config.go index ee79624..beadc8d 100644 --- a/config/tunnel_config.go +++ b/config/tunnel_config.go @@ -14,6 +14,7 @@ type Tunnel interface { // the public interface with internal details. type tunnelConfigPrivate interface { ForwardsTo() string + WithForwardsTo(string) Extra() proto.BindExtra Proto() string Opts() any diff --git a/examples/go.mod b/examples/go.mod index b7b76e5..0562468 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -6,7 +6,6 @@ require ( golang.ngrok.com/ngrok v0.0.0 golang.ngrok.com/ngrok/log/slog v0.0.0-00010101000000-000000000000 golang.org/x/exp v0.0.0-20230307190834-24139beb5833 - golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7 ) require ( @@ -19,6 +18,7 @@ require ( go.uber.org/multierr v1.10.0 // indirect golang.ngrok.com/muxado/v2 v2.0.0 // indirect golang.org/x/net v0.10.0 // indirect + golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.8.0 // indirect golang.org/x/term v0.8.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/examples/go.sum b/examples/go.sum index 0042222..4804d65 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -26,8 +26,8 @@ golang.org/x/exp v0.0.0-20230307190834-24139beb5833 h1:SChBja7BCQewoTAU7IgvucQKM golang.org/x/exp v0.0.0-20230307190834-24139beb5833/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7 h1:ZrnxWX62AgTKOSagEqxvb3ffipvEDX2pl7E1TdqLqIc= -golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/examples/ngrok-forward-lite/main.go b/examples/ngrok-forward-lite/main.go new file mode 100644 index 0000000..4f592a9 --- /dev/null +++ b/examples/ngrok-forward-lite/main.go @@ -0,0 +1,75 @@ +// Naïve ngrok agent implementation. +// Sets up a single tunnel and forwards it to another service. + +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + "strings" + + "golang.ngrok.com/ngrok" + "golang.ngrok.com/ngrok/config" + ngrok_log "golang.ngrok.com/ngrok/log" +) + +func usage(bin string) { + log.Fatalf("Usage: %s ", bin) +} + +// Simple logger that forwards to the Go standard logger. +type logger struct { + lvl ngrok_log.LogLevel +} + +func (l *logger) Log(ctx context.Context, lvl ngrok_log.LogLevel, msg string, data map[string]interface{}) { + if lvl > l.lvl { + return + } + lvlName, _ := ngrok_log.StringFromLogLevel(lvl) + log.Printf("[%s] %s %v", lvlName, msg, data) +} + +var l *logger = &logger{ + lvl: ngrok_log.LogLevelDebug, +} + +func main() { + if len(os.Args) != 2 { + usage(os.Args[0]) + } + backend := os.Args[1] + if !strings.Contains(backend, "://") { + backend = fmt.Sprintf("tcp://%s", backend) + } + + backendUrl, err := url.Parse(backend) + if err != nil { + usage(os.Args[0]) + } + + if err := run(context.Background(), backendUrl); err != nil { + log.Fatal(err) + } +} + +func run(ctx context.Context, backend *url.URL) error { + fwd, err := ngrok.ListenAndForward(ctx, + backend, + config.HTTPEndpoint(), + ngrok.WithAuthtokenFromEnv(), + ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}), + ) + if err != nil { + return err + } + + l.Log(ctx, ngrok_log.LogLevelInfo, "tunnel created", map[string]any{ + "url": fwd.URL(), + }) + + return fwd.Wait() +} diff --git a/examples/ngrok-http-lite/main.go b/examples/ngrok-http-lite/main.go new file mode 100644 index 0000000..451117b --- /dev/null +++ b/examples/ngrok-http-lite/main.go @@ -0,0 +1,67 @@ +// Naïve ngrok agent implementation. +// Sets up a single tunnel and connects to an arbitrary HTTP server. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "golang.ngrok.com/ngrok" + "golang.ngrok.com/ngrok/config" + ngrok_log "golang.ngrok.com/ngrok/log" +) + +// Simple logger that forwards to the Go standard logger. +type logger struct { + lvl ngrok_log.LogLevel +} + +func (l *logger) Log(ctx context.Context, lvl ngrok_log.LogLevel, msg string, data map[string]interface{}) { + if lvl > l.lvl { + return + } + lvlName, _ := ngrok_log.StringFromLogLevel(lvl) + log.Printf("[%s] %s %v", lvlName, msg, data) +} + +var l *logger = &logger{ + lvl: ngrok_log.LogLevelDebug, +} + +func main() { + // Spin up a simple HTTP server + server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello from ngrok-go!") + })} + + // Serve with tunnel backend + if err := run(context.Background(), server); err != nil { + log.Fatal(err) + } + + // Sleep main thread + for { + time.Sleep(5 * time.Second) + } +} + +func run(ctx context.Context, server *http.Server) error { + tunnel, err := ngrok.ListenAndServeHTTP(ctx, + server, + config.HTTPEndpoint(), + ngrok.WithAuthtokenFromEnv(), + ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}), + ) + + if err == nil { + l.Log(ctx, ngrok_log.LogLevelInfo, "tunnel created", map[string]any{ + "url": tunnel.URL(), + }) + } + + return err +} diff --git a/examples/ngrok-lite/main.go b/examples/ngrok-lite/main.go deleted file mode 100644 index 7e52ba7..0000000 --- a/examples/ngrok-lite/main.go +++ /dev/null @@ -1,76 +0,0 @@ -// Naïve ngrok agent implementation. -// Sets up a single tunnel and forwards it to another service. - -package main - -import ( - "context" - "io" - "log" - "net" - "os" - - "golang.org/x/sync/errgroup" - - "golang.ngrok.com/ngrok" - "golang.ngrok.com/ngrok/config" -) - -func usage(bin string) { - log.Fatalf("Usage: %s ", bin) -} - -func main() { - if len(os.Args) != 2 { - usage(os.Args[0]) - } - if err := run(context.Background(), os.Args[1]); err != nil { - log.Fatal(err) - } -} - -func run(ctx context.Context, dest string) error { - tun, err := ngrok.Listen(ctx, - config.HTTPEndpoint(), - ngrok.WithAuthtokenFromEnv(), - ) - if err != nil { - return err - } - - log.Println("tunnel created:", tun.URL()) - - for { - conn, err := tun.Accept() - if err != nil { - return err - } - - log.Println("accepted connection from", conn.RemoteAddr()) - - go func() { - err := handleConn(ctx, dest, conn) - log.Println("connection closed:", err) - }() - } -} - -func handleConn(ctx context.Context, dest string, conn net.Conn) error { - next, err := net.Dial("tcp", dest) - if err != nil { - return err - } - - g, _ := errgroup.WithContext(ctx) - - g.Go(func() error { - _, err := io.Copy(next, conn) - return err - }) - g.Go(func() error { - _, err := io.Copy(conn, next) - return err - }) - - return g.Wait() -} diff --git a/forward.go b/forward.go new file mode 100644 index 0000000..922f6c2 --- /dev/null +++ b/forward.go @@ -0,0 +1,181 @@ +package ngrok + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/inconshreveable/log15/v3" + "golang.org/x/sync/errgroup" +) + +// Forwarder is a tunnel that has every connection forwarded to some URL. +type Forwarder interface { + // Information about the tunnel being forwarded + TunnelInfo + + // Close is a convenience method for calling Tunnel.CloseWithContext + // with a context that has a timeout of 5 seconds. This also allows the + // Tunnel to satisfy the io.Closer interface. + Close() error + + // CloseWithContext closes the Tunnel. Closing a tunnel is an operation + // that involves sending a "close" message over the parent session. + // Since this is a network operation, it is most correct to provide a + // context with a timeout. + CloseWithContext(context.Context) error + + // Session returns the tunnel's parent Session object that it + // was started on. + Session() Session + + // Wait blocks until the forwarding task exits (usually due to tunnel + // close), or the `context.Context` that it was started with is canceled. + Wait() error +} + +type forwarder struct { + Tunnel + mainGroup *errgroup.Group +} + +func (fwd *forwarder) Wait() error { + return fwd.mainGroup.Wait() +} + +// compile-time check that we're implementing the proper interface +var _ Forwarder = (*forwarder)(nil) + +func join(ctx context.Context, left, right io.ReadWriter) { + g := &sync.WaitGroup{} + g.Add(2) + go func() { + _, _ = io.Copy(left, right) + g.Done() + }() + go func() { + _, _ = io.Copy(right, left) + g.Done() + }() + g.Wait() +} + +func forwardTunnel(ctx context.Context, tun Tunnel, url *url.URL) Forwarder { + mainGroup, ctx := errgroup.WithContext(ctx) + fwdTasks := &sync.WaitGroup{} + + sess := tun.Session() + sessImpl := sess.(*sessionImpl) + logger := sessImpl.inner().Logger.New("task", "forward", "toUrl", url, "tunnelUrl", tun.URL()) + + mainGroup.Go(func() error { + for { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + + conn, err := tun.Accept() + if err != nil { + return err + } + fwdTasks.Add(1) + + go func() { + ngrokConn := conn.(Conn) + defer ngrokConn.Close() + + backend, err := openBackend(ctx, logger, tun, ngrokConn, url) + if err != nil { + logger.Warn("failed to connect to backend url", "error", err) + fwdTasks.Done() + return + } + + defer backend.Close() + join(ctx, ngrokConn, backend) + fwdTasks.Done() + }() + } + }) + + return &forwarder{ + Tunnel: tun, + mainGroup: mainGroup, + } +} + +// TODO: use an actual reverse proxy for http/s tunnels so that the host header gets set? +func openBackend(ctx context.Context, logger log15.Logger, tun Tunnel, tunnelConn Conn, url *url.URL) (net.Conn, error) { + host := url.Hostname() + port := url.Port() + if port == "" { + switch { + case usesTLS(url.Scheme): + port = "443" + case isHTTP(url.Scheme): + port = "80" + default: + return nil, fmt.Errorf("no default tcp port available for %s", url.Scheme) + } + logger.Debug("set default port", "port", port) + } + + // Create TLS config if necessary + var tlsConfig *tls.Config + if usesTLS(url.Scheme) { + tlsConfig = &tls.Config{ServerName: url.Hostname()} + } + + dialer := &net.Dialer{} + address := fmt.Sprintf("%s:%s", host, port) + logger.Debug("dial backend tcp", "address", address) + + conn, err := dialer.DialContext(ctx, "tcp", address) + if err != nil { + defer tunnelConn.Close() + + if isHTTP(tunnelConn.Proto()) { + _ = writeHTTPError(tunnelConn, err) + } + return nil, err + } + + if usesTLS(url.Scheme) && !tunnelConn.PassthroughTLS() { + logger.Debug("establishing TLS connection with backend") + return tls.Client(conn, tlsConfig), nil + } + + return conn, nil +} + +func writeHTTPError(w io.Writer, err error) error { + resp := &http.Response{} + resp.StatusCode = http.StatusBadGateway + resp.Body = io.NopCloser(bytes.NewBufferString(fmt.Sprintf("failed to connect to backend: %s", err.Error()))) + return resp.Write(w) +} + +func usesTLS(scheme string) bool { + switch strings.ToLower(scheme) { + case "https", "tls": + return true + default: + return false + } +} + +func isHTTP(scheme string) bool { + switch strings.ToLower(scheme) { + case "https", "http": + return true + default: + return false + } +} diff --git a/go.mod b/go.mod index a07ca44..91a8996 100644 --- a/go.mod +++ b/go.mod @@ -5,24 +5,22 @@ go 1.20 require ( github.com/inconshreveable/log15/v3 v3.0.0-testing.5 github.com/jpillora/backoff v1.0.0 - github.com/stretchr/testify v1.8.0 - go.uber.org/multierr v1.10.0 + github.com/stretchr/testify v1.8.4 + go.uber.org/multierr v1.11.0 golang.ngrok.com/muxado/v2 v2.0.0 - golang.org/x/net v0.10.0 - google.golang.org/protobuf v1.28.1 + golang.org/x/net v0.15.0 + golang.org/x/sync v0.3.0 + google.golang.org/protobuf v1.31.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-stack/stack v1.8.1 // indirect - github.com/google/go-cmp v0.5.8 // indirect github.com/inconshreveable/log15 v3.0.0-testing.3+incompatible // indirect - github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/term v0.8.0 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/term v0.12.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3e00d4f..6f156dc 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,10 @@ -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= github.com/inconshreveable/log15 v3.0.0-testing.3+incompatible h1:zaX5fYT98jX5j4UhO/WbfY8T1HkgVrydiDMC9PWqGCo= github.com/inconshreveable/log15 v3.0.0-testing.3+incompatible/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o= @@ -15,42 +12,34 @@ github.com/inconshreveable/log15/v3 v3.0.0-testing.5 h1:h4e0f3kjgg+RJBlKOabrohjH github.com/inconshreveable/log15/v3 v3.0.0-testing.5/go.mod h1:3GQg1SVrLoWGfRv/kAZMsdyU5cp8eFc1P3cw+Wwku94= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= -go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.ngrok.com/muxado/v2 v2.0.0 h1:bu9eIDhRdYNtIXNnqat/HyMeHYOAbUH55ebD7gTvW6c= golang.ngrok.com/muxado/v2 v2.0.0/go.mod h1:wzxJYX4xiAtmwumzL+QsukVwFRXmPNv86vB8RPpOxyM= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go.work.sum b/go.work.sum index 1220897..b4277b1 100644 --- a/go.work.sum +++ b/go.work.sum @@ -4,15 +4,18 @@ github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c= github.com/creack/pty v1.1.7 h1:6pwm8kMQKCmgUg0ZHTm5+/YvRK0s3THD/28+T6/kk4A= +github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ= github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -22,25 +25,34 @@ github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/zenazn/goji v0.9.0 h1:RSQQAbXGArQ0dIDEq+PI6WqN6if+5KHu6x2Cx/GXLTQ= go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.ngrok.com/muxado/v2 v2.0.0 h1:bu9eIDhRdYNtIXNnqat/HyMeHYOAbUH55ebD7gTvW6c= -golang.ngrok.com/muxado/v2 v2.0.0/go.mod h1:wzxJYX4xiAtmwumzL+QsukVwFRXmPNv86vB8RPpOxyM= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee h1:WG0RUwxtNT4qqaXX3DPA8zHFNm/D9xaBpxzHt1WcA/E= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/online_test.go b/online_test.go index 5f997bf..ba8a469 100644 --- a/online_test.go +++ b/online_test.go @@ -122,13 +122,10 @@ func TestTunnelConnMetadata(t *testing.T) { func TestWithHTTPHandler(t *testing.T) { ctx := context.Background() - sess := setupSession(ctx, t) - - tun := startTunnel(ctx, t, sess, config.HTTPEndpoint( + tun, _ := serveHTTP(ctx, t, nil, config.HTTPEndpoint( config.WithMetadata("Hello, world!"), config.WithForwardsTo("some application"), - config.WithHTTPHandler(helloHandler), - )) + ), helloHandler) resp, err := http.Get(tun.URL()) require.NoError(t, err, "GET tunnel url") @@ -595,13 +592,13 @@ func TestConnectionCallbacks(t *testing.T) { disconnectNils := 0 sess := setupSession(ctx, t, WithConnectHandler(func(ctx context.Context, sess Session) { - connects += 1 + connects++ }), WithDisconnectHandler(func(ctx context.Context, sess Session, err error) { if err == nil { - disconnectNils += 1 + disconnectNils++ } else { - disconnectErrs += 1 + disconnectErrs++ } }), WithDialer(&sketchyDialer{1 * time.Second})) @@ -642,7 +639,7 @@ func TestHeartbeatCallback(t *testing.T) { heartbeats := 0 sess := setupSession(ctx, t, WithHeartbeatHandler(func(ctx context.Context, sess Session, latency time.Duration) { - heartbeats += 1 + heartbeats++ }), WithHeartbeatInterval(10*time.Second)) @@ -744,4 +741,4 @@ func TestWebsockets(t *testing.T) { tun.Close() require.Error(t, <-errCh) -} +} \ No newline at end of file diff --git a/session.go b/session.go index eccd095..912212c 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" - _ "embed" + _ "embed" // nolint "errors" "fmt" "net" @@ -21,6 +21,7 @@ import ( "github.com/inconshreveable/log15/v3" "go.uber.org/multierr" "golang.org/x/net/proxy" + "golang.org/x/sync/errgroup" "golang.ngrok.com/ngrok/config" @@ -35,6 +36,7 @@ import ( //go:embed VERSION var libraryAgentVersion string +// AgentVersionDeprecated is a type wrapper for [proto.AgentVersionDeprecated] type AgentVersionDeprecated proto.AgentVersionDeprecated func (avd *AgentVersionDeprecated) Error() string { @@ -51,6 +53,19 @@ type Session interface { // Warnings returns a list of warnings generated for the session on connect/auth Warnings() []error + // ListenAndForward creates a new Tunnel which will listen for new inbound + // connections. Connections on this tunnel are automatically forwarded to + // the provided URL. + ListenAndForward(ctx context.Context, backend *url.URL, cfg config.Tunnel) (Forwarder, error) + + // ListenAndServeHTTP creates a new Tunnel to serve as a backend for an HTTP server. Connections will be + // forwarded to the provided HTTP server. + ListenAndServeHTTP(ctx context.Context, cfg config.Tunnel, server *http.Server) (Forwarder, error) + + // ListenAndHandleHTTP creates a new Tunnel to serve as a backend for an HTTP handler. Connections will be + // forwarded to a new HTTP server and handled by the provided HTTP handler. + ListenAndHandleHTTP(ctx context.Context, cfg config.Tunnel, handler *http.Handler) (Forwarder, error) + // Close ends the ngrok session. All Tunnel objects created by Listen // on this session will be closed. Close() error @@ -78,14 +93,13 @@ type SessionConnectHandler func(ctx context.Context, sess Session) // SessionDisconnectHandler is the callback type for [WithDisconnectHandler] type SessionDisconnectHandler func(ctx context.Context, sess Session, err error) -// SessionHearbeatHandler is the callback type for [WithHearbeatHandler] +// SessionHeartbeatHandler is the callback type for [WithHearbeatHandler] type SessionHeartbeatHandler func(ctx context.Context, sess Session, latency time.Duration) // ServerCommandHandler is the callback type for [WithStopHandler] type ServerCommandHandler func(ctx context.Context, sess Session) error -// ConnectOptions are passed to [Connect] to customize session connection and -// establishment. +// ConnectOption is passed to [Connect] to customize session connection and establishment. type ConnectOption func(*connectConfig) type clientInfo struct { @@ -176,7 +190,7 @@ type connectConfig struct { Logger log.Logger } -// WithMetdata configures the opaque, machine-readable metadata string for this +// WithMetadata configures the opaque, machine-readable metadata string for this // session. Metadata is made available to you in the ngrok dashboard and the // Agents API resource. It is a useful way to allow you to uniquely identify // sessions. We suggest encoding the value in a structured format like JSON. @@ -635,6 +649,7 @@ func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) { Banner: resp.Extra.Banner, SessionDuration: resp.Extra.SessionDuration, DeprecationWarning: resp.Extra.DeprecationWarning, + Logger: logger, }) if cfg.HeartbeatHandler != nil { @@ -740,6 +755,8 @@ type sessionInner struct { Banner string SessionDuration int64 DeprecationWarning *proto.AgentVersionDeprecated + + Logger log15.Logger } func (s *sessionImpl) inner() *sessionInner { @@ -758,53 +775,100 @@ func (s *sessionImpl) Close() error { return s.inner().Close() } +func (s *sessionImpl) Warnings() []error { + deprecated := s.inner().DeprecationWarning + if deprecated != nil { + return []error{(*AgentVersionDeprecated)(deprecated)} + } + return nil + +} + func (s *sessionImpl) Listen(ctx context.Context, cfg config.Tunnel) (Tunnel, error) { var ( tunnel tunnel_client.Tunnel err error ) - tunnelCfg, ok := cfg.(tunnelConfigPrivate) if !ok { return nil, errors.New("invalid tunnel config") } extra := tunnelCfg.Extra() - if tunnelCfg.Proto() != "" { tunnel, err = s.inner().Listen(tunnelCfg.Proto(), tunnelCfg.Opts(), extra, tunnelCfg.ForwardsTo()) } else { tunnel, err = s.inner().ListenLabel(tunnelCfg.Labels(), extra.Metadata, tunnelCfg.ForwardsTo()) } - if err != nil { - return nil, errListen{err} - } - - t := &tunnelImpl{ + impl := &tunnelImpl{ Sess: s, Tunnel: tunnel, } - if httpServerCfg, ok := cfg.(interface { - HTTPServer() *http.Server - }); ok { - if srv := httpServerCfg.HTTPServer(); srv != nil { - go func() { - _ = srv.Serve(t) - }() + // Legacy support for passing HTTP server via config options. + // TODO: Remove this after we feel HTTP options via config have been deprecated. + if serverCfg, ok := cfg.(interface{ HTTPServer() *http.Server }); ok { + server := serverCfg.HTTPServer() + if server != nil { + go func() { _ = server.Serve(impl) }() + impl.server = server } } - return t, nil + if err == nil { + return impl, nil + } + return nil, errListen{err} } -func (s *sessionImpl) Warnings() []error { - deprecated := s.inner().DeprecationWarning - if deprecated != nil { - return []error{(*AgentVersionDeprecated)(deprecated)} +func (s *sessionImpl) ListenAndForward(ctx context.Context, url *url.URL, cfg config.Tunnel) (Forwarder, error) { + tunnelCfg, ok := cfg.(tunnelConfigPrivate) + if !ok { + return nil, errors.New("invalid tunnel config") } - return nil + + // Set 'Forwards To' + tunnelCfg.WithForwardsTo(url.Host) + + tun, err := s.Listen(ctx, cfg) + if err != nil { + return nil, err + } + + return forwardTunnel(ctx, tun, url), nil +} + +func (s *sessionImpl) ListenAndServeHTTP(ctx context.Context, cfg config.Tunnel, server *http.Server) (Forwarder, error) { + tun, err := s.Listen(ctx, cfg) + if err != nil { + return nil, err + } + + mainGroup, _ := errgroup.WithContext(ctx) + if server != nil { + // Store server ref to close when tunnel closes + impl, _ := tun.(*tunnelImpl) + + // Check if tunnel is already serving an HTTP server + // TODO: Remove this once we feel HTTP options via config have been deprecated. + if impl.server == nil { + mainGroup.Go(func() error { return server.Serve(tun) }) + impl.server = server + } else { + // Inform end user that they're using a deprecated option. + fmt.Println("Tunnel is serving an HTTP server via HTTP options. This has been deprecated. Please use Session.ListenAndServeHTTP instead.") + } + } + + return &forwarder{ + Tunnel: tun, + mainGroup: mainGroup, + }, nil +} + +func (s *sessionImpl) ListenAndHandleHTTP(ctx context.Context, cfg config.Tunnel, handler *http.Handler) (Forwarder, error) { + return s.ListenAndServeHTTP(ctx, cfg, &http.Server{Handler: *handler}) } // The rest of the `sessionImpl` methods are non-public, but can be diff --git a/tunnel.go b/tunnel.go index a558195..be1149f 100644 --- a/tunnel.go +++ b/tunnel.go @@ -3,6 +3,8 @@ package ngrok import ( "context" "net" + "net/http" + "net/url" "time" "golang.ngrok.com/ngrok/config" @@ -18,15 +20,27 @@ type Tunnel interface { // code that expects a net.Listener seamlessly without any changes. net.Listener + // Information associated with the tunnel + TunnelInfo + // Close is a convenience method for calling Tunnel.CloseWithContext // with a context that has a timeout of 5 seconds. This also allows the // Tunnel to satisfy the io.Closer interface. Close() error + // CloseWithContext closes the Tunnel. Closing a tunnel is an operation // that involves sending a "close" message over the parent session. // Since this is a network operation, it is most correct to provide a // context with a timeout. CloseWithContext(context.Context) error + + // Session returns the tunnel's parent Session object that it + // was started on. + Session() Session +} + +// TunnelInfo implementations contain metadata about a [Tunnel]. +type TunnelInfo interface { // ForwardsTo returns a human-readable string presented in the ngrok // dashboard and the Tunnels API. Use config.WithForwardsTo when // calling Session.Listen to set this value explicitly. @@ -41,9 +55,6 @@ type Tunnel interface { // Proto returns the protocol of the tunnel's endpoint. // Labeled tunnels will return the empty string. Proto() string - // Session returns the tunnel's parent Session object that it - // was started on. - Session() Session // URL returns the tunnel endpoint's URL. // Labeled tunnels will return the empty string. URL() string @@ -70,9 +81,69 @@ func Listen(ctx context.Context, tunnelConfig config.Tunnel, connectOpts ...Conn return tunnel, nil } +// ListenAndForward creates a new [Forwarder] after connecting a new [Session], and +// then forwards all connections to the provided URL. +// This is a shortcut for calling [Connect] then [Session].ListenAndForward. +// +// Access to the underlying [Session] that was started automatically can be +// accessed via [Forwarder].Session. +// +// If an error is encountered during [Session].ListenAndForward, the [Session] +// object that was created will be closed automatically. +func ListenAndForward(ctx context.Context, backend *url.URL, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Forwarder, error) { + sess, err := Connect(ctx, connectOpts...) + if err != nil { + return nil, err + } + fwd, err := sess.ListenAndForward(ctx, backend, tunnelConfig) + if err != nil { + _ = sess.Close() + return nil, err + } + + return fwd, nil +} + +// ListenAndServeHTTP creates a new [Forwarder] after connecting a new [Session], and +// then forwards all connections to the provided HTTP server. +// This is a shortcut for calling [Connect] then [Session].ListenAndForward. +// +// Access to the underlying [Session] that was started automatically can be +// accessed via [Tunnel].Session. +// +// If an error is encountered during [Session].ListenAndServeHTTP, the [Session] +// object that was created will be closed automatically. +func ListenAndServeHTTP(ctx context.Context, server *http.Server, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Forwarder, error) { + sess, err := Connect(ctx, connectOpts...) + if err != nil { + return nil, err + } + + forwarder, err := sess.ListenAndServeHTTP(ctx, tunnelConfig, server) + if err != nil { + _ = sess.Close() + return nil, err + } + + return forwarder, nil +} + +// ListenAndHandleHTTP creates a new [Forwarder] after connecting a new [Session], and +// then forwards all connections to a new HTTP server and handles them with the provided HTTP handler. +// +// Access to the underlying [Session] that was started automatically can be +// accessed via [Tunnel].Session. +// +// If an error is encountered during [Session].ListenAndHandleHTTP, the [Session] +// object that was created will be closed automatically. +func ListenAndHandleHTTP(ctx context.Context, handler *http.Handler, tunnelConfig config.Tunnel, connectOpts ...ConnectOption) (Forwarder, error) { + return ListenAndServeHTTP(ctx, &http.Server{Handler: *handler}, tunnelConfig, connectOpts...) +} + type tunnelImpl struct { Sess Session Tunnel tunnel_client.Tunnel + server *http.Server } func (t *tunnelImpl) Accept() (net.Conn, error) { @@ -93,6 +164,12 @@ func (t *tunnelImpl) Close() error { } func (t *tunnelImpl) CloseWithContext(_ context.Context) error { + if t.server != nil { + err := t.server.Close() + if err != nil { + return err + } + } return t.Tunnel.Close() } @@ -150,8 +227,10 @@ type Conn interface { PassthroughTLS() bool } +// EdgeType is the type of the edge (https, tls, or tcp) for this tunnel. type EdgeType proto.EdgeType +// All possible edge types. Currently only https, tls, and tcp are supported. const ( EdgeTypeUndefined EdgeType = 0 EdgeTypeTCP EdgeType = 1 @@ -164,6 +243,7 @@ type connImpl struct { Proxy *tunnel_client.ProxyConn } +// compile-time check that we're implementing the proper interface var _ Conn = &connImpl{} func (c *connImpl) ProxyConn() *tunnel_client.ProxyConn { diff --git a/tunnel_config.go b/tunnel_config.go index 0e7cf97..697f795 100644 --- a/tunnel_config.go +++ b/tunnel_config.go @@ -13,4 +13,5 @@ type tunnelConfigPrivate interface { Proto() string Opts() any Labels() map[string]string + WithForwardsTo(string) }