diff --git a/client.go b/client.go index 6e82bcfe..ecc2203f 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,8 @@ import ( "net/http" "time" + "crypto/rand" + "golang.org/x/net/http2" ) @@ -22,10 +24,10 @@ const ( HostProduction = "https://api.push.apple.com" ) -// DefaultHost is a mutable var for testing purposes -var DefaultHost = HostDevelopment - var ( + // DefaultHost is a mutable var for testing purposes + DefaultHost = HostDevelopment + // TLSDialTimeout is the maximum amount of time a dial will wait for a connect // to complete. TLSDialTimeout = 20 * time.Second @@ -40,6 +42,11 @@ type Client struct { HTTPClient *http.Client Certificate tls.Certificate Host string + + pinging bool + stopPinging chan struct{} + pingInterval time.Duration + conn *tls.Conn } // NewClient returns a new Client with an underlying http.Client configured with @@ -60,20 +67,77 @@ func NewClient(certificate tls.Certificate) *Client { if len(certificate.Certificate) > 0 { tlsConfig.BuildNameToCertificate() } + + client := &Client{} + transport := &http2.Transport{ TLSClientConfig: tlsConfig, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg) - }, } - return &Client{ - HTTPClient: &http.Client{ - Transport: transport, - Timeout: HTTPClientTimeout, - }, - Certificate: certificate, - Host: DefaultHost, + + dialTsl := func(network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg) + if err != nil { + return nil, err + } + + client.conn = conn + return conn, nil + } + + transport.DialTLS = dialTsl + client.HTTPClient = &http.Client{ + Transport: transport, + Timeout: HTTPClientTimeout, + } + client.Certificate = certificate + client.Host = DefaultHost + + return client +} + +func (c *Client) EnablePinging(pingInterval time.Duration, pingError chan error) { + //lets make sure that the old goroutine has exited in case the user calls this method multiple times + c.DisablePinging() + + c.pinging = true + c.pingInterval = pingInterval + + go func() { + t := time.NewTicker(pingInterval) + var framer *http2.Framer + for { + select { + case <-t.C: + if c.conn == nil { + continue + } + + if framer == nil { + framer = http2.NewFramer(c.conn, c.conn) + } + + var p [8]byte + rand.Read(p[:]) + err := framer.WritePing(false, p) + if err != nil && pingError != nil { + pingError <- err + } + case <-c.stopPinging: + t.Stop() + framer = nil + close(pingError) + return + } + } + }() +} + +func (c *Client) DisablePinging() { + if c.pinging { + c.stopPinging <- struct{}{} } + + c.pinging = false } // Development sets the Client to use the APNs development push endpoint. @@ -88,6 +152,14 @@ func (c *Client) Production() *Client { return c } +func (c *Client) IsPinging() bool { + return c.pinging +} + +func (c *Client) GetPingInterval() time.Duration { + return c.pingInterval +} + // Push sends a Notification to the APNs gateway. If the underlying http.Client // is not currently connected, this method will attempt to reconnect // transparently before sending the notification. It will return a Response @@ -117,6 +189,7 @@ func (c *Client) Push(n *Notification) (*Response, error) { if err := decoder.Decode(&response); err != nil && err != io.EOF { return &Response{}, err } + return response, nil }