diff --git a/config/basic_auth_test.go b/config/basic_auth_test.go index deecf7e..b436508 100644 --- a/config/basic_auth_test.go +++ b/config/basic_auth_test.go @@ -10,7 +10,7 @@ import ( ) func TestBasicAuth(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "single", opts: HTTPEndpoint(WithBasicAuth("foo", "bar")), diff --git a/config/cidr_restrictions_test.go b/config/cidr_restrictions_test.go index 2377757..fb45aa7 100644 --- a/config/cidr_restrictions_test.go +++ b/config/cidr_restrictions_test.go @@ -122,15 +122,15 @@ func testCIDRRestrictions[T tunnelConfigPrivate, O any, OT any](t *testing.T, } func TestCIDRRestrictions(t *testing.T) { - testCIDRRestrictions[httpOptions](t, HTTPEndpoint, + testCIDRRestrictions[*httpOptions](t, HTTPEndpoint, func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_IPRestriction { return h.IPRestriction }) - testCIDRRestrictions[tcpOptions](t, TCPEndpoint, + testCIDRRestrictions[*tcpOptions](t, TCPEndpoint, func(h *proto.TCPEndpoint) *pb.MiddlewareConfiguration_IPRestriction { return h.IPRestriction }) - testCIDRRestrictions[tlsOptions](t, TLSEndpoint, + testCIDRRestrictions[*tlsOptions](t, TLSEndpoint, func(h *proto.TLSEndpoint) *pb.MiddlewareConfiguration_IPRestriction { return h.IPRestriction }) diff --git a/config/circuit_breaker_test.go b/config/circuit_breaker_test.go index 180aef9..c487752 100644 --- a/config/circuit_breaker_test.go +++ b/config/circuit_breaker_test.go @@ -9,7 +9,7 @@ import ( ) func TestCircuitBreaker(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/config/common.go b/config/common.go index 30e05d9..8fe0fdd 100644 --- a/config/common.go +++ b/config/common.go @@ -31,4 +31,4 @@ func (cfg *commonOpts) getForwardsTo() string { return cfg.ForwardsTo } -func (cfg commonOpts) tunnelOptions() {} +func (cfg *commonOpts) tunnelOptions() {} diff --git a/config/compression_test.go b/config/compression_test.go index 59e6427..6a4b288 100644 --- a/config/compression_test.go +++ b/config/compression_test.go @@ -9,7 +9,7 @@ import ( ) func TestCompression(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/config/domain_test.go b/config/domain_test.go index db132cd..3927747 100644 --- a/config/domain_test.go +++ b/config/domain_test.go @@ -40,10 +40,10 @@ func testDomain[T tunnelConfigPrivate, O any, OT any](t *testing.T, } func TestDomain(t *testing.T) { - testDomain[httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) string { + testDomain[*httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) string { return opts.Domain }) - testDomain[tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) string { + testDomain[*tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) string { return opts.Domain }) } diff --git a/config/forwards_to_test.go b/config/forwards_to_test.go index 637447a..429b2ef 100644 --- a/config/forwards_to_test.go +++ b/config/forwards_to_test.go @@ -28,8 +28,8 @@ func testForwardsTo[T tunnelConfigPrivate, OT any](t *testing.T, } func TestForwardsTo(t *testing.T) { - testForwardsTo[httpOptions](t, HTTPEndpoint) - testForwardsTo[tlsOptions](t, TLSEndpoint) - testForwardsTo[tcpOptions](t, TCPEndpoint) - testForwardsTo[labeledOptions](t, LabeledTunnel) + testForwardsTo[*httpOptions](t, HTTPEndpoint) + testForwardsTo[*tlsOptions](t, TLSEndpoint) + testForwardsTo[*tcpOptions](t, TCPEndpoint) + testForwardsTo[*labeledOptions](t, LabeledTunnel) } diff --git a/config/http.go b/config/http.go index 995b634..e77e4b2 100644 --- a/config/http.go +++ b/config/http.go @@ -3,6 +3,7 @@ package config import ( "crypto/x509" "net/http" + "net/url" "golang.ngrok.com/ngrok/internal/pb" "golang.ngrok.com/ngrok/internal/tunnel/proto" @@ -24,7 +25,7 @@ func HTTPEndpoint(opts ...HTTPEndpointOption) Tunnel { for _, opt := range opts { opt.ApplyHTTP(&cfg) } - return cfg + return &cfg } type httpOptions struct { @@ -63,6 +64,9 @@ type httpOptions struct { // Headers to be added to or removed from all responses at the ngrok edge. ResponseHeaders *headers + // Auto-rewrite host header on ListenAndForward? + RewriteHostHeader bool + // Credentials for basic authentication. // If empty, basic authentication is disabled. BasicAuth []basicAuth @@ -126,8 +130,11 @@ func (cfg httpOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } -func (cfg httpOptions) WithForwardsTo(hostname string) { - cfg.commonOpts.ForwardsTo = hostname +func (cfg *httpOptions) WithForwardsTo(url *url.URL) { + cfg.commonOpts.ForwardsTo = url.Host + if cfg.RewriteHostHeader { + WithRequestHeader("host", url.Host).ApplyHTTP(cfg) + } } func (cfg httpOptions) Extra() proto.BindExtra { diff --git a/config/http_headers.go b/config/http_headers.go index f54f025..3f8d207 100644 --- a/config/http_headers.go +++ b/config/http_headers.go @@ -62,6 +62,18 @@ func (h responseHeaders) ApplyHTTP(cfg *httpOptions) { cfg.ResponseHeaders = cfg.ResponseHeaders.merge(headers(h)) } +// WithHostHeaderRewrite will automatically set the `Host` header to the one in +// the URL passed to `ListenAndForward`. Does nothing if using `Listen`. +// Defaults to `false`. +// +// If you need to set the host header to a specific value, use +// `WithRequestHeader("host", "some.host.com")` instead. +func WithHostHeaderRewrite(rewrite bool) HTTPEndpointOption { + return httpOptionFunc(func(cfg *httpOptions) { + cfg.RewriteHostHeader = rewrite + }) +} + // WithRequestHeader adds a header to all requests to this edge. func WithRequestHeader(name, value string) HTTPEndpointOption { return requestHeaders(headers{ diff --git a/config/http_headers_test.go b/config/http_headers_test.go index 5db9eeb..ba8599e 100644 --- a/config/http_headers_test.go +++ b/config/http_headers_test.go @@ -9,7 +9,7 @@ import ( ) func TestHTTPHeaders(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/config/http_test.go b/config/http_test.go index 77f6a00..5e2cf3f 100644 --- a/config/http_test.go +++ b/config/http_test.go @@ -9,7 +9,7 @@ import ( ) func TestHTTP(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "empty", opts: HTTPEndpoint(), diff --git a/config/labeled.go b/config/labeled.go index 65287dc..9cb8a16 100644 --- a/config/labeled.go +++ b/config/labeled.go @@ -2,6 +2,7 @@ package config import ( "net/http" + "net/url" "golang.ngrok.com/ngrok/internal/tunnel/proto" ) @@ -22,7 +23,7 @@ func LabeledTunnel(opts ...LabeledTunnelOption) Tunnel { for _, opt := range opts { opt.ApplyLabeled(&cfg) } - return cfg + return &cfg } // Options for labeled tunnels. @@ -53,8 +54,8 @@ func (cfg labeledOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } -func (cfg labeledOptions) WithForwardsTo(hostname string) { - cfg.commonOpts.ForwardsTo = hostname +func (cfg *labeledOptions) WithForwardsTo(url *url.URL) { + cfg.commonOpts.ForwardsTo = url.Host } func (cfg labeledOptions) Extra() proto.BindExtra { diff --git a/config/labeled_test.go b/config/labeled_test.go index cbb37a0..8fe8ff5 100644 --- a/config/labeled_test.go +++ b/config/labeled_test.go @@ -7,7 +7,7 @@ import ( ) func TestLabeled(t *testing.T) { - cases := testCases[labeledOptions, proto.LabelOptions]{ + cases := testCases[*labeledOptions, proto.LabelOptions]{ { name: "simple", opts: LabeledTunnel(WithLabel("foo", "bar")), diff --git a/config/metadata_test.go b/config/metadata_test.go index 95b538c..f1ceb9a 100644 --- a/config/metadata_test.go +++ b/config/metadata_test.go @@ -32,8 +32,8 @@ func testMetadata[T tunnelConfigPrivate, OT any](t *testing.T, } func TestMetadata(t *testing.T) { - testMetadata[httpOptions](t, HTTPEndpoint) - testMetadata[tlsOptions](t, TLSEndpoint) - testMetadata[tcpOptions](t, TCPEndpoint) - testMetadata[labeledOptions](t, LabeledTunnel) + testMetadata[*httpOptions](t, HTTPEndpoint) + testMetadata[*tlsOptions](t, TLSEndpoint) + testMetadata[*tcpOptions](t, TCPEndpoint) + testMetadata[*labeledOptions](t, LabeledTunnel) } diff --git a/config/mutual_tls_test.go b/config/mutual_tls_test.go index 11b788c..b85e595 100644 --- a/config/mutual_tls_test.go +++ b/config/mutual_tls_test.go @@ -54,10 +54,10 @@ func testMutualTLS[T tunnelConfigPrivate, O any, OT any](t *testing.T, } func TestMutualTLS(t *testing.T) { - testMutualTLS[httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_MutualTLS { + testMutualTLS[*httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_MutualTLS { return opts.MutualTLSCA }) - testMutualTLS[tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) *pb.MiddlewareConfiguration_MutualTLS { + testMutualTLS[*tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) *pb.MiddlewareConfiguration_MutualTLS { return opts.MutualTLSAtEdge }) } diff --git a/config/oauth_test.go b/config/oauth_test.go index 89fe43b..8f181b9 100644 --- a/config/oauth_test.go +++ b/config/oauth_test.go @@ -9,7 +9,7 @@ import ( ) func TestOAuth(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/config/oidc_test.go b/config/oidc_test.go index 9c5a0ce..c796ca2 100644 --- a/config/oidc_test.go +++ b/config/oidc_test.go @@ -9,7 +9,7 @@ import ( ) func TestOIDC(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/config/proxy_proto_test.go b/config/proxy_proto_test.go index 0228bef..9eb55b1 100644 --- a/config/proxy_proto_test.go +++ b/config/proxy_proto_test.go @@ -41,13 +41,13 @@ func testProxyProto[T tunnelConfigPrivate, O any, OT any](t *testing.T, } func TestProxyProto(t *testing.T) { - testProxyProto[httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) proto.ProxyProto { + testProxyProto[*httpOptions](t, HTTPEndpoint, func(opts *proto.HTTPEndpoint) proto.ProxyProto { return opts.ProxyProto }) - testProxyProto[tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) proto.ProxyProto { + testProxyProto[*tlsOptions](t, TLSEndpoint, func(opts *proto.TLSEndpoint) proto.ProxyProto { return opts.ProxyProto }) - testProxyProto[tcpOptions](t, TCPEndpoint, func(opts *proto.TCPEndpoint) proto.ProxyProto { + testProxyProto[*tcpOptions](t, TCPEndpoint, func(opts *proto.TCPEndpoint) proto.ProxyProto { return opts.ProxyProto }) } diff --git a/config/scheme_test.go b/config/scheme_test.go index 713e012..c12b610 100644 --- a/config/scheme_test.go +++ b/config/scheme_test.go @@ -7,7 +7,7 @@ import ( ) func TestScheme(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "default", opts: HTTPEndpoint(), diff --git a/config/tcp.go b/config/tcp.go index 1b1f9f0..52ec67a 100644 --- a/config/tcp.go +++ b/config/tcp.go @@ -2,6 +2,7 @@ package config import ( "net/http" + "net/url" "golang.ngrok.com/ngrok/internal/tunnel/proto" ) @@ -22,7 +23,7 @@ func TCPEndpoint(opts ...TCPEndpointOption) Tunnel { for _, opt := range opts { opt.ApplyTCP(&cfg) } - return cfg + return &cfg } // The options for a TCP edge. @@ -55,8 +56,8 @@ func (cfg tcpOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } -func (cfg tcpOptions) WithForwardsTo(hostname string) { - cfg.commonOpts.ForwardsTo = hostname +func (cfg *tcpOptions) WithForwardsTo(url *url.URL) { + cfg.commonOpts.ForwardsTo = url.Host } func (cfg tcpOptions) Extra() proto.BindExtra { diff --git a/config/tcp_test.go b/config/tcp_test.go index ab93346..934856b 100644 --- a/config/tcp_test.go +++ b/config/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCP(t *testing.T) { - cases := testCases[tcpOptions, proto.TCPEndpoint]{ + cases := testCases[*tcpOptions, proto.TCPEndpoint]{ { name: "empty", opts: TCPEndpoint(), diff --git a/config/tls.go b/config/tls.go index 01f5d3f..2175658 100644 --- a/config/tls.go +++ b/config/tls.go @@ -3,6 +3,7 @@ package config import ( "crypto/x509" "net/http" + "net/url" "golang.ngrok.com/ngrok/internal/pb" "golang.ngrok.com/ngrok/internal/tunnel/proto" @@ -24,7 +25,7 @@ func TLSEndpoint(opts ...TLSEndpointOption) Tunnel { for _, opt := range opts { opt.ApplyTLS(&cfg) } - return cfg + return &cfg } // The options for TLS edges. @@ -85,8 +86,8 @@ func (cfg tlsOptions) ForwardsTo() string { return cfg.commonOpts.getForwardsTo() } -func (cfg tlsOptions) WithForwardsTo(hostname string) { - cfg.commonOpts.ForwardsTo = hostname +func (cfg *tlsOptions) WithForwardsTo(url *url.URL) { + cfg.commonOpts.ForwardsTo = url.Host } func (cfg tlsOptions) Extra() proto.BindExtra { diff --git a/config/tls_termination_test.go b/config/tls_termination_test.go index e3cd1a6..6cbb510 100644 --- a/config/tls_termination_test.go +++ b/config/tls_termination_test.go @@ -9,7 +9,7 @@ import ( ) func TestTLSTermination(t *testing.T) { - cases := testCases[tlsOptions, proto.TLSEndpoint]{ + cases := testCases[*tlsOptions, proto.TLSEndpoint]{ { name: "absent", opts: TLSEndpoint(), diff --git a/config/tls_test.go b/config/tls_test.go index b3907f1..0c88fd9 100644 --- a/config/tls_test.go +++ b/config/tls_test.go @@ -7,7 +7,7 @@ import ( ) func TestTLS(t *testing.T) { - cases := testCases[tlsOptions, proto.TLSEndpoint]{ + cases := testCases[*tlsOptions, proto.TLSEndpoint]{ { name: "basic", opts: TLSEndpoint(), diff --git a/config/tunnel_config.go b/config/tunnel_config.go index beadc8d..3b20039 100644 --- a/config/tunnel_config.go +++ b/config/tunnel_config.go @@ -1,6 +1,10 @@ package config -import "golang.ngrok.com/ngrok/internal/tunnel/proto" +import ( + "net/url" + + "golang.ngrok.com/ngrok/internal/tunnel/proto" +) // Tunnel is a marker interface for options that can be used to start // tunnels. @@ -14,9 +18,11 @@ type Tunnel interface { // the public interface with internal details. type tunnelConfigPrivate interface { ForwardsTo() string - WithForwardsTo(string) Extra() proto.BindExtra Proto() string Opts() any Labels() map[string]string + // Extra config when auto-forwarding to a URL. + // Normal operation should use the functional builder. + WithForwardsTo(*url.URL) } diff --git a/config/user_agent_filter_test.go b/config/user_agent_filter_test.go index acae73a..9976ccf 100644 --- a/config/user_agent_filter_test.go +++ b/config/user_agent_filter_test.go @@ -84,7 +84,7 @@ func testUserAgentFilter[T tunnelConfigPrivate, O any, OT any](t *testing.T, } func TestUserAgentFilter(t *testing.T) { - testUserAgentFilter[httpOptions](t, HTTPEndpoint, + testUserAgentFilter[*httpOptions](t, HTTPEndpoint, func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_UserAgentFilter { return h.UserAgentFilter }) diff --git a/config/webhook_verification_test.go b/config/webhook_verification_test.go index d0761a4..de8ce44 100644 --- a/config/webhook_verification_test.go +++ b/config/webhook_verification_test.go @@ -9,7 +9,7 @@ import ( ) func TestWebhookVerification(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/config/websocket_tcp_conversion_test.go b/config/websocket_tcp_conversion_test.go index edcdd4e..4c1092b 100644 --- a/config/websocket_tcp_conversion_test.go +++ b/config/websocket_tcp_conversion_test.go @@ -9,7 +9,7 @@ import ( ) func TestWebsocketTCPConversion(t *testing.T) { - cases := testCases[httpOptions, proto.HTTPEndpoint]{ + cases := testCases[*httpOptions, proto.HTTPEndpoint]{ { name: "absent", opts: HTTPEndpoint(), diff --git a/session.go b/session.go index 64163e1..9ed3180 100644 --- a/session.go +++ b/session.go @@ -829,7 +829,7 @@ func (s *sessionImpl) ListenAndForward(ctx context.Context, url *url.URL, cfg co } // Set 'Forwards To' - tunnelCfg.WithForwardsTo(url.Host) + tunnelCfg.WithForwardsTo(url) tun, err := s.Listen(ctx, cfg) if err != nil { diff --git a/tunnel_config.go b/tunnel_config.go index 697f795..3747b20 100644 --- a/tunnel_config.go +++ b/tunnel_config.go @@ -1,6 +1,10 @@ package ngrok -import "golang.ngrok.com/ngrok/internal/tunnel/proto" +import ( + "net/url" + + "golang.ngrok.com/ngrok/internal/tunnel/proto" +) // This is the internal-only interface that all config.Tunnel implementations // *also* implement. This lets us pull the necessary bits out of it without @@ -13,5 +17,7 @@ type tunnelConfigPrivate interface { Proto() string Opts() any Labels() map[string]string - WithForwardsTo(string) + // Extra config when auto-forwarding to a URL. + // Normal operation should use the functional builder. + WithForwardsTo(*url.URL) }