Skip to content

Commit

Permalink
Add custom requeste tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheberle committed Dec 18, 2023
1 parent ac4632d commit c514b17
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 6 deletions.
10 changes: 4 additions & 6 deletions pkg/sp/sp.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func NewServiceProvider(cert, key string, metadata interface{}, root *url.URL, m
// set SHA256 as the signature method
mw.ServiceProvider.SignatureMethod = dsig.RSASHA256SignatureMethod

// use custom request tracker
tracker := DefaultRequestTracker(opts, &mw.ServiceProvider)
mw.RequestTracker = tracker

// set up custom session provider
if err := setSessionProvider(root, mw); err != nil {
return nil, fmt.Errorf("session provider error: %w", err)
Expand Down Expand Up @@ -325,12 +329,6 @@ func (s *ServiceProvider) doAuthFlow(w http.ResponseWriter, r *http.Request) {
// transfer headers to response
for header, v := range rr.Result().Header {
for _, item := range v {
if header == "Set-Cookie" {
// add Domain to cookie if not set
if !strings.Contains(item, "Domain=") {
item = item + "; Domain=" + s.mw.Session.(samlsp.CookieSessionProvider).Domain
}
}
w.Header().Add(header, item)
}
}
Expand Down
137 changes: 137 additions & 0 deletions pkg/sp/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package sp

import (
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
)

func DefaultRequestTracker(opts samlsp.Options, serviceProvider *saml.ServiceProvider) CookieRequestTracker {
return CookieRequestTracker{
ServiceProvider: serviceProvider,
NamePrefix: "saml_",
Codec: samlsp.DefaultTrackedRequestCodec(opts),
MaxAge: saml.MaxIssueDelay,
RelayStateFunc: opts.RelayStateFunc,
SameSite: opts.CookieSameSite,
CookieDomain: getDomain(&serviceProvider.AcsURL),
}
}

// CookieRequestTracker tracks requests by setting a uniquely named
// cookie for each request.
//
// This implementation is idenitical to samlsp.CookieRequestTracker apart
// from the addition of setting the CookieDomain for the tracker cookie.
type CookieRequestTracker struct {
ServiceProvider *saml.ServiceProvider
NamePrefix string
Codec samlsp.TrackedRequestCodec
MaxAge time.Duration
RelayStateFunc func(w http.ResponseWriter, r *http.Request) string
SameSite http.SameSite
CookieDomain string
}

// TrackRequest starts tracking the SAML request with the given ID. It returns an
// `index` that should be used as the RelayState in the SAMl request flow.
func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) {
trackedRequest := samlsp.TrackedRequest{
Index: base64.RawURLEncoding.EncodeToString(randomBytes(42)),
SAMLRequestID: samlRequestID,
URI: r.URL.String(),
}

if t.RelayStateFunc != nil {
relayState := t.RelayStateFunc(w, r)
if relayState != "" {
trackedRequest.Index = relayState
}
}

signedTrackedRequest, err := t.Codec.Encode(trackedRequest)
if err != nil {
return "", err
}

http.SetCookie(w, &http.Cookie{
Name: t.NamePrefix + trackedRequest.Index,
Value: signedTrackedRequest,
MaxAge: int(t.MaxAge.Seconds()),
HttpOnly: true,
SameSite: t.SameSite,
Secure: t.ServiceProvider.AcsURL.Scheme == "https",
Path: t.ServiceProvider.AcsURL.Path,
Domain: t.CookieDomain,
})

return trackedRequest.Index, nil
}

// StopTrackingRequest stops tracking the SAML request given by index, which is a string
// previously returned from TrackRequest
func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
cookie, err := r.Cookie(t.NamePrefix + index)
if err != nil {
return err
}
cookie.Value = ""
cookie.Domain = t.CookieDomain
cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{}
http.SetCookie(w, cookie)
return nil
}

// GetTrackedRequests returns all the pending tracked requests
func (t CookieRequestTracker) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest {
rv := []samlsp.TrackedRequest{}
for _, cookie := range r.Cookies() {
if !strings.HasPrefix(cookie.Name, t.NamePrefix) {
continue
}

trackedRequest, err := t.Codec.Decode(cookie.Value)
if err != nil {
continue
}
index := strings.TrimPrefix(cookie.Name, t.NamePrefix)
if index != trackedRequest.Index {
continue
}

rv = append(rv, *trackedRequest)
}
return rv
}

// GetTrackedRequest returns a pending tracked request.
func (t CookieRequestTracker) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) {
cookie, err := r.Cookie(t.NamePrefix + index)
if err != nil {
return nil, err
}

trackedRequest, err := t.Codec.Decode(cookie.Value)
if err != nil {
return nil, err
}
if trackedRequest.Index != index {
return nil, fmt.Errorf("expected index %q, got %q", index, trackedRequest.Index)
}
return trackedRequest, nil
}

func randomBytes(n int) []byte {
rv := make([]byte, n)

if _, err := io.ReadFull(saml.RandReader, rv); err != nil {
panic(err)
}
return rv
}

0 comments on commit c514b17

Please sign in to comment.