diff --git a/cmd/goatak_server/tcpserver.go b/cmd/goatak_server/tcpserver.go index 9f6cfbf..2d1f80a 100644 --- a/cmd/goatak_server/tcpserver.go +++ b/cmd/goatak_server/tcpserver.go @@ -2,8 +2,8 @@ package main import ( "crypto/tls" - "crypto/x509" "fmt" + "github.com/kdudkov/goatak/pkg/tlsutil" "net" "go.uber.org/zap" @@ -88,7 +88,7 @@ func (app *App) listenTls(addr string) error { func (app *App) verifyConnection(st tls.ConnectionState) error { user, sn := getCertUser(&st) - app.logCert(st.PeerCertificates) + tlsutil.LogCerts(app.Logger, st.PeerCertificates...) if !app.users.UserIsValid(user, sn) { app.Logger.Warnf("bad user %s", user) @@ -108,14 +108,6 @@ func getCertUser(st *tls.ConnectionState) (string, string) { return "", "" } -func (app *App) logCert(cert []*x509.Certificate) { - for i, cert := range cert { - app.Logger.Infof("#%d issuer: %s", i, cert.Issuer.String()) - app.Logger.Infof("#%d subject: %s", i, cert.Subject.String()) - app.Logger.Infof("#%d sn: %x", i, cert.SerialNumber) - } -} - func (app *App) onTlsClientConnect(username, sn string) { } diff --git a/cmd/webclient/enroll.go b/cmd/webclient/enroll.go index d78fc6c..3fe1608 100644 --- a/cmd/webclient/enroll.go +++ b/cmd/webclient/enroll.go @@ -98,7 +98,6 @@ func (e *Enroller) getOrEnrollCert(uid, version string) (*tls.Certificate, []*x5 fname := fmt.Sprintf("%s_%s.p12", e.host, e.user) if cert, cas, err := loadP12(fname, viper.GetString("ssl.password")); err == nil { e.logger.Infof("loading cert from file %s", fname) - e.logger.Infof("cert is valid till %s", cert.Leaf.NotAfter) return cert, cas, nil } @@ -183,6 +182,8 @@ func (e *Enroller) getOrEnrollCert(uid, version string) (*tls.Certificate, []*x5 return nil, nil, fmt.Errorf("no signed cert in answer") } + tlsutil.LogCert(e.logger, "signed cert", cert) + if e.save { if err := e.saveP12(key, cert, ca); err != nil { e.logger.Errorf("%s", err) @@ -257,7 +258,7 @@ func (e *Enroller) saveP12(key interface{}, cert *x509.Certificate, ca []*x509.C } defer f.Close() - data, err := pkcs12.Encode(rand.Reader, key, cert, ca, viper.GetString("ssl.password")) + data, err := pkcs12.Modern.Encode(key, cert, ca, viper.GetString("ssl.password")) if err != nil { return err } diff --git a/cmd/webclient/main.go b/cmd/webclient/main.go index f0dd01f..73ba53a 100644 --- a/cmd/webclient/main.go +++ b/cmd/webclient/main.go @@ -452,6 +452,7 @@ func main() { app.Logger.Errorf("error while loading cert: %s", err.Error()) return } + tlsutil.LogCert(app.Logger, "loaded cert", cert.Leaf) app.tlsCert = cert app.cas = tlsutil.MakeCertPool(cas...) } diff --git a/cmd/webclient/tcp_handler.go b/cmd/webclient/tcp_handler.go index 56b084c..59e4e63 100644 --- a/cmd/webclient/tcp_handler.go +++ b/cmd/webclient/tcp_handler.go @@ -3,9 +3,9 @@ package main import ( "crypto/tls" "fmt" + "github.com/kdudkov/goatak/pkg/tlsutil" "github.com/spf13/viper" "net" - "strings" ) func (app *App) connect() (net.Conn, error) { @@ -25,11 +25,7 @@ func (app *App) connect() (net.Conn, error) { app.Logger.Infof("Handshake complete: %t", cs.HandshakeComplete) app.Logger.Infof("version: %d", cs.Version) - for i, cert := range cs.PeerCertificates { - app.Logger.Infof("cert #%d subject: %s", i, cert.Subject.String()) - app.Logger.Infof("cert #%d issuer: %s", i, cert.Issuer.String()) - app.Logger.Infof("cert #%d dns_names: %s", i, strings.Join(cert.DNSNames, ",")) - } + tlsutil.LogCerts(app.Logger, cs.PeerCertificates...) return conn, nil } else { app.Logger.Infof("connecting to %s...", addr) @@ -41,7 +37,7 @@ func (app *App) getTlsConfig() *tls.Config { conf := &tls.Config{ Certificates: []tls.Certificate{*app.tlsCert}, RootCAs: app.cas, - //InsecureSkipVerify: true, + ClientCAs: app.cas, } if !viper.GetBool("ssl.strict") { diff --git a/pkg/tlsutil/util.go b/pkg/tlsutil/util.go index 4a7955f..afe5d49 100644 --- a/pkg/tlsutil/util.go +++ b/pkg/tlsutil/util.go @@ -6,6 +6,8 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "fmt" + "go.uber.org/zap" "strings" "software.sslmate.com/src/go-pkcs12" @@ -77,3 +79,30 @@ func MakeCertPool(certs ...*x509.Certificate) *x509.CertPool { return cp } + +func LogCert(logger *zap.SugaredLogger, name string, cert *x509.Certificate) { + if cert == nil { + logger.Errorf("no %s!!!", name) + return + } + logger.Infof("%s sn: %x", name, cert.SerialNumber) + logger.Infof("%s subject: %s", name, cert.Subject.String()) + logger.Infof("%s issuer: %s", name, cert.Issuer.String()) + logger.Infof("%s valid till %s", name, cert.NotAfter) + if len(cert.DNSNames) > 0 { + logger.Infof("%s dns_names: %s", name, strings.Join(cert.DNSNames, ",")) + } + if len(cert.IPAddresses) > 0 { + ip1 := make([]string, len(cert.IPAddresses)) + for i, ip := range cert.IPAddresses { + ip1[i] = ip.String() + } + logger.Infof("%s ip_addresses: %s", name, strings.Join(ip1, ",")) + } +} + +func LogCerts(logger *zap.SugaredLogger, certs ...*x509.Certificate) { + for i, c := range certs { + LogCert(logger, fmt.Sprintf("cert #%d", i), c) + } +}