From 80f91bb4c0bfa038aea22316cfc9237817a1c9c9 Mon Sep 17 00:00:00 2001 From: bobzilladev Date: Wed, 1 Nov 2023 18:28:55 +0000 Subject: [PATCH] Handle StopTunnel request --- errors.go | 23 +++++++++++++++++++++++ examples/ngrok-forward-lite/main.go | 27 ++++++++++++++++++++------- internal/tunnel/client/raw_session.go | 8 +++++++- internal/tunnel/client/session.go | 13 +++++++++++++ internal/tunnel/client/tunnel.go | 17 ++++++++++++++--- internal/tunnel/proto/msg.go | 18 +++++++++++++----- session.go | 15 ++++++++++++++- tunnel.go | 8 +++++++- 8 files changed, 111 insertions(+), 18 deletions(-) diff --git a/errors.go b/errors.go index cb22d4a..4448538 100644 --- a/errors.go +++ b/errors.go @@ -132,3 +132,26 @@ func (e errSessionDial) Is(target error) bool { _, ok := target.(errSessionDial) return ok } + +// Generic ngrok error that requires no parsing +type ngrokError struct { + Message string + ErrCode string +} + +func (m ngrokError) Error() string { + return m.Message + "\n\n" + m.ErrCode +} + +func (m ngrokError) Msg() string { + return m.Message +} + +func (m ngrokError) ErrorCode() string { + return m.ErrCode +} + +func (e ngrokError) Is(target error) bool { + _, ok := target.(ngrokError) + return ok +} diff --git a/examples/ngrok-forward-lite/main.go b/examples/ngrok-forward-lite/main.go index 50cd234..4baa62c 100644 --- a/examples/ngrok-forward-lite/main.go +++ b/examples/ngrok-forward-lite/main.go @@ -57,9 +57,7 @@ func main() { } func run(ctx context.Context, backend *url.URL) error { - fwd, err := ngrok.ListenAndForward(ctx, - backend, - config.HTTPEndpoint(), + sess, err := ngrok.Connect(ctx, ngrok.WithAuthtokenFromEnv(), ngrok.WithLogger(&logger{lvl: ngrok_log.LogLevelDebug}), ) @@ -67,9 +65,24 @@ func run(ctx context.Context, backend *url.URL) error { return err } - l.Log(ctx, ngrok_log.LogLevelInfo, "ingress established", map[string]any{ - "url": fwd.URL(), - }) + for { + fwd, err := sess.ListenAndForward(ctx, + backend, + config.HTTPEndpoint(), + ) + if err != nil { + return err + } - return fwd.Wait() + l.Log(ctx, ngrok_log.LogLevelInfo, "ingress established", map[string]any{ + "url": fwd.URL(), + }) + + err = fwd.Wait() + if err == nil { + return nil + } + l.Log(ctx, ngrok_log.LogLevelWarn, "accept error. now setting up a new forwarder.", + map[string]any{"err": err}) + } } diff --git a/internal/tunnel/client/raw_session.go b/internal/tunnel/client/raw_session.go index 224dba2..c9bc6e7 100644 --- a/internal/tunnel/client/raw_session.go +++ b/internal/tunnel/client/raw_session.go @@ -37,6 +37,7 @@ type SessionHandler interface { OnStop(*proto.Stop, HandlerRespFunc) OnRestart(*proto.Restart, HandlerRespFunc) OnUpdate(*proto.Update, HandlerRespFunc) + OnStopTunnel(*proto.StopTunnel, HandlerRespFunc) } // A RawSession is a client session which handles authorization with the tunnel @@ -75,7 +76,7 @@ func (s *rawSession) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp req := proto.Auth{ ClientID: id, Extra: extra, - Version: []string{proto.Version}, + Version: proto.Version, } if err = s.rpc(proto.AuthReq, &req, &resp); err != nil { return @@ -201,6 +202,11 @@ func (s *rawSession) Accept() (netx.LoggedConn, error) { if deserialize(&req) { go s.handler.OnUpdate(&req, respFunc) } + case proto.StopTunnelReq: + var req proto.StopTunnel + if deserialize(&req) { + go s.handler.OnStopTunnel(&req, respFunc) + } default: return netx.NewLoggedConn(s.Logger, raw, "type", "proxy", "sess", s.id), nil } diff --git a/internal/tunnel/client/session.go b/internal/tunnel/client/session.go index 398dfad..54b3bcb 100644 --- a/internal/tunnel/client/session.go +++ b/internal/tunnel/client/session.go @@ -67,6 +67,10 @@ type Session interface { // Latency updates Latency() <-chan time.Duration + // Close the tunnel with this clientID, with an error that will be reported + // from the tunnel's Accept() method. + CloseTunnel(clientID string, err error) error + // Closes the session Close() error } @@ -176,6 +180,15 @@ func (s *session) SrvInfo() (proto.SrvInfoResp, error) { return s.raw.SrvInfo() } +func (s *session) CloseTunnel(clientId string, err error) error { + t, ok := s.getTunnel(clientId) + if !ok { + return proto.StringError("no listener found for client id " + clientId) + } + t.CloseWithError(err) + return nil +} + func (s *session) Close() error { return s.raw.Close() } diff --git a/internal/tunnel/client/tunnel.go b/internal/tunnel/client/tunnel.go index 5cb1871..7871493 100644 --- a/internal/tunnel/client/tunnel.go +++ b/internal/tunnel/client/tunnel.go @@ -35,8 +35,9 @@ type tunnel struct { labels map[string]string forwardsTo string - accept chan *ProxyConn // new connections come on this channel - unlisten func() error // call this function to close the tunnel + accept chan *ProxyConn // new connections come on this channel + unlisten func() error // call this function to close the tunnel + closeError error // error to use on accept error after a tunnel close shut shutdown // for clean shutdowns } @@ -54,6 +55,7 @@ func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsT accept: make(chan *ProxyConn), unlisten: func() error { return s.unlisten(resp.ClientID) }, forwardsTo: forwardsTo, + closeError: errors.New("Listener closed"), } } @@ -69,6 +71,7 @@ func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels accept: make(chan *ProxyConn), unlisten: func() error { return s.unlisten(resp.ID) }, forwardsTo: forwardsTo, + closeError: errors.New("Listener closed"), } } @@ -83,11 +86,19 @@ func (t *tunnel) handleConn(r *ProxyConn) { func (t *tunnel) Accept() (*ProxyConn, error) { conn, ok := <-t.accept if !ok { - return nil, errors.New("Tunnel closed") + return nil, t.closeError } return conn, nil } +func (t *tunnel) CloseWithError(closeError error) { + t.closeError = closeError + // Skips the call to unlisten, since the remote has already rejected it. + t.shut.Shut(func() { + close(t.accept) + }) +} + // Closes the Tunnel by asking the remote machine to deallocate its listener, or // an error if the request failed. func (t *tunnel) Close() (err error) { diff --git a/internal/tunnel/proto/msg.go b/internal/tunnel/proto/msg.go index 34b1756..2b70613 100644 --- a/internal/tunnel/proto/msg.go +++ b/internal/tunnel/proto/msg.go @@ -23,16 +23,17 @@ const ( StartTunnelWithLabelReq ReqType = 7 // sent from the server to the client - ProxyReq ReqType = 3 - RestartReq ReqType = 4 - StopReq ReqType = 5 - UpdateReq ReqType = 6 + ProxyReq ReqType = 3 + RestartReq ReqType = 4 + StopReq ReqType = 5 + UpdateReq ReqType = 6 + StopTunnelReq ReqType = 9 // sent from client to the server SrvInfoReq ReqType = 8 ) -const Version = "2" +var Version = []string{"3", "2"} // integers in priority order // Match the error code in the format (ERR_NGROK_\d+). var ngrokErrorCodeRegex = regexp.MustCompile(`(ERR_NGROK_\d+)`) @@ -400,6 +401,13 @@ type UpdateResp struct { Error string // an error, if one } +// This request is sent from the server to the ngrok agent to request a tunnel to close, with a notice to display to the user +type StopTunnel struct { + ClientID string `json:"Id"` // a session-unique bind ID generated by the client + Message string // an message to display to the user + ErrorCode string // an error code to display to the user. empty on OK +} + type SrvInfo struct{} type SrvInfoResp struct { diff --git a/session.go b/session.go index 9ed3180..a060394 100644 --- a/session.go +++ b/session.go @@ -771,6 +771,10 @@ func (s *sessionImpl) setInner(raw *sessionInner) { atomic.StorePointer(&s.raw, unsafe.Pointer(raw)) } +func (s *sessionImpl) closeTunnel(clientID string, err error) error { + return s.inner().CloseTunnel(clientID, err) +} + func (s *sessionImpl) Close() error { return s.inner().Close() } @@ -908,7 +912,7 @@ func (s *sessionImpl) Latency() <-chan time.Duration { type remoteCallbackHandler struct { log15.Logger - sess Session + sess *sessionImpl stopHandler ServerCommandHandler restartHandler ServerCommandHandler updateHandler ServerCommandHandler @@ -959,3 +963,12 @@ func (rc remoteCallbackHandler) OnUpdate(_ *proto.Update, respond tunnel_client. } } } + +func (rc remoteCallbackHandler) OnStopTunnel(stopTunnel *proto.StopTunnel, respond tunnel_client.HandlerRespFunc) { + ngrokErr := &ngrokError{Message: stopTunnel.Message, ErrCode: stopTunnel.ErrorCode} + // close the tunnel and maintain the session + err := rc.sess.closeTunnel(stopTunnel.ClientID, ngrokErr) + if err != nil { + rc.Warn("error closing tunnel", "error", err) + } +} diff --git a/tunnel.go b/tunnel.go index 4b9db68..a878328 100644 --- a/tunnel.go +++ b/tunnel.go @@ -149,7 +149,13 @@ type tunnelImpl struct { func (t *tunnelImpl) Accept() (net.Conn, error) { conn, err := t.Tunnel.Accept() if err != nil { - return nil, errAcceptFailed{Inner: err} + err = errAcceptFailed{Inner: err} + if s, ok := t.Sess.(*sessionImpl); ok { + if si := s.inner(); si != nil { + si.Logger.Info(err.Error(), "clientid", t.Tunnel.ID()) + } + } + return nil, err } return &connImpl{ Conn: conn.Conn,