Skip to content

Commit

Permalink
Handle StopTunnel request
Browse files Browse the repository at this point in the history
  • Loading branch information
bobzilladev committed Nov 22, 2023
1 parent fd2067b commit 80f91bb
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 18 deletions.
23 changes: 23 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,26 @@ func (e errSessionDial) Is(target error) bool {
_, ok := target.(errSessionDial)
return ok
}

// Generic ngrok error that requires no parsing
type ngrokError struct {
Message string
ErrCode string
}

func (m ngrokError) Error() string {
return m.Message + "\n\n" + m.ErrCode
}

func (m ngrokError) Msg() string {
return m.Message
}

func (m ngrokError) ErrorCode() string {
return m.ErrCode
}

func (e ngrokError) Is(target error) bool {
_, ok := target.(ngrokError)
return ok
}
27 changes: 20 additions & 7 deletions examples/ngrok-forward-lite/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,32 @@ func main() {
}

func run(ctx context.Context, backend *url.URL) error {
fwd, err := ngrok.ListenAndForward(ctx,
backend,
config.HTTPEndpoint(),
sess, err := ngrok.Connect(ctx,
ngrok.WithAuthtokenFromEnv(),
ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}),
)
if err != nil {
return err
}

l.Log(ctx, ngrok_log.LogLevelInfo, "ingress established", map[string]any{
"url": fwd.URL(),
})
for {
fwd, err := sess.ListenAndForward(ctx,
backend,
config.HTTPEndpoint(),
)
if err != nil {
return err
}

return fwd.Wait()
l.Log(ctx, ngrok_log.LogLevelInfo, "ingress established", map[string]any{
"url": fwd.URL(),
})

err = fwd.Wait()
if err == nil {
return nil
}
l.Log(ctx, ngrok_log.LogLevelWarn, "accept error. now setting up a new forwarder.",
map[string]any{"err": err})
}
}
8 changes: 7 additions & 1 deletion internal/tunnel/client/raw_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type SessionHandler interface {
OnStop(*proto.Stop, HandlerRespFunc)
OnRestart(*proto.Restart, HandlerRespFunc)
OnUpdate(*proto.Update, HandlerRespFunc)
OnStopTunnel(*proto.StopTunnel, HandlerRespFunc)
}

// A RawSession is a client session which handles authorization with the tunnel
Expand Down Expand Up @@ -75,7 +76,7 @@ func (s *rawSession) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp
req := proto.Auth{
ClientID: id,
Extra: extra,
Version: []string{proto.Version},
Version: proto.Version,
}
if err = s.rpc(proto.AuthReq, &req, &resp); err != nil {
return
Expand Down Expand Up @@ -201,6 +202,11 @@ func (s *rawSession) Accept() (netx.LoggedConn, error) {
if deserialize(&req) {
go s.handler.OnUpdate(&req, respFunc)
}
case proto.StopTunnelReq:
var req proto.StopTunnel
if deserialize(&req) {
go s.handler.OnStopTunnel(&req, respFunc)
}
default:
return netx.NewLoggedConn(s.Logger, raw, "type", "proxy", "sess", s.id), nil
}
Expand Down
13 changes: 13 additions & 0 deletions internal/tunnel/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type Session interface {
// Latency updates
Latency() <-chan time.Duration

// Close the tunnel with this clientID, with an error that will be reported
// from the tunnel's Accept() method.
CloseTunnel(clientID string, err error) error

// Closes the session
Close() error
}
Expand Down Expand Up @@ -176,6 +180,15 @@ func (s *session) SrvInfo() (proto.SrvInfoResp, error) {
return s.raw.SrvInfo()
}

func (s *session) CloseTunnel(clientId string, err error) error {
t, ok := s.getTunnel(clientId)
if !ok {
return proto.StringError("no listener found for client id " + clientId)
}
t.CloseWithError(err)
return nil
}

func (s *session) Close() error {
return s.raw.Close()
}
Expand Down
17 changes: 14 additions & 3 deletions internal/tunnel/client/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ type tunnel struct {
labels map[string]string
forwardsTo string

accept chan *ProxyConn // new connections come on this channel
unlisten func() error // call this function to close the tunnel
accept chan *ProxyConn // new connections come on this channel
unlisten func() error // call this function to close the tunnel
closeError error // error to use on accept error after a tunnel close

shut shutdown // for clean shutdowns
}
Expand All @@ -54,6 +55,7 @@ func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsT
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ClientID) },
forwardsTo: forwardsTo,
closeError: errors.New("Listener closed"),
}
}

Expand All @@ -69,6 +71,7 @@ func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ID) },
forwardsTo: forwardsTo,
closeError: errors.New("Listener closed"),
}
}

Expand All @@ -83,11 +86,19 @@ func (t *tunnel) handleConn(r *ProxyConn) {
func (t *tunnel) Accept() (*ProxyConn, error) {
conn, ok := <-t.accept
if !ok {
return nil, errors.New("Tunnel closed")
return nil, t.closeError
}
return conn, nil
}

func (t *tunnel) CloseWithError(closeError error) {
t.closeError = closeError
// Skips the call to unlisten, since the remote has already rejected it.
t.shut.Shut(func() {
close(t.accept)
})
}

// Closes the Tunnel by asking the remote machine to deallocate its listener, or
// an error if the request failed.
func (t *tunnel) Close() (err error) {
Expand Down
18 changes: 13 additions & 5 deletions internal/tunnel/proto/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ const (
StartTunnelWithLabelReq ReqType = 7

// sent from the server to the client
ProxyReq ReqType = 3
RestartReq ReqType = 4
StopReq ReqType = 5
UpdateReq ReqType = 6
ProxyReq ReqType = 3
RestartReq ReqType = 4
StopReq ReqType = 5
UpdateReq ReqType = 6
StopTunnelReq ReqType = 9

// sent from client to the server
SrvInfoReq ReqType = 8
)

const Version = "2"
var Version = []string{"3", "2"} // integers in priority order

// Match the error code in the format (ERR_NGROK_\d+).
var ngrokErrorCodeRegex = regexp.MustCompile(`(ERR_NGROK_\d+)`)
Expand Down Expand Up @@ -400,6 +401,13 @@ type UpdateResp struct {
Error string // an error, if one
}

// This request is sent from the server to the ngrok agent to request a tunnel to close, with a notice to display to the user
type StopTunnel struct {
ClientID string `json:"Id"` // a session-unique bind ID generated by the client
Message string // an message to display to the user
ErrorCode string // an error code to display to the user. empty on OK
}

type SrvInfo struct{}

type SrvInfoResp struct {
Expand Down
15 changes: 14 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,10 @@ func (s *sessionImpl) setInner(raw *sessionInner) {
atomic.StorePointer(&s.raw, unsafe.Pointer(raw))
}

func (s *sessionImpl) closeTunnel(clientID string, err error) error {
return s.inner().CloseTunnel(clientID, err)
}

func (s *sessionImpl) Close() error {
return s.inner().Close()
}
Expand Down Expand Up @@ -908,7 +912,7 @@ func (s *sessionImpl) Latency() <-chan time.Duration {

type remoteCallbackHandler struct {
log15.Logger
sess Session
sess *sessionImpl
stopHandler ServerCommandHandler
restartHandler ServerCommandHandler
updateHandler ServerCommandHandler
Expand Down Expand Up @@ -959,3 +963,12 @@ func (rc remoteCallbackHandler) OnUpdate(_ *proto.Update, respond tunnel_client.
}
}
}

func (rc remoteCallbackHandler) OnStopTunnel(stopTunnel *proto.StopTunnel, respond tunnel_client.HandlerRespFunc) {
ngrokErr := &ngrokError{Message: stopTunnel.Message, ErrCode: stopTunnel.ErrorCode}
// close the tunnel and maintain the session
err := rc.sess.closeTunnel(stopTunnel.ClientID, ngrokErr)
if err != nil {
rc.Warn("error closing tunnel", "error", err)
}
}
8 changes: 7 additions & 1 deletion tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ type tunnelImpl struct {
func (t *tunnelImpl) Accept() (net.Conn, error) {
conn, err := t.Tunnel.Accept()
if err != nil {
return nil, errAcceptFailed{Inner: err}
err = errAcceptFailed{Inner: err}
if s, ok := t.Sess.(*sessionImpl); ok {
if si := s.inner(); si != nil {
si.Logger.Info(err.Error(), "clientid", t.Tunnel.ID())
}
}
return nil, err
}
return &connImpl{
Conn: conn.Conn,
Expand Down

0 comments on commit 80f91bb

Please sign in to comment.