Skip to content

Commit

Permalink
Merge pull request #40 from andrewheberle/multi-sp
Browse files Browse the repository at this point in the history
Multi sp
  • Loading branch information
andrewheberle authored Sep 20, 2024
2 parents fbb7aee + 5a356fb commit da25c8a
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 141 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ AUTH_IDP_METADATA=https://idp.example.net/metadata \
```
--cert string HTTPS Certificate
--db-connection string Database connection string
--db-prefix string Database table prefix
--debug Enable debug logging
-h, --help help for http-auth-server
--idp-certificate string IdP Certificate/Public Key
Expand Down
10 changes: 10 additions & 0 deletions config_multiple.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
service-providers:
- sp-url: http://localhost:9091/a
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
- name: b
sp-url: http://localhost:9091/b
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
6 changes: 6 additions & 0 deletions config_one.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
service-providers:
- name: one
sp-url: http://localhost:9091
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
3 changes: 3 additions & 0 deletions config_single.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sp-cert: ./samlsp.crt
sp-key: ./samlsp.key
idp-metadata: https://mocksaml.com/api/saml/metadata
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ require (
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
gitlab.com/andrewheberle/routerswapper v1.2.0
)

require (
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
gitlab.com/andrewheberle/routerswapper v1.2.0 h1:43e23lnlcTI31DoI/4HP2aw27WCgsghLCcezgCCraz0=
gitlab.com/andrewheberle/routerswapper v1.2.0/go.mod h1:olw/7+vGWD6II0k84qQuevoj46o5DIcG1OvM9MmyW5Q=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
Expand Down
234 changes: 136 additions & 98 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"gitlab.com/andrewheberle/routerswapper"
)

var rootCmd = &cobra.Command{
Expand Down Expand Up @@ -51,13 +50,14 @@ func init() {
rootCmd.Flags().String("idp-certificate", "", "IdP Certificate/Public Key")
rootCmd.Flags().String("db-connection", "", "Database connection string")
rootCmd.Flags().String("db-prefix", "", "Database table prefix")
rootCmd.Flags().StringP("config", "c", "", "Configuration file")
rootCmd.Flags().Bool("debug", false, "Enable debug logging")

// flag requirements
rootCmd.MarkFlagsRequiredTogether("cert", "key")
rootCmd.MarkFlagsRequiredTogether("sp-cert", "sp-key")
rootCmd.MarkFlagRequired("sp-cert")
rootCmd.MarkFlagRequired("sp-key")
rootCmd.MarkFlagsRequiredTogether("cert", "key")
rootCmd.MarkFlagsRequiredTogether("idp-issuer", "idp-sso-endpoint", "idp-certificate")
rootCmd.MarkFlagsMutuallyExclusive("idp-metadata", "idp-issuer")
rootCmd.MarkFlagsMutuallyExclusive("idp-metadata", "idp-sso-endpoint")
Expand All @@ -74,14 +74,45 @@ func initConfig() {
// bind flags to viper
viper.BindPFlags(rootCmd.Flags())

// set any flags found in environment via viper
// load config file if flag is set
if config := viper.GetString("config"); config != "" {
viper.SetConfigFile(config)
if err := viper.ReadInConfig(); err != nil {
slog.Error("problem loading configuration", "error", err)
os.Exit(1)
}

// set sp-cert and sp-key to something just to allow things to work when using multiple SP's
for _, name := range []string{"sp-cert", "sp-key"} {
if !viper.IsSet(name) {
rootCmd.Flags().Set(name, "unused")
}
}
}

// set any flags found in environment/config via viper
rootCmd.Flags().VisitAll(func(f *pflag.Flag) {
if viper.IsSet(f.Name) && viper.GetString(f.Name) != "" {
slog.Info("setting flag", "name", f.Name, "value", viper.GetString(f.Name))
rootCmd.Flags().Set(f.Name, viper.GetString(f.Name))
}
})
}

type serviceProvider struct {
Name string `mapstructure:"name"`
ServiceProviderURL string `mapstructure:"sp-url"`
ServiceProviderClaimMapping map[string]string `mapstructure:"sp-claim-mapping"`
ServiceProviderCertificate string `mapstructure:"sp-cert"`
ServiceProviderKey string `mapstructure:"sp-key"`
IdPMetadata string `mapstructure:"idp-metadata"`
IdPIssuer string `mapstructure:"idp-issuer"`
IdPSSOEndpoint string `mapstructure:"idp-sso-endpoint"`
IdPCertificate string `mapstructure:"idp-certificate"`
DatabaseConnection string `mapstructure:"db-connection"`
DatabaseTablePrefix string `mapstructure:"db-prefix"`
}

func runRootCmd() error {
// logging setup
var logLevel = new(slog.LevelVar)
Expand All @@ -91,76 +122,131 @@ func runRootCmd() error {
logLevel.Set(slog.LevelDebug)
}

// validate service provider root url
root, err := url.Parse(viper.GetString("sp-url"))
if err != nil {
return fmt.Errorf("problem with SP URL: %w", err)
}
// did we load in via a config file
var serviceProviders []serviceProvider
if viper.ConfigFileUsed() != "" {
// has a list of service providers been provided?
if viper.Get("service-providers") != nil {
if err := viper.UnmarshalKey("service-providers", &serviceProviders); err != nil {
return fmt.Errorf("error with service providers list: %w", err)
}
} else {
var sp serviceProvider
if err := viper.Unmarshal(&sp); err != nil {
return fmt.Errorf("error with service provider: %w", err)
}

// set up service provider options
opts := []sp.ServiceProviderOption{
sp.WithClaimMapping(viper.GetStringMapString("sp-claim-mapping")),
serviceProviders = []serviceProvider{sp}
}
}

// handle metadata
if m := viper.GetString("idp-metadata"); m != "" {
metadata, err := url.Parse(m)
// create run group
g := run.Group{}

// new mux
mux := http.NewServeMux()

// set up service provider(s)
for _, spConfig := range serviceProviders {
// validate service provider root url
root, err := url.Parse(spConfig.ServiceProviderURL)
if err != nil {
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
return fmt.Errorf("problem with SP URL: %w", err)
}

opts = append(opts, sp.WithMetadataURL(metadata))
} else {
metadata := sp.ServiceProviderMetadata{
Issuer: viper.GetString("idp-issuer"),
Endpoint: viper.GetString("idp-sso-endpoint"),
NameId: "persistent",
Certificate: viper.GetString("idp-certificate"),
// set up service provider options
opts := []sp.ServiceProviderOption{
sp.WithClaimMapping(spConfig.ServiceProviderClaimMapping),
}

opts = append(opts, sp.WithCustomMetadata(metadata))
}
// handle metadata
if spConfig.IdPMetadata != "" {
metadata, err := url.Parse(spConfig.IdPMetadata)
if err != nil {
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
}

// are we using a database for storing session attributes
if dsn := viper.GetString("db-connection"); dsn != "" {
store, err := sp.NewDbAttributeStore(viper.GetString("db-prefix"), dsn)
opts = append(opts, sp.WithMetadataURL(metadata))
} else {
metadata := sp.ServiceProviderMetadata{
Issuer: spConfig.IdPIssuer,
Endpoint: spConfig.IdPSSOEndpoint,
Certificate: spConfig.IdPCertificate,
}

opts = append(opts, sp.WithCustomMetadata(metadata))
}

// are we using a database for storing session attributes
if dsn := spConfig.DatabaseConnection; dsn != "" {
store, err := sp.NewDbAttributeStore(spConfig.DatabaseTablePrefix, dsn)
if err != nil {
return fmt.Errorf("problem setting up db attribute store: %w", err)
}
defer store.Close()

opts = append(opts, sp.WithAttributeStore(store))
}

// set Service Provider name if provided
if spConfig.Name != "" {
opts = append(opts, sp.WithName(spConfig.Name))
}

// set up auth provider
provider, err := sp.NewServiceProvider(spConfig.ServiceProviderCertificate, spConfig.ServiceProviderKey, root, opts...)
if err != nil {
return fmt.Errorf("problem setting up db attribute store: %w", err)
return fmt.Errorf("problem setting up SP: %w", err)
}
defer store.Close()

opts = append(opts, sp.WithAttributeStore(store))
}
// set up refresh/reload of service provider metdata
if spConfig.IdPMetadata != "" {
quit := make(chan struct{})
g.Add(func() error {
slog.Info("service provider refresh", "action", "started", "next", time.Now().Add(time.Hour*24))
for {
select {
case <-quit:
return nil
default:
if err := provider.RefreshMetadata(); err != nil {
// not a fatal error
slog.Error("saml service provider reload", "error", err)
continue
}
}

// set up auth provider
provider, err := sp.NewServiceProvider(viper.GetString("sp-cert"), viper.GetString("sp-key"), root, opts...)
if err != nil {
return fmt.Errorf("problem setting up SP: %w", err)
}
// some logging
slog.Info("service provider refresh", "action", "refreshed", "next", time.Now().Add(time.Hour*24))
}
}, func(err error) {
slog.Info("service provider refresh", "action", "shutting down")
close(quit)
})
}

// new server mux
mux := sp.NewMux(provider)
// new server mux
if err := provider.NewMux(mux); err != nil {
return fmt.Errorf("error setting up mux: %w", err)
}

// allow swapping of mux
rs := routerswapper.New(mux)
slog.Info("set up service provider",
"acs-url", provider.AcsURL().String(),
"metdata-url", provider.MetadataURL().String(),
"logout-url", provider.LogoutUrl().String(),
"name", spConfig.Name,
)
}

// set up server
srv := &http.Server{
Addr: viper.GetString("listen"),
Handler: rs,
Handler: mux,
ReadTimeout: time.Second * 3,
WriteTimeout: time.Second * 3,
}

slog.Info("starting service",
"listen", srv.Addr,
"sp-acs-url", provider.AcsURL().String(),
"sp-metdata-url", provider.MetadataURL().String(),
"sp-logout-url", provider.LogoutUrl().String(),
)

// create run group
g := run.Group{}
slog.Info("starting service", "listen", srv.Addr)

// add http server
if viper.GetString("cert") == "" && viper.GetString("key") == "" {
Expand Down Expand Up @@ -213,54 +299,6 @@ func runRootCmd() error {
})
}

// set up refresh/reload of service provider metdata
if viper.GetString("idp-metadata") != "" {
quit := make(chan struct{})
g.Add(func() error {
slog.Info("service provider refresh", "action", "started", "next", time.Now().Add(time.Hour*24))
for {
select {
case <-quit:
return nil
default:
time.Sleep(time.Hour * 24)

// parse url
metadata, _ := url.Parse(viper.GetString("idp-metadata"))
if err != nil {
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
}

// set up service provider options
opts := []sp.ServiceProviderOption{
sp.WithClaimMapping(viper.GetStringMapString("sp-claim-mapping")),
sp.WithMetadataURL(metadata),
}

// set up provider
provider, err := sp.NewServiceProvider(viper.GetString("sp-cert"), viper.GetString("sp-key"), root, opts...)
if err != nil {
// not a fatal error
slog.Error("saml service provider reload", "error", err)
continue
}

// new server mux
mux := sp.NewMux(provider)

// swap to new mux
rs.Swap(mux)
}

// some logging
slog.Info("service provider refresh", "action", "refreshed", "next", time.Now().Add(time.Hour*24))
}
}, func(err error) {
slog.Info("service provider refresh", "action", "shutting down")
close(quit)
})
}

if err := g.Run(); err != nil {
return fmt.Errorf("problem while running: %w", err)
}
Expand Down
17 changes: 15 additions & 2 deletions pkg/sp/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type ServiceProviderOption func(*ServiceProvider)

func WithMetadataURL(metadata *url.URL) ServiceProviderOption {
return func(s *ServiceProvider) {
// populate metadata either from a metadata URL or from custom values
// populate metadata from a metadata URL
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

Expand All @@ -26,13 +26,14 @@ func WithMetadataURL(metadata *url.URL) ServiceProviderOption {
}

s.idpMetadata = idpMetadata
s.idpMetadataURL = metadata
}
}

func WithCustomMetadata(metadata ServiceProviderMetadata) ServiceProviderOption {
return func(s *ServiceProvider) {
// build metadata from provided values
b, err := buildMetadata(metadata.Issuer, metadata.Endpoint, metadata.NameId, metadata.Certificate)
b, err := buildMetadata(metadata.Issuer, metadata.Endpoint, metadata.Certificate)
if err != nil {
slog.Error("metadata build error", "error", err)
return
Expand All @@ -59,3 +60,15 @@ func WithAttributeStore(store AttributeStore) ServiceProviderOption {
s.store = store
}
}

func WithMetadataRefreshInterval(d time.Duration) ServiceProviderOption {
return func(s *ServiceProvider) {
s.idpMetadataRefreshInterval = d
}
}

func WithName(name string) ServiceProviderOption {
return func(s *ServiceProvider) {
s.name = name
}
}
Loading

0 comments on commit da25c8a

Please sign in to comment.