Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allocation optimization #33

Merged
merged 4 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,12 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
keyBytes += 1
}

// We serialize the outpus (r and s) into big-endian byte arrays and pad
// them with zeros on the left to make sure the sizes work out. Both arrays
// must be keyBytes long, and the output must be 2*keyBytes long.
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)

sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)

out := append(rBytesPadded, sBytesPadded...)
// We serialize the outputs (r and s) into big-endian byte arrays
// padded with zeros on the left to make sure the sizes work out.
// Output must be 2*keyBytes long.
out := make([]byte, 2*keyBytes)
r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output.
s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output.

return EncodeSegment(out), nil
} else {
Expand Down
48 changes: 47 additions & 1 deletion ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,60 @@ func TestECDSASign(t *testing.T) {

if data.valid {
parts := strings.Split(data.tokenString, ".")
toSign := strings.Join(parts[0:2], ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), ecdsaKey)
sig, err := method.Sign(toSign, ecdsaKey)

if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
}

err = method.Verify(toSign, sig, ecdsaKey.Public())
if err != nil {
t.Errorf("[%v] Sign produced an invalid signature: %v", data.name, err)
}
}
}
}

func BenchmarkECDSASigning(b *testing.B) {
for _, data := range ecdsaTestData {
key, _ := ioutil.ReadFile(data.keys["private"])

ecdsaKey, err := jwt.ParseECPrivateKeyFromPEM(key)
if err != nil {
b.Fatalf("Unable to parse ECDSA private key: %v", err)
}

method := jwt.GetSigningMethod(data.alg)

b.Run(data.name, func(b *testing.B) {
benchmarkSigning(b, method, ecdsaKey)
})

// Directly call method.Sign without the decoration of *Token.
b.Run(data.name+"/sign-only", func(b *testing.B) {
if !data.valid {
b.Skipf("Skipping because data is not valid")
}

parts := strings.Split(data.tokenString, ".")
toSign := strings.Join(parts[0:2], ".")

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
sig, err := method.Sign(toSign, ecdsaKey)
if err != nil {
b.Fatalf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
b.Fatalf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
}
}
})
}
}
2 changes: 2 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ func TestParser_ParseUnverified(t *testing.T) {
// Helper method for benchmarking various methods
func benchmarkSigning(b *testing.B, method jwt.SigningMethod, key interface{}) {
t := jwt.New(method)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := t.SignedString(key); err != nil {
Expand Down
8 changes: 2 additions & 6 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,10 @@ func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token

// Encode JWT specific base64url encoding with padding stripped
func EncodeSegment(seg []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")
return base64.RawURLEncoding.EncodeToString(seg)
}

// Decode JWT specific base64url encoding with padding stripped
func DecodeSegment(seg string) ([]byte, error) {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}

return base64.URLEncoding.DecodeString(seg)
return base64.RawURLEncoding.DecodeString(seg)
}