diff --git a/go.mod b/go.mod index cbd144d..7d81d0c 100644 --- a/go.mod +++ b/go.mod @@ -4,3 +4,5 @@ require ( golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5 golang.org/x/sys v0.0.0-20190412213103-97732733099d ) + +go 1.13 diff --git a/prime.go b/prime.go index 9d658da..298fa86 100644 --- a/prime.go +++ b/prime.go @@ -81,10 +81,8 @@ func isGenerator(g, p *big.Int) bool { func ok(g, x *big.Int, p *big.Int) bool { z := big.NewInt(0).Exp(g, x, p) - if z.Cmp(one) != 0 { // the expmod should NOT be 1 - return true - } - return false + // the expmod should NOT be 1 + return z.Cmp(one) != 0 } // vim: noexpandtab:sw=8:ts=8:tw=92: diff --git a/srp.go b/srp.go index a44c178..2bf2bb9 100644 --- a/srp.go +++ b/srp.go @@ -181,14 +181,14 @@ func NewWithHash(h crypto.Hash, bits int) (*SRP, error) { func ServerBegin(creds string) (string, *big.Int, error) { v := strings.Split(creds, ":") if len(v) != 2 { - return "", nil, fmt.Errorf("invalid client public key") + return "", nil, fmt.Errorf("srp: invalid client public key") } //fmt.Printf("v0: %s\nv1: %s\n", v[0], v[1]) A, ok := big.NewInt(0).SetString(v[1], 16) if !ok { - return "", nil, fmt.Errorf("Invalid client public key A") + return "", nil, fmt.Errorf("srp: invalid client public key A") } return v[0], A, nil @@ -374,29 +374,29 @@ func (c *Client) Credentials() string { func (c *Client) Generate(srv string) (string, error) { v := strings.Split(srv, ":") if len(v) != 2 { - return "", fmt.Errorf("invalid server public key") + return "", fmt.Errorf("srp: invalid server public key") } salt, err := hex.DecodeString(v[0]) if err != nil { - return "", fmt.Errorf("invalid server public key") + return "", fmt.Errorf("srp: invalid server public key") } B, ok1 := big.NewInt(0).SetString(v[1], 16) if !ok1 { - return "", fmt.Errorf("invalid server public key") + return "", fmt.Errorf("srp: invalid server public key") } pf := c.s.pf zero := big.NewInt(0) z := big.NewInt(0).Mod(B, pf.N) if zero.Cmp(z) == 0 { - return "", fmt.Errorf("invalid server public key") + return "", fmt.Errorf("srp: invalid server public key") } u := c.s.hashint(pad(c.xA, pf.n), pad(B, pf.n)) if u.Cmp(zero) == 0 { - return "", fmt.Errorf("invalid server public key") + return "", fmt.Errorf("srp: invalid server public key") } // S := ((B - kg^x) ^ (a + ux)) % N @@ -457,7 +457,7 @@ func (s *SRP) NewServer(v *Verifier, A *big.Int) (*Server, error) { zero := big.NewInt(0) z := big.NewInt(0).Mod(A, pf.N) if zero.Cmp(z) == 0 { - return nil, fmt.Errorf("invalid client public key") + return nil, fmt.Errorf("srp: invalid client public key") } sx := &Server{ @@ -483,7 +483,7 @@ func (s *SRP) NewServer(v *Verifier, A *big.Int) (*Server, error) { u := s.hashint(pad(A, pf.n), pad(B, pf.n)) if u.Cmp(zero) == 0 { - return nil, fmt.Errorf("Invalid client public key u") + return nil, fmt.Errorf("srp: invalid client public key u") } t0 = big.NewInt(0).Mul(A, big.NewInt(0).Exp(sx.v, u, pf.N)) @@ -570,7 +570,7 @@ func pad(x *big.Int, n int) []byte { b := x.Bytes() if len(b) < n { z := n - len(b) - p := make([]byte, n, n) + p := make([]byte, n) for i := 0; i < z; i++ { p[i] = 0 } @@ -595,7 +595,7 @@ func randbytes(n int) []byte { // Generate and return a bigInt 'bits' bits in length func randBigInt(bits int) *big.Int { n := bits / 8 - if 0 != bits%8 { + if (bits%8) != 0 { n += 1 } b := randbytes(n) @@ -646,7 +646,7 @@ func newPrimeField(nbits int) (*primeField, error) { } } } - return nil, fmt.Errorf("can't find generator after 100 tries") + return nil, fmt.Errorf("srp: can't find generator after 100 tries") } // Find a pre-generated safe-prime and its generator from our list below. @@ -685,7 +685,7 @@ func init() { N: atobi(v[2], 0), n: b / 8, } - if 0 == big.NewInt(0).Cmp(pf.N) { + if big.NewInt(0).Cmp(pf.N) == 0 { panic(fmt.Sprintf("srp init: N (%s) is zero", v[2])) } pflist[b] = pf diff --git a/srp_test.go b/srp_test.go index f1ede4a..faf499b 100644 --- a/srp_test.go +++ b/srp_test.go @@ -31,9 +31,6 @@ func newAsserter(t *testing.T) func(cond bool, msg string, args ...interface{}) } } -type user struct { - v *Verifier -} type userdb struct { s *SRP @@ -133,7 +130,7 @@ func (db *userdb) verify(t *testing.T, user, pass []byte, goodPw bool) { kc := c.RawKey() ks := srv.RawKey() - assert(1 == subtle.ConstantTimeCompare(kc, ks), "key mismatch;\nclient %x, server %x", kc, ks) + assert(subtle.ConstantTimeCompare(kc, ks) == 1, "key mismatch;\nclient %x, server %x", kc, ks) } func TestSRP(t *testing.T) { @@ -154,38 +151,3 @@ func TestSRP(t *testing.T) { } } -func mustDecode(s string) []byte { - n := len(s) - b := make([]byte, 0, n) - var z, x byte - var shift uint = 4 - for i := 0; i < n; i++ { - c := s[i] - switch { - case '0' <= c && c <= '9': - x = c - '0' - case 'a' <= c && c <= 'f': - x = c - 'a' + 10 - case 'A' <= c && c <= 'F': - x = c - 'A' + 10 - case c == ' ' || c == '\n' || c == '\t': - continue - default: - panic(fmt.Sprintf("invalid hex char %c in %s", c, s)) - } - - if shift == 0 { - z |= x - b = append(b, z) - z = 0 - shift = 4 - } else { - z |= (x << shift) - shift -= 4 - } - } - if shift != 4 { - b = append(b, z) - } - return b -}