-
Notifications
You must be signed in to change notification settings - Fork 364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
HdonCr/Allow multiple audiences #427
base: main
Are you sure you want to change the base?
Changes from all commits
8f964e5
ca00a1f
ab835f5
f1c5576
ca1d34c
e107f51
8e711bc
de939e0
622a310
9144621
773bd50
2728fcc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,9 +51,12 @@ type Validator struct { | |
// unrealistic, i.e., in the future. | ||
verifyIat bool | ||
|
||
// expectedAud contains the audience this token expects. Supplying an empty | ||
// string will disable aud checking. | ||
expectedAud string | ||
//expectedAuds contains the audiences this token expects. Supplying an empty | ||
// []string will disable auds checking. | ||
expectedAuds []string | ||
|
||
// matchAllAud specifies whether all expected audiences must match all auds from claim | ||
matchAllAud bool | ||
|
||
// expectedIss contains the issuer this token expects. Supplying an empty | ||
// string will disable iss checking. | ||
|
@@ -119,9 +122,9 @@ func (v *Validator) Validate(claims Claims) error { | |
} | ||
} | ||
|
||
// If we have an expected audience, we also require the audience claim | ||
if v.expectedAud != "" { | ||
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil { | ||
// If we have expected audiences, we also require the audiences claim | ||
if len(v.expectedAuds) > 0 { | ||
if err := v.verifyAudiences(claims, v.expectedAuds, true, v.matchAllAud); err != nil { | ||
errs = append(errs, err) | ||
} | ||
} | ||
|
@@ -219,40 +222,124 @@ func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) | |
return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet) | ||
} | ||
|
||
// verifyAudience compares the aud claim against cmp. | ||
// / verifyAudiences compares the aud claim against cmps. | ||
// If matchAllAuds is true, all cmps must match a aud. | ||
// If matchAllAuds is false, at least one cmp must match a aud. | ||
// | ||
// If matchAllAuds is true and aud length does not match cmps length, an ErrTokenInvalidAudience error will be returned. | ||
// Note that this does not account for any duplicate aud or cmps | ||
// | ||
// If aud is not set or an empty list, it will succeed if the claim is not required, | ||
// otherwise ErrTokenRequiredClaimMissing will be returned. | ||
// | ||
// Additionally, if any error occurs while retrieving the claim, e.g., when its | ||
// the wrong type, an ErrTokenUnverifiable error will be returned. | ||
func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error { | ||
func (v *Validator) verifyAudiences(claims Claims, cmps []string, required bool, matchAllAuds bool) error { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function seems to be very complex, I wonder if we could streamline this implementation a little bit. Furthermore, it should be possible to include the previous implementation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. I would propose to remove the validation for single audiences, since the I have removed the Note that this is a breaking change. |
||
|
||
// Get the audience claim(s) from the token | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @oxisto, regarding your first comment here: https://github.com/golang-jwt/jwt/pull/427/files#r1917080944 I agree that this function is quite complex. I think this is the result of the option to match all auds. Deduplicating both situations is difficult, making the function very long due to two separate flows. Additionally, the situation in which at least one cmp must match an aud is relatively simple. However, all cmps matching all auds is more difficult to implement. I have added documentation in an attempt to make the function more legible. If required, I can look into simplifying it further. |
||
aud, err := claims.GetAudience() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// If no audience is provided, return an error if required | ||
if len(aud) == 0 { | ||
return errorIfRequired(required, "aud") | ||
} | ||
|
||
// use a var here to keep constant time compare when looping over a number of claims | ||
result := false | ||
// Deduplicate the aud and cmps slices | ||
aud = deduplicateStrings(aud) | ||
cmps = deduplicateStrings(cmps) | ||
|
||
var stringClaims string | ||
for _, a := range aud { | ||
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { | ||
result = true | ||
|
||
// If matchAllAuds is true, check if all the cmps matches any of the aud | ||
if matchAllAuds { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @oxisto, I am not sure whether this |
||
|
||
// cmps and aud length should match if matchAllAuds is true | ||
if len(cmps) != len(aud) { | ||
return errorIfFalse(false, ErrTokenInvalidAudience) | ||
} | ||
|
||
// Check all cmps values | ||
for _, cmp := range cmps { | ||
matchFound := false | ||
|
||
// Check all aud values | ||
for _, a := range aud { | ||
|
||
// Perform constant time comparison | ||
result := subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. quite a large duplicate code fragment, this is essentially what the existing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment here: https://github.com/golang-jwt/jwt/pull/427/files#r1948157329 |
||
|
||
// Concatenate all aud values to stringClaims | ||
stringClaims = stringClaims + a | ||
|
||
// If a match is found, set matchFound to true and break out of inner aud loop and continue to next cmp | ||
if result { | ||
matchFound = true | ||
break | ||
} | ||
} | ||
|
||
// If no match was found for the current cmp, return a ErrTokenInvalidAudience error | ||
if !matchFound { | ||
return ErrTokenInvalidAudience | ||
} | ||
} | ||
|
||
} else { | ||
// if matchAllAuds is false, check if any of the cmps matches any of the aud | ||
|
||
matchFound := false | ||
|
||
// Label to break out of both loops if a match is found | ||
outer: | ||
|
||
// Check all aud values | ||
for _, a := range aud { | ||
|
||
// Check all cmp values | ||
for _, cmp := range cmps { | ||
|
||
// Perform constant time comparison | ||
result := subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment here: https://github.com/golang-jwt/jwt/pull/427/files#r1948157329 |
||
|
||
// Concatenate all aud values to stringClaims | ||
stringClaims = stringClaims + a | ||
|
||
// If a match is found, break out of both loops and finish comparison | ||
if result { | ||
matchFound = true | ||
break outer | ||
} | ||
} | ||
} | ||
|
||
// If no match was found for any cmp, return an error | ||
if !matchFound { | ||
return errorIfFalse(false, ErrTokenInvalidAudience) | ||
} | ||
stringClaims = stringClaims + a | ||
} | ||
|
||
// case where "" is sent in one or many aud claims | ||
if stringClaims == "" { | ||
return errorIfRequired(required, "aud") | ||
} | ||
|
||
return errorIfFalse(result, ErrTokenInvalidAudience) | ||
return nil | ||
} | ||
|
||
// deduplicateStrings removes duplicate elements from a string slice | ||
func deduplicateStrings(slice []string) []string { | ||
unique := make(map[string]bool) | ||
var result []string | ||
for _, item := range slice { | ||
if _, found := unique[item]; !found { | ||
unique[item] = true | ||
result = append(result, item) | ||
} | ||
} | ||
return result | ||
} | ||
|
||
// verifyIssuer compares the iss claim in claims against cmp. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder whether we could merge this with the
expectedAud
aboveThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my comment here: https://github.com/golang-jwt/jwt/pull/427/files#r1948157329