Skip to content

Commit

Permalink
print chains with insecure flag
Browse files Browse the repository at this point in the history
  • Loading branch information
pete911 committed Dec 17, 2024
1 parent a2eaff6 commit 5396db3
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 56 deletions.
18 changes: 16 additions & 2 deletions pkg/cert/cert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cert

import (
"bytes"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
Expand Down Expand Up @@ -186,7 +187,7 @@ func (c Certificate) String() string {
fmt.Sprintf("Version: %d", c.x509Certificate.Version),
fmt.Sprintf("Serial Number: %s", formatHexArray(c.x509Certificate.SerialNumber.Bytes())),
fmt.Sprintf("Signature Algorithm: %s", c.x509Certificate.SignatureAlgorithm),
fmt.Sprintf("Type: %s", CertificateType(c.x509Certificate)),
fmt.Sprintf("Type: %s", c.Type()),
fmt.Sprintf("Issuer: %s", c.x509Certificate.Issuer),
fmt.Sprintf("Validity\n Not Before: %s\n Not After : %s",
ValidityFormat(c.x509Certificate.NotBefore),
Expand All @@ -202,6 +203,17 @@ func (c Certificate) String() string {
}, "\n")
}

func (c Certificate) Type() string {
if c.x509Certificate.AuthorityKeyId == nil || bytes.Equal(c.x509Certificate.AuthorityKeyId, c.x509Certificate.SubjectKeyId) {
return "root"
}

if c.x509Certificate.IsCA {
return "intermediate"
}
return "end-entity"
}

func (c Certificate) Extensions() string {
var lines []string
for _, v := range ToExtensions(c.x509Certificate.Extensions) {
Expand Down Expand Up @@ -236,7 +248,9 @@ func formatExpiry(t time.Time) string {
}

func formatHexArray(b []byte) string {

if len(b) == 0 {
return ""
}
buf := make([]byte, 0, 3*len(b))
x := buf[1*len(b) : 3*len(b)]
hex.Encode(x, b)
Expand Down
6 changes: 3 additions & 3 deletions pkg/cert/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ func Test_rootIdentification(t *testing.T) {
require.Len(t, certificate, 1)
require.Equal(t, certificate[0].x509Certificate.RawSubject, certificate[0].x509Certificate.RawIssuer)
require.NotEmpty(t, certificate[0].x509Certificate.AuthorityKeyId)
require.Equal(t, "root", CertificateType(certificate[0].x509Certificate))
require.Equal(t, "root", certificate[0].Type())
})

t.Run("given certificate authority key id is unset then identify as root", func(t *testing.T) {
certificate := loadTestCertificates(t, "cert.pem")
require.Len(t, certificate, 1)
assert.Len(t, certificate[0].x509Certificate.AuthorityKeyId, 0)
assert.True(t, certificate[0].x509Certificate.IsCA)
require.Equal(t, "root", CertificateType(certificate[0].x509Certificate))
require.Equal(t, "root", certificate[0].Type())
})
}

Expand All @@ -106,6 +106,6 @@ func Test_intermediateIdentification(t *testing.T) {
require.Len(t, certificate, 1)
require.Equal(t, certificate[0].x509Certificate.RawSubject, certificate[0].x509Certificate.RawIssuer)
require.NotEmpty(t, certificate[0].x509Certificate.AuthorityKeyId)
require.Equal(t, "intermediate", CertificateType(certificate[0].x509Certificate))
require.Equal(t, "intermediate", certificate[0].Type())
})
}
55 changes: 41 additions & 14 deletions pkg/cert/location.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cert
import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"golang.design/x/clipboard"
Expand Down Expand Up @@ -73,11 +74,43 @@ func (c CertificateLocations) SortByExpiry() CertificateLocations {
}

type CertificateLocation struct {
TLSVersion uint16 // only applicable for network certificates
Path string
Error error
Certificates Certificates
VerifiedChains []Certificates // only applicable for network certificates
TLSVersion uint16 // only applicable for network certificates
Path string
Error error
Certificates Certificates
}

func (c CertificateLocation) Chains() ([]Certificates, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

// we are not verifying time and dns, because we want to work with -insecure flag as well
// just to see what local chains are used for verification
opts := x509.VerifyOptions{
Roots: pool,
Intermediates: x509.NewCertPool(),
}
for _, cert := range c.Certificates {
if cert.Type() == "intermediate" {
opts.Intermediates.AddCert(cert.x509Certificate)
}
}

var verifiedChains []Certificates
for _, cert := range c.Certificates {
if cert.Type() == "end-entity" {
chains, err := cert.x509Certificate.Verify(opts)
if err != nil {
return nil, err
}
for _, chain := range chains {
verifiedChains = append(verifiedChains, FromX509Certificates(chain))
}
}
}
return verifiedChains, nil
}

func (c CertificateLocation) Name() string {
Expand Down Expand Up @@ -119,16 +152,10 @@ func LoadCertificatesFromNetwork(addr string, tlsSkipVerify bool) CertificateLoc
connectionState := conn.ConnectionState()
x509Certificates := connectionState.PeerCertificates

var verifiedChains []Certificates
for _, chain := range connectionState.VerifiedChains {
verifiedChains = append(verifiedChains, FromX509Certificates(chain))
}

return CertificateLocation{
TLSVersion: conn.ConnectionState().Version,
Path: addr,
Certificates: FromX509Certificates(x509Certificates),
VerifiedChains: verifiedChains,
TLSVersion: conn.ConnectionState().Version,
Path: addr,
Certificates: FromX509Certificates(x509Certificates),
}
}

Expand Down
13 changes: 0 additions & 13 deletions pkg/cert/util.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cert

import (
"bytes"
"crypto/x509"
"time"
)
Expand Down Expand Up @@ -44,18 +43,6 @@ func ValidityFormat(t time.Time) string {
return t.Format(validityFormat)
}

func CertificateType(cert *x509.Certificate) string {

if cert.AuthorityKeyId == nil || bytes.Equal(cert.AuthorityKeyId, cert.SubjectKeyId) {
return "root"
}

if cert.IsCA {
return "intermediate"
}
return "end-entity"
}

// ExtKeyUsageToString converts extended key usage integer values to strings
func ExtKeyUsageToString(extKeyUsage []x509.ExtKeyUsage) []string {

Expand Down
57 changes: 33 additions & 24 deletions print.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,38 @@ func PrintCertificatesLocations(certificateLocations []cert.CertificateLocation,
}

fmt.Printf("--- [%s] ---\n", certificateLocation.Name())
printCertificates(certificateLocation.Certificates, printPem, printExtensions)
printCertificates(certificateLocation, printPem, printChains, printExtensions)
}
}

if certificateLocation.VerifiedChains != nil {
fmt.Printf("--- %d verified chains ---\n", len(certificateLocation.VerifiedChains))
}
func printCertificates(certLocation cert.CertificateLocation, printPem, printChains, printExtensions bool) {

if printChains {
for i, chain := range certificateLocation.VerifiedChains {
fmt.Printf("--- chain %d ---\n", i+1)
printCertificates(chain, printPem, printExtensions)
var prt = func(certs []cert.Certificate, printPem, printExtensions bool) {
for _, certificate := range certs {
fmt.Println(certificate)
if printExtensions {
fmt.Println("--- extensions ---")
fmt.Print(certificate.Extensions())
fmt.Println()
}
fmt.Println()
if printPem {
fmt.Println(string(certificate.ToPEM()))
}
}
}
}

func printCertificates(certificates []cert.Certificate, printPem, printExtensions bool) {

for _, certificate := range certificates {
fmt.Println(certificate)
if printExtensions {
fmt.Println("--- extensions ---")
fmt.Print(certificate.Extensions())
fmt.Println()
prt(certLocation.Certificates, printPem, printExtensions)
if printChains {
chains, err := certLocation.Chains()
if err != nil {
fmt.Printf("--- chains: %v ---\n", err)
return
}
fmt.Println()
if printPem {
fmt.Println(string(certificate.ToPEM()))
fmt.Printf("--- %d chains ---\n", len(chains))
for i, chain := range chains {
fmt.Printf("--- chain %d ---\n", i+1)
prt(chain, printPem, printExtensions)
}
}
}
Expand All @@ -54,10 +59,14 @@ func PrintPemOnly(certificateLocations []cert.CertificateLocation, printChains b
}

if printChains {
for _, chains := range certificateLocation.VerifiedChains {
fmt.Println()
for _, chain := range chains {
fmt.Print(string(chain.ToPEM()))
chains, err := certificateLocation.Chains()
if err != nil {
fmt.Printf("--- chains: %v ---\n", err)
continue
}
for _, chain := range chains {
for _, c := range chain {
fmt.Print(string(c.ToPEM()))
}
}
}
Expand Down

0 comments on commit 5396db3

Please sign in to comment.