Skip to content

Commit

Permalink
Merge pull request #7 from ripienaar/6
Browse files Browse the repository at this point in the history
(#6) support adding chain issuer data when issuer is in vault
  • Loading branch information
ripienaar authored Jan 18, 2023
2 parents de9d178 + d478f52 commit 75e1631
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/codeql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v3

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
Expand Down
40 changes: 33 additions & 7 deletions standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
package tokens

import (
"context"
"crypto/ed25519"
"crypto/tls"
"encoding/hex"
"fmt"
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/segmentio/ksuid"
"github.com/sirupsen/logrus"
)

type StandardClaims struct {
Expand All @@ -32,13 +35,13 @@ type StandardClaims struct {
}

// ExpireTime determines the expiry time based on issuer expiry and token expiry
func (s *StandardClaims) ExpireTime() time.Time {
func (c *StandardClaims) ExpireTime() time.Time {
var iexp, exp time.Time
if s.IssuerExpiresAt != nil {
iexp = s.IssuerExpiresAt.Time
if c.IssuerExpiresAt != nil {
iexp = c.IssuerExpiresAt.Time
}
if s.ExpiresAt != nil {
exp = s.ExpiresAt.Time
if c.ExpiresAt != nil {
exp = c.ExpiresAt.Time
}

if iexp.IsZero() {
Expand All @@ -57,8 +60,8 @@ func (s *StandardClaims) ExpireTime() time.Time {
}

// IsExpired checks if the token has expired
func (s *StandardClaims) IsExpired() bool {
return time.Now().After(s.ExpireTime())
func (c *StandardClaims) IsExpired() bool {
return time.Now().After(c.ExpireTime())
}

// AddOrgIssuerData adds the data that a Chain Issuer needs to be able to issue clients in an Org managed by an Issuer
Expand All @@ -79,6 +82,29 @@ func (c *StandardClaims) AddOrgIssuerData(priK ed25519.PrivateKey) error {
return nil
}

// AddOrgIssuerDataUsingVault adds the data that a Chain Issuer needs to be able to issue clients in an Org managed by an Issuer by using Vault to sign the data using the named key `priK`.
func (c *StandardClaims) AddOrgIssuerDataUsingVault(ctx context.Context, tlsc *tls.Config, priK string, log *logrus.Entry) error {
issuer, err := getVaultIssuerPubKey(ctx, tlsc, priK, log)
if err != nil {
return err
}

dat, err := c.OrgIssuerChainData()
if err != nil {
return err
}

sig, err := signWithVault(ctx, tlsc, priK, dat, log)
if err != nil {
return err
}

c.SetOrgIssuer(issuer)
c.SetChainIssuerTrustSignature(sig)

return nil
}

// AddChainIssuerData adds the data that a Signed token needs from a Chain Issuer in an Org managed by an Issuer
func (c *StandardClaims) AddChainIssuerData(chainIssuer *ClientIDClaims, prik ed25519.PrivateKey) error {
err := c.SetChainIssuer(chainIssuer)
Expand Down
124 changes: 103 additions & 21 deletions tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ func ParseToken(token string, claims jwt.Claims, pk any) error {

var sc *StandardClaims

// if its a client and from a chain we will verify it using the chain issuer pubk
// if it's a client and from a chain we will verify it using the chain issuer pubk
client, ok := claims.(*ClientIDClaims)
if ok && strings.HasPrefix(client.Issuer, ChainIssuerPrefix) {
sc = &client.StandardClaims
}

// if its a server and from a chain we will verify it using the chain issuer pubk
// if it's a server and from a chain we will verify it using the chain issuer pubk
server, ok := claims.(*ServerClaims)
if ok && strings.HasPrefix(server.Issuer, ChainIssuerPrefix) {
sc = &server.StandardClaims
Expand Down Expand Up @@ -161,7 +161,7 @@ func TokenPurposeBytes(token []byte) Purpose {
return TokenPurpose(string(token))
}

// SignTokenWithKeyFile signs a JWT using a RSA Private Key in PEM format
// SignTokenWithKeyFile signs a JWT using an RSA Private Key in PEM format
func SignTokenWithKeyFile(claims jwt.Claims, pkFile string) (string, error) {
keydat, err := os.ReadFile(pkFile)
if err != nil {
Expand Down Expand Up @@ -189,7 +189,7 @@ func SignTokenWithKeyFile(claims jwt.Claims, pkFile string) (string, error) {
return "", fmt.Errorf("unsupported key in %v", pkFile)
}

// SignToken signs a JWT using a RSA Private Key
// SignToken signs a JWT using an RSA Private Key
func SignToken(claims jwt.Claims, pk any) (string, error) {
var stoken string
var err error
Expand Down Expand Up @@ -224,34 +224,98 @@ func SaveAndSignTokenWithKeyFile(claims jwt.Claims, pkFile string, outFile strin
return os.WriteFile(outFile, []byte(token), perm)
}

// SaveAndSignTokenWithVault signs a token using the named key in a Vault Transit engine. Requires VAULT_TOKEN and VAULT_ADDR to be set.
func SaveAndSignTokenWithVault(ctx context.Context, claims jwt.Claims, key string, outFile string, perm os.FileMode, tlsc *tls.Config, log *logrus.Entry) error {
func getVaultIssuerPubKey(ctx context.Context, tlsc *tls.Config, key string, log *logrus.Entry) (ed25519.PublicKey, error) {
vt := os.Getenv("VAULT_TOKEN")
va := os.Getenv("VAULT_ADDR")

if vt == "" || va == "" {
return fmt.Errorf("requires VAULT_TOKEN and VAULT_ADDR environment variables")
return nil, fmt.Errorf("requires VAULT_TOKEN and VAULT_ADDR environment variables")
}

token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
ss, err := token.SigningString()
uri, err := url.Parse(va)
if err != nil {
return err
return nil, err
}

uri.Path = fmt.Sprintf("/v1/transit/keys/%s", key)
client := &http.Client{}
if tlsc != nil {
client.Transport = &http.Transport{TLSClientConfig: tlsc}
}

req, err := http.NewRequestWithContext(ctx, "GET", uri.String(), nil)
if err != nil {
return nil, err
}
req.Header.Add("X-Vault-Token", vt)

resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

if resp.StatusCode != 200 {
return nil, fmt.Errorf("request failed: code: %d: %s", resp.StatusCode, string(body))
}

log.Debugf("JSON Response: %s", string(body))

var vr struct {
Data struct {
Keys map[string]struct {
PublicKey []byte `json:"public_key"`
} `json:"keys"`
} `json:"data"`
}
err = json.Unmarshal(body, &vr)
if err != nil {
return nil, err
}

if len(vr.Data.Keys) == 0 {
return nil, fmt.Errorf("did not receive keys in response")
}

pk, ok := vr.Data.Keys["1"]
if !ok {
return nil, fmt.Errorf("did not receive keys in response")
}

if len(pk.PublicKey) != ed25519.PublicKeySize {
return nil, fmt.Errorf("did not receive a valid public key in response")
}

return ed25519.PublicKey(pk.PublicKey), nil
}

func signWithVault(ctx context.Context, tlsc *tls.Config, key string, ss []byte, log *logrus.Entry) ([]byte, error) {
vt := os.Getenv("VAULT_TOKEN")
va := os.Getenv("VAULT_ADDR")

if vt == "" || va == "" {
return nil, fmt.Errorf("requires VAULT_TOKEN and VAULT_ADDR environment variables")
}

uri, err := url.Parse(va)
if err != nil {
return err
return nil, err
}

uri.Path = fmt.Sprintf("/v1/transit/sign/%s", key)

dat := map[string]any{
"signature_algorithm": "ed25519",
"input": base64.StdEncoding.EncodeToString([]byte(ss)),
"input": base64.StdEncoding.EncodeToString(ss),
}
jdat, err := json.Marshal(dat)
if err != nil {
return err
return nil, err
}
log.Debugf("JSON Request: %s", string(jdat))

Expand All @@ -262,48 +326,66 @@ func SaveAndSignTokenWithVault(ctx context.Context, claims jwt.Claims, key strin

req, err := http.NewRequestWithContext(ctx, "POST", uri.String(), bytes.NewBuffer(jdat))
if err != nil {
return err
return nil, err
}
req.Header.Add("X-Vault-Token", vt)

resp, err := client.Do(req)
if err != nil {
return err
return nil, err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err
return nil, err
}

if resp.StatusCode != 200 {
return fmt.Errorf("request failed: code: %d: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("request failed: code: %d: %s", resp.StatusCode, string(body))
}

log.Debugf("JSON Response: %s", string(body))

var vr struct {
Data struct {
Sig string `json:"signature"`
} `json:"data"`
}
err = json.Unmarshal(body, &vr)
if err != nil {
return err
return nil, err
}

if vr.Data.Sig == "" {
return fmt.Errorf("no signature in response: %s", string(body))
return nil, fmt.Errorf("no signature in response: %s", string(body))
}

const vaultSigPrefix = "vault:v1:"

if !strings.HasPrefix(vr.Data.Sig, vaultSigPrefix) {
return fmt.Errorf("invalid signature, no vault:v1 prefix")
return nil, fmt.Errorf("invalid signature, no vault:v1 prefix")
}

signature, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(vr.Data.Sig, vaultSigPrefix))
if err != nil {
return fmt.Errorf("could not decode vault response: %w", err)
return nil, fmt.Errorf("could not decode vault response: %w", err)
}

return signature, nil
}

// SaveAndSignTokenWithVault signs a token using the named key in a Vault Transit engine. Requires VAULT_TOKEN and VAULT_ADDR to be set.
func SaveAndSignTokenWithVault(ctx context.Context, claims jwt.Claims, key string, outFile string, perm os.FileMode, tlsc *tls.Config, log *logrus.Entry) error {
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
ss, err := token.SigningString()
if err != nil {
return err
}

signature, err := signWithVault(ctx, tlsc, key, []byte(ss), log)
if err != nil {
return err
}

signed := fmt.Sprintf("%s.%s", ss, strings.TrimRight(base64.RawURLEncoding.EncodeToString(signature), "="))
Expand Down

0 comments on commit 75e1631

Please sign in to comment.