From 644dc323eb21b8826f71b93fb6b146e5ab73cebf Mon Sep 17 00:00:00 2001 From: AlexandreEXFO <154447827+AlexandreEXFO@users.noreply.github.com> Date: Mon, 16 Dec 2024 13:05:32 -0500 Subject: [PATCH] Improve crypto & constants (#381) * Improve crypto & constants * Fix lint * Update constants.go (remove underscore)) --- tpm2/constants.go | 43 +++-- tpm2/crypto.go | 30 +++- tpm2/crypto_test.go | 370 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 431 insertions(+), 12 deletions(-) create mode 100644 tpm2/crypto_test.go diff --git a/tpm2/constants.go b/tpm2/constants.go index 733c3eb7..70edae27 100644 --- a/tpm2/constants.go +++ b/tpm2/constants.go @@ -27,6 +27,7 @@ const ( TPMAlgSHA256 TPMAlgID = 0x000B TPMAlgSHA384 TPMAlgID = 0x000C TPMAlgSHA512 TPMAlgID = 0x000D + TPMAlgSHA256192 TPMAlgID = 0x000E TPMAlgNull TPMAlgID = 0x0010 TPMAlgSM3256 TPMAlgID = 0x0012 TPMAlgSM4 TPMAlgID = 0x0013 @@ -49,12 +50,31 @@ const ( TPMAlgSHA3256 TPMAlgID = 0x0027 TPMAlgSHA3384 TPMAlgID = 0x0028 TPMAlgSHA3512 TPMAlgID = 0x0029 + TPMAlgSHAKE128 TPMAlgID = 0x002A + TPMAlgSHAKE256 TPMAlgID = 0x002B + TPMAlgSHAKE256192 TPMAlgID = 0x002C + TPMAlgSHAKE256256 TPMAlgID = 0x002D + TPMAlgSHAKE256512 TPMAlgID = 0x002E TPMAlgCMAC TPMAlgID = 0x003F TPMAlgCTR TPMAlgID = 0x0040 TPMAlgOFB TPMAlgID = 0x0041 TPMAlgCBC TPMAlgID = 0x0042 TPMAlgCFB TPMAlgID = 0x0043 TPMAlgECB TPMAlgID = 0x0044 + TPMAlgCCM TPMAlgID = 0x0050 + TPMAlgGCM TPMAlgID = 0x0051 + TPMAlgKW TPMAlgID = 0x0052 + TPMAlgKWP TPMAlgID = 0x0053 + TPMAlgEAX TPMAlgID = 0x0054 + TPMAlgEDDSA TPMAlgID = 0x0060 + TPMAlgEDDSAPH TPMAlgID = 0x0061 + TPMAlgLMS TPMAlgID = 0x0070 + TPMAlgXMSS TPMAlgID = 0x0071 + TPMAlgKEYEDXOF TPMAlgID = 0x0080 + TPMAlgKMACXOF128 TPMAlgID = 0x0081 + TPMAlgKMACXOF256 TPMAlgID = 0x0082 + TPMAlgKMAC128 TPMAlgID = 0x0090 + TPMAlgKMAC256 TPMAlgID = 0x0091 ) // TPMECCCurve represents a TPM_ECC_Curve. @@ -63,15 +83,20 @@ type TPMECCCurve uint16 // TPMECCCurve values come from Part 2: Structures, section 6.4. const ( - TPMECCNone TPMECCCurve = 0x0000 - TPMECCNistP192 TPMECCCurve = 0x0001 - TPMECCNistP224 TPMECCCurve = 0x0002 - TPMECCNistP256 TPMECCCurve = 0x0003 - TPMECCNistP384 TPMECCCurve = 0x0004 - TPMECCNistP521 TPMECCCurve = 0x0005 - TPMECCBNP256 TPMECCCurve = 0x0010 - TPMECCBNP638 TPMECCCurve = 0x0011 - TPMECCSM2P256 TPMECCCurve = 0x0020 + TPMECCNone TPMECCCurve = 0x0000 + TPMECCNistP192 TPMECCCurve = 0x0001 + TPMECCNistP224 TPMECCCurve = 0x0002 + TPMECCNistP256 TPMECCCurve = 0x0003 + TPMECCNistP384 TPMECCCurve = 0x0004 + TPMECCNistP521 TPMECCCurve = 0x0005 + TPMECCBNP256 TPMECCCurve = 0x0010 + TPMECCBNP638 TPMECCCurve = 0x0011 + TPMECCSM2P256 TPMECCCurve = 0x0020 + TPMECCBrainpoolP256R1 TPMECCCurve = 0x0030 + TPMECCBrainpoolP384R1 TPMECCCurve = 0x0031 + TPMECCBrainpoolP512R1 TPMECCCurve = 0x0032 + TPMECCCurve25519 TPMECCCurve = 0x0040 + TPMECCCurve448 TPMECCCurve = 0x0041 ) // TPMCC represents a TPM_CC. diff --git a/tpm2/crypto.go b/tpm2/crypto.go index c2c72072..2f12e0a3 100644 --- a/tpm2/crypto.go +++ b/tpm2/crypto.go @@ -24,6 +24,10 @@ func Priv(public TPMTPublic, sensitive TPMTSensitive) (crypto.PrivateKey, error) case TPMAlgRSA: publicKey := publicKey.(*rsa.PublicKey) + if sensitive.SensitiveType != TPMAlgRSA { + return nil, fmt.Errorf("sensitive type is not equal to public type") + } + prime, err := sensitive.Sensitive.RSA() if err != nil { return nil, fmt.Errorf("failed to retrieve the RSA prime number") @@ -34,14 +38,34 @@ func Priv(public TPMTPublic, sensitive TPMTSensitive) (crypto.PrivateKey, error) phiN := new(big.Int).Mul(new(big.Int).Sub(P, big.NewInt(1)), new(big.Int).Sub(Q, big.NewInt(1))) D := new(big.Int).ModInverse(big.NewInt(int64(publicKey.E)), phiN) - privateKey = rsa.PrivateKey{ + rsaKey := &rsa.PrivateKey{ PublicKey: *publicKey, D: D, Primes: []*big.Int{P, Q}, } - privateKey := privateKey.(rsa.PrivateKey) + rsaKey.Precompute() + + privateKey = rsaKey + case TPMAlgECC: + publicKey := publicKey.(*ecdsa.PublicKey) + + if sensitive.SensitiveType != TPMAlgECC { + return nil, fmt.Errorf("sensitive type is not equal to public type") + } + + d, err := sensitive.Sensitive.ECC() + if err != nil { + return nil, fmt.Errorf("failed to retrieve the ECC") + } + + D := new(big.Int).SetBytes(d.Buffer) + + ecdsaKey := &ecdsa.PrivateKey{ + PublicKey: *publicKey, + D: D, + } - privateKey.Precompute() + privateKey = ecdsaKey default: return nil, fmt.Errorf("unsupported public key type: %v", public.Type) } diff --git a/tpm2/crypto_test.go b/tpm2/crypto_test.go new file mode 100644 index 00000000..c9555f8b --- /dev/null +++ b/tpm2/crypto_test.go @@ -0,0 +1,370 @@ +package tpm2 + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "reflect" + "testing" +) + +func TestPriv(t *testing.T) { + + t.Parallel() + + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + ecdsaKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + seed := make([]byte, crypto.SHA256.New().Size()) + rand.Read(seed) + + tests := map[string]struct { + sensitive TPMTSensitive + public TPMTPublic + result bool + }{ + "valid rsa": { + sensitive: TPMTSensitive{ + SensitiveType: TPMAlgRSA, + AuthValue: TPM2BAuth{ + Buffer: nil, + }, + SeedValue: TPM2BDigest{ + Buffer: seed, + }, + Sensitive: NewTPMUSensitiveComposite( + TPMAlgRSA, + &TPM2BPrivateKeyRSA{ + Buffer: rsaKey.Primes[0].Bytes(), + }, + ), + }, + public: TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + STClear: false, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: false, + NoDA: false, + EncryptedDuplication: false, + Restricted: true, + Decrypt: true, + SignEncrypt: false, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + KeyBits: TPMKeyBits(rsaKey.PublicKey.N.BitLen()), + Exponent: 0, + Symmetric: TPMTSymDefObject{ + Algorithm: TPMAlgAES, + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(128), + ), + }, + }, + ), + Unique: NewTPMUPublicID( + TPMAlgRSA, + &TPM2BPublicKeyRSA{ + Buffer: rsaKey.PublicKey.N.Bytes(), + }, + ), + }, + result: true, + }, + "valid ecdsa": { + sensitive: TPMTSensitive{ + SensitiveType: TPMAlgECC, + AuthValue: TPM2BAuth{ + Buffer: nil, + }, + SeedValue: TPM2BDigest{ + Buffer: seed, + }, + Sensitive: NewTPMUSensitiveComposite( + TPMAlgECC, + &TPM2BECCParameter{Buffer: ecdsaKey.D.FillBytes(make([]byte, len(ecdsaKey.D.Bytes())))}, + ), + }, + public: TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + STClear: false, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: false, + NoDA: false, + EncryptedDuplication: false, + Restricted: true, + Decrypt: true, + SignEncrypt: false, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + CurveID: TPMECCNistP256, + Scheme: TPMTECCScheme{ + Scheme: TPMAlgECDSA, + Details: NewTPMUAsymScheme( + TPMAlgECDSA, + &TPMSSigSchemeECDSA{ + HashAlg: TPMAlgSHA256, + }, + ), + }, + }, + ), + Unique: NewTPMUPublicID( + TPMAlgECC, + &TPMSECCPoint{ + X: TPM2BECCParameter{ + Buffer: ecdsaKey.X.Bytes(), + }, + Y: TPM2BECCParameter{ + Buffer: ecdsaKey.Y.Bytes(), + }, + }, + ), + }, + result: true, + }, + "public error": { + sensitive: TPMTSensitive{}, + public: TPMTPublic{ + Type: TPMAlgAES, + }, + result: false, + }, + } + + for name, test := range tests { + test := test + + t.Run(name, func(t *testing.T) { + t.Parallel() + + key, err := Priv(test.public, test.sensitive) + + if (key != nil) != test.result { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + test.result, + key != nil, + ) + } + + if (err == nil) != test.result { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + test.result, + err == nil, + ) + } + + if key != nil { + switch key := key.(type) { + case *rsa.PrivateKey: + if !reflect.DeepEqual(rsaKey, key) { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + rsaKey, + key, + ) + } + case *ecdsa.PrivateKey: + if !reflect.DeepEqual(ecdsaKey, key) { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + ecdsaKey, + key, + ) + } + default: + t.Fatalf("unexpected case") + } + } + }) + } +} + +func TestPub(t *testing.T) { + + t.Parallel() + + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + ecdsaKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + tests := map[string]struct { + public TPMTPublic + result bool + }{ + "valid rsa": { + public: TPMTPublic{ + Type: TPMAlgRSA, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + STClear: false, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: false, + NoDA: false, + EncryptedDuplication: false, + Restricted: true, + Decrypt: true, + SignEncrypt: false, + }, + Parameters: NewTPMUPublicParms( + TPMAlgRSA, + &TPMSRSAParms{ + KeyBits: TPMKeyBits(rsaKey.PublicKey.N.BitLen()), + Exponent: 0, + Symmetric: TPMTSymDefObject{ + Algorithm: TPMAlgAES, + Mode: NewTPMUSymMode( + TPMAlgAES, + TPMAlgCFB, + ), + KeyBits: NewTPMUSymKeyBits( + TPMAlgAES, + TPMKeyBits(128), + ), + }, + }, + ), + Unique: NewTPMUPublicID( + TPMAlgRSA, + &TPM2BPublicKeyRSA{ + Buffer: rsaKey.PublicKey.N.Bytes(), + }, + ), + }, + result: true, + }, + "valid ecdsa": { + public: TPMTPublic{ + Type: TPMAlgECC, + NameAlg: TPMAlgSHA256, + ObjectAttributes: TPMAObject{ + FixedTPM: true, + STClear: false, + FixedParent: true, + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: false, + NoDA: false, + EncryptedDuplication: false, + Restricted: true, + Decrypt: true, + SignEncrypt: false, + }, + Parameters: NewTPMUPublicParms( + TPMAlgECC, + &TPMSECCParms{ + CurveID: TPMECCNistP256, + Scheme: TPMTECCScheme{ + Scheme: TPMAlgECDSA, + Details: NewTPMUAsymScheme( + TPMAlgECDSA, + &TPMSSigSchemeECDSA{ + HashAlg: TPMAlgSHA256, + }, + ), + }, + }, + ), + Unique: NewTPMUPublicID( + TPMAlgECC, + &TPMSECCPoint{ + X: TPM2BECCParameter{ + Buffer: ecdsaKey.X.Bytes(), + }, + Y: TPM2BECCParameter{ + Buffer: ecdsaKey.Y.Bytes(), + }, + }, + ), + }, + result: true, + }, + "unsupported algorithm": { + public: TPMTPublic{ + Type: TPMAlgAES, + }, + result: false, + }, + } + + for name, test := range tests { + test := test + + t.Run(name, func(t *testing.T) { + t.Parallel() + + key, err := Pub(test.public) + + if (key != nil) != test.result { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + test.result, + key != nil, + ) + } + + if (err == nil) != test.result { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + test.result, + err == nil, + ) + } + + if key != nil { + switch key := key.(type) { + case *rsa.PublicKey: + if !reflect.DeepEqual(rsaKey.PublicKey, *key) { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + rsaKey.PublicKey, + key, + ) + } + case *ecdsa.PublicKey: + if !reflect.DeepEqual(ecdsaKey.PublicKey, *key) { + t.Errorf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", + ecdsaKey.PublicKey, + key, + ) + } + default: + t.Fatalf("unexpected case") + } + } + }) + } +}