diff --git a/errors.go b/errors.go index a30bb96..0b52f61 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,7 @@ package nkeys // Errors const ( ErrInvalidPrefixByte = nkeysError("nkeys: invalid prefix byte") + ErrDuplicatePrefixByte = nkeysError("nkeys: prefix byte already present") ErrInvalidKey = nkeysError("nkeys: invalid key") ErrInvalidPublicKey = nkeysError("nkeys: invalid public key") ErrInvalidPrivateKey = nkeysError("nkeys: invalid private key") diff --git a/nkeys_test.go b/nkeys_test.go index e0db328..1693617 100644 --- a/nkeys_test.go +++ b/nkeys_test.go @@ -659,7 +659,7 @@ func TestValidateKeyPairRole(t *testing.T) { t.Fatal(err) } - var keyroles = []struct { + keyroles := []struct { kp KeyPair roles []PrefixByte ok bool @@ -685,7 +685,6 @@ func TestValidateKeyPairRole(t *testing.T) { } if err != nil && e.ok { t.Fatalf("test %q should have not failed: %v", e.name, err) - } if err != nil && !e.ok && err != ErrIncompatibleKey { t.Fatalf("unexpected error type for %q: %v", e.name, err) @@ -729,3 +728,75 @@ func TestSealOpen(t *testing.T) { testSealOpen(t, PrefixByteAccount) testSealOpen(t, PrefixByteUser) } + +func TestCustomPublicPrefix(t *testing.T) { + var modulePrefix PrefixByte = 12 << 3 // Base32-encodes to 'M...' + + AddPublicPrefix(modulePrefix, "module") + if v := modulePrefix.String(); v != "module" { + t.Fatalf("Expected 'module', got %v", v) + } + + testSealOpen(t, modulePrefix) + + module, err := CreatePair(modulePrefix) + if err != nil { + t.Fatalf("Expected non-nill error on CreatePair with custom prefix, received %v", err) + } + + if module == nil { + t.Fatal("Expect a non-nil keypair") + } + + seed, err := module.Seed() + if err != nil { + t.Fatalf("Unexpected error retrieving seed: %v", err) + } + + _, err = Decode(PrefixByteSeed, seed) + if err != nil { + t.Fatalf("Expected a proper seed string, got %s", seed) + } + + // Check Public + public, err := module.PublicKey() + if err != nil { + t.Fatalf("Received an error retrieving public key: %v", err) + } + if public[0] != 'M' { + t.Fatalf("Expected a prefix of 'M' but got %c", public[0]) + } + + if _, err := Decode(modulePrefix, []byte(public)); err != nil { + t.Fatalf("Not a valid public key") + } + + // Check Private + private, err := module.PrivateKey() + if err != nil { + t.Fatalf("Received an error retrieving private key: %v", err) + } + if private[0] != 'P' { + t.Fatalf("Expected a prefix of 'P' but got %v", private[0]) + } + + // Check Sign and Verify + data := []byte("Hello World") + sig, err := module.Sign(data) + if err != nil { + t.Fatalf("Unexpected error signing from custom prefix: %v", err) + } + if len(sig) != ed25519.SignatureSize { + t.Fatalf("Expected signature size of %d but got %d", + ed25519.SignatureSize, len(sig)) + } + err = module.Verify(data, sig) + if err != nil { + t.Fatalf("Unexpected error verifying signature: %v", err) + } + + RemovePublicPrefix(modulePrefix) + if v := modulePrefix.String(); v != "unknown" { + t.Fatalf("Expected 'unknown', got %v", v) + } +} diff --git a/strkey.go b/strkey.go index 8ae3311..e5eca01 100644 --- a/strkey.go +++ b/strkey.go @@ -51,9 +51,41 @@ const ( PrefixByteUnknown PrefixByte = 25 << 3 // Base32-encodes to 'Z...' ) +var publicPrefixes = map[PrefixByte]string{ + PrefixByteOperator: "operator", + PrefixByteServer: "server", + PrefixByteCluster: "cluster", + PrefixByteAccount: "account", + PrefixByteUser: "user", + PrefixByteCurve: "x25519", +} + +var privatePrefixes = map[PrefixByte]string{ + PrefixByteSeed: "seed", + PrefixBytePrivate: "private", +} + // Set our encoding to not include padding '==' var b32Enc = base32.StdEncoding.WithPadding(base32.NoPadding) +// AddPublicPrefix adds a public prefix byte. Must not collide with existing prefixes. +func AddPublicPrefix(prefix PrefixByte, name string) error { + if _, ok := publicPrefixes[prefix]; ok { + return ErrDuplicatePrefixByte + } + publicPrefixes[prefix] = name + return nil +} + +// RemovePublicPrefix removes a public prefix byte. Must be a valid prefix. +func RemovePublicPrefix(prefix PrefixByte) error { + if _, ok := publicPrefixes[prefix]; !ok { + return ErrInvalidPrefixByte + } + delete(publicPrefixes, prefix) + return nil +} + // Encode will encode a raw key or seed with the prefix and crc16 and then base32 encoded. func Encode(prefix PrefixByte, src []byte) ([]byte, error) { if err := checkValidPrefixByte(prefix); err != nil { @@ -257,43 +289,31 @@ func IsValidPublicCurveKey(src string) bool { // checkValidPrefixByte returns an error if the provided value // is not one of the defined valid prefix byte constants. func checkValidPrefixByte(prefix PrefixByte) error { - switch prefix { - case PrefixByteOperator, PrefixByteServer, PrefixByteCluster, - PrefixByteAccount, PrefixByteUser, PrefixByteSeed, PrefixBytePrivate, PrefixByteCurve: + if _, ok := privatePrefixes[prefix]; ok { return nil } - return ErrInvalidPrefixByte + + return checkValidPublicPrefixByte(prefix) } // checkValidPublicPrefixByte returns an error if the provided value // is not one of the public defined valid prefix byte constants. func checkValidPublicPrefixByte(prefix PrefixByte) error { - switch prefix { - case PrefixByteOperator, PrefixByteServer, PrefixByteCluster, PrefixByteAccount, PrefixByteUser, PrefixByteCurve: + if _, ok := publicPrefixes[prefix]; ok { return nil } return ErrInvalidPrefixByte } func (p PrefixByte) String() string { - switch p { - case PrefixByteOperator: - return "operator" - case PrefixByteServer: - return "server" - case PrefixByteCluster: - return "cluster" - case PrefixByteAccount: - return "account" - case PrefixByteUser: - return "user" - case PrefixByteSeed: - return "seed" - case PrefixBytePrivate: - return "private" - case PrefixByteCurve: - return "x25519" + if v, ok := privatePrefixes[p]; ok { + return v } + + if v, ok := publicPrefixes[p]; ok { + return v + } + return "unknown" }