Skip to content

Commit

Permalink
(#6) support adding chain issuer data when issuer is in vault
Browse files Browse the repository at this point in the history
Previously we could only sign direct authorized JWTs using vault
but when the issuer is created in vault and we wish to delegate
issuing clients we need to be able to get the issuer data signed
by vault, which is now possible.

Signed-off-by: R.I.Pienaar <[email protected]>
  • Loading branch information
ripienaar committed Jan 18, 2023
1 parent de9d178 commit d478f52
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/codeql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v3

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
Expand Down
40 changes: 33 additions & 7 deletions standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
package tokens

import (
"context"
"crypto/ed25519"
"crypto/tls"
"encoding/hex"
"fmt"
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/segmentio/ksuid"
"github.com/sirupsen/logrus"
)

type StandardClaims struct {
Expand All @@ -32,13 +35,13 @@ type StandardClaims struct {
}

// ExpireTime determines the expiry time based on issuer expiry and token expiry
func (s *StandardClaims) ExpireTime() time.Time {
func (c *StandardClaims) ExpireTime() time.Time {
var iexp, exp time.Time
if s.IssuerExpiresAt != nil {
iexp = s.IssuerExpiresAt.Time
if c.IssuerExpiresAt != nil {
iexp = c.IssuerExpiresAt.Time
}
if s.ExpiresAt != nil {
exp = s.ExpiresAt.Time
if c.ExpiresAt != nil {
exp = c.ExpiresAt.Time
}

if iexp.IsZero() {
Expand All @@ -57,8 +60,8 @@ func (s *StandardClaims) ExpireTime() time.Time {
}

// IsExpired checks if the token has expired
func (s *StandardClaims) IsExpired() bool {
return time.Now().After(s.ExpireTime())
func (c *StandardClaims) IsExpired() bool {
return time.Now().After(c.ExpireTime())
}

// AddOrgIssuerData adds the data that a Chain Issuer needs to be able to issue clients in an Org managed by an Issuer
Expand All @@ -79,6 +82,29 @@ func (c *StandardClaims) AddOrgIssuerData(priK ed25519.PrivateKey) error {
return nil
}

// AddOrgIssuerDataUsingVault adds the data that a Chain Issuer needs to be able to issue clients in an Org managed by an Issuer by using Vault to sign the data using the named key `priK`.
func (c *StandardClaims) AddOrgIssuerDataUsingVault(ctx context.Context, tlsc *tls.Config, priK string, log *logrus.Entry) error {
issuer, err := getVaultIssuerPubKey(ctx, tlsc, priK, log)
if err != nil {
return err
}

dat, err := c.OrgIssuerChainData()
if err != nil {
return err
}

sig, err := signWithVault(ctx, tlsc, priK, dat, log)
if err != nil {
return err
}

c.SetOrgIssuer(issuer)
c.SetChainIssuerTrustSignature(sig)

return nil
}

// AddChainIssuerData adds the data that a Signed token needs from a Chain Issuer in an Org managed by an Issuer
func (c *StandardClaims) AddChainIssuerData(chainIssuer *ClientIDClaims, prik ed25519.PrivateKey) error {
err := c.SetChainIssuer(chainIssuer)
Expand Down
124 changes: 103 additions & 21 deletions tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ func ParseToken(token string, claims jwt.Claims, pk any) error {

var sc *StandardClaims

// if its a client and from a chain we will verify it using the chain issuer pubk
// if it's a client and from a chain we will verify it using the chain issuer pubk
client, ok := claims.(*ClientIDClaims)
if ok && strings.HasPrefix(client.Issuer, ChainIssuerPrefix) {
sc = &client.StandardClaims
}

// if its a server and from a chain we will verify it using the chain issuer pubk
// if it's a server and from a chain we will verify it using the chain issuer pubk
server, ok := claims.(*ServerClaims)
if ok && strings.HasPrefix(server.Issuer, ChainIssuerPrefix) {
sc = &server.StandardClaims
Expand Down Expand Up @@ -161,7 +161,7 @@ func TokenPurposeBytes(token []byte) Purpose {
return TokenPurpose(string(token))
}

// SignTokenWithKeyFile signs a JWT using a RSA Private Key in PEM format
// SignTokenWithKeyFile signs a JWT using an RSA Private Key in PEM format
func SignTokenWithKeyFile(claims jwt.Claims, pkFile string) (string, error) {
keydat, err := os.ReadFile(pkFile)
if err != nil {
Expand Down Expand Up @@ -189,7 +189,7 @@ func SignTokenWithKeyFile(claims jwt.Claims, pkFile string) (string, error) {
return "", fmt.Errorf("unsupported key in %v", pkFile)
}

// SignToken signs a JWT using a RSA Private Key
// SignToken signs a JWT using an RSA Private Key
func SignToken(claims jwt.Claims, pk any) (string, error) {
var stoken string
var err error
Expand Down Expand Up @@ -224,34 +224,98 @@ func SaveAndSignTokenWithKeyFile(claims jwt.Claims, pkFile string, outFile strin
return os.WriteFile(outFile, []byte(token), perm)
}

// SaveAndSignTokenWithVault signs a token using the named key in a Vault Transit engine. Requires VAULT_TOKEN and VAULT_ADDR to be set.
func SaveAndSignTokenWithVault(ctx context.Context, claims jwt.Claims, key string, outFile string, perm os.FileMode, tlsc *tls.Config, log *logrus.Entry) error {
func getVaultIssuerPubKey(ctx context.Context, tlsc *tls.Config, key string, log *logrus.Entry) (ed25519.PublicKey, error) {
vt := os.Getenv("VAULT_TOKEN")
va := os.Getenv("VAULT_ADDR")

if vt == "" || va == "" {
return fmt.Errorf("requires VAULT_TOKEN and VAULT_ADDR environment variables")
return nil, fmt.Errorf("requires VAULT_TOKEN and VAULT_ADDR environment variables")
}

token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
ss, err := token.SigningString()
uri, err := url.Parse(va)
if err != nil {
return err
return nil, err
}

uri.Path = fmt.Sprintf("/v1/transit/keys/%s", key)
client := &http.Client{}
if tlsc != nil {
client.Transport = &http.Transport{TLSClientConfig: tlsc}
}

req, err := http.NewRequestWithContext(ctx, "GET", uri.String(), nil)
if err != nil {
return nil, err
}
req.Header.Add("X-Vault-Token", vt)

resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

if resp.StatusCode != 200 {
return nil, fmt.Errorf("request failed: code: %d: %s", resp.StatusCode, string(body))
}

log.Debugf("JSON Response: %s", string(body))

var vr struct {
Data struct {
Keys map[string]struct {
PublicKey []byte `json:"public_key"`
} `json:"keys"`
} `json:"data"`
}
err = json.Unmarshal(body, &vr)
if err != nil {
return nil, err
}

if len(vr.Data.Keys) == 0 {
return nil, fmt.Errorf("did not receive keys in response")
}

pk, ok := vr.Data.Keys["1"]
if !ok {
return nil, fmt.Errorf("did not receive keys in response")
}

if len(pk.PublicKey) != ed25519.PublicKeySize {
return nil, fmt.Errorf("did not receive a valid public key in response")
}

return ed25519.PublicKey(pk.PublicKey), nil
}

func signWithVault(ctx context.Context, tlsc *tls.Config, key string, ss []byte, log *logrus.Entry) ([]byte, error) {
vt := os.Getenv("VAULT_TOKEN")
va := os.Getenv("VAULT_ADDR")

if vt == "" || va == "" {
return nil, fmt.Errorf("requires VAULT_TOKEN and VAULT_ADDR environment variables")
}

uri, err := url.Parse(va)
if err != nil {
return err
return nil, err
}

uri.Path = fmt.Sprintf("/v1/transit/sign/%s", key)

dat := map[string]any{
"signature_algorithm": "ed25519",
"input": base64.StdEncoding.EncodeToString([]byte(ss)),
"input": base64.StdEncoding.EncodeToString(ss),
}
jdat, err := json.Marshal(dat)
if err != nil {
return err
return nil, err
}
log.Debugf("JSON Request: %s", string(jdat))

Expand All @@ -262,48 +326,66 @@ func SaveAndSignTokenWithVault(ctx context.Context, claims jwt.Claims, key strin

req, err := http.NewRequestWithContext(ctx, "POST", uri.String(), bytes.NewBuffer(jdat))
if err != nil {
return err
return nil, err
}
req.Header.Add("X-Vault-Token", vt)

resp, err := client.Do(req)
if err != nil {
return err
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err
return nil, err
}

if resp.StatusCode != 200 {
return fmt.Errorf("request failed: code: %d: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("request failed: code: %d: %s", resp.StatusCode, string(body))
}

log.Debugf("JSON Response: %s", string(body))

var vr struct {
Data struct {
Sig string `json:"signature"`
} `json:"data"`
}
err = json.Unmarshal(body, &vr)
if err != nil {
return err
return nil, err
}

if vr.Data.Sig == "" {
return fmt.Errorf("no signature in response: %s", string(body))
return nil, fmt.Errorf("no signature in response: %s", string(body))
}

const vaultSigPrefix = "vault:v1:"

if !strings.HasPrefix(vr.Data.Sig, vaultSigPrefix) {
return fmt.Errorf("invalid signature, no vault:v1 prefix")
return nil, fmt.Errorf("invalid signature, no vault:v1 prefix")
}

signature, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(vr.Data.Sig, vaultSigPrefix))
if err != nil {
return fmt.Errorf("could not decode vault response: %w", err)
return nil, fmt.Errorf("could not decode vault response: %w", err)
}

return signature, nil
}

// SaveAndSignTokenWithVault signs a token using the named key in a Vault Transit engine. Requires VAULT_TOKEN and VAULT_ADDR to be set.
func SaveAndSignTokenWithVault(ctx context.Context, claims jwt.Claims, key string, outFile string, perm os.FileMode, tlsc *tls.Config, log *logrus.Entry) error {
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
ss, err := token.SigningString()
if err != nil {
return err
}

signature, err := signWithVault(ctx, tlsc, key, []byte(ss), log)
if err != nil {
return err
}

signed := fmt.Sprintf("%s.%s", ss, strings.TrimRight(base64.RawURLEncoding.EncodeToString(signature), "="))
Expand Down

0 comments on commit d478f52

Please sign in to comment.