diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 659ec8f..0c57f37 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -1,7 +1,7 @@ { "ImportPath": "github.com/hashrocket/ws", "GoVersion": "go1.7", - "GodepVersion": "v74", + "GodepVersion": "v79", "Deps": [ { "ImportPath": "github.com/chzyer/readline", @@ -15,8 +15,8 @@ }, { "ImportPath": "github.com/gorilla/websocket", - "Comment": "v1.0.0-20-ga69d25b", - "Rev": "a69d25be2fe2923a97c2af6849b2f52426f68fc0" + "Comment": "v1.2.0-44-g5ed622c", + "Rev": "5ed622c449da6d44c3c8329331ff47a9e5844f71" }, { "ImportPath": "github.com/inconshreveable/mousetrap", diff --git a/connection.go b/connection.go index 9a6265e..5adfd6b 100644 --- a/connection.go +++ b/connection.go @@ -18,15 +18,16 @@ type session struct { errChan chan error } -func connect(url, origin string, rlConf *readline.Config, allowInsecure bool) error { +func connect(url, origin string, rlConf *readline.Config, allowInsecure, enableCompression bool) error { headers := make(http.Header) headers.Add("Origin", origin) dialer := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, - TLSClientConfig:&tls.Config{ + TLSClientConfig: &tls.Config{ InsecureSkipVerify: allowInsecure, }, + EnableCompression: enableCompression, } ws, _, err := dialer.Dial(url, headers) if err != nil { diff --git a/main.go b/main.go index 7fdc26b..a09c1b4 100644 --- a/main.go +++ b/main.go @@ -17,7 +17,8 @@ const Version = "0.2.1" var options struct { origin string printVersion bool - insecure bool + insecure bool + compression bool } func main() { @@ -29,6 +30,7 @@ func main() { rootCmd.Flags().StringVarP(&options.origin, "origin", "o", "", "websocket origin") rootCmd.Flags().BoolVarP(&options.printVersion, "version", "v", false, "print version") rootCmd.Flags().BoolVarP(&options.insecure, "insecure", "k", false, "skip ssl certificate check") + rootCmd.Flags().BoolVarP(&options.compression, "compression", "c", false, "enable compression") rootCmd.Execute() } @@ -72,7 +74,7 @@ func root(cmd *cobra.Command, args []string) { err = connect(dest.String(), origin, &readline.Config{ Prompt: "> ", HistoryFile: historyFile, - }, options.insecure) + }, options.insecure, options.compression) if err != nil { fmt.Fprintln(os.Stderr, err) if err != io.EOF && err != readline.ErrInterrupt { diff --git a/vendor/github.com/gorilla/websocket/.gitignore b/vendor/github.com/gorilla/websocket/.gitignore index ac71020..cd3fcd1 100644 --- a/vendor/github.com/gorilla/websocket/.gitignore +++ b/vendor/github.com/gorilla/websocket/.gitignore @@ -22,4 +22,4 @@ _testmain.go *.exe .idea/ -*.iml \ No newline at end of file +*.iml diff --git a/vendor/github.com/gorilla/websocket/.travis.yml b/vendor/github.com/gorilla/websocket/.travis.yml index 66435ac..1f73047 100644 --- a/vendor/github.com/gorilla/websocket/.travis.yml +++ b/vendor/github.com/gorilla/websocket/.travis.yml @@ -4,8 +4,12 @@ sudo: false matrix: include: - go: 1.4 - - go: 1.5 - - go: 1.6 + - go: 1.5.x + - go: 1.6.x + - go: 1.7.x + - go: 1.8.x + - go: 1.9.x + - go: 1.10.x - go: tip allow_failures: - go: tip diff --git a/vendor/github.com/gorilla/websocket/AUTHORS b/vendor/github.com/gorilla/websocket/AUTHORS index b003eca..1931f40 100644 --- a/vendor/github.com/gorilla/websocket/AUTHORS +++ b/vendor/github.com/gorilla/websocket/AUTHORS @@ -4,5 +4,6 @@ # Please keep the list sorted. Gary Burd +Google LLC (https://opensource.google.com/) Joachim Bauch diff --git a/vendor/github.com/gorilla/websocket/README.md b/vendor/github.com/gorilla/websocket/README.md index 9d71959..20e391f 100644 --- a/vendor/github.com/gorilla/websocket/README.md +++ b/vendor/github.com/gorilla/websocket/README.md @@ -3,6 +3,9 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. +[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket) +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) + ### Documentation * [API Reference](http://godoc.org/github.com/gorilla/websocket) @@ -43,12 +46,12 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn Send pings and receive pongsYesNo Get the type of a received data messageYesYes, see note 2 Other Features -Limit size of received messageYesNo +Compression ExtensionsExperimentalNo Read message using io.ReaderYesNo, see note 3 Write message using io.WriteCloserYesNo, see note 3 -Notes: +Notes: 1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). 2. The application can get the type of a received data message by implementing diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go index 879d33e..41f8ed5 100644 --- a/vendor/github.com/gorilla/websocket/client.go +++ b/vendor/github.com/gorilla/websocket/client.go @@ -5,10 +5,8 @@ package websocket import ( - "bufio" "bytes" "crypto/tls" - "encoding/base64" "errors" "io" "io/ioutil" @@ -23,6 +21,8 @@ import ( // invalid. var ErrBadHandshake = errors.New("websocket: bad handshake") +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + // NewClient creates a new client connection using the given net connection. // The URL u specifies the host and request URI. Use requestHeader to specify // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies @@ -64,61 +64,28 @@ type Dialer struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration - // Input and output buffer sizes. If the buffer size is zero, then a - // default value of 4096 is used. + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer + // size is zero, then a useful default size is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the client's requested subprotocols. Subprotocols []string -} - -var errMalformedURL = errors.New("malformed ws or wss URL") - -// parseURL parses the URL. -// -// This function is a replacement for the standard library url.Parse function. -// In Go 1.4 and earlier, url.Parse loses information from the path. -func parseURL(s string) (*url.URL, error) { - // From the RFC: - // - // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] - // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - - var u url.URL - switch { - case strings.HasPrefix(s, "ws://"): - u.Scheme = "ws" - s = s[len("ws://"):] - case strings.HasPrefix(s, "wss://"): - u.Scheme = "wss" - s = s[len("wss://"):] - default: - return nil, errMalformedURL - } - - if i := strings.Index(s, "?"); i >= 0 { - u.RawQuery = s[i+1:] - s = s[:i] - } - if i := strings.Index(s, "/"); i >= 0 { - u.Opaque = s[i:] - s = s[:i] - } else { - u.Opaque = "/" - } + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool - u.Host = s - - if strings.Contains(u.Host, "@") { - // Don't bother parsing user information because user information is - // not allowed in websocket URIs. - return nil, errMalformedURL - } - - return &u, nil + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar http.CookieJar } +var errMalformedURL = errors.New("malformed ws or wss URL") + func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host @@ -137,11 +104,15 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { return hostPort, hostNoPort } -// DefaultDialer is a dialer with all fields set to the default zero values. +// DefaultDialer is a dialer with all fields set to the default values. var DefaultDialer = &Dialer{ - Proxy: http.ProxyFromEnvironment, + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, } +// nilDialer is dialer to use when receiver is nil. +var nilDialer Dialer = *DefaultDialer + // Dial creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // Use the response.Header to get the selected subprotocol @@ -154,9 +125,7 @@ var DefaultDialer = &Dialer{ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { - d = &Dialer{ - Proxy: http.ProxyFromEnvironment, - } + d = &nilDialer } challengeKey, err := generateChallengeKey() @@ -164,7 +133,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - u, err := parseURL(urlStr) + u, err := url.Parse(urlStr) if err != nil { return nil, nil, err } @@ -193,6 +162,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re Host: u.Host, } + // Set the cookies present in the cookie jar of the dialer + if d.Jar != nil { + for _, cookie := range d.Jar.Cookies(u) { + req.AddCookie(cookie) + } + } + // Set the request headers using the capitalization for names and values in // RFC examples. Although the capitalization shouldn't matter, there are // servers that depend on it. The Header.Set method is not used because the @@ -214,29 +190,18 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs default: req.Header[k] = vs } } - hostPort, hostNoPort := hostPortNoPort(u) - - var proxyURL *url.URL - // Check wether the proxy method has been configured - if d.Proxy != nil { - proxyURL, err = d.Proxy(req) - } - if err != nil { - return nil, nil, err - } - - var targetHostPort string - if proxyURL != nil { - targetHostPort, _ = hostPortNoPort(proxyURL) - } else { - targetHostPort = hostPort + if d.EnableCompression { + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} } var deadline time.Time @@ -244,13 +209,47 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re deadline = time.Now().Add(d.HandshakeTimeout) } + // Get network dial function. netDial := d.NetDial if netDial == nil { netDialer := &net.Dialer{Deadline: deadline} netDial = netDialer.Dial } - netConn, err := netDial("tcp", targetHostPort) + // If needed, wrap the dial function to set the connection deadline. + if !deadline.Equal(time.Time{}) { + forwardDial := netDial + netDial = func(network, addr string) (net.Conn, error) { + c, err := forwardDial(network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } + } + + // If needed, wrap the dial function to connect through a proxy. + if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + if err != nil { + return nil, nil, err + } + netDial = dialer.Dial + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + netConn, err := netDial("tcp", hostPort) if err != nil { return nil, nil, err } @@ -261,42 +260,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } }() - if err := netConn.SetDeadline(deadline); err != nil { - return nil, nil, err - } - - if proxyURL != nil { - connectHeader := make(http.Header) - if user := proxyURL.User; user != nil { - proxyUser := user.Username() - if proxyPassword, passwordSet := user.Password(); passwordSet { - credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) - connectHeader.Set("Proxy-Authorization", "Basic "+credential) - } - } - connectReq := &http.Request{ - Method: "CONNECT", - URL: &url.URL{Opaque: hostPort}, - Host: hostPort, - Header: connectHeader, - } - - connectReq.Write(netConn) - - // Read response. - // Okay to use and discard buffered reader here, because - // TLS server will not speak until spoken to. - br := bufio.NewReader(netConn) - resp, err := http.ReadResponse(br, connectReq) - if err != nil { - return nil, nil, err - } - if resp.StatusCode != 200 { - f := strings.SplitN(resp.Status, " ", 2) - return nil, nil, errors.New(f[1]) - } - } - if u.Scheme == "https" { cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { @@ -324,6 +287,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re if err != nil { return nil, nil, err } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + if resp.StatusCode != 101 || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || @@ -337,6 +307,20 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, resp, ErrBadHandshake } + for _, ext := range parseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") @@ -344,32 +328,3 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netConn = nil // to avoid close in defer. return conn, resp, nil } - -// cloneTLSConfig clones all public fields except the fields -// SessionTicketsDisabled and SessionTicketKey. This avoids copying the -// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a -// config in active use. -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - } -} diff --git a/vendor/github.com/gorilla/websocket/client_clone.go b/vendor/github.com/gorilla/websocket/client_clone.go new file mode 100644 index 0000000..4f0d943 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client_clone.go @@ -0,0 +1,16 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package websocket + +import "crypto/tls" + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} diff --git a/vendor/github.com/gorilla/websocket/client_clone_legacy.go b/vendor/github.com/gorilla/websocket/client_clone_legacy.go new file mode 100644 index 0000000..babb007 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client_clone_legacy.go @@ -0,0 +1,38 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +import "crypto/tls" + +// cloneTLSConfig clones all public fields except the fields +// SessionTicketsDisabled and SessionTicketKey. This avoids copying the +// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a +// config in active use. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} diff --git a/vendor/github.com/gorilla/websocket/compression.go b/vendor/github.com/gorilla/websocket/compression.go index e2ac761..813ffb1 100644 --- a/vendor/github.com/gorilla/websocket/compression.go +++ b/vendor/github.com/gorilla/websocket/compression.go @@ -1,4 +1,4 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -9,22 +9,48 @@ import ( "errors" "io" "strings" + "sync" ) -func decompressNoContextTakeover(r io.Reader) io.Reader { +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + // Add final block to squelch unexpected EOF error from flate reader. "\x01\x00\x00\xff\xff" - return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} } -func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} - fw, err := flate.NewWriter(tw, 3) - return &flateWrapper{fw: fw, tw: tw}, err + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} } // truncWriter is an io.Writer that writes all but the last four bytes of the @@ -63,17 +89,26 @@ func (w *truncWriter) Write(p []byte) (int, error) { return n + nn, err } -type flateWrapper struct { +type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter + p *sync.Pool } -func (w *flateWrapper) Write(p []byte) (int, error) { +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } return w.fw.Write(p) } -func (w *flateWrapper) Close() error { +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } @@ -83,3 +118,31 @@ func (w *flateWrapper) Close() error { } return err2 } + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go index eb4334e..5f46bf4 100644 --- a/vendor/github.com/gorilla/websocket/conn.go +++ b/vendor/github.com/gorilla/websocket/conn.go @@ -13,6 +13,7 @@ import ( "math/rand" "net" "strconv" + "sync" "time" "unicode/utf8" ) @@ -75,7 +76,7 @@ const ( // is UTF-8 encoded text. PingMessage = 9 - // PongMessage denotes a ping control message. The optional message payload + // PongMessage denotes a pong control message. The optional message payload // is UTF-8 encoded text. PongMessage = 10 ) @@ -99,9 +100,8 @@ func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Timeout() bool { return e.timeout } -// CloseError represents close frame. +// CloseError represents a close message. type CloseError struct { - // Code is defined in RFC 6455, section 11.7. Code int @@ -180,6 +180,11 @@ var ( errInvalidControlFrame = errors.New("websocket: invalid control frame") ) +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + func hideTempErr(err error) error { if e, ok := err.(net.Error); ok && e.Temporary() { err = &netError{msg: e.Error(), timeout: e.Timeout()} @@ -218,42 +223,28 @@ func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } -func maskBytes(key [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= key[pos&3] - pos++ - } - return pos & 3 -} - -func newMaskKey() [4]byte { - n := rand.Uint32() - return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} -} - -// Conn represents a WebSocket connection. +// The Conn type represents a WebSocket connection. type Conn struct { conn net.Conn isServer bool subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn and closeSent - closeSent bool // whether close message was sent - writeErr error - writeBuf []byte // frame is constructed in this buffer. - writePos int // end of data in writeBuf. - writeFrameType int // type of the current frame. - writeDeadline time.Time - messageWriter *messageWriter // the current low-level message writer - writer io.WriteCloser // the current writer returned to the application - isWriting bool // for best-effort concurrent write detection + mu chan bool // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error enableWriteCompression bool - writeCompress bool // whether next call to flushFrame should set RSV1 - newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields + reader io.ReadCloser // the current reader returned to the application readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. @@ -264,38 +255,83 @@ type Conn struct { readMaskKey [4]byte handlePong func(string) error handlePing func(string) error + handleClose func(int, string) error readErrCount int messageReader *messageReader // the current low-level reader readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader) io.Reader + newDecompressionReader func(io.Reader) io.ReadCloser } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { + return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) +} + +type writeHook struct { + p []byte +} + +func (wh *writeHook) Write(p []byte) (int, error) { + wh.p = p + return len(p), nil +} + +func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { mu := make(chan bool, 1) mu <- true - if readBufferSize == 0 { - readBufferSize = defaultReadBufferSize + var br *bufio.Reader + if readBufferSize == 0 && brw != nil && brw.Reader != nil { + // Reuse the supplied bufio.Reader if the buffer has a useful size. + // This code assumes that peek on a reader returns + // bufio.Reader.buf[:0]. + brw.Reader.Reset(conn) + if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { + br = brw.Reader + } } - if readBufferSize < maxControlFramePayloadSize { - readBufferSize = maxControlFramePayloadSize + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } + if readBufferSize < maxControlFramePayloadSize { + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + var writeBuf []byte + if writeBufferSize == 0 && brw != nil && brw.Writer != nil { + // Use the bufio.Writer's buffer if the buffer has a useful size. This + // code assumes that bufio.Writer.buf[:1] is passed to the + // bufio.Writer's underlying writer. + var wh writeHook + brw.Writer.Reset(&wh) + brw.Writer.WriteByte(0) + brw.Flush() + if cap(wh.p) >= maxFrameHeaderSize+256 { + writeBuf = wh.p[:cap(wh.p)] + } } - if writeBufferSize == 0 { - writeBufferSize = defaultWriteBufferSize + + if writeBuf == nil { + if writeBufferSize == 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) } c := &Conn{ isServer: isServer, - br: bufio.NewReaderSize(conn, readBufferSize), + br: br, conn: conn, mu: mu, readFinal: true, - writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), - writeFrameType: noFrame, - writePos: maxFrameHeaderSize, + writeBuf: writeBuf, enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, } + c.SetCloseHandler(nil) c.SetPingHandler(nil) c.SetPongHandler(nil) return c @@ -306,7 +342,8 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -// Close closes the underlying network connection without sending or waiting for a close frame. +// Close closes the underlying network connection without sending or waiting +// for a close message. func (c *Conn) Close() error { return c.conn.Close() } @@ -323,28 +360,38 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods -func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { <-c.mu defer func() { c.mu <- true }() - if c.closeSent { - return ErrCloseSent - } else if frameType == CloseMessage { - c.closeSent = true + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err } c.conn.SetWriteDeadline(deadline) - for _, buf := range bufs { - if len(buf) > 0 { - n, err := c.conn.Write(buf) - if n != len(buf) { - // Close on partial write. - c.conn.Close() - } - if err != nil { - return err - } - } + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) + } + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) } return nil } @@ -394,84 +441,106 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er } defer func() { c.mu <- true }() - if c.closeSent { - return ErrCloseSent - } else if messageType == CloseMessage { - c.closeSent = true + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err } c.conn.SetWriteDeadline(deadline) - n, err := c.conn.Write(buf) - if n != 0 && n != len(buf) { - c.conn.Close() + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) } - return hideTempErr(err) -} - -// NextWriter returns a writer for the next message to send. The writer's Close -// method flushes the complete message to the network. -// -// There can be at most one open writer on a connection. NextWriter closes the -// previous writer if the application has not already done so. -func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { - if c.writeErr != nil { - return nil, c.writeErr + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) } + return err +} +func (c *Conn) prepWrite(messageType int) error { // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. if c.writer != nil { - err := c.writer.Close() - if err != nil { - return nil, err - } + c.writer.Close() + c.writer = nil } if !isControl(messageType) && !isData(messageType) { - return nil, errBadWriteOpCode + return errBadWriteOpCode } - c.writeFrameType = messageType - c.messageWriter = &messageWriter{c} + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + return err +} - var w io.WriteCloser = c.messageWriter +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + if err := c.prepWrite(messageType); err != nil { + return nil, err + } + + mw := &messageWriter{ + c: c, + frameType: messageType, + pos: maxFrameHeaderSize, + } + c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { - c.writeCompress = true - var err error - w, err = c.newCompressionWriter(w) - if err != nil { - c.writer.Close() - return nil, err - } + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w } + return c.writer, nil +} - return w, nil +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) fatal(err error) error { + if w.err != nil { + w.err = err + w.c.writer = nil + } + return err } // flushFrame writes buffered data and extra as a frame to the network. The // final argument indicates that this is the last frame in the message. -func (c *Conn) flushFrame(final bool, extra []byte) error { - length := c.writePos - maxFrameHeaderSize + len(extra) +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. - if isControl(c.writeFrameType) && + if isControl(w.frameType) && (!final || length > maxControlFramePayloadSize) { - c.messageWriter = nil - c.writer = nil - c.writeFrameType = noFrame - c.writePos = maxFrameHeaderSize - return errInvalidControlFrame + return w.fatal(errInvalidControlFrame) } - b0 := byte(c.writeFrameType) + b0 := byte(w.frameType) if final { b0 |= finalBit } - if c.writeCompress { + if w.compress { b0 |= rsv1Bit } - c.writeCompress = false + w.compress = false b1 := byte(0) if !c.isServer { @@ -504,10 +573,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { if !c.isServer { key := newMaskKey() copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) - maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { - c.writeErr = errors.New("websocket: internal error, extra used in client mode") - return c.writeErr + return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) } } @@ -520,44 +588,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { } c.isWriting = true - c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) if !c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = false - // Setup for next frame. - c.writePos = maxFrameHeaderSize - c.writeFrameType = continuationFrame + if err != nil { + return w.fatal(err) + } + if final { - c.messageWriter = nil c.writer = nil - c.writeFrameType = noFrame + return nil } - return c.writeErr -} -type messageWriter struct{ c *Conn } - -func (w *messageWriter) err() error { - c := w.c - if c.messageWriter != w { - return errWriteClosed - } - if c.writeErr != nil { - return c.writeErr - } + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame return nil } func (w *messageWriter) ncopy(max int) (int, error) { - n := len(w.c.writeBuf) - w.c.writePos + n := len(w.c.writeBuf) - w.pos if n <= 0 { - if err := w.c.flushFrame(false, nil); err != nil { + if err := w.flushFrame(false, nil); err != nil { return 0, err } - n = len(w.c.writeBuf) - w.c.writePos + n = len(w.c.writeBuf) - w.pos } if n > max { n = max @@ -566,13 +625,13 @@ func (w *messageWriter) ncopy(max int) (int, error) { } func (w *messageWriter) Write(p []byte) (int, error) { - if err := w.err(); err != nil { - return 0, err + if w.err != nil { + return 0, w.err } if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. - err := w.c.flushFrame(false, p) + err := w.flushFrame(false, p) if err != nil { return 0, err } @@ -585,16 +644,16 @@ func (w *messageWriter) Write(p []byte) (int, error) { if err != nil { return 0, err } - copy(w.c.writeBuf[w.c.writePos:], p[:n]) - w.c.writePos += n + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n p = p[n:] } return nn, nil } func (w *messageWriter) WriteString(p string) (int, error) { - if err := w.err(); err != nil { - return 0, err + if w.err != nil { + return 0, w.err } nn := len(p) @@ -603,27 +662,27 @@ func (w *messageWriter) WriteString(p string) (int, error) { if err != nil { return 0, err } - copy(w.c.writeBuf[w.c.writePos:], p[:n]) - w.c.writePos += n + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n p = p[n:] } return nn, nil } func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { - if err := w.err(); err != nil { - return 0, err + if w.err != nil { + return 0, w.err } for { - if w.c.writePos == len(w.c.writeBuf) { - err = w.c.flushFrame(false, nil) + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) if err != nil { break } } var n int - n, err = r.Read(w.c.writeBuf[w.c.writePos:]) - w.c.writePos += n + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n nn += int64(n) if err != nil { if err == io.EOF { @@ -636,27 +695,59 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { } func (w *messageWriter) Close() error { - if err := w.err(); err != nil { + if w.err != nil { + return w.err + } + if err := w.flushFrame(true, nil); err != nil { return err } - return w.c.flushFrame(true, nil) + w.err = errWriteClosed + return nil +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err } // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + if err := c.prepWrite(messageType); err != nil { + return err + } + mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize} + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + w, err := c.NextWriter(messageType) if err != nil { return err } - if _, ok := w.(*messageWriter); ok && c.isServer { - // Optimize write as a single frame. - n := copy(c.writeBuf[c.writePos:], data) - c.writePos += n - data = data[n:] - err = c.flushFrame(true, data) - return err - } if _, err = w.Write(data); err != nil { return err } @@ -675,7 +766,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Read methods func (c *Conn) advanceFrame() (int, error) { - // 1. Skip remainder of previous frame. if c.readRemaining > 0 { @@ -799,11 +889,9 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, err } case CloseMessage: - echoMessage := []byte{} closeCode := CloseNoStatusReceived closeText := "" if len(payload) >= 2 { - echoMessage = payload[:2] closeCode = int(binary.BigEndian.Uint16(payload)) if !isValidReceivedCloseCode(closeCode) { return noFrame, c.handleProtocolError("invalid close code") @@ -813,7 +901,9 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") } } - c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait)) + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } return noFrame, &CloseError{Code: closeCode, Text: closeText} } @@ -836,6 +926,11 @@ func (c *Conn) handleProtocolError(message string) error { // permanent. Once this method returns a non-nil error, all subsequent calls to // this method return the same error. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } c.messageReader = nil c.readLength = 0 @@ -848,11 +943,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { } if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} - var r io.Reader = c.messageReader + c.reader = c.messageReader if c.readDecompress { - r = c.newDecompressionReader(r) + c.reader = c.newDecompressionReader(c.reader) } - return frameType, r, nil + return frameType, c.reader, nil } } @@ -914,6 +1009,10 @@ func (r *messageReader) Read(b []byte) (int, error) { return 0, err } +func (r *messageReader) Close() error { + return nil +} + // ReadMessage is a helper method for getting a reader using NextReader and // reading from that reader to a buffer. func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { @@ -935,20 +1034,54 @@ func (c *Conn) SetReadDeadline(t time.Time) error { } // SetReadLimit sets the maximum size for a message read from the peer. If a -// message exceeds the limit, the connection sends a close frame to the peer +// message exceeds the limit, the connection sends a close message to the peer // and returns ErrReadLimit to the application. func (c *Conn) SetReadLimit(limit int64) { c.readLimit = limit } +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close +// message back to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. +// +// The connection read methods return a CloseError when a close message is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close message back to +// the peer. +func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := FormatCloseMessage(code, "") + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + // PingHandler returns the current ping handler func (c *Conn) PingHandler() func(appData string) error { return c.handlePing } // SetPingHandler sets the handler for ping messages received from the peer. -// The appData argument to h is the PING frame application data. The default +// The appData argument to h is the PING message application data. The default // ping handler sends a pong to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. func (c *Conn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { @@ -970,8 +1103,12 @@ func (c *Conn) PongHandler() func(appData string) error { } // SetPongHandler sets the handler for pong messages received from the peer. -// The appData argument to h is the PONG frame application data. The default +// The appData argument to h is the PONG message application data. The default // pong handler does nothing. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. func (c *Conn) SetPongHandler(h func(appData string) error) { if h == nil { h = func(string) error { return nil } @@ -985,8 +1122,34 @@ func (c *Conn) UnderlyingConn() net.Conn { return c.conn } +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. + return []byte{} + } buf := make([]byte, 2+len(text)) binary.BigEndian.PutUint16(buf, uint16(closeCode)) copy(buf[2:], text) diff --git a/vendor/github.com/gorilla/websocket/conn_write.go b/vendor/github.com/gorilla/websocket/conn_write.go new file mode 100644 index 0000000..a509a21 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_write.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.8 + +package websocket + +import "net" + +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} diff --git a/vendor/github.com/gorilla/websocket/conn_write_legacy.go b/vendor/github.com/gorilla/websocket/conn_write_legacy.go new file mode 100644 index 0000000..37edaff --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn_write_legacy.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.8 + +package websocket + +func (c *Conn) writeBufs(bufs ...[]byte) error { + for _, buf := range bufs { + if len(buf) > 0 { + if _, err := c.conn.Write(buf); err != nil { + return err + } + } + } + return nil +} diff --git a/vendor/github.com/gorilla/websocket/doc.go b/vendor/github.com/gorilla/websocket/doc.go index c901a7a..dcce1a6 100644 --- a/vendor/github.com/gorilla/websocket/doc.go +++ b/vendor/github.com/gorilla/websocket/doc.go @@ -6,9 +6,8 @@ // // Overview // -// The Conn type represents a WebSocket connection. A server application uses -// the Upgrade function from an Upgrader object with a HTTP request handler -// to get a pointer to a Conn: +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: // // var upgrader = websocket.Upgrader{ // ReadBufferSize: 1024, @@ -31,10 +30,12 @@ // for { // messageType, p, err := conn.ReadMessage() // if err != nil { +// log.Println(err) // return // } -// if err = conn.WriteMessage(messageType, p); err != nil { -// return err +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return // } // } // @@ -85,20 +86,26 @@ // and pong. Call the connection WriteControl, WriteMessage or NextWriter // methods to send a control message to the peer. // -// Connections handle received close messages by sending a close message to the -// peer and returning a *CloseError from the the NextReader, ReadMessage or the -// message Read method. +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. // -// Connections handle received ping and pong messages by invoking callback -// functions set with SetPingHandler and SetPongHandler methods. The callback -// functions are called from the NextReader, ReadMessage and the message Read -// methods. +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. // -// The default ping handler sends a pong to the peer. The application's reading -// goroutine can block for a short time while the handler writes the pong data -// to the connection. +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. // -// The application must read the connection to process ping, pong and close +// The application must read the connection to process close, ping and pong // messages sent from the peer. If the application is not otherwise interested // in messages from the peer, then the application should start a goroutine to // read and discard messages from the peer. A simple example is: @@ -118,9 +125,10 @@ // // Applications are responsible for ensuring that no more than one goroutine // calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, -// WriteJSON) concurrently and that no more than one goroutine calls the read -// methods (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, -// SetPingHandler) concurrently. +// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and +// that no more than one goroutine calls the read methods (NextReader, +// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// concurrently. // // The Close and WriteControl methods can be called concurrently with all other // methods. @@ -136,17 +144,37 @@ // method fails the WebSocket handshake with HTTP status 403. // // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail -// the handshake if the Origin request header is present and not equal to the -// Host request header. +// the handshake if the Origin request header is present and the Origin host is +// not equal to the Host request header. // -// An application can allow connections from any origin by specifying a -// function that always returns true: +// The deprecated package-level Upgrade function does not perform origin +// checking. The application is responsible for checking the Origin header +// before calling the Upgrade function. +// +// Compression EXPERIMENTAL +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. // // var upgrader = websocket.Upgrader{ -// CheckOrigin: func(r *http.Request) bool { return true }, +// EnableCompression: true, // } // -// The deprecated Upgrade function does not enforce an origin policy. It's the -// application's responsibility to check the Origin header before calling -// Upgrade. +// If compression was successfully negotiated with the connection's peer, any +// message received in compressed form will be automatically decompressed. +// All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(false) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. package websocket diff --git a/vendor/github.com/gorilla/websocket/json.go b/vendor/github.com/gorilla/websocket/json.go index 4f0e368..dc2c1f6 100644 --- a/vendor/github.com/gorilla/websocket/json.go +++ b/vendor/github.com/gorilla/websocket/json.go @@ -9,12 +9,14 @@ import ( "io" ) -// WriteJSON is deprecated, use c.WriteJSON instead. +// WriteJSON writes the JSON encoding of v as a message. +// +// Deprecated: Use c.WriteJSON instead. func WriteJSON(c *Conn, v interface{}) error { return c.WriteJSON(v) } -// WriteJSON writes the JSON encoding of v to the connection. +// WriteJSON writes the JSON encoding of v as a message. // // See the documentation for encoding/json Marshal for details about the // conversion of Go values to JSON. @@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error { return err2 } -// ReadJSON is deprecated, use c.ReadJSON instead. +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// Deprecated: Use c.ReadJSON instead. func ReadJSON(c *Conn, v interface{}) error { return c.ReadJSON(v) } diff --git a/vendor/github.com/gorilla/websocket/mask.go b/vendor/github.com/gorilla/websocket/mask.go new file mode 100644 index 0000000..577fce9 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask.go @@ -0,0 +1,54 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// +build !appengine + +package websocket + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/mask_safe.go b/vendor/github.com/gorilla/websocket/mask_safe.go new file mode 100644 index 0000000..2aac060 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask_safe.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// +build appengine + +package websocket + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/prepared.go b/vendor/github.com/gorilla/websocket/prepared.go new file mode 100644 index 0000000..1efffbd --- /dev/null +++ b/vendor/github.com/gorilla/websocket/prepared.go @@ -0,0 +1,103 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + err error + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan bool, 1) + mu <- true + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + if key.compress { + c.newCompressionWriter = compressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/vendor/github.com/gorilla/websocket/proxy.go b/vendor/github.com/gorilla/websocket/proxy.go new file mode 100644 index 0000000..bf2478e --- /dev/null +++ b/vendor/github.com/gorilla/websocket/proxy.go @@ -0,0 +1,77 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "encoding/base64" + "errors" + "net" + "net/http" + "net/url" + "strings" +) + +type netDialerFunc func(network, addr string) (net.Conn, error) + +func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return fn(network, addr) +} + +func init() { + proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { + return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil + }) +} + +type httpProxyDialer struct { + proxyURL *url.URL + fowardDial func(network, addr string) (net.Conn, error) +} + +func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { + hostPort, _ := hostPortNoPort(hpd.proxyURL) + conn, err := hpd.fowardDial(network, hostPort) + if err != nil { + return nil, err + } + + connectHeader := make(http.Header) + if user := hpd.proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: connectHeader, + } + + if err := connectReq.Write(conn); err != nil { + conn.Close() + return nil, err + } + + // Read response. It's OK to use and discard buffered reader here becaue + // the remote server does not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } + + if resp.StatusCode != 200 { + conn.Close() + f := strings.SplitN(resp.Status, " ", 2) + return nil, errors.New(f[1]) + } + return conn, nil +} diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go index 8402d20..aee2705 100644 --- a/vendor/github.com/gorilla/websocket/server.go +++ b/vendor/github.com/gorilla/websocket/server.go @@ -28,8 +28,9 @@ type Upgrader struct { HandshakeTimeout time.Duration // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer - // size is zero, then a default value of 4096 is used. The I/O buffer sizes - // do not limit the size of the messages that can be sent or received. + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the server's supported protocols in order of @@ -43,9 +44,19 @@ type Upgrader struct { Error func(w http.ResponseWriter, r *http.Request, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. If - // CheckOrigin is nil, the host in the Origin header must not be set or - // must match the host of the request. + // CheckOrigin is nil, then a safe default is used: return false if the + // Origin request header is present and the origin host is not equal to + // request Host header. + // + // A CheckOrigin function should carefully validate the request origin to + // prevent cross-site request forgery. CheckOrigin func(r *http.Request) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -69,7 +80,7 @@ func checkSameOrigin(r *http.Request) bool { if err != nil { return false } - return u.Host == r.Host + return equalASCIIFold(u.Host, r.Host) } func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { @@ -92,24 +103,31 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the -// application negotiated subprotocol (Sec-Websocket-Protocol). +// application negotiated subprotocol (Sec-WebSocket-Protocol). // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { - if r.Method != "GET" { - return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") - } - if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { - return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") - } + const badHandshake = "websocket: the client is not using the websocket protocol: " if !tokenListContainsValue(r.Header, "Connection", "upgrade") { - return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'") + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") } if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { - return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'") + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") + } + + if r.Method != "GET" { + return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") } checkOrigin := u.CheckOrigin @@ -117,19 +135,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade checkOrigin = checkSameOrigin } if !checkOrigin(r) { - return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed") + return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") } challengeKey := r.Header.Get("Sec-Websocket-Key") if challengeKey == "" { - return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank") + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank") } subprotocol := u.selectSubprotocol(r, responseHeader) + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + var ( netConn net.Conn - br *bufio.Reader err error ) @@ -137,30 +166,37 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if !ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") } - var rw *bufio.ReadWriter - netConn, rw, err = h.Hijack() + var brw *bufio.ReadWriter + netConn, brw, err = h.Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } - br = rw.Reader - if br.Buffered() > 0 { + if brw.Reader.Buffered() > 0 { netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } - c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) + c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) c.subprotocol = subprotocol + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + p := c.writeBuf[:0] p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) p = append(p, "\r\n"...) if c.subprotocol != "" { - p = append(p, "Sec-Websocket-Protocol: "...) + p = append(p, "Sec-WebSocket-Protocol: "...) p = append(p, c.subprotocol...) p = append(p, "\r\n"...) } + if compress { + p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { continue @@ -200,13 +236,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // -// This function is deprecated, use websocket.Upgrader instead. +// Deprecated: Use websocket.Upgrader instead. // -// The application is responsible for checking the request origin before -// calling Upgrade. An example implementation of the same origin policy is: +// Upgrade does not perform origin checking. The application is responsible for +// checking the Origin header before calling Upgrade. An example implementation +// of the same origin policy check is: // // if req.Header.Get("Origin") != "http://"+req.Host { -// http.Error(w, "Origin not allowed", 403) +// http.Error(w, "Origin not allowed", http.StatusForbidden) // return // } // diff --git a/vendor/github.com/gorilla/websocket/util.go b/vendor/github.com/gorilla/websocket/util.go index 9a4908d..385fa01 100644 --- a/vendor/github.com/gorilla/websocket/util.go +++ b/vendor/github.com/gorilla/websocket/util.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "strings" + "unicode/utf8" ) var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") @@ -111,14 +112,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) { case escape: escape = false p[j] = b - j += 1 + j++ case b == '\\': escape = true case b == '"': return string(p[:j]), s[i+1:] default: p[j] = b - j += 1 + j++ } } return "", "" @@ -127,8 +128,31 @@ func nextTokenOrQuoted(s string) (value string, rest string) { return "", "" } +// equalASCIIFold returns true if s is equal to t with ASCII case folding. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + // tokenListContainsValue returns true if the 1#token header with the given -// name contains token. +// name contains a token equal to value with ASCII case folding. func tokenListContainsValue(header http.Header, name string, value string) bool { headers: for _, s := range header[name] { @@ -142,7 +166,7 @@ headers: if s != "" && s[0] != ',' { continue headers } - if strings.EqualFold(t, value) { + if equalASCIIFold(t, value) { return true } if s == "" { @@ -156,7 +180,6 @@ headers: // parseExtensiosn parses WebSocket extensions from a header. func parseExtensions(header http.Header) []map[string]string { - // From RFC 6455: // // Sec-WebSocket-Extensions = extension-list diff --git a/vendor/github.com/gorilla/websocket/x_net_proxy.go b/vendor/github.com/gorilla/websocket/x_net_proxy.go new file mode 100644 index 0000000..2e668f6 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/x_net_proxy.go @@ -0,0 +1,473 @@ +// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. +//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy + +// Package proxy provides support for a variety of protocols to proxy network +// data. +// + +package websocket + +import ( + "errors" + "io" + "net" + "net/url" + "os" + "strconv" + "strings" + "sync" +) + +type proxy_direct struct{} + +// Direct is a direct proxy: one that makes network connections directly. +var proxy_Direct = proxy_direct{} + +func (proxy_direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type proxy_PerHost struct { + def, bypass proxy_Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { + return &proxy_PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *proxy_PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *proxy_PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *proxy_PerHost) AddZone(zone string) { + if strings.HasSuffix(zone, ".") { + zone = zone[:len(zone)-1] + } + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *proxy_PerHost) AddHost(host string) { + if strings.HasSuffix(host, ".") { + host = host[:len(host)-1] + } + p.bypassHosts = append(p.bypassHosts, host) +} + +// A Dialer is a means to establish a connection. +type proxy_Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type proxy_Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy related variables in +// the environment. +func proxy_FromEnvironment() proxy_Dialer { + allProxy := proxy_allProxyEnv.Get() + if len(allProxy) == 0 { + return proxy_Direct + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return proxy_Direct + } + proxy, err := proxy_FromURL(proxyURL, proxy_Direct) + if err != nil { + return proxy_Direct + } + + noProxy := proxy_noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := proxy_NewPerHost(proxy, proxy_Direct) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { + if proxy_proxySchemes == nil { + proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) + } + proxy_proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { + var auth *proxy_Auth + if u.User != nil { + auth = new(proxy_Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5": + return proxy_SOCKS5("tcp", u.Host, auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxy_proxySchemes != nil { + if f, ok := proxy_proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + proxy_allProxyEnv = &proxy_envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + proxy_noProxyEnv = &proxy_envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +// (Borrowed from net/http/transport.go) +type proxy_envOnce struct { + names []string + once sync.Once + val string +} + +func (e *proxy_envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *proxy_envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address +// with an optional username and password. See RFC 1928 and RFC 1929. +func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { + s := &proxy_socks5{ + network: network, + addr: addr, + forward: forward, + } + if auth != nil { + s.user = auth.User + s.password = auth.Password + } + + return s, nil +} + +type proxy_socks5 struct { + user, password string + network, addr string + forward proxy_Dialer +} + +const proxy_socks5Version = 5 + +const ( + proxy_socks5AuthNone = 0 + proxy_socks5AuthPassword = 2 +) + +const proxy_socks5Connect = 1 + +const ( + proxy_socks5IP4 = 1 + proxy_socks5Domain = 3 + proxy_socks5IP6 = 4 +) + +var proxy_socks5Errors = []string{ + "", + "general failure", + "connection forbidden", + "network unreachable", + "host unreachable", + "connection refused", + "TTL expired", + "command not supported", + "address type not supported", +} + +// Dial connects to the address addr on the given network via the SOCKS5 proxy. +func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) + } + + conn, err := s.forward.Dial(s.network, s.addr) + if err != nil { + return nil, err + } + if err := s.connect(conn, addr); err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +// connect takes an existing connection to a socks5 proxy server, +// and commands the server to extend that connection to target, +// which must be a canonical address with a host and port. +func (s *proxy_socks5) connect(conn net.Conn, target string) error { + host, portStr, err := net.SplitHostPort(target) + if err != nil { + return err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return errors.New("proxy: failed to parse port number: " + portStr) + } + if port < 1 || port > 0xffff { + return errors.New("proxy: port number out of range: " + portStr) + } + + // the size here is just an estimate + buf := make([]byte, 0, 6+len(host)) + + buf = append(buf, proxy_socks5Version) + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { + buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) + } else { + buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) + } + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + if buf[0] != 5 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) + } + if buf[1] == 0xff { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") + } + + // See RFC 1929 + if buf[1] == proxy_socks5AuthPassword { + buf = buf[:0] + buf = append(buf, 1 /* password protocol version */) + buf = append(buf, uint8(len(s.user))) + buf = append(buf, s.user...) + buf = append(buf, uint8(len(s.password))) + buf = append(buf, s.password...) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if buf[1] != 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") + } + } + + buf = buf[:0] + buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) + + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, proxy_socks5IP4) + ip = ip4 + } else { + buf = append(buf, proxy_socks5IP6) + } + buf = append(buf, ip...) + } else { + if len(host) > 255 { + return errors.New("proxy: destination host name too long: " + host) + } + buf = append(buf, proxy_socks5Domain) + buf = append(buf, byte(len(host))) + buf = append(buf, host...) + } + buf = append(buf, byte(port>>8), byte(port)) + + if _, err := conn.Write(buf); err != nil { + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + if _, err := io.ReadFull(conn, buf[:4]); err != nil { + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + failure := "unknown error" + if int(buf[1]) < len(proxy_socks5Errors) { + failure = proxy_socks5Errors[buf[1]] + } + + if len(failure) > 0 { + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) + } + + bytesToDiscard := 0 + switch buf[3] { + case proxy_socks5IP4: + bytesToDiscard = net.IPv4len + case proxy_socks5IP6: + bytesToDiscard = net.IPv6len + case proxy_socks5Domain: + _, err := io.ReadFull(conn, buf[:1]) + if err != nil { + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + bytesToDiscard = int(buf[0]) + default: + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) + } + + if cap(buf) < bytesToDiscard { + buf = make([]byte, bytesToDiscard) + } else { + buf = buf[:bytesToDiscard] + } + if _, err := io.ReadFull(conn, buf); err != nil { + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + // Also need to discard the port number + if _, err := io.ReadFull(conn, buf[:2]); err != nil { + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) + } + + return nil +}