Skip to content

Commit

Permalink
tcp: add extended tcp support (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey authored Dec 27, 2022
1 parent 87e770b commit 7e0f697
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 32 deletions.
10 changes: 5 additions & 5 deletions api/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ func newTunnel(conn *pb.Connection, browserCmd, serviceAccount, serviceAccountFi
listenAddr = *conn.ListenAddr
}

pxy, err := getProxy(conn)
destinationAddr, proxyURL, err := tcptunnel.ParseURLs(conn.GetRemoteAddr(), conn.GetPomeriumUrl())
if err != nil {
return nil, "", fmt.Errorf("cannot determine proxy host: %w", err)
return nil, "", err
}

var tlsCfg *tls.Config
if pxy.Scheme == "https" {
if proxyURL.Scheme == "https" {
tlsCfg, err = getTLSConfig(conn)
if err != nil {
return nil, "", fmt.Errorf("tls: %w", err)
}
}

return tcptunnel.New(
tcptunnel.WithDestinationHost(conn.GetRemoteAddr()),
tcptunnel.WithProxyHost(pxy.Host),
tcptunnel.WithDestinationHost(destinationAddr),
tcptunnel.WithProxyHost(proxyURL.Host),
tcptunnel.WithServiceAccount(serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountFile),
tcptunnel.WithTLSConfig(tlsCfg),
Expand Down
32 changes: 5 additions & 27 deletions cmd/pomerium-cli/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ import (
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"os"
"os/signal"
"strings"
"syscall"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -39,32 +36,13 @@ var tcpCmd = &cobra.Command{
Short: "creates a TCP tunnel through Pomerium",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
dstHost := args[0]
dstHostname, _, err := net.SplitHostPort(dstHost)
destinationAddr, proxyURL, err := tcptunnel.ParseURLs(args[0], tcpCmdOptions.pomeriumURL)
if err != nil {
return fmt.Errorf("invalid destination: %w", err)
}

pomeriumURL := &url.URL{
Scheme: "https",
Host: net.JoinHostPort(dstHostname, "443"),
}
if tcpCmdOptions.pomeriumURL != "" {
pomeriumURL, err = url.Parse(tcpCmdOptions.pomeriumURL)
if err != nil {
return fmt.Errorf("invalid pomerium URL: %w", err)
}
if !strings.Contains(pomeriumURL.Host, ":") {
if pomeriumURL.Scheme == "https" {
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "443")
} else {
pomeriumURL.Host = net.JoinHostPort(pomeriumURL.Hostname(), "80")
}
}
return err
}

var tlsConfig *tls.Config
if pomeriumURL.Scheme == "https" {
if proxyURL.Scheme == "https" {
tlsConfig, err = getTLSConfig()
if err != nil {
return err
Expand All @@ -81,8 +59,8 @@ var tcpCmd = &cobra.Command{

tun := tcptunnel.New(
tcptunnel.WithBrowserCommand(browserOptions.command),
tcptunnel.WithDestinationHost(dstHost),
tcptunnel.WithProxyHost(pomeriumURL.Host),
tcptunnel.WithDestinationHost(destinationAddr),
tcptunnel.WithProxyHost(proxyURL.Host),
tcptunnel.WithServiceAccount(serviceAccountOptions.serviceAccount),
tcptunnel.WithServiceAccountFile(serviceAccountOptions.serviceAccountFile),
tcptunnel.WithTLSConfig(tlsConfig),
Expand Down
60 changes: 60 additions & 0 deletions tcptunnel/urls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package tcptunnel

import (
"fmt"
"net"
"net/url"
"strings"
)

func ParseURLs(destination string, pomeriumURL string) (destinationAddr string, proxyURL *url.URL, err error) {
if strings.Contains(destination, "://") {
destinationURL, err := url.Parse(destination)
if err != nil {
return "", nil, fmt.Errorf("invalid destination")
}

paths := strings.Split(destinationURL.Path, "/")[1:]
if len(paths) == 0 {
destinationAddr = destinationURL.Host
proxyURL = &url.URL{
Scheme: strings.TrimPrefix(destinationURL.Scheme, "tcp+"),
Host: destinationURL.Hostname(),
}
} else {
destinationAddr = paths[0]
proxyURL = &url.URL{
Scheme: strings.TrimPrefix(destinationURL.Scheme, "tcp+"),
Host: destinationURL.Host,
}
}
} else if h, p, err := net.SplitHostPort(destination); err == nil {
destinationAddr = net.JoinHostPort(h, p)
proxyURL = &url.URL{
Scheme: "https",
Host: h,
}
} else {
return "", nil, fmt.Errorf("invalid destination")
}

if pomeriumURL != "" {
proxyURL, err = url.Parse(pomeriumURL)
if err != nil {
return "", nil, fmt.Errorf("invalid pomerium url")
}
if proxyURL.Host == "" {
return "", nil, fmt.Errorf("invalid pomerium url")
}
}

if !strings.Contains(proxyURL.Host, ":") {
if proxyURL.Scheme == "https" {
proxyURL.Host = net.JoinHostPort(proxyURL.Host, "443")
} else {
proxyURL.Host = net.JoinHostPort(proxyURL.Host, "80")
}
}

return destinationAddr, proxyURL, nil
}
53 changes: 53 additions & 0 deletions tcptunnel/urls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package tcptunnel

import (
"errors"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParseURLs(t *testing.T) {
t.Parallel()

for _, tc := range []struct {
name string
destination, pomeriumURL string
destinationAddr string
proxyURL string
err error
}{
{"invalid destination", "", "", "", "", errors.New("invalid destination")},
{"host:port", "redis.example.com:6379", "", "redis.example.com:6379", "https://redis.example.com:443", nil},
{"https url", "tcp+https://redis.example.com:6379", "", "redis.example.com:6379", "https://redis.example.com:443", nil},
{"http url", "http://redis.example.com:6379", "", "redis.example.com:6379", "http://redis.example.com:80", nil},
{"https url path", "https://proxy.example.com/redis.example.com:6379", "", "redis.example.com:6379", "https://proxy.example.com:443", nil},
{"non standard port path", "https://proxy.example.com:8443/redis.example.com:6379", "", "redis.example.com:6379", "https://proxy.example.com:8443", nil},

{"invalid pomerium url", "redis.example.com:6379", "example.com:1234", "", "", errors.New("invalid pomerium url")},
{"pomerium url", "redis.example.com:6379", "https://proxy.example.com", "redis.example.com:6379", "https://proxy.example.com:443", nil},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

var expectedProxyURL *url.URL
if tc.proxyURL != "" {
expectedProxyURL = must(url.Parse(tc.proxyURL))
}

destinationAddr, proxyURL, err := ParseURLs(tc.destination, tc.pomeriumURL)
assert.Equal(t, tc.destinationAddr, destinationAddr)
assert.Equal(t, expectedProxyURL, proxyURL)
assert.Equal(t, tc.err, err)
})
}
}

func must[T any](ret T, err error) T {
if err != nil {
panic(err)
}
return ret
}

0 comments on commit 7e0f697

Please sign in to comment.