Skip to content

Commit

Permalink
watch the change the certificate files and reload them
Browse files Browse the repository at this point in the history
  • Loading branch information
xgfone committed Nov 3, 2021
1 parent bcb4a8e commit 467c0cc
Showing 1 changed file with 86 additions and 9 deletions.
95 changes: 86 additions & 9 deletions runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ package ship

import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"os"
"os/signal"
"sync/atomic"
"syscall"
"time"
)

// DefaultSignals is a set of default signals.
Expand Down Expand Up @@ -184,18 +187,16 @@ func (r *Runner) startServer(certFile, keyFile string) {
r.infof("The HTTP Server [%s] is running on %s", r.Name, r.Server.Addr)
}

go r.handleSignals(r.done)

var isTLS bool
if certFile != "" && keyFile != "" {
isTLS = true
} else if r.Server.TLSConfig != nil &&
(len(r.Server.TLSConfig.Certificates) > 0 ||
r.Server.TLSConfig.GetCertificate != nil) {
isTLS = true
if r.Server.TLSConfig == nil {
r.Server.TLSConfig = &tls.Config{GetCertificate: r.getCertificate(certFile, keyFile)}
} else if r.Server.TLSConfig.GetCertificate == nil {
r.Server.TLSConfig.GetCertificate = r.getCertificate(certFile, keyFile)
}
}

if isTLS {
go r.handleSignals(r.done)
if r.Server.TLSConfig != nil {
r.err = r.Server.ListenAndServeTLS(certFile, keyFile)
} else {
r.err = r.Server.ListenAndServe()
Expand Down Expand Up @@ -250,3 +251,79 @@ func (r *Runner) handleSignals(exit <-chan struct{}) {
}
}
}

func (r *Runner) getCertificate(certFile, keyFile string) getCertificate {
cert := tlscert{runner: r, certFile: certFile, keyFile: keyFile}
if _, err := cert.updateCert(); err != nil {
r.errorf("fail to load certificate: cert=%s, key=%s, err=%v",
certFile, keyFile, err)
}
go cert.WatchCertFile()
return cert.GetCertificate
}

type getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)

type tlscert struct {
runner *Runner
certFile string
keyFile string
certLast time.Time
keyLast time.Time
cert atomic.Value
}

func (c *tlscert) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if cert := c.cert.Load(); cert != nil {
return cert.(*tls.Certificate), nil
}
return nil, errors.New("missing the certificate")
}

func (c *tlscert) WatchCertFile() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()

for {
select {
case <-c.runner.done:
return
case <-ticker.C:
if ok, err := c.updateCert(); err != nil {
c.runner.errorf("fail to reload certificate: cert=%s, key=%s, err=%v",
c.certFile, c.keyFile, err)
} else if ok {
c.runner.infof("successfully reload the certificate: cert=%s, key=%s",
c.certFile, c.keyFile)
}
}
}
}

func (c *tlscert) updateCert() (ok bool, err error) {
certfi, err := os.Stat(c.certFile)
if err != nil {
return
}

keyfi, err := os.Stat(c.keyFile)
if err != nil {
return
}

certLast := certfi.ModTime()
keyLast := keyfi.ModTime()
if !certLast.After(c.certLast) && !keyLast.After(c.keyLast) {
return false, nil
}

cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile)
if err != nil {
return
}

c.certLast = certLast
c.keyLast = keyLast
c.cert.Store(&cert)
return true, nil
}

0 comments on commit 467c0cc

Please sign in to comment.