Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exported NewValidator #349

Merged
merged 2 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions MIGRATION_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ type Claims interface {
}
```

Users that previously directly called the `Valid` function on their claims,
e.g., to perform validation independently of parsing/verifying a token, can now
use the `jwt.NewValidator` function to create a `validator` independently of the
oxisto marked this conversation as resolved.
Show resolved Hide resolved
`Parser`.

```go
var v = jwt.NewValidator(jwt.WithLeeway(5*time.Second))
v.Validate(myClaims)
```

### Supported Claim Types and Removal of `StandardClaims`

The two standard claim types supported by this library, `MapClaims` and
Expand Down
14 changes: 7 additions & 7 deletions map_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestVerifyAud(t *testing.T) {
opts = append(opts, WithAudience(test.Comparison))
}

validator := newValidator(opts...)
validator := NewValidator(opts...)
got := validator.Validate(test.MapClaims)

if (got == nil) != test.Expected {
Expand All @@ -77,7 +77,7 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
"iat": "foo",
}
want := false
got := newValidator(WithIssuedAt()).Validate(mapClaims)
got := NewValidator(WithIssuedAt()).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
Expand All @@ -88,7 +88,7 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
"nbf": "foo",
}
want := false
got := newValidator().Validate(mapClaims)
got := NewValidator().Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
Expand All @@ -99,7 +99,7 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
"exp": "foo",
}
want := false
got := newValidator().Validate(mapClaims)
got := NewValidator().Validate(mapClaims)

if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
Expand All @@ -112,22 +112,22 @@ func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
"exp": float64(exp.Unix()),
}
want := false
got := newValidator(WithTimeFunc(func() time.Time {
got := NewValidator(WithTimeFunc(func() time.Time {
return exp
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}

got = newValidator(WithTimeFunc(func() time.Time {
got = NewValidator(WithTimeFunc(func() time.Time {
return exp.Add(1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}

want = true
got = newValidator(WithTimeFunc(func() time.Time {
got = NewValidator(WithTimeFunc(func() time.Time {
return exp.Add(-1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
Expand Down
6 changes: 3 additions & 3 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Parser struct {
// Skip claims validation during token parsing.
skipClaimsValidation bool

validator *validator
validator *Validator

decodeStrict bool

Expand All @@ -28,7 +28,7 @@ type Parser struct {
// NewParser creates a new Parser with the specified options
func NewParser(options ...ParserOption) *Parser {
p := &Parser{
validator: &validator{},
validator: &Validator{},
}

// Loop through our parsing options and apply them
Expand Down Expand Up @@ -115,7 +115,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
if !p.skipClaimsValidation {
// Make sure we have at least a default validator
if p.validator == nil {
p.validator = newValidator()
p.validator = NewValidator()
}

if err := p.validator.Validate(claims); err != nil {
Expand Down
39 changes: 25 additions & 14 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@ type ClaimsValidator interface {
Validate() error
}

// validator is the core of the new Validation API. It is automatically used by
// Validator is the core of the new Validation API. It is automatically used by
// a [Parser] during parsing and can be modified with various parser options.
//
// Note: This struct is intentionally not exported (yet) as we want to
// internally finalize its API. In the future, we might make it publicly
// available.
type validator struct {
// The [NewValidator] function should be used to create an instance of this
// struct.
type Validator struct {
oxisto marked this conversation as resolved.
Show resolved Hide resolved
// leeway is an optional leeway that can be provided to account for clock skew.
leeway time.Duration

Expand Down Expand Up @@ -62,16 +61,28 @@ type validator struct {
expectedSub string
}

// newValidator can be used to create a stand-alone validator with the supplied
// NewValidator can be used to create a stand-alone validator with the supplied
// options. This validator can then be used to validate already parsed claims.
func newValidator(opts ...ParserOption) *validator {
//
// Note: Under normal circumstances, explicitly creating a validator is not
// needed and can potentially be dangerous; instead functions of the [Parser]
// class should be used.
//
// The [Validator] is only checking the *validity* of the claims, such as its
// expiration time, but it does NOT perform *signature verification* of the
// token.
func NewValidator(opts ...ParserOption) *Validator {
p := NewParser(opts...)
return p.validator
}

// Validate validates the given claims. It will also perform any custom
// validation if claims implements the [ClaimsValidator] interface.
func (v *validator) Validate(claims Claims) error {
//
// Note: It will NOT perform any *signature verification* on the token that
// contains the claims and expects that the [Claim] was already successfully
// verified.
func (v *Validator) Validate(claims Claims) error {
var (
now time.Time
errs []error = make([]error, 0, 6)
Expand Down Expand Up @@ -149,7 +160,7 @@ func (v *validator) Validate(claims Claims) error {
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
func (v *Validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
exp, err := claims.GetExpirationTime()
if err != nil {
return err
Expand All @@ -170,7 +181,7 @@ func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool)
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
func (v *Validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
iat, err := claims.GetIssuedAt()
if err != nil {
return err
Expand All @@ -191,7 +202,7 @@ func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool)
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
nbf, err := claims.GetNotBefore()
if err != nil {
return err
Expand All @@ -211,7 +222,7 @@ func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool)
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error {
func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error {
aud, err := claims.GetAudience()
if err != nil {
return err
Expand Down Expand Up @@ -247,7 +258,7 @@ func (v *validator) verifyAudience(claims Claims, cmp string, required bool) err
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error {
func (v *Validator) verifyIssuer(claims Claims, cmp string, required bool) error {
iss, err := claims.GetIssuer()
if err != nil {
return err
Expand All @@ -267,7 +278,7 @@ func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifySubject(claims Claims, cmp string, required bool) error {
func (v *Validator) verifySubject(claims Claims, cmp string, required bool) error {
sub, err := claims.GetSubject()
if err != nil {
return err
Expand Down
20 changes: 10 additions & 10 deletions validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (m MyCustomClaims) Validate() error {
return nil
}

func Test_validator_Validate(t *testing.T) {
func Test_Validator_Validate(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
Expand Down Expand Up @@ -71,7 +71,7 @@ func Test_validator_Validate(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
v := &Validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
verifyIat: tt.fields.verifyIat,
Expand All @@ -86,7 +86,7 @@ func Test_validator_Validate(t *testing.T) {
}
}

func Test_validator_verifyExpiresAt(t *testing.T) {
func Test_Validator_verifyExpiresAt(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
Expand Down Expand Up @@ -117,7 +117,7 @@ func Test_validator_verifyExpiresAt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
v := &Validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
}
Expand All @@ -130,7 +130,7 @@ func Test_validator_verifyExpiresAt(t *testing.T) {
}
}

func Test_validator_verifyIssuer(t *testing.T) {
func Test_Validator_verifyIssuer(t *testing.T) {
type fields struct {
expectedIss string
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func Test_validator_verifyIssuer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
v := &Validator{
expectedIss: tt.fields.expectedIss,
}
err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required)
Expand All @@ -171,7 +171,7 @@ func Test_validator_verifyIssuer(t *testing.T) {
}
}

func Test_validator_verifySubject(t *testing.T) {
func Test_Validator_verifySubject(t *testing.T) {
type fields struct {
expectedSub string
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func Test_validator_verifySubject(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
v := &Validator{
expectedSub: tt.fields.expectedSub,
}
err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required)
Expand All @@ -212,7 +212,7 @@ func Test_validator_verifySubject(t *testing.T) {
}
}

func Test_validator_verifyIssuedAt(t *testing.T) {
func Test_Validator_verifyIssuedAt(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
Expand Down Expand Up @@ -248,7 +248,7 @@ func Test_validator_verifyIssuedAt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
v := &Validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
verifyIat: tt.fields.verifyIat,
Expand Down