diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1790b87..cc92844 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,7 +11,7 @@ jobs: build: strategy: matrix: - go-version: [1.18.x, 1.19.x] + go-version: [1.19.x, 1.20.x] name: Linux runs-on: ubuntu-latest steps: @@ -22,8 +22,6 @@ jobs: id: go - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: Install staticcheck - run: go install honnef.co/go/tools/cmd/staticcheck@v0.3.3 - name: Install libpcsc run: sudo apt-get install -y libpcsclite-dev pcscd pcsc-tools - name: Test @@ -31,7 +29,7 @@ jobs: build-windows: strategy: matrix: - go-version: [1.18.x, 1.19.x] + go-version: [1.19.x, 1.20.x] name: Windows runs-on: windows-latest steps: @@ -42,8 +40,6 @@ jobs: id: go - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: Install staticcheck - run: go install honnef.co/go/tools/cmd/staticcheck@v0.3.3 - name: Test run: "make build" env: diff --git a/Makefile b/Makefile index 698d316..d454b69 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,7 @@ .PHONY: test -test: lint +test: go test -v ./... -.PHONY: lint -lint: - staticcheck ./... - .PHONY: build -build: lint +build: go build ./... diff --git a/piv/key.go b/piv/key.go index d3ec1a5..5b8bc07 100644 --- a/piv/key.go +++ b/piv/key.go @@ -31,6 +31,8 @@ import ( "math/big" "strconv" "strings" + + rsafork "github.com/go-piv/piv-go/third_party/rsa" ) // errMismatchingAlgorithms is returned when a cryptographic operation @@ -1072,7 +1074,7 @@ func (k *keyRSA) Public() crypto.PublicKey { func (k *keyRSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { return k.auth.do(k.yk, k.pp, func(tx *scTx) ([]byte, error) { - return ykSignRSA(tx, k.slot, k.pub, digest, opts) + return ykSignRSA(tx, rand, k.slot, k.pub, digest, opts) }) } @@ -1285,43 +1287,54 @@ func ykDecryptRSA(tx *scTx, slot Slot, pub *rsa.PublicKey, data []byte) ([]byte, // PKCS#1 v15 is largely informed by the standard library // https://github.com/golang/go/blob/go1.13.5/src/crypto/rsa/pkcs1v15.go -func ykSignRSA(tx *scTx, slot Slot, pub *rsa.PublicKey, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - if _, ok := opts.(*rsa.PSSOptions); ok { - return nil, fmt.Errorf("rsassa-pss signatures not supported") +func ykSignRSA(tx *scTx, rand io.Reader, slot Slot, pub *rsa.PublicKey, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + hash := opts.HashFunc() + if hash.Size() != len(digest) { + return nil, fmt.Errorf("input must be a hashed message") } alg, err := rsaAlg(pub) if err != nil { return nil, err } - hash := opts.HashFunc() - if hash.Size() != len(digest) { - return nil, fmt.Errorf("input must be a hashed message") - } - prefix, ok := hashPrefixes[hash] - if !ok { - return nil, fmt.Errorf("unsupported hash algorithm: crypto.Hash(%d)", hash) - } - // https://tools.ietf.org/pdf/rfc2313.pdf#page=9 - d := make([]byte, len(prefix)+len(digest)) - copy(d[:len(prefix)], prefix) - copy(d[len(prefix):], digest) + var data []byte + if o, ok := opts.(*rsa.PSSOptions); ok { + salt, err := rsafork.NewSalt(rand, pub, hash, o) + if err != nil { + return nil, err + } + em, err := rsafork.EMSAPSSEncode(digest, pub, salt, hash.New()) + if err != nil { + return nil, err + } + data = em + } else { + prefix, ok := hashPrefixes[hash] + if !ok { + return nil, fmt.Errorf("unsupported hash algorithm: crypto.Hash(%d)", hash) + } + + // https://tools.ietf.org/pdf/rfc2313.pdf#page=9 + d := make([]byte, len(prefix)+len(digest)) + copy(d[:len(prefix)], prefix) + copy(d[len(prefix):], digest) - paddingLen := pub.Size() - 3 - len(d) - if paddingLen < 0 { - return nil, fmt.Errorf("message too large") - } + paddingLen := pub.Size() - 3 - len(d) + if paddingLen < 0 { + return nil, fmt.Errorf("message too large") + } - padding := make([]byte, paddingLen) - for i := range padding { - padding[i] = 0xff - } + padding := make([]byte, paddingLen) + for i := range padding { + padding[i] = 0xff + } - // https://tools.ietf.org/pdf/rfc2313.pdf#page=9 - data := append([]byte{0x00, 0x01}, padding...) - data = append(data, 0x00) - data = append(data, d...) + // https://tools.ietf.org/pdf/rfc2313.pdf#page=9 + data = append([]byte{0x00, 0x01}, padding...) + data = append(data, 0x00) + data = append(data, d...) + } // https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-73-4.pdf#page=117 cmd := apdu{ diff --git a/piv/key_test.go b/piv/key_test.go index 46b9b8a..f8ebf55 100644 --- a/piv/key_test.go +++ b/piv/key_test.go @@ -22,11 +22,14 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/pem" "errors" + "fmt" + "io" "math/big" "testing" "time" @@ -333,6 +336,179 @@ func TestYubiKeySignRSA(t *testing.T) { } } +func TestYubiKeySignRSAPSS(t *testing.T) { + tests := []struct { + name string + alg Algorithm + long bool + }{ + {"rsa1024", AlgorithmRSA1024, false}, + {"rsa2048", AlgorithmRSA2048, true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.long && testing.Short() { + t.Skip("skipping test in short mode") + } + yk, close := newTestYubiKey(t) + defer close() + slot := SlotAuthentication + key := Key{ + Algorithm: test.alg, + TouchPolicy: TouchPolicyNever, + PINPolicy: PINPolicyNever, + } + pubKey, err := yk.GenerateKey(DefaultManagementKey, slot, key) + if err != nil { + t.Fatalf("generating key: %v", err) + } + pub, ok := pubKey.(*rsa.PublicKey) + if !ok { + t.Fatalf("public key is not an rsa key") + } + data := sha256.Sum256([]byte("hello")) + priv, err := yk.PrivateKey(slot, pub, KeyAuth{}) + if err != nil { + t.Fatalf("getting private key: %v", err) + } + s, ok := priv.(crypto.Signer) + if !ok { + t.Fatalf("private key didn't implement crypto.Signer") + } + + opt := &rsa.PSSOptions{Hash: crypto.SHA256} + out, err := s.Sign(rand.Reader, data[:], opt) + if err != nil { + t.Fatalf("signing failed: %v", err) + } + if err := rsa.VerifyPSS(pub, crypto.SHA256, data[:], out, opt); err != nil { + t.Errorf("failed to verify signature: %v", err) + } + }) + } +} + +func TestTLS13(t *testing.T) { + yk, close := newTestYubiKey(t) + defer close() + slot := SlotAuthentication + key := Key{ + Algorithm: AlgorithmRSA1024, + TouchPolicy: TouchPolicyNever, + PINPolicy: PINPolicyNever, + } + pub, err := yk.GenerateKey(DefaultManagementKey, slot, key) + if err != nil { + t.Fatalf("generating key: %v", err) + } + priv, err := yk.PrivateKey(slot, pub, KeyAuth{}) + if err != nil { + t.Fatalf("getting private key: %v", err) + } + + tmpl := &x509.Certificate{ + Subject: pkix.Name{CommonName: "test"}, + SerialNumber: big.NewInt(100), + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + DNSNames: []string{"example.com"}, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + } + + rawCert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv) + if err != nil { + t.Fatalf("creating certificate: %v", err) + } + x509Cert, err := x509.ParseCertificate(rawCert) + if err != nil { + t.Fatalf("parsing cert: %v", err) + } + cert := tls.Certificate{ + Certificate: [][]byte{rawCert}, + PrivateKey: priv, + SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.PSSWithSHA256, + }, + } + pool := x509.NewCertPool() + pool.AddCert(x509Cert) + + cliConf := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: pool, + ServerName: "example.com", + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + srvConf := &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + + srv, err := tls.Listen("tcp", "0.0.0.0:0", srvConf) + if err != nil { + t.Fatalf("creating tls listener: %v", err) + } + defer srv.Close() + + errCh := make(chan error, 2) + + want := []byte("hello, world") + + go func() { + conn, err := srv.Accept() + if err != nil { + errCh <- fmt.Errorf("accepting conn: %v", err) + return + } + defer conn.Close() + + got := make([]byte, len(want)) + if _, err := io.ReadFull(conn, got); err != nil { + errCh <- fmt.Errorf("read data: %v", err) + return + } + if !bytes.Equal(want, got) { + errCh <- fmt.Errorf("unexpected value read: %s", got) + return + } + errCh <- nil + }() + + go func() { + conn, err := tls.Dial("tcp", srv.Addr().String(), cliConf) + if err != nil { + errCh <- fmt.Errorf("dial: %v", err) + return + } + defer conn.Close() + + if v := conn.ConnectionState().Version; v != tls.VersionTLS13 { + errCh <- fmt.Errorf("client got verison 0x%x, want=0x%x", v, tls.VersionTLS13) + return + } + + if _, err := conn.Write(want); err != nil { + errCh <- fmt.Errorf("write: %v", err) + return + } + errCh <- nil + }() + + for i := 0; i < 2; i++ { + if err := <-errCh; err != nil { + t.Fatalf("%v", err) + } + } +} + func TestYubiKeyDecryptRSA(t *testing.T) { tests := []struct { name string diff --git a/third_party/rsa/LICENSE b/third_party/rsa/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/third_party/rsa/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/rsa/README b/third_party/rsa/README new file mode 100644 index 0000000..fef7ed2 --- /dev/null +++ b/third_party/rsa/README @@ -0,0 +1,2 @@ +This directory contains a fork of internal crypto/rsa logic to allow computation +of PSS padding. diff --git a/third_party/rsa/pss.go b/third_party/rsa/pss.go new file mode 100644 index 0000000..142b55d --- /dev/null +++ b/third_party/rsa/pss.go @@ -0,0 +1,168 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rsa + +import ( + "crypto" + "crypto/rsa" + "errors" + "hash" + "io" +) + +var invalidSaltLenErr = errors.New("crypto/rsa: PSSOptions.SaltLength cannot be negative") + +// Per RFC 8017, Section 9.1 +// +// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc +// +// where +// +// DB = PS || 0x01 || salt +// +// and PS can be empty so +// +// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2 +// + +// EMSAPSSEncode is extracted from SignPSS, and is used to generate a EM value +// for a PSS signature operation. +func EMSAPSSEncode(mHash []byte, pub *rsa.PublicKey, salt []byte, hash hash.Hash) ([]byte, error) { + emBits := pub.N.BitLen() - 1 + + // See RFC 8017, Section 9.1.1. + + hLen := hash.Size() + sLen := len(salt) + emLen := (emBits + 7) / 8 + + // 1. If the length of M is greater than the input limitation for the + // hash function (2^61 - 1 octets for SHA-1), output "message too + // long" and stop. + // + // 2. Let mHash = Hash(M), an octet string of length hLen. + + if len(mHash) != hLen { + return nil, errors.New("crypto/rsa: input must be hashed with given hash") + } + + // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. + + if emLen < hLen+sLen+2 { + return nil, rsa.ErrMessageTooLong + } + + em := make([]byte, emLen) + psLen := emLen - sLen - hLen - 2 + db := em[:psLen+1+sLen] + h := em[psLen+1+sLen : emLen-1] + + // 4. Generate a random octet string salt of length sLen; if sLen = 0, + // then salt is the empty string. + // + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + // + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + // + // 6. Let H = Hash(M'), an octet string of length hLen. + + var prefix [8]byte + + hash.Write(prefix[:]) + hash.Write(mHash) + hash.Write(salt) + + h = hash.Sum(h[:0]) + hash.Reset() + + // 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2 + // zero octets. The length of PS may be 0. + // + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + + db[psLen] = 0x01 + copy(db[psLen+1:], salt) + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + // + // 10. Let maskedDB = DB \xor dbMask. + + mgf1XOR(db, hash, h) + + // 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in + // maskedDB to zero. + + db[0] &= 0xff >> (8*emLen - emBits) + + // 12. Let EM = maskedDB || H || 0xbc. + em[emLen-1] = 0xbc + + // 13. Output EM. + return em, nil +} + +// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function +// specified in PKCS #1 v2.1. +func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { + var counter [4]byte + var digest []byte + + done := 0 + for done < len(out) { + hash.Write(seed) + hash.Write(counter[0:4]) + digest = hash.Sum(digest[:0]) + hash.Reset() + + for i := 0; i < len(digest) && done < len(out); i++ { + out[done] ^= digest[i] + done++ + } + incCounter(&counter) + } +} + +// incCounter increments a four byte, big-endian counter. +func incCounter(c *[4]byte) { + if c[3]++; c[3] != 0 { + return + } + if c[2]++; c[2] != 0 { + return + } + if c[1]++; c[1] != 0 { + return + } + c[0]++ +} + +// NewSalt is extracted from SignPSS and is used to generate a salt value for a +// PSS signature. +func NewSalt(rand io.Reader, pub *rsa.PublicKey, hash crypto.Hash, opts *rsa.PSSOptions) ([]byte, error) { + saltLength := opts.SaltLength + switch saltLength { + case rsa.PSSSaltLengthAuto: + saltLength = (pub.N.BitLen()-1+7)/8 - 2 - hash.Size() + if saltLength < 0 { + return nil, rsa.ErrMessageTooLong + } + case rsa.PSSSaltLengthEqualsHash: + saltLength = hash.Size() + default: + // If we get here saltLength is either > 0 or < -1, in the + // latter case we fail out. + if saltLength <= 0 { + return nil, invalidSaltLenErr + } + } + salt := make([]byte, saltLength) + if _, err := io.ReadFull(rand, salt); err != nil { + return nil, err + } + return salt, nil +}