From 062aee4d45b8837ce710ced6ab73f999b26e6513 Mon Sep 17 00:00:00 2001 From: Lennart Fleischmann <67686424+lfleischmann@users.noreply.github.com> Date: Fri, 31 Jan 2025 14:17:52 +0100 Subject: [PATCH] fix: SAML issues (#2041) Rename identities table columns for more clarity. Rename parameters, arguments etc. to accommodate these changes. Change that the SAML provider domain is persisted in the identities table as the provider ID. Use the SAML Entity ID/Issuer ID of the IdP instead. Introduce saml identity entity (including migrations and a persister) as a specialization of an identity to allow for determining the correct provider name to return to the client/frontend and for assisting in determining whether an identity is a SAML identity (i.e. SAML identities should have a corresponding SAML Identity instance while OAuth/OIDC entities do not). --- backend/config/config_default.go | 14 ++-- backend/config/config_third_party.go | 14 ++-- backend/dto/admin/identity.go | 4 +- backend/dto/thirdparty.go | 16 ++-- backend/ee/saml/config/saml.go | 17 +++- backend/ee/saml/handler.go | 7 +- backend/ee/saml/provider/saml.go | 2 +- .../flow/shared/action_exchange_token.go | 2 +- .../flow/shared/action_thirdparty_oauth.go | 4 +- backend/flow_api/services/user.go | 4 +- backend/handler/thirdparty.go | 6 +- backend/handler/thirdparty_test.go | 12 +-- backend/persistence/identity_persister.go | 6 +- ...20250130154010_change_identities.down.fizz | 6 ++ .../20250130154010_change_identities.up.fizz | 6 ++ ...130170131_create_saml_identities.down.fizz | 1 + ...50130170131_create_saml_identities.up.fizz | 8 ++ backend/persistence/models/email.go | 9 +++ backend/persistence/models/identity.go | 43 +++++----- backend/persistence/models/saml_identity.go | 27 +++++++ backend/persistence/persister.go | 10 +++ .../persistence/saml_identity_persister.go | 46 +++++++++++ backend/persistence/user_persister.go | 4 +- .../test/fixtures/thirdparty/identities.yaml | 24 +++--- backend/thirdparty/linking.go | 79 +++++++++++++++---- backend/thirdparty/provider.go | 2 +- backend/thirdparty/provider_apple.go | 4 +- backend/thirdparty/provider_custom.go | 6 +- backend/thirdparty/provider_discord.go | 4 +- backend/thirdparty/provider_facebook.go | 4 +- backend/thirdparty/provider_github.go | 4 +- backend/thirdparty/provider_google.go | 4 +- backend/thirdparty/provider_linkedin.go | 4 +- backend/thirdparty/provider_microsoft.go | 4 +- 34 files changed, 295 insertions(+), 112 deletions(-) create mode 100644 backend/persistence/migrations/20250130154010_change_identities.down.fizz create mode 100644 backend/persistence/migrations/20250130154010_change_identities.up.fizz create mode 100644 backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz create mode 100644 backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz create mode 100644 backend/persistence/models/saml_identity.go create mode 100644 backend/persistence/saml_identity_persister.go diff --git a/backend/config/config_default.go b/backend/config/config_default.go index 932c44d6f..9d101f3b8 100644 --- a/backend/config/config_default.go +++ b/backend/config/config_default.go @@ -122,37 +122,37 @@ func DefaultConfig() *Config { Apple: ThirdPartyProvider{ DisplayName: "Apple", AllowLinking: true, - Name: "apple", + ID: "apple", }, Discord: ThirdPartyProvider{ DisplayName: "Discord", AllowLinking: true, - Name: "discord", + ID: "discord", }, LinkedIn: ThirdPartyProvider{ DisplayName: "LinkedIn", AllowLinking: true, - Name: "linkedin", + ID: "linkedin", }, Microsoft: ThirdPartyProvider{ DisplayName: "Microsoft", AllowLinking: true, - Name: "microsoft", + ID: "microsoft", }, GitHub: ThirdPartyProvider{ DisplayName: "GitHub", AllowLinking: true, - Name: "github", + ID: "github", }, Google: ThirdPartyProvider{ DisplayName: "Google", AllowLinking: true, - Name: "google", + ID: "google", }, Facebook: ThirdPartyProvider{ DisplayName: "Facebook", AllowLinking: true, - Name: "facebook", + ID: "facebook", }, }, }, diff --git a/backend/config/config_third_party.go b/backend/config/config_third_party.go index 71ae125fd..3960d8e6b 100644 --- a/backend/config/config_third_party.go +++ b/backend/config/config_third_party.go @@ -173,7 +173,7 @@ func (t *ThirdParty) PostProcess() error { for key, provider := range t.CustomProviders { // add prefix per default to ensure built-in and custom providers can be distinguished keyLower := strings.ToLower(key) - provider.Name = "custom_" + keyLower + provider.ID = "custom_" + keyLower providers[keyLower] = provider } t.CustomProviders = providers @@ -211,7 +211,7 @@ func (p *CustomThirdPartyProviders) Validate() error { if err != nil { return fmt.Errorf( "failed to validate third party provider %s: %w", - strings.TrimPrefix(v.Name, "custom_"), + strings.TrimPrefix(v.ID, "custom_"), err, ) } @@ -250,9 +250,9 @@ type CustomThirdPartyProvider struct { // // Required if `use_discovery` is false or omitted. AuthorizationEndpoint string `yaml:"authorization_endpoint" json:"authorization_endpoint,omitempty" koanf:"authorization_endpoint"` - // `name` is a unique identifier for the provider, derived from the key in the `custom_providers` map, by + // `ID` is a unique identifier for the provider, derived from the key in the `custom_providers` map, by // concatenating the prefix "custom_". This allows distinguishing between built-in and custom providers at runtime. - Name string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` + ID string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` // `issuer` is the provider's issuer identifier. It should be a URL that uses the "https" // scheme and has no query or fragment components. // @@ -446,9 +446,9 @@ type ThirdPartyProvider struct { // // Required if the provider is `enabled`. Secret string `yaml:"secret" json:"secret,omitempty" koanf:"secret"` - // `name` is a unique name/slug/identifier for the provider. It is the lowercased key of the corresponding field - // in ThirdPartyProviders. See also: CustomThirdPartyProvider.Name. - Name string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` + // `ID` is a unique name/slug/identifier for the provider. It is the lowercased key of the corresponding field + // in ThirdPartyProviders. See also: CustomThirdPartyProvider.ID. + ID string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` } func (ThirdPartyProvider) JSONSchemaExtend(schema *jsonschema.Schema) { diff --git a/backend/dto/admin/identity.go b/backend/dto/admin/identity.go index 5e9315f4f..90d9f90a3 100644 --- a/backend/dto/admin/identity.go +++ b/backend/dto/admin/identity.go @@ -18,8 +18,8 @@ type Identity struct { func FromIdentityModel(model models.Identity) Identity { return Identity{ ID: model.ID, - ProviderID: model.ProviderID, - ProviderName: model.ProviderName, + ProviderID: model.ProviderUserID, + ProviderName: model.ProviderID, EmailID: model.EmailID, CreatedAt: model.CreatedAt, UpdatedAt: model.UpdatedAt, diff --git a/backend/dto/thirdparty.go b/backend/dto/thirdparty.go index e527a001a..3059350f9 100644 --- a/backend/dto/thirdparty.go +++ b/backend/dto/thirdparty.go @@ -45,23 +45,29 @@ func FromIdentityModel(identity *models.Identity, cfg *config.Config) *Identity } return &Identity{ - ID: identity.ProviderID, + ID: identity.ProviderUserID, Provider: getProviderDisplayName(identity, cfg), } } func getProviderDisplayName(identity *models.Identity, cfg *config.Config) string { - if strings.HasPrefix(identity.ProviderName, "custom_") { - providerNameWithoutPrefix := strings.TrimPrefix(identity.ProviderName, "custom_") + if identity.SamlIdentity != nil { + for _, ip := range cfg.Saml.IdentityProviders { + if ip.Enabled && ip.Domain == identity.SamlIdentity.Domain { + return ip.Name + } + } + } else if strings.HasPrefix(identity.ProviderID, "custom_") { + providerNameWithoutPrefix := strings.TrimPrefix(identity.ProviderID, "custom_") return cfg.ThirdParty.CustomProviders[providerNameWithoutPrefix].DisplayName } else { s := structs.New(config.ThirdPartyProviders{}) for _, field := range s.Fields() { - if strings.ToLower(field.Name()) == strings.ToLower(identity.ProviderName) { + if strings.ToLower(field.Name()) == strings.ToLower(identity.ProviderID) { return field.Name() } } } - return strings.TrimSpace(identity.ProviderName) + return strings.TrimSpace(identity.ProviderID) } diff --git a/backend/ee/saml/config/saml.go b/backend/ee/saml/config/saml.go index 4d0774f42..94916f7d6 100644 --- a/backend/ee/saml/config/saml.go +++ b/backend/ee/saml/config/saml.go @@ -139,10 +139,21 @@ func (s *Saml) Validate() error { return errors.New("at least one SAML provider is needed") } + configuredDomains := make(map[string]int) for _, provider := range s.IdentityProviders { - validationErrors = provider.Validate() - if validationErrors != nil { - return validationErrors + if provider.Enabled { + validationErrors = provider.Validate() + if validationErrors != nil { + return validationErrors + } + + configuredDomains[provider.Domain] += 1 + } + } + + for configuredDomain, configuredDomainCount := range configuredDomains { + if configuredDomainCount > 1 { + return fmt.Errorf("provider domains must be unique, found domain %s configured %d times", configuredDomain, configuredDomainCount) } } } diff --git a/backend/ee/saml/handler.go b/backend/ee/saml/handler.go index a8a534711..285b90a97 100644 --- a/backend/ee/saml/handler.go +++ b/backend/ee/saml/handler.go @@ -155,8 +155,9 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * var samlError error samlError = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error { userdata := provider.GetUserData(assertionInfo) - - linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, state.Provider, true, state.IsFlow) + identityProviderIssuer := assertionInfo.Assertions[0].Issuer + samlDomain := provider.GetDomain() + linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, state.IsFlow) if samlErrorTx != nil { return samlErrorTx } @@ -164,7 +165,7 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * accountLinkingResult = linkResult emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email) - identityModel := emailModel.Identities.GetIdentity(provider.GetDomain(), userdata.Metadata.Subject) + identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject) token, tokenError := models.NewToken( linkResult.User.ID, diff --git a/backend/ee/saml/provider/saml.go b/backend/ee/saml/provider/saml.go index 82eabc97a..d0635d6fa 100644 --- a/backend/ee/saml/provider/saml.go +++ b/backend/ee/saml/provider/saml.go @@ -37,7 +37,7 @@ func NewBaseSamlProvider(cfg *config.Config, idpConfig samlConfig.IdentityProvid IDPCertificateStore: &idpMetadata.certs, AssertionConsumerServiceURL: fmt.Sprintf("%s/saml/callback", cfg.Saml.Endpoint), - ServiceProviderIssuer: fmt.Sprintf("%s/saml/metadata", cfg.Saml.Endpoint), + ServiceProviderIssuer: cfg.Saml.Endpoint, ServiceProviderSLOURL: fmt.Sprintf("%s/saml/logout", cfg.Saml.Endpoint), SPKeyStore: serviceProviderCertStore, diff --git a/backend/flow_api/flow/shared/action_exchange_token.go b/backend/flow_api/flow/shared/action_exchange_token.go index 1e007b949..1031d6556 100644 --- a/backend/flow_api/flow/shared/action_exchange_token.go +++ b/backend/flow_api/flow/shared/action_exchange_token.go @@ -90,7 +90,7 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error { return fmt.Errorf("failed to set login_method to the stash: %w", err) } - if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderName); err != nil { + if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil { return fmt.Errorf("failed to set third_party_provider to the stash: %w", err) } diff --git a/backend/flow_api/flow/shared/action_thirdparty_oauth.go b/backend/flow_api/flow/shared/action_thirdparty_oauth.go index 4a1443056..de6086a44 100644 --- a/backend/flow_api/flow/shared/action_thirdparty_oauth.go +++ b/backend/flow_api/flow/shared/action_thirdparty_oauth.go @@ -39,7 +39,7 @@ func (a ThirdPartyOAuth) Initialize(c flowpilot.InitializationContext) { Required(true) for _, provider := range enabledThirdPartyProviders { - providerInput.AllowedValue(provider.DisplayName, provider.Name) + providerInput.AllowedValue(provider.DisplayName, provider.ID) } slices.SortFunc(enabledCustomThirdPartyProviders, func(a, b config.CustomThirdPartyProvider) bool { @@ -47,7 +47,7 @@ func (a ThirdPartyOAuth) Initialize(c flowpilot.InitializationContext) { }) for _, provider := range enabledCustomThirdPartyProviders { - providerInput.AllowedValue(provider.DisplayName, provider.Name) + providerInput.AllowedValue(provider.DisplayName, provider.ID) } c.AddInputs(flowpilot.StringInput("redirect_to").Hidden(true).Required(true), providerInput) diff --git a/backend/flow_api/services/user.go b/backend/flow_api/services/user.go index 9b13b2cd1..46930d52a 100644 --- a/backend/flow_api/services/user.go +++ b/backend/flow_api/services/user.go @@ -8,7 +8,7 @@ import ( func UserCanDoThirdParty(cfg config.Config, identities models.Identities) bool { for _, identity := range identities { - if provider := cfg.ThirdParty.Providers.Get(identity.ProviderName); provider != nil { + if provider := cfg.ThirdParty.Providers.Get(identity.ProviderID); provider != nil { return provider.Enabled } } @@ -18,7 +18,7 @@ func UserCanDoThirdParty(cfg config.Config, identities models.Identities) bool { func UserCanDoSaml(cfg config.Config, identities models.Identities) bool { for _, identity := range identities { - if provider := cfg.Saml.GetProviderByDomain(identity.ProviderName); provider != nil { + if provider := cfg.Saml.GetProviderByDomain(identity.ProviderID); provider != nil { return cfg.Saml.Enabled && provider.Enabled } } diff --git a/backend/handler/thirdparty.go b/backend/handler/thirdparty.go index d0f1c850d..3b2e073fd 100644 --- a/backend/handler/thirdparty.go +++ b/backend/handler/thirdparty.go @@ -61,7 +61,7 @@ func (h *ThirdPartyHandler) Auth(c echo.Context) error { return h.redirectError(c, thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), errorRedirectTo) } - state, err := thirdparty.GenerateState(h.cfg, provider.Name(), request.RedirectTo) + state, err := thirdparty.GenerateState(h.cfg, provider.ID(), request.RedirectTo) if err != nil { return h.redirectError(c, thirdparty.ErrorServer("could not generate state").WithCause(err), errorRedirectTo) } @@ -143,14 +143,14 @@ func (h *ThirdPartyHandler) Callback(c echo.Context) error { return thirdparty.ErrorInvalidRequest("could not retrieve user data from provider").WithCause(terr) } - linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.Name(), false, state.IsFlow) + linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.ID(), false, nil, state.IsFlow) if terr != nil { return terr } accountLinkingResult = linkingResult emailModel := linkingResult.User.Emails.GetEmailByAddress(userData.Metadata.Email) - identityModel := emailModel.Identities.GetIdentity(provider.Name(), userData.Metadata.Subject) + identityModel := emailModel.Identities.GetIdentity(provider.ID(), userData.Metadata.Subject) token, terr := models.NewToken( linkingResult.User.ID, diff --git a/backend/handler/thirdparty_test.go b/backend/handler/thirdparty_test.go index cea8e196a..812d79ad9 100644 --- a/backend/handler/thirdparty_test.go +++ b/backend/handler/thirdparty_test.go @@ -60,42 +60,42 @@ func (s *thirdPartySuite) setUpConfig(enabledProviders []string, allowedRedirect cfg.ThirdParty = config.ThirdParty{ Providers: config.ThirdPartyProviders{ Apple: config.ThirdPartyProvider{ - Name: "apple", + ID: "apple", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, Google: config.ThirdPartyProvider{ - Name: "google", + ID: "google", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, GitHub: config.ThirdPartyProvider{ - Name: "github", + ID: "github", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, Discord: config.ThirdPartyProvider{ - Name: "discord", + ID: "discord", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, Microsoft: config.ThirdPartyProvider{ - Name: "microsoft", + ID: "microsoft", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: false, }, Facebook: config.ThirdPartyProvider{ - Name: "facebook", + ID: "facebook", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", diff --git a/backend/persistence/identity_persister.go b/backend/persistence/identity_persister.go index 6e1a07aa8..1a7e65c4b 100644 --- a/backend/persistence/identity_persister.go +++ b/backend/persistence/identity_persister.go @@ -10,7 +10,7 @@ import ( ) type IdentityPersister interface { - Get(userProviderID string, providerID string) (*models.Identity, error) + Get(providerUserID string, providerID string) (*models.Identity, error) GetByID(identityID uuid.UUID) (*models.Identity, error) Create(identity models.Identity) error Update(identity models.Identity) error @@ -32,9 +32,9 @@ func (p identityPersister) GetByID(identityID uuid.UUID) (*models.Identity, erro return identity, nil } -func (p identityPersister) Get(userProviderID string, providerID string) (*models.Identity, error) { +func (p identityPersister) Get(providerUserID string, providerID string) (*models.Identity, error) { identity := &models.Identity{} - if err := p.db.EagerPreload().Where("provider_id = ? AND provider_name = ?", userProviderID, providerID).First(identity); err != nil { + if err := p.db.EagerPreload().Where("provider_user_id = ? AND provider_id = ?", providerUserID, providerID).First(identity); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/backend/persistence/migrations/20250130154010_change_identities.down.fizz b/backend/persistence/migrations/20250130154010_change_identities.down.fizz new file mode 100644 index 000000000..19ca26dca --- /dev/null +++ b/backend/persistence/migrations/20250130154010_change_identities.down.fizz @@ -0,0 +1,6 @@ +drop_index("identities", "identities_provider_user_id_provider_id_idx") + +rename_column("identities", "provider_id", "provider_name") +rename_column("identities", "provider_user_id", "provider_id") + +add_index("identities", ["provider_id", "provider_name"], {unique: true}) diff --git a/backend/persistence/migrations/20250130154010_change_identities.up.fizz b/backend/persistence/migrations/20250130154010_change_identities.up.fizz new file mode 100644 index 000000000..206f0ab5d --- /dev/null +++ b/backend/persistence/migrations/20250130154010_change_identities.up.fizz @@ -0,0 +1,6 @@ +drop_index("identities", "identities_provider_id_provider_name_idx") + +rename_column("identities", "provider_id", "provider_user_id") +rename_column("identities", "provider_name", "provider_id") + +add_index("identities", ["provider_user_id", "provider_id"], {unique: true}) diff --git a/backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz b/backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz new file mode 100644 index 000000000..b7aa5e6d4 --- /dev/null +++ b/backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz @@ -0,0 +1 @@ +drop_table("saml_identities") diff --git a/backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz b/backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz new file mode 100644 index 000000000..50fd61504 --- /dev/null +++ b/backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz @@ -0,0 +1,8 @@ +create_table("saml_identities") { + t.Column("id", "uuid", {primary: true}) + t.Column("identity_id", "uuid", { "null": false }) + t.Column("domain", "string", { "null": false }) + t.Timestamps() + t.ForeignKey("identity_id", {"identities": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"}) + t.Index(["identity_id", "domain"], {"unique": true}) +} diff --git a/backend/persistence/models/email.go b/backend/persistence/models/email.go index 7349d09e0..30e1bf4a4 100644 --- a/backend/persistence/models/email.go +++ b/backend/persistence/models/email.go @@ -45,6 +45,15 @@ func (email *Email) IsPrimary() bool { return false } +func (email *Email) GetSamlIdentityForDomain(domain string) *SamlIdentity { + for _, identity := range email.Identities { + if identity.SamlIdentity != nil && identity.SamlIdentity.Domain == domain { + return identity.SamlIdentity + } + } + return nil +} + func (emails *Emails) GetVerified() Emails { var list Emails for _, email := range *emails { diff --git a/backend/persistence/models/identity.go b/backend/persistence/models/identity.go index e12777753..d88793080 100644 --- a/backend/persistence/models/identity.go +++ b/backend/persistence/models/identity.go @@ -12,21 +12,22 @@ import ( // Identity is used by pop to map your identities database table to your go code. type Identity struct { - ID uuid.UUID `json:"id" db:"id"` - ProviderID string `json:"provider_id" db:"provider_id"` - ProviderName string `json:"provider_name" db:"provider_name"` - Data slices.Map `json:"data" db:"data"` - EmailID uuid.UUID `json:"email_id" db:"email_id"` - Email *Email `json:"email,omitempty" belongs_to:"email"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + ID uuid.UUID `json:"id" db:"id"` + ProviderUserID string `json:"provider_user_id" db:"provider_user_id"` + ProviderID string `json:"provider_id" db:"provider_id"` + Data slices.Map `json:"data" db:"data"` + EmailID uuid.UUID `json:"email_id" db:"email_id"` + Email *Email `json:"email,omitempty" belongs_to:"email"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + SamlIdentity *SamlIdentity `json:"saml_identity" has_one:"saml_identity"` } type Identities []Identity -func (identities Identities) GetIdentity(providerName string, providerId string) *Identity { +func (identities Identities) GetIdentity(providerID string, providerUserID string) *Identity { for _, identity := range identities { - if identity.ProviderName == providerName && identity.ProviderID == providerId { + if identity.ProviderID == providerID && identity.ProviderUserID == providerUserID { return &identity } } @@ -34,22 +35,22 @@ func (identities Identities) GetIdentity(providerName string, providerId string) return nil } -func NewIdentity(provider string, identityData map[string]interface{}, emailID uuid.UUID) (*Identity, error) { - providerID, ok := identityData["sub"] +func NewIdentity(providerID string, identityData map[string]interface{}, emailID uuid.UUID) (*Identity, error) { + providerUserID, ok := identityData["sub"] if !ok { - return nil, errors.New("missing provider id") + return nil, errors.New("missing provider user id") } now := time.Now().UTC() id, _ := uuid.NewV4() identity := &Identity{ - ID: id, - Data: identityData, - ProviderID: providerID.(string), - ProviderName: provider, - EmailID: emailID, - CreatedAt: now, - UpdatedAt: now, + ID: id, + Data: identityData, + ProviderUserID: providerUserID.(string), + ProviderID: providerID, + EmailID: emailID, + CreatedAt: now, + UpdatedAt: now, } return identity, nil @@ -60,8 +61,8 @@ func NewIdentity(provider string, identityData map[string]interface{}, emailID u func (i *Identity) Validate(tx *pop.Connection) (*validate.Errors, error) { return validate.Validate( &validators.UUIDIsPresent{Name: "ID", Field: i.ID}, + &validators.StringIsPresent{Name: "ProviderUserID", Field: i.ProviderUserID}, &validators.StringIsPresent{Name: "ProviderID", Field: i.ProviderID}, - &validators.StringIsPresent{Name: "ProviderName", Field: i.ProviderName}, &validators.TimeIsPresent{Name: "CreatedAt", Field: i.CreatedAt}, &validators.TimeIsPresent{Name: "UpdatedAt", Field: i.UpdatedAt}, ), nil diff --git a/backend/persistence/models/saml_identity.go b/backend/persistence/models/saml_identity.go new file mode 100644 index 000000000..c58dfa511 --- /dev/null +++ b/backend/persistence/models/saml_identity.go @@ -0,0 +1,27 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" + "time" +) + +type SamlIdentity struct { + ID uuid.UUID `json:"id" db:"id"` + IdentityID uuid.UUID `json:"identity_id" db:"identity_id"` + Domain string `json:"domain" db:"domain"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type SamlIdentities []SamlIdentity + +func (i *SamlIdentity) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: i.ID}, + &validators.UUIDIsPresent{Name: "IdentityID", Field: i.IdentityID}, + &validators.StringIsPresent{Name: "Domain", Field: i.Domain}, + ), nil +} diff --git a/backend/persistence/persister.go b/backend/persistence/persister.go index f9742ef55..aeaa1a884 100644 --- a/backend/persistence/persister.go +++ b/backend/persistence/persister.go @@ -35,6 +35,8 @@ type Persister interface { GetSamlCertificatePersisterWithConnection(tx *pop.Connection) SamlCertificatePersister GetSamlStatePersister() SamlStatePersister GetSamlStatePersisterWithConnection(tx *pop.Connection) SamlStatePersister + GetSamlIdentityPersister() SamlIdentityPersister + GetSamlIdentityPersisterWithConnection(tx *pop.Connection) SamlIdentityPersister GetTokenPersister() TokenPersister GetTokenPersisterWithConnection(tx *pop.Connection) TokenPersister GetUserPersister() UserPersister @@ -263,6 +265,14 @@ func (p *persister) GetSamlCertificatePersisterWithConnection(tx *pop.Connection return NewSamlCertificatePersister(tx) } +func (p *persister) GetSamlIdentityPersister() SamlIdentityPersister { + return NewSamlIdentityPersister(p.DB) +} + +func (p *persister) GetSamlIdentityPersisterWithConnection(tx *pop.Connection) SamlIdentityPersister { + return NewSamlIdentityPersister(tx) +} + func (p *persister) GetWebhookPersister(tx *pop.Connection) WebhookPersister { if tx != nil { return NewWebhookPersister(tx) diff --git a/backend/persistence/saml_identity_persister.go b/backend/persistence/saml_identity_persister.go new file mode 100644 index 000000000..4554d8353 --- /dev/null +++ b/backend/persistence/saml_identity_persister.go @@ -0,0 +1,46 @@ +package persistence + +import ( + "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +type SamlIdentityPersister interface { + Create(samlIdentity models.SamlIdentity) error + Update(samlIdentity models.SamlIdentity) error +} + +type samlIdentityPersister struct { + db *pop.Connection +} + +func NewSamlIdentityPersister(db *pop.Connection) SamlIdentityPersister { + return &samlIdentityPersister{db: db} +} + +func (p samlIdentityPersister) Create(samlIdentity models.SamlIdentity) error { + vErr, err := p.db.Eager().ValidateAndCreate(&samlIdentity) + if err != nil { + return fmt.Errorf("failed to store saml identity: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("saml identity object validation failed: %w", vErr) + } + + return nil +} + +func (p samlIdentityPersister) Update(samlIdentity models.SamlIdentity) error { + vErr, err := p.db.ValidateAndUpdate(&samlIdentity) + if err != nil { + return fmt.Errorf("failed to update saml identity: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("saml identity object validation failed: %w", vErr) + } + + return nil +} diff --git a/backend/persistence/user_persister.go b/backend/persistence/user_persister.go index 0803a6f3f..389880ce4 100644 --- a/backend/persistence/user_persister.go +++ b/backend/persistence/user_persister.go @@ -36,7 +36,7 @@ func (p *userPersister) Get(id uuid.UUID) (*models.User, error) { eagerPreloadFields := []string{ "Emails", "Emails.PrimaryEmail", - "Emails.Identities", + "Emails.Identities.SamlIdentity", "WebauthnCredentials", "WebauthnCredentials.Transports", "Username", @@ -57,7 +57,7 @@ func (p *userPersister) Get(id uuid.UUID) (*models.User, error) { func (p *userPersister) GetByEmailAddress(emailAddress string) (*models.User, error) { email := models.Email{} - err := p.db.Where("address = (?)", emailAddress).First(&email) + err := p.db.Eager().Where("address = (?)", emailAddress).First(&email) if err != nil && errors.Is(err, sql.ErrNoRows) { return nil, nil diff --git a/backend/test/fixtures/thirdparty/identities.yaml b/backend/test/fixtures/thirdparty/identities.yaml index d0feca72b..5f4590237 100644 --- a/backend/test/fixtures/thirdparty/identities.yaml +++ b/backend/test/fixtures/thirdparty/identities.yaml @@ -1,41 +1,41 @@ - id: 443a984d-bb1c-46fe-b685-151bd0f017b1 - provider_id: "google_abcde" - provider_name: "google" + provider_user_id: "google_abcde" + provider_id: "google" data: '{"email":"test-with-google-identity@example.com","email_verified":true,"iss":"https://www.googleapis.com","sub":"google_abcde"}' email_id: 15ed767e-1544-4e03-a732-4bf80caa2e78 created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: b140230b-be5b-4589-9762-ec72f32d2833 - provider_id: "1234" - provider_name: "github" + provider_user_id: "1234" + provider_id: "github" data: '{"email":"test-with-github-identity@example.com","email_verified":true,"iss":"https://api.github.com","sub":"1234"}' email_id: 80939e4c-0028-4b67-aeee-98c617a20b6b created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: 09b084b7-36af-41c5-b4b6-db9b8ed8385b - provider_id: "apple_abcde" - provider_name: "apple" + provider_user_id: "apple_abcde" + provider_id: "apple" data: '{"email":"test-with-apple-identity@example.com","email_verified":true,"iss":"https://appleid.apple.com","sub":"apple_abcde"}' email_id: 05ab6e1f-8dfb-4329-ae04-22571a68d96b created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: 18d61f13-2789-467a-a3a6-4292c0621580 - provider_id: "discord_abcde" - provider_name: "discord" + provider_user_id: "discord_abcde" + provider_id: "discord" data: '{"email":"test-with-discord-identity@example.com","email_verified":true,"iss":"https://discord.com/api","sub":"discord_abcde"}' email_id: 09f35529-cca6-44a7-ab1d-b07e95a04e3b created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: e4cd954b-8f1f-4fa2-b7e9-60a0ffa56363 - provider_id: "microsoft_abcde" - provider_name: "microsoft" + provider_user_id: "microsoft_abcde" + provider_id: "microsoft" data: '{"email":"test-with-microsoft-identity@example.com","iss":"https://login.microsoftonline.com/common","sub":"microsoft_abcde"}' email_id: d781006b-4f55-4327-bad6-55bc34b88585 created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: b6b1309d-61de-4a82-b8b8-d54db0be679b - provider_id: "facebook_abcde" - provider_name: "facebook" + provider_user_id: "facebook_abcde" + provider_id: "facebook" data: '{"email":"test-with-facebook-identity@example.com","sub":"facebook_abcde"}' email_id: d781006b-4f55-4327-bad6-55bc34b88585 created_at: 2020-12-31 23:59:59 diff --git a/backend/thirdparty/linking.go b/backend/thirdparty/linking.go index a33561929..bc4703e7c 100644 --- a/backend/thirdparty/linking.go +++ b/backend/thirdparty/linking.go @@ -3,11 +3,13 @@ package thirdparty import ( "fmt" "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/persistence" "github.com/teamhanko/hanko/backend/persistence/models" "github.com/teamhanko/hanko/backend/webhooks/events" "strings" + "time" ) type AccountLinkingResult struct { @@ -17,14 +19,14 @@ type AccountLinkingResult struct { UserCreated bool } -func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, isSaml bool, isFlow bool) (*AccountLinkingResult, error) { +func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerID string, isSaml bool, samlDomain *string, isFlow bool) (*AccountLinkingResult, error) { if !isFlow { if cfg.Email.RequireVerification && !userData.Metadata.EmailVerified { return nil, ErrorUnverifiedProviderEmail("third party provider email must be verified") } } - identity, err := p.GetIdentityPersister().Get(userData.Metadata.Subject, providerName) + identity, err := p.GetIdentityPersister().Get(userData.Metadata.Subject, providerID) if err != nil { return nil, ErrorServer("could not get identity").WithCause(err) } @@ -36,29 +38,29 @@ func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister } if user == nil { - return signUp(tx, cfg, p, userData, providerName) + return signUp(tx, cfg, p, userData, providerID, isSaml, samlDomain) } else { - return link(tx, cfg, p, userData, providerName, user, isSaml) + return link(tx, cfg, p, userData, providerID, user, isSaml, samlDomain) } } else { return signIn(tx, cfg, p, userData, identity) } } -func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, user *models.User, isSaml bool) (*AccountLinkingResult, error) { +func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerID string, user *models.User, isSaml bool, samlDomain *string) (*AccountLinkingResult, error) { if !isSaml { - if strings.HasPrefix(providerName, "custom_") { - provider, ok := cfg.ThirdParty.CustomProviders[strings.TrimPrefix(providerName, "custom_")] + if strings.HasPrefix(providerID, "custom_") { + provider, ok := cfg.ThirdParty.CustomProviders[strings.TrimPrefix(providerID, "custom_")] if !ok { - return nil, ErrorServer(fmt.Sprintf("unknown provider: %s", providerName)) + return nil, ErrorServer(fmt.Sprintf("unknown provider: %s", providerID)) } if !provider.AllowLinking { return nil, ErrorUserConflict("third party account linking for existing user with same email disallowed") } } else { - provider := cfg.ThirdParty.Providers.Get(providerName) + provider := cfg.ThirdParty.Providers.Get(providerID) if provider == nil { - return nil, fmt.Errorf("unknown provider: %s", providerName) + return nil, fmt.Errorf("unknown provider: %s", providerID) } if !provider.AllowLinking { @@ -74,7 +76,7 @@ func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userD return nil, ErrorServer("could not link account").WithCause(err) } - identity, err := models.NewIdentity(providerName, userDataMap, email.ID) + identity, err := models.NewIdentity(providerID, userDataMap, email.ID) if err != nil { return nil, ErrorServer("could not create identity").WithCause(err) } @@ -84,6 +86,38 @@ func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userD return nil, ErrorServer("could not create identity").WithCause(err) } + if isSaml && samlDomain != nil && *samlDomain != "" { + if existingSamlIdentity := email.GetSamlIdentityForDomain(*samlDomain); existingSamlIdentity != nil { + identityToDeleteID := existingSamlIdentity.IdentityID + existingSamlIdentity.IdentityID = identity.ID + + err = p.GetSamlIdentityPersisterWithConnection(tx).Update(*existingSamlIdentity) + if err != nil { + return nil, ErrorServer("could update saml identity").WithCause(err) + } + + err = p.GetIdentityPersisterWithConnection(tx).Delete(models.Identity{ID: identityToDeleteID}) + if err != nil { + return nil, ErrorServer("could not delete identity").WithCause(err) + } + } else { + samlIdentityID, _ := uuid.NewV4() + now := time.Now().UTC() + samlIdentity := &models.SamlIdentity{ + ID: samlIdentityID, + IdentityID: identity.ID, + Domain: *samlDomain, + CreatedAt: now, + UpdatedAt: now, + } + + err = p.GetSamlIdentityPersisterWithConnection(tx).Create(*samlIdentity) + if err != nil { + return nil, ErrorServer("could not create saml identity").WithCause(err) + } + } + } + u, terr := p.GetUserPersisterWithConnection(tx).Get(*email.UserID) if terr != nil { return nil, ErrorServer("could not get user").WithCause(terr) @@ -187,7 +221,7 @@ func signIn(tx *pop.Connection, cfg *config.Config, p persistence.Persister, use return linkingResult, nil } -func signUp(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string) (*AccountLinkingResult, error) { +func signUp(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerID string, isSaml bool, samlDomain *string) (*AccountLinkingResult, error) { if !cfg.Account.AllowSignup { return nil, ErrorSignUpDisabled("account signup is disabled") } @@ -241,14 +275,31 @@ func signUp(tx *pop.Connection, cfg *config.Config, p persistence.Persister, use return nil, ErrorServer("could not link account").WithCause(err) } - identity, terr := models.NewIdentity(providerName, userDataMap, email.ID) + identity, terr := models.NewIdentity(providerID, userDataMap, email.ID) if terr != nil { return nil, ErrorServer("could not create identity").WithCause(terr) } terr = identityPersister.Create(*identity) if terr != nil { - return nil, ErrorServer("could not create identity").WithCause(terr) + return nil, ErrorServer("could not store identity").WithCause(terr) + } + + if isSaml && samlDomain != nil && *samlDomain != "" { + samlIdentityID, _ := uuid.NewV4() + now := time.Now().UTC() + samlIdentity := &models.SamlIdentity{ + ID: samlIdentityID, + IdentityID: identity.ID, + Domain: *samlDomain, + CreatedAt: now, + UpdatedAt: now, + } + + err = p.GetSamlIdentityPersisterWithConnection(tx).Create(*samlIdentity) + if err != nil { + return nil, ErrorServer("could not store saml identity").WithCause(err) + } } u, terr := userPersister.Get(*email.UserID) diff --git a/backend/thirdparty/provider.go b/backend/thirdparty/provider.go index 02a9b516f..a4759b1fa 100644 --- a/backend/thirdparty/provider.go +++ b/backend/thirdparty/provider.go @@ -85,7 +85,7 @@ type OAuthProvider interface { AuthCodeURL(string, ...oauth2.AuthCodeOption) string GetUserData(*oauth2.Token) (*UserData, error) GetOAuthToken(string) (*oauth2.Token, error) - Name() string + ID() string } func GetProvider(config config.ThirdParty, id string) (OAuthProvider, error) { diff --git a/backend/thirdparty/provider_apple.go b/backend/thirdparty/provider_apple.go index 37e9bd603..8cf59a3f0 100644 --- a/backend/thirdparty/provider_apple.go +++ b/backend/thirdparty/provider_apple.go @@ -116,6 +116,6 @@ func (a appleProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return userData, nil } -func (a appleProvider) Name() string { - return a.config.Name +func (a appleProvider) ID() string { + return a.config.ID } diff --git a/backend/thirdparty/provider_custom.go b/backend/thirdparty/provider_custom.go index d5bbf7d6b..0c499dfe4 100644 --- a/backend/thirdparty/provider_custom.go +++ b/backend/thirdparty/provider_custom.go @@ -17,7 +17,7 @@ type customProvider struct { func NewCustomThirdPartyProvider(config *config.CustomThirdPartyProvider, redirectURL string) (OAuthProvider, error) { if !config.Enabled { - return nil, fmt.Errorf("provider %s is disabled", config.Name) + return nil, fmt.Errorf("provider %s is disabled", config.ID) } customProvider := &customProvider{ @@ -100,6 +100,6 @@ func (p customProvider) GetUserData(token *oauth2.Token) (*UserData, error) { }, nil } -func (p customProvider) Name() string { - return p.config.Name +func (p customProvider) ID() string { + return p.config.ID } diff --git a/backend/thirdparty/provider_discord.go b/backend/thirdparty/provider_discord.go index ddfeab2dd..16bb26e80 100644 --- a/backend/thirdparty/provider_discord.go +++ b/backend/thirdparty/provider_discord.go @@ -104,6 +104,6 @@ func (g discordProvider) buildAvatarURL(userID string, avatarHash string) string return fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", userID, avatarHash) } -func (g discordProvider) Name() string { - return g.config.Name +func (g discordProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_facebook.go b/backend/thirdparty/provider_facebook.go index 6c3c9f637..158b1e6be 100644 --- a/backend/thirdparty/provider_facebook.go +++ b/backend/thirdparty/provider_facebook.go @@ -124,6 +124,6 @@ func (f facebookProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (f facebookProvider) Name() string { - return f.config.Name +func (f facebookProvider) ID() string { + return f.config.ID } diff --git a/backend/thirdparty/provider_github.go b/backend/thirdparty/provider_github.go index 1815ccb86..33fedff37 100644 --- a/backend/thirdparty/provider_github.go +++ b/backend/thirdparty/provider_github.go @@ -114,6 +114,6 @@ func (g githubProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (g githubProvider) Name() string { - return g.config.Name +func (g githubProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_google.go b/backend/thirdparty/provider_google.go index 0f5f99c46..a6ad38c09 100644 --- a/backend/thirdparty/provider_google.go +++ b/backend/thirdparty/provider_google.go @@ -93,6 +93,6 @@ func (g googleProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (g googleProvider) Name() string { - return g.config.Name +func (g googleProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_linkedin.go b/backend/thirdparty/provider_linkedin.go index 381f20645..9dd3ee56a 100644 --- a/backend/thirdparty/provider_linkedin.go +++ b/backend/thirdparty/provider_linkedin.go @@ -107,6 +107,6 @@ func (g linkedInProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (g linkedInProvider) Name() string { - return g.config.Name +func (g linkedInProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_microsoft.go b/backend/thirdparty/provider_microsoft.go index 9bcf240e6..6ff753d61 100644 --- a/backend/thirdparty/provider_microsoft.go +++ b/backend/thirdparty/provider_microsoft.go @@ -159,8 +159,8 @@ func (p microsoftProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (p microsoftProvider) Name() string { - return p.config.Name +func (p microsoftProvider) ID() string { + return p.config.ID } func (p microsoftProvider) issuerValidator() jwt.ValidatorFunc {