This repository has been archived by the owner on Dec 19, 2023. It is now read-only.
forked from marcinwyszynski/kmsjwt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path kmsjwt.go
150 lines (120 loc) · 3.49 KB
/
kmsjwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package kmsjwt
import (
"context"
"crypto/sha512"
"crypto/subtle"
"encoding/base64"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
cache "github.com/patrickmn/go-cache"
"github.com/pkg/errors"
)
// kmsAlgorighm is the value New stores in KMSJWT.algorithm, i.e. the name
// Alg reports for this signing method.
// NOTE(review): the identifier is misspelled ("Algorighm"); it is kept as-is
// because New refers to it by this name.
const kmsAlgorighm = "KMS"

var (
	// ErrKmsVerification is returned by Verify when KMS reports that the
	// signature is not valid. Its message is also used to wrap other KMS
	// verification failures.
	ErrKmsVerification = errors.New("kms: verification error")

	// ErrInvalidKey indicates that the key is invalid. Sign uses the same
	// message when wrapping KMS signing failures. Those wrapped errors carry
	// only the message, so they are not errors.Is-equal to ErrInvalidKey.
	ErrInvalidKey = errors.New("key is invalid")
)
// KMSJWT is a JWT signing method whose Sign and Verify operations are
// performed by AWS KMS through the embedded client. It can optionally cache
// signatures that it produced or successfully verified, keyed by the signing
// string, so that repeated verification of the same token can skip KMS.
// Construct it with New; the zero value has no key ID or signing algorithm
// configured.
type KMSJWT struct {
	// KMSAPI is the client used for every KMS Sign and Verify request.
	kmsiface.KMSAPI

	// algorithm is the name returned by Alg.
	algorithm string

	// cache maps a signing string to its raw signature bytes. A nil cache
	// disables caching.
	cache *cache.Cache

	// kmsKeyID identifies the KMS key used for signing and verification.
	kmsKeyID string

	// withCache tells New whether to create cache.
	withCache bool

	// defaultExpiration is the cache entry lifetime passed to cache.New.
	defaultExpiration time.Duration

	// cleanupInterval is the expired-entry cleanup interval passed to
	// cache.New.
	cleanupInterval time.Duration

	// signingAlgorithm is the KMS signing algorithm spec sent with every
	// Sign and Verify request.
	signingAlgorithm string
}
// New returns a KMS-backed JWT signing method. Tokens are signed and
// verified with the KMS key kmsKeyID through client.
//
// These defaults are set before opts run:
//   - algorithm name kmsAlgorighm
//   - signing algorithm RSASSA-PSS with SHA-512
//   - signature caching enabled, with a one-hour entry expiration and a
//     one-minute cleanup interval
//
// Each Option is applied in order. The cache is created only after all
// options have run, so an option can change its settings or disable it.
func New(client kmsiface.KMSAPI, kmsKeyID string, opts ...Option) *KMSJWT {
	k := &KMSJWT{
		KMSAPI:            client,
		kmsKeyID:          kmsKeyID,
		algorithm:         kmsAlgorighm,
		signingAlgorithm:  kms.SigningAlgorithmSpecRsassaPssSha512,
		withCache:         true,
		defaultExpiration: time.Hour,
		cleanupInterval:   time.Minute,
	}
	for i := range opts {
		opts[i](k)
	}
	if !k.withCache {
		return k
	}
	k.cache = cache.New(k.defaultExpiration, k.cleanupInterval)
	return k
}
// Alg returns the name of this signing method, as stored in k.algorithm.
// New initializes that field to kmsAlgorighm ("KMS") before applying options.
func (k *KMSJWT) Alg() string {
	return k.algorithm
}
// Sign signs signingString with the configured KMS key and returns the
// signature.
//
// key must be a context.Context; it is passed to the KMS Sign call.
// KMS receives the SHA-512 digest of signingString (MessageType DIGEST)
// rather than the raw string.
//
// On success, the raw signature is stored in the cache under signingString
// when caching is enabled, and the signature is returned in standard base64.
//
// Error cases:
//   - If key is not a context.Context, Sign returns "key is not a context".
//   - If the KMS call fails because of context cancellation or an expired
//     deadline, that error is returned unwrapped.
//   - Any other KMS error is wrapped with ErrInvalidKey's message.
//
// NOTE(review): JWS compact serialization (RFC 7515) uses unpadded base64url,
// whereas this uses base64.StdEncoding. Verify expects the same encoding, so
// changing it would invalidate existing tokens; confirm before altering.
func (k *KMSJWT) Sign(signingString string, key interface{}) (string, error) {
	ctx, ok := key.(context.Context)
	if !ok {
		return "", errors.New("key is not a context")
	}

	out, err := k.SignWithContext(ctx, &kms.SignInput{
		KeyId:            aws.String(k.kmsKeyID),
		Message:          checksum(signingString),
		MessageType:      aws.String("DIGEST"),
		SigningAlgorithm: aws.String(k.signingAlgorithm),
	})
	if err != nil {
		// A cancelled or timed-out request says nothing about the key, so
		// don't label it "key is invalid".
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return "", err
		}
		return "", errors.Wrap(err, ErrInvalidKey.Error())
	}

	// Remember our own signature so Verify on the same token can skip KMS.
	if k.cache != nil {
		k.cache.SetDefault(signingString, out.Signature)
	}
	return base64.StdEncoding.EncodeToString(out.Signature), nil
}
// Verify checks that stringSignature is a valid signature of signingString
// under the configured KMS key.
//
// key must be a context.Context; it is passed to the KMS Verify call.
// stringSignature must be standard base64, the encoding Sign produces.
//
// When caching is enabled and the decoded signature matches the one cached
// for signingString, Verify returns nil without calling KMS. Otherwise KMS
// verifies the SHA-512 digest of signingString, and a signature it accepts is
// cached for later calls.
//
// Error cases:
//   - If key is not a context.Context, Verify returns "key is not a context".
//   - Undecodable base64 yields an error prefixed with
//     "invalid signature encoding".
//   - Context cancellation or an expired deadline is returned unwrapped.
//   - When KMS reports the signature invalid (or omits the result), Verify
//     returns ErrKmsVerification itself.
//   - Any other KMS error is wrapped with ErrKmsVerification's message.
//     Because only the message is used, such errors are not errors.Is-equal
//     to ErrKmsVerification.
func (k *KMSJWT) Verify(signingString, stringSignature string, key interface{}) error {
	ctx, ok := key.(context.Context)
	if !ok {
		return errors.New("key is not a context")
	}

	signature, err := base64.StdEncoding.DecodeString(stringSignature)
	if err != nil {
		// Keep the decoder's detail (offset of the bad byte) for debugging.
		return errors.Wrap(err, "invalid signature encoding")
	}

	if k.verifyCache(signingString, signature) {
		return nil
	}

	out, err := k.VerifyWithContext(ctx, &kms.VerifyInput{
		KeyId:            aws.String(k.kmsKeyID),
		Message:          checksum(signingString),
		MessageType:      aws.String("DIGEST"),
		Signature:        signature,
		SigningAlgorithm: aws.String(k.signingAlgorithm),
	})
	if err != nil {
		// A cancelled or timed-out request is not a verification verdict.
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return err
		}
		return errors.Wrap(err, ErrKmsVerification.Error())
	}

	// aws.BoolValue maps a nil SignatureValid to false: no result means invalid.
	if !aws.BoolValue(out.SignatureValid) {
		return ErrKmsVerification
	}

	if k.cache != nil {
		k.cache.SetDefault(signingString, signature)
	}
	return nil
}
// verifyCache reports whether providedSignature equals the signature cached
// for signingString. It returns false in three cases: caching is disabled,
// nothing is cached under signingString, or the cached value is not a []byte.
// The byte comparison is constant-time so a cache probe does not leak how
// much of the signature matched.
func (k *KMSJWT) verifyCache(signingString string, providedSignature []byte) bool {
	if k.cache == nil {
		return false
	}
	entry, found := k.cache.Get(signingString)
	if !found {
		return false
	}
	cached, isBytes := entry.([]byte)
	return isBytes && subtle.ConstantTimeCompare(cached, providedSignature) == 1
}
// checksum returns the 64-byte SHA-512 digest of in. Sign and Verify send it
// to KMS as a pre-computed DIGEST message.
// NOTE(review): KMS DIGEST mode expects the digest to match the hash of the
// configured signing algorithm. This is always SHA-512, so confirm that no
// option selects a non-SHA-512 signing algorithm.
func checksum(in string) []byte {
	h := sha512.New()
	// hash.Hash.Write is documented never to return an error.
	h.Write([]byte(in))
	return h.Sum(nil)
}