diff --git a/attribute.go b/attribute.go index bcf5a05..473d230 100644 --- a/attribute.go +++ b/attribute.go @@ -139,7 +139,7 @@ func NewIFID(addr net.HardwareAddr) (Attribute, error) { // attribute length is invalid, the secret is empty, or the requestAuthenticator // length is invalid. func UserPassword(a Attribute, secret, requestAuthenticator []byte) ([]byte, error) { - if len(a) < 16 || len(a) > 128 { + if len(a) < 16 || len(a) > 128 || len(a)%16 != 0 { return nil, errors.New("invalid attribute length (" + strconv.Itoa(len(a)) + ")") } if len(secret) == 0 { @@ -204,8 +204,8 @@ func NewUserPassword(plaintext, secret, requestAuthenticator []byte) (Attribute, hash.Write(requestAuthenticator) enc = hash.Sum(enc) - for i, b := range plaintext[:16] { - enc[i] ^= b + for i := 0; i < 16 && i < len(plaintext); i++ { + enc[i] ^= plaintext[i] } for i := 16; i < len(plaintext); i += 16 { @@ -214,8 +214,8 @@ func NewUserPassword(plaintext, secret, requestAuthenticator []byte) (Attribute, hash.Write(enc[i-16 : i]) enc = hash.Sum(enc) - for j, b := range plaintext[i : i+16] { - enc[i+j] ^= b + for j := 0; j < 16 && i+j < len(plaintext); j++ { + enc[i+j] ^= plaintext[i+j] } } diff --git a/attribute_test.go b/attribute_test.go index 2df7dc5..9935b4e 100644 --- a/attribute_test.go +++ b/attribute_test.go @@ -10,28 +10,36 @@ import ( func TestNewUserPassword_length(t *testing.T) { tbl := []struct { - Password string + Password []byte EncodedLength int }{ - {"", 16}, - {"abc", 16}, - {"0123456789abcde", 16}, - {"0123456789abcdef", 16}, - {"0123456789abcdef0", 16 * 2}, - {"0123456789abcdef0123456789abcdef0123456789abcdef", 16 * 3}, + {append(make([]byte, 0, 0), ""...), 16}, + {append(make([]byte, 0, 15), "abc"...), 16}, + {append(make([]byte, 0, 15), "0123456789abcde"...), 16}, + {append(make([]byte, 0, 16), "0123456789abcdef"...), 16}, + {append(make([]byte, 0, 30), "0123456789abcdef0"...), 16 * 2}, + {append(make([]byte, 0, 48), "0123456789abcdefzzzzzzzzzzzzzzzzQQQQQQQQQQQQQQQQ"...), 16 * 3}, } secret := []byte(`12345`) ra := []byte(`0123456789abcdef`) for _, x := range tbl { - attr, err := NewUserPassword([]byte(x.Password), secret, ra) + attr, err := NewUserPassword(x.Password, secret, ra) if err != nil { t.Fatal(err) } if len(attr) != x.EncodedLength { t.Fatalf("expected encoded length of %#v = %d, got %d", x.Password, x.EncodedLength, len(attr)) } + + decoded, err := UserPassword(attr, secret, ra) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decoded, x.Password) { + t.Fatalf("expected roundtrip to succeed") + } } }