diff --git a/backend/config/config_default.go b/backend/config/config_default.go index 932c44d6f..9d101f3b8 100644 --- a/backend/config/config_default.go +++ b/backend/config/config_default.go @@ -122,37 +122,37 @@ func DefaultConfig() *Config { Apple: ThirdPartyProvider{ DisplayName: "Apple", AllowLinking: true, - Name: "apple", + ID: "apple", }, Discord: ThirdPartyProvider{ DisplayName: "Discord", AllowLinking: true, - Name: "discord", + ID: "discord", }, LinkedIn: ThirdPartyProvider{ DisplayName: "LinkedIn", AllowLinking: true, - Name: "linkedin", + ID: "linkedin", }, Microsoft: ThirdPartyProvider{ DisplayName: "Microsoft", AllowLinking: true, - Name: "microsoft", + ID: "microsoft", }, GitHub: ThirdPartyProvider{ DisplayName: "GitHub", AllowLinking: true, - Name: "github", + ID: "github", }, Google: ThirdPartyProvider{ DisplayName: "Google", AllowLinking: true, - Name: "google", + ID: "google", }, Facebook: ThirdPartyProvider{ DisplayName: "Facebook", AllowLinking: true, - Name: "facebook", + ID: "facebook", }, }, }, diff --git a/backend/config/config_third_party.go b/backend/config/config_third_party.go index 71ae125fd..3960d8e6b 100644 --- a/backend/config/config_third_party.go +++ b/backend/config/config_third_party.go @@ -173,7 +173,7 @@ func (t *ThirdParty) PostProcess() error { for key, provider := range t.CustomProviders { // add prefix per default to ensure built-in and custom providers can be distinguished keyLower := strings.ToLower(key) - provider.Name = "custom_" + keyLower + provider.ID = "custom_" + keyLower providers[keyLower] = provider } t.CustomProviders = providers @@ -211,7 +211,7 @@ func (p *CustomThirdPartyProviders) Validate() error { if err != nil { return fmt.Errorf( "failed to validate third party provider %s: %w", - strings.TrimPrefix(v.Name, "custom_"), + strings.TrimPrefix(v.ID, "custom_"), err, ) } @@ -250,9 +250,9 @@ type CustomThirdPartyProvider struct { // // Required if `use_discovery` is false or omitted. AuthorizationEndpoint string `yaml:"authorization_endpoint" json:"authorization_endpoint,omitempty" koanf:"authorization_endpoint"` - // `name` is a unique identifier for the provider, derived from the key in the `custom_providers` map, by + // `ID` is a unique identifier for the provider, derived from the key in the `custom_providers` map, by // concatenating the prefix "custom_". This allows distinguishing between built-in and custom providers at runtime. - Name string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` + ID string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` // `issuer` is the provider's issuer identifier. It should be a URL that uses the "https" // scheme and has no query or fragment components. // @@ -446,9 +446,9 @@ type ThirdPartyProvider struct { // // Required if the provider is `enabled`. Secret string `yaml:"secret" json:"secret,omitempty" koanf:"secret"` - // `name` is a unique name/slug/identifier for the provider. It is the lowercased key of the corresponding field - // in ThirdPartyProviders. See also: CustomThirdPartyProvider.Name. - Name string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` + // `ID` is a unique name/slug/identifier for the provider. It is the lowercased key of the corresponding field + // in ThirdPartyProviders. See also: CustomThirdPartyProvider.ID. + ID string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"` } func (ThirdPartyProvider) JSONSchemaExtend(schema *jsonschema.Schema) { diff --git a/backend/dto/admin/identity.go b/backend/dto/admin/identity.go index 5e9315f4f..90d9f90a3 100644 --- a/backend/dto/admin/identity.go +++ b/backend/dto/admin/identity.go @@ -18,8 +18,8 @@ type Identity struct { func FromIdentityModel(model models.Identity) Identity { return Identity{ ID: model.ID, - ProviderID: model.ProviderID, - ProviderName: model.ProviderName, + ProviderID: model.ProviderUserID, + ProviderName: model.ProviderID, EmailID: model.EmailID, CreatedAt: model.CreatedAt, UpdatedAt: model.UpdatedAt, diff --git a/backend/dto/thirdparty.go b/backend/dto/thirdparty.go index e527a001a..3059350f9 100644 --- a/backend/dto/thirdparty.go +++ b/backend/dto/thirdparty.go @@ -45,23 +45,29 @@ func FromIdentityModel(identity *models.Identity, cfg *config.Config) *Identity } return &Identity{ - ID: identity.ProviderID, + ID: identity.ProviderUserID, Provider: getProviderDisplayName(identity, cfg), } } func getProviderDisplayName(identity *models.Identity, cfg *config.Config) string { - if strings.HasPrefix(identity.ProviderName, "custom_") { - providerNameWithoutPrefix := strings.TrimPrefix(identity.ProviderName, "custom_") + if identity.SamlIdentity != nil { + for _, ip := range cfg.Saml.IdentityProviders { + if ip.Enabled && ip.Domain == identity.SamlIdentity.Domain { + return ip.Name + } + } + } else if strings.HasPrefix(identity.ProviderID, "custom_") { + providerNameWithoutPrefix := strings.TrimPrefix(identity.ProviderID, "custom_") return cfg.ThirdParty.CustomProviders[providerNameWithoutPrefix].DisplayName } else { s := structs.New(config.ThirdPartyProviders{}) for _, field := range s.Fields() { - if strings.ToLower(field.Name()) == strings.ToLower(identity.ProviderName) { + if strings.ToLower(field.Name()) == strings.ToLower(identity.ProviderID) { return field.Name() } } } - return strings.TrimSpace(identity.ProviderName) + return strings.TrimSpace(identity.ProviderID) } diff --git a/backend/ee/saml/config/saml.go b/backend/ee/saml/config/saml.go index 4d0774f42..94916f7d6 100644 --- a/backend/ee/saml/config/saml.go +++ b/backend/ee/saml/config/saml.go @@ -139,10 +139,21 @@ func (s *Saml) Validate() error { return errors.New("at least one SAML provider is needed") } + configuredDomains := make(map[string]int) for _, provider := range s.IdentityProviders { - validationErrors = provider.Validate() - if validationErrors != nil { - return validationErrors + if provider.Enabled { + validationErrors = provider.Validate() + if validationErrors != nil { + return validationErrors + } + + configuredDomains[provider.Domain] += 1 + } + } + + for configuredDomain, configuredDomainCount := range configuredDomains { + if configuredDomainCount > 1 { + return fmt.Errorf("provider domains must be unique, found domain %s configured %d times", configuredDomain, configuredDomainCount) } } } diff --git a/backend/ee/saml/handler.go b/backend/ee/saml/handler.go index a8a534711..285b90a97 100644 --- a/backend/ee/saml/handler.go +++ b/backend/ee/saml/handler.go @@ -155,8 +155,9 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * var samlError error samlError = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error { userdata := provider.GetUserData(assertionInfo) - - linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, state.Provider, true, state.IsFlow) + identityProviderIssuer := assertionInfo.Assertions[0].Issuer + samlDomain := provider.GetDomain() + linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, state.IsFlow) if samlErrorTx != nil { return samlErrorTx } @@ -164,7 +165,7 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * accountLinkingResult = linkResult emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email) - identityModel := emailModel.Identities.GetIdentity(provider.GetDomain(), userdata.Metadata.Subject) + identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject) token, tokenError := models.NewToken( linkResult.User.ID, diff --git a/backend/ee/saml/provider/saml.go b/backend/ee/saml/provider/saml.go index 82eabc97a..d0635d6fa 100644 --- a/backend/ee/saml/provider/saml.go +++ b/backend/ee/saml/provider/saml.go @@ -37,7 +37,7 @@ func NewBaseSamlProvider(cfg *config.Config, idpConfig samlConfig.IdentityProvid IDPCertificateStore: &idpMetadata.certs, AssertionConsumerServiceURL: fmt.Sprintf("%s/saml/callback", cfg.Saml.Endpoint), - ServiceProviderIssuer: fmt.Sprintf("%s/saml/metadata", cfg.Saml.Endpoint), + ServiceProviderIssuer: cfg.Saml.Endpoint, ServiceProviderSLOURL: fmt.Sprintf("%s/saml/logout", cfg.Saml.Endpoint), SPKeyStore: serviceProviderCertStore, diff --git a/backend/flow_api/flow/shared/action_exchange_token.go b/backend/flow_api/flow/shared/action_exchange_token.go index 1e007b949..1031d6556 100644 --- a/backend/flow_api/flow/shared/action_exchange_token.go +++ b/backend/flow_api/flow/shared/action_exchange_token.go @@ -90,7 +90,7 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error { return fmt.Errorf("failed to set login_method to the stash: %w", err) } - if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderName); err != nil { + if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil { return fmt.Errorf("failed to set third_party_provider to the stash: %w", err) } diff --git a/backend/flow_api/flow/shared/action_thirdparty_oauth.go b/backend/flow_api/flow/shared/action_thirdparty_oauth.go index 4a1443056..de6086a44 100644 --- a/backend/flow_api/flow/shared/action_thirdparty_oauth.go +++ b/backend/flow_api/flow/shared/action_thirdparty_oauth.go @@ -39,7 +39,7 @@ func (a ThirdPartyOAuth) Initialize(c flowpilot.InitializationContext) { Required(true) for _, provider := range enabledThirdPartyProviders { - providerInput.AllowedValue(provider.DisplayName, provider.Name) + providerInput.AllowedValue(provider.DisplayName, provider.ID) } slices.SortFunc(enabledCustomThirdPartyProviders, func(a, b config.CustomThirdPartyProvider) bool { @@ -47,7 +47,7 @@ func (a ThirdPartyOAuth) Initialize(c flowpilot.InitializationContext) { }) for _, provider := range enabledCustomThirdPartyProviders { - providerInput.AllowedValue(provider.DisplayName, provider.Name) + providerInput.AllowedValue(provider.DisplayName, provider.ID) } c.AddInputs(flowpilot.StringInput("redirect_to").Hidden(true).Required(true), providerInput) diff --git a/backend/flow_api/services/user.go b/backend/flow_api/services/user.go index 9b13b2cd1..46930d52a 100644 --- a/backend/flow_api/services/user.go +++ b/backend/flow_api/services/user.go @@ -8,7 +8,7 @@ import ( func UserCanDoThirdParty(cfg config.Config, identities models.Identities) bool { for _, identity := range identities { - if provider := cfg.ThirdParty.Providers.Get(identity.ProviderName); provider != nil { + if provider := cfg.ThirdParty.Providers.Get(identity.ProviderID); provider != nil { return provider.Enabled } } @@ -18,7 +18,7 @@ func UserCanDoThirdParty(cfg config.Config, identities models.Identities) bool { func UserCanDoSaml(cfg config.Config, identities models.Identities) bool { for _, identity := range identities { - if provider := cfg.Saml.GetProviderByDomain(identity.ProviderName); provider != nil { + if provider := cfg.Saml.GetProviderByDomain(identity.ProviderID); provider != nil { return cfg.Saml.Enabled && provider.Enabled } } diff --git a/backend/handler/thirdparty.go b/backend/handler/thirdparty.go index d0f1c850d..3b2e073fd 100644 --- a/backend/handler/thirdparty.go +++ b/backend/handler/thirdparty.go @@ -61,7 +61,7 @@ func (h *ThirdPartyHandler) Auth(c echo.Context) error { return h.redirectError(c, thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), errorRedirectTo) } - state, err := thirdparty.GenerateState(h.cfg, provider.Name(), request.RedirectTo) + state, err := thirdparty.GenerateState(h.cfg, provider.ID(), request.RedirectTo) if err != nil { return h.redirectError(c, thirdparty.ErrorServer("could not generate state").WithCause(err), errorRedirectTo) } @@ -143,14 +143,14 @@ func (h *ThirdPartyHandler) Callback(c echo.Context) error { return thirdparty.ErrorInvalidRequest("could not retrieve user data from provider").WithCause(terr) } - linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.Name(), false, state.IsFlow) + linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.ID(), false, nil, state.IsFlow) if terr != nil { return terr } accountLinkingResult = linkingResult emailModel := linkingResult.User.Emails.GetEmailByAddress(userData.Metadata.Email) - identityModel := emailModel.Identities.GetIdentity(provider.Name(), userData.Metadata.Subject) + identityModel := emailModel.Identities.GetIdentity(provider.ID(), userData.Metadata.Subject) token, terr := models.NewToken( linkingResult.User.ID, diff --git a/backend/handler/thirdparty_test.go b/backend/handler/thirdparty_test.go index cea8e196a..812d79ad9 100644 --- a/backend/handler/thirdparty_test.go +++ b/backend/handler/thirdparty_test.go @@ -60,42 +60,42 @@ func (s *thirdPartySuite) setUpConfig(enabledProviders []string, allowedRedirect cfg.ThirdParty = config.ThirdParty{ Providers: config.ThirdPartyProviders{ Apple: config.ThirdPartyProvider{ - Name: "apple", + ID: "apple", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, Google: config.ThirdPartyProvider{ - Name: "google", + ID: "google", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, GitHub: config.ThirdPartyProvider{ - Name: "github", + ID: "github", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, Discord: config.ThirdPartyProvider{ - Name: "discord", + ID: "discord", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: true, }, Microsoft: config.ThirdPartyProvider{ - Name: "microsoft", + ID: "microsoft", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", AllowLinking: false, }, Facebook: config.ThirdPartyProvider{ - Name: "facebook", + ID: "facebook", Enabled: false, ClientID: "fakeClientID", Secret: "fakeClientSecret", diff --git a/backend/persistence/identity_persister.go b/backend/persistence/identity_persister.go index 6e1a07aa8..1a7e65c4b 100644 --- a/backend/persistence/identity_persister.go +++ b/backend/persistence/identity_persister.go @@ -10,7 +10,7 @@ import ( ) type IdentityPersister interface { - Get(userProviderID string, providerID string) (*models.Identity, error) + Get(providerUserID string, providerID string) (*models.Identity, error) GetByID(identityID uuid.UUID) (*models.Identity, error) Create(identity models.Identity) error Update(identity models.Identity) error @@ -32,9 +32,9 @@ func (p identityPersister) GetByID(identityID uuid.UUID) (*models.Identity, erro return identity, nil } -func (p identityPersister) Get(userProviderID string, providerID string) (*models.Identity, error) { +func (p identityPersister) Get(providerUserID string, providerID string) (*models.Identity, error) { identity := &models.Identity{} - if err := p.db.EagerPreload().Where("provider_id = ? AND provider_name = ?", userProviderID, providerID).First(identity); err != nil { + if err := p.db.EagerPreload().Where("provider_user_id = ? AND provider_id = ?", providerUserID, providerID).First(identity); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/backend/persistence/migrations/20250130154010_change_identities.down.fizz b/backend/persistence/migrations/20250130154010_change_identities.down.fizz new file mode 100644 index 000000000..19ca26dca --- /dev/null +++ b/backend/persistence/migrations/20250130154010_change_identities.down.fizz @@ -0,0 +1,6 @@ +drop_index("identities", "identities_provider_user_id_provider_id_idx") + +rename_column("identities", "provider_id", "provider_name") +rename_column("identities", "provider_user_id", "provider_id") + +add_index("identities", ["provider_id", "provider_name"], {unique: true}) diff --git a/backend/persistence/migrations/20250130154010_change_identities.up.fizz b/backend/persistence/migrations/20250130154010_change_identities.up.fizz new file mode 100644 index 000000000..206f0ab5d --- /dev/null +++ b/backend/persistence/migrations/20250130154010_change_identities.up.fizz @@ -0,0 +1,6 @@ +drop_index("identities", "identities_provider_id_provider_name_idx") + +rename_column("identities", "provider_id", "provider_user_id") +rename_column("identities", "provider_name", "provider_id") + +add_index("identities", ["provider_user_id", "provider_id"], {unique: true}) diff --git a/backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz b/backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz new file mode 100644 index 000000000..b7aa5e6d4 --- /dev/null +++ b/backend/persistence/migrations/20250130170131_create_saml_identities.down.fizz @@ -0,0 +1 @@ +drop_table("saml_identities") diff --git a/backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz b/backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz new file mode 100644 index 000000000..50fd61504 --- /dev/null +++ b/backend/persistence/migrations/20250130170131_create_saml_identities.up.fizz @@ -0,0 +1,8 @@ +create_table("saml_identities") { + t.Column("id", "uuid", {primary: true}) + t.Column("identity_id", "uuid", { "null": false }) + t.Column("domain", "string", { "null": false }) + t.Timestamps() + t.ForeignKey("identity_id", {"identities": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"}) + t.Index(["identity_id", "domain"], {"unique": true}) +} diff --git a/backend/persistence/models/email.go b/backend/persistence/models/email.go index 7349d09e0..30e1bf4a4 100644 --- a/backend/persistence/models/email.go +++ b/backend/persistence/models/email.go @@ -45,6 +45,15 @@ func (email *Email) IsPrimary() bool { return false } +func (email *Email) GetSamlIdentityForDomain(domain string) *SamlIdentity { + for _, identity := range email.Identities { + if identity.SamlIdentity != nil && identity.SamlIdentity.Domain == domain { + return identity.SamlIdentity + } + } + return nil +} + func (emails *Emails) GetVerified() Emails { var list Emails for _, email := range *emails { diff --git a/backend/persistence/models/identity.go b/backend/persistence/models/identity.go index e12777753..d88793080 100644 --- a/backend/persistence/models/identity.go +++ b/backend/persistence/models/identity.go @@ -12,21 +12,22 @@ import ( // Identity is used by pop to map your identities database table to your go code. type Identity struct { - ID uuid.UUID `json:"id" db:"id"` - ProviderID string `json:"provider_id" db:"provider_id"` - ProviderName string `json:"provider_name" db:"provider_name"` - Data slices.Map `json:"data" db:"data"` - EmailID uuid.UUID `json:"email_id" db:"email_id"` - Email *Email `json:"email,omitempty" belongs_to:"email"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + ID uuid.UUID `json:"id" db:"id"` + ProviderUserID string `json:"provider_user_id" db:"provider_user_id"` + ProviderID string `json:"provider_id" db:"provider_id"` + Data slices.Map `json:"data" db:"data"` + EmailID uuid.UUID `json:"email_id" db:"email_id"` + Email *Email `json:"email,omitempty" belongs_to:"email"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + SamlIdentity *SamlIdentity `json:"saml_identity" has_one:"saml_identity"` } type Identities []Identity -func (identities Identities) GetIdentity(providerName string, providerId string) *Identity { +func (identities Identities) GetIdentity(providerID string, providerUserID string) *Identity { for _, identity := range identities { - if identity.ProviderName == providerName && identity.ProviderID == providerId { + if identity.ProviderID == providerID && identity.ProviderUserID == providerUserID { return &identity } } @@ -34,22 +35,22 @@ func (identities Identities) GetIdentity(providerName string, providerId string) return nil } -func NewIdentity(provider string, identityData map[string]interface{}, emailID uuid.UUID) (*Identity, error) { - providerID, ok := identityData["sub"] +func NewIdentity(providerID string, identityData map[string]interface{}, emailID uuid.UUID) (*Identity, error) { + providerUserID, ok := identityData["sub"] if !ok { - return nil, errors.New("missing provider id") + return nil, errors.New("missing provider user id") } now := time.Now().UTC() id, _ := uuid.NewV4() identity := &Identity{ - ID: id, - Data: identityData, - ProviderID: providerID.(string), - ProviderName: provider, - EmailID: emailID, - CreatedAt: now, - UpdatedAt: now, + ID: id, + Data: identityData, + ProviderUserID: providerUserID.(string), + ProviderID: providerID, + EmailID: emailID, + CreatedAt: now, + UpdatedAt: now, } return identity, nil @@ -60,8 +61,8 @@ func NewIdentity(provider string, identityData map[string]interface{}, emailID u func (i *Identity) Validate(tx *pop.Connection) (*validate.Errors, error) { return validate.Validate( &validators.UUIDIsPresent{Name: "ID", Field: i.ID}, + &validators.StringIsPresent{Name: "ProviderUserID", Field: i.ProviderUserID}, &validators.StringIsPresent{Name: "ProviderID", Field: i.ProviderID}, - &validators.StringIsPresent{Name: "ProviderName", Field: i.ProviderName}, &validators.TimeIsPresent{Name: "CreatedAt", Field: i.CreatedAt}, &validators.TimeIsPresent{Name: "UpdatedAt", Field: i.UpdatedAt}, ), nil diff --git a/backend/persistence/models/saml_identity.go b/backend/persistence/models/saml_identity.go new file mode 100644 index 000000000..c58dfa511 --- /dev/null +++ b/backend/persistence/models/saml_identity.go @@ -0,0 +1,27 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" + "time" +) + +type SamlIdentity struct { + ID uuid.UUID `json:"id" db:"id"` + IdentityID uuid.UUID `json:"identity_id" db:"identity_id"` + Domain string `json:"domain" db:"domain"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type SamlIdentities []SamlIdentity + +func (i *SamlIdentity) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: i.ID}, + &validators.UUIDIsPresent{Name: "IdentityID", Field: i.IdentityID}, + &validators.StringIsPresent{Name: "Domain", Field: i.Domain}, + ), nil +} diff --git a/backend/persistence/persister.go b/backend/persistence/persister.go index f9742ef55..aeaa1a884 100644 --- a/backend/persistence/persister.go +++ b/backend/persistence/persister.go @@ -35,6 +35,8 @@ type Persister interface { GetSamlCertificatePersisterWithConnection(tx *pop.Connection) SamlCertificatePersister GetSamlStatePersister() SamlStatePersister GetSamlStatePersisterWithConnection(tx *pop.Connection) SamlStatePersister + GetSamlIdentityPersister() SamlIdentityPersister + GetSamlIdentityPersisterWithConnection(tx *pop.Connection) SamlIdentityPersister GetTokenPersister() TokenPersister GetTokenPersisterWithConnection(tx *pop.Connection) TokenPersister GetUserPersister() UserPersister @@ -263,6 +265,14 @@ func (p *persister) GetSamlCertificatePersisterWithConnection(tx *pop.Connection return NewSamlCertificatePersister(tx) } +func (p *persister) GetSamlIdentityPersister() SamlIdentityPersister { + return NewSamlIdentityPersister(p.DB) +} + +func (p *persister) GetSamlIdentityPersisterWithConnection(tx *pop.Connection) SamlIdentityPersister { + return NewSamlIdentityPersister(tx) +} + func (p *persister) GetWebhookPersister(tx *pop.Connection) WebhookPersister { if tx != nil { return NewWebhookPersister(tx) diff --git a/backend/persistence/saml_identity_persister.go b/backend/persistence/saml_identity_persister.go new file mode 100644 index 000000000..4554d8353 --- /dev/null +++ b/backend/persistence/saml_identity_persister.go @@ -0,0 +1,46 @@ +package persistence + +import ( + "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +type SamlIdentityPersister interface { + Create(samlIdentity models.SamlIdentity) error + Update(samlIdentity models.SamlIdentity) error +} + +type samlIdentityPersister struct { + db *pop.Connection +} + +func NewSamlIdentityPersister(db *pop.Connection) SamlIdentityPersister { + return &samlIdentityPersister{db: db} +} + +func (p samlIdentityPersister) Create(samlIdentity models.SamlIdentity) error { + vErr, err := p.db.Eager().ValidateAndCreate(&samlIdentity) + if err != nil { + return fmt.Errorf("failed to store saml identity: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("saml identity object validation failed: %w", vErr) + } + + return nil +} + +func (p samlIdentityPersister) Update(samlIdentity models.SamlIdentity) error { + vErr, err := p.db.ValidateAndUpdate(&samlIdentity) + if err != nil { + return fmt.Errorf("failed to update saml identity: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("saml identity object validation failed: %w", vErr) + } + + return nil +} diff --git a/backend/persistence/user_persister.go b/backend/persistence/user_persister.go index 0803a6f3f..389880ce4 100644 --- a/backend/persistence/user_persister.go +++ b/backend/persistence/user_persister.go @@ -36,7 +36,7 @@ func (p *userPersister) Get(id uuid.UUID) (*models.User, error) { eagerPreloadFields := []string{ "Emails", "Emails.PrimaryEmail", - "Emails.Identities", + "Emails.Identities.SamlIdentity", "WebauthnCredentials", "WebauthnCredentials.Transports", "Username", @@ -57,7 +57,7 @@ func (p *userPersister) Get(id uuid.UUID) (*models.User, error) { func (p *userPersister) GetByEmailAddress(emailAddress string) (*models.User, error) { email := models.Email{} - err := p.db.Where("address = (?)", emailAddress).First(&email) + err := p.db.Eager().Where("address = (?)", emailAddress).First(&email) if err != nil && errors.Is(err, sql.ErrNoRows) { return nil, nil diff --git a/backend/test/fixtures/thirdparty/identities.yaml b/backend/test/fixtures/thirdparty/identities.yaml index d0feca72b..5f4590237 100644 --- a/backend/test/fixtures/thirdparty/identities.yaml +++ b/backend/test/fixtures/thirdparty/identities.yaml @@ -1,41 +1,41 @@ - id: 443a984d-bb1c-46fe-b685-151bd0f017b1 - provider_id: "google_abcde" - provider_name: "google" + provider_user_id: "google_abcde" + provider_id: "google" data: '{"email":"test-with-google-identity@example.com","email_verified":true,"iss":"https://www.googleapis.com","sub":"google_abcde"}' email_id: 15ed767e-1544-4e03-a732-4bf80caa2e78 created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: b140230b-be5b-4589-9762-ec72f32d2833 - provider_id: "1234" - provider_name: "github" + provider_user_id: "1234" + provider_id: "github" data: '{"email":"test-with-github-identity@example.com","email_verified":true,"iss":"https://api.github.com","sub":"1234"}' email_id: 80939e4c-0028-4b67-aeee-98c617a20b6b created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: 09b084b7-36af-41c5-b4b6-db9b8ed8385b - provider_id: "apple_abcde" - provider_name: "apple" + provider_user_id: "apple_abcde" + provider_id: "apple" data: '{"email":"test-with-apple-identity@example.com","email_verified":true,"iss":"https://appleid.apple.com","sub":"apple_abcde"}' email_id: 05ab6e1f-8dfb-4329-ae04-22571a68d96b created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: 18d61f13-2789-467a-a3a6-4292c0621580 - provider_id: "discord_abcde" - provider_name: "discord" + provider_user_id: "discord_abcde" + provider_id: "discord" data: '{"email":"test-with-discord-identity@example.com","email_verified":true,"iss":"https://discord.com/api","sub":"discord_abcde"}' email_id: 09f35529-cca6-44a7-ab1d-b07e95a04e3b created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: e4cd954b-8f1f-4fa2-b7e9-60a0ffa56363 - provider_id: "microsoft_abcde" - provider_name: "microsoft" + provider_user_id: "microsoft_abcde" + provider_id: "microsoft" data: '{"email":"test-with-microsoft-identity@example.com","iss":"https://login.microsoftonline.com/common","sub":"microsoft_abcde"}' email_id: d781006b-4f55-4327-bad6-55bc34b88585 created_at: 2020-12-31 23:59:59 updated_at: 2020-12-31 23:59:59 - id: b6b1309d-61de-4a82-b8b8-d54db0be679b - provider_id: "facebook_abcde" - provider_name: "facebook" + provider_user_id: "facebook_abcde" + provider_id: "facebook" data: '{"email":"test-with-facebook-identity@example.com","sub":"facebook_abcde"}' email_id: d781006b-4f55-4327-bad6-55bc34b88585 created_at: 2020-12-31 23:59:59 diff --git a/backend/thirdparty/linking.go b/backend/thirdparty/linking.go index a33561929..bc4703e7c 100644 --- a/backend/thirdparty/linking.go +++ b/backend/thirdparty/linking.go @@ -3,11 +3,13 @@ package thirdparty import ( "fmt" "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/persistence" "github.com/teamhanko/hanko/backend/persistence/models" "github.com/teamhanko/hanko/backend/webhooks/events" "strings" + "time" ) type AccountLinkingResult struct { @@ -17,14 +19,14 @@ type AccountLinkingResult struct { UserCreated bool } -func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, isSaml bool, isFlow bool) (*AccountLinkingResult, error) { +func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerID string, isSaml bool, samlDomain *string, isFlow bool) (*AccountLinkingResult, error) { if !isFlow { if cfg.Email.RequireVerification && !userData.Metadata.EmailVerified { return nil, ErrorUnverifiedProviderEmail("third party provider email must be verified") } } - identity, err := p.GetIdentityPersister().Get(userData.Metadata.Subject, providerName) + identity, err := p.GetIdentityPersister().Get(userData.Metadata.Subject, providerID) if err != nil { return nil, ErrorServer("could not get identity").WithCause(err) } @@ -36,29 +38,29 @@ func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister } if user == nil { - return signUp(tx, cfg, p, userData, providerName) + return signUp(tx, cfg, p, userData, providerID, isSaml, samlDomain) } else { - return link(tx, cfg, p, userData, providerName, user, isSaml) + return link(tx, cfg, p, userData, providerID, user, isSaml, samlDomain) } } else { return signIn(tx, cfg, p, userData, identity) } } -func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, user *models.User, isSaml bool) (*AccountLinkingResult, error) { +func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerID string, user *models.User, isSaml bool, samlDomain *string) (*AccountLinkingResult, error) { if !isSaml { - if strings.HasPrefix(providerName, "custom_") { - provider, ok := cfg.ThirdParty.CustomProviders[strings.TrimPrefix(providerName, "custom_")] + if strings.HasPrefix(providerID, "custom_") { + provider, ok := cfg.ThirdParty.CustomProviders[strings.TrimPrefix(providerID, "custom_")] if !ok { - return nil, ErrorServer(fmt.Sprintf("unknown provider: %s", providerName)) + return nil, ErrorServer(fmt.Sprintf("unknown provider: %s", providerID)) } if !provider.AllowLinking { return nil, ErrorUserConflict("third party account linking for existing user with same email disallowed") } } else { - provider := cfg.ThirdParty.Providers.Get(providerName) + provider := cfg.ThirdParty.Providers.Get(providerID) if provider == nil { - return nil, fmt.Errorf("unknown provider: %s", providerName) + return nil, fmt.Errorf("unknown provider: %s", providerID) } if !provider.AllowLinking { @@ -74,7 +76,7 @@ func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userD return nil, ErrorServer("could not link account").WithCause(err) } - identity, err := models.NewIdentity(providerName, userDataMap, email.ID) + identity, err := models.NewIdentity(providerID, userDataMap, email.ID) if err != nil { return nil, ErrorServer("could not create identity").WithCause(err) } @@ -84,6 +86,38 @@ func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userD return nil, ErrorServer("could not create identity").WithCause(err) } + if isSaml && samlDomain != nil && *samlDomain != "" { + if existingSamlIdentity := email.GetSamlIdentityForDomain(*samlDomain); existingSamlIdentity != nil { + identityToDeleteID := existingSamlIdentity.IdentityID + existingSamlIdentity.IdentityID = identity.ID + + err = p.GetSamlIdentityPersisterWithConnection(tx).Update(*existingSamlIdentity) + if err != nil { + return nil, ErrorServer("could update saml identity").WithCause(err) + } + + err = p.GetIdentityPersisterWithConnection(tx).Delete(models.Identity{ID: identityToDeleteID}) + if err != nil { + return nil, ErrorServer("could not delete identity").WithCause(err) + } + } else { + samlIdentityID, _ := uuid.NewV4() + now := time.Now().UTC() + samlIdentity := &models.SamlIdentity{ + ID: samlIdentityID, + IdentityID: identity.ID, + Domain: *samlDomain, + CreatedAt: now, + UpdatedAt: now, + } + + err = p.GetSamlIdentityPersisterWithConnection(tx).Create(*samlIdentity) + if err != nil { + return nil, ErrorServer("could not create saml identity").WithCause(err) + } + } + } + u, terr := p.GetUserPersisterWithConnection(tx).Get(*email.UserID) if terr != nil { return nil, ErrorServer("could not get user").WithCause(terr) @@ -187,7 +221,7 @@ func signIn(tx *pop.Connection, cfg *config.Config, p persistence.Persister, use return linkingResult, nil } -func signUp(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string) (*AccountLinkingResult, error) { +func signUp(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerID string, isSaml bool, samlDomain *string) (*AccountLinkingResult, error) { if !cfg.Account.AllowSignup { return nil, ErrorSignUpDisabled("account signup is disabled") } @@ -241,14 +275,31 @@ func signUp(tx *pop.Connection, cfg *config.Config, p persistence.Persister, use return nil, ErrorServer("could not link account").WithCause(err) } - identity, terr := models.NewIdentity(providerName, userDataMap, email.ID) + identity, terr := models.NewIdentity(providerID, userDataMap, email.ID) if terr != nil { return nil, ErrorServer("could not create identity").WithCause(terr) } terr = identityPersister.Create(*identity) if terr != nil { - return nil, ErrorServer("could not create identity").WithCause(terr) + return nil, ErrorServer("could not store identity").WithCause(terr) + } + + if isSaml && samlDomain != nil && *samlDomain != "" { + samlIdentityID, _ := uuid.NewV4() + now := time.Now().UTC() + samlIdentity := &models.SamlIdentity{ + ID: samlIdentityID, + IdentityID: identity.ID, + Domain: *samlDomain, + CreatedAt: now, + UpdatedAt: now, + } + + err = p.GetSamlIdentityPersisterWithConnection(tx).Create(*samlIdentity) + if err != nil { + return nil, ErrorServer("could not store saml identity").WithCause(err) + } } u, terr := userPersister.Get(*email.UserID) diff --git a/backend/thirdparty/provider.go b/backend/thirdparty/provider.go index 02a9b516f..a4759b1fa 100644 --- a/backend/thirdparty/provider.go +++ b/backend/thirdparty/provider.go @@ -85,7 +85,7 @@ type OAuthProvider interface { AuthCodeURL(string, ...oauth2.AuthCodeOption) string GetUserData(*oauth2.Token) (*UserData, error) GetOAuthToken(string) (*oauth2.Token, error) - Name() string + ID() string } func GetProvider(config config.ThirdParty, id string) (OAuthProvider, error) { diff --git a/backend/thirdparty/provider_apple.go b/backend/thirdparty/provider_apple.go index 37e9bd603..8cf59a3f0 100644 --- a/backend/thirdparty/provider_apple.go +++ b/backend/thirdparty/provider_apple.go @@ -116,6 +116,6 @@ func (a appleProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return userData, nil } -func (a appleProvider) Name() string { - return a.config.Name +func (a appleProvider) ID() string { + return a.config.ID } diff --git a/backend/thirdparty/provider_custom.go b/backend/thirdparty/provider_custom.go index d5bbf7d6b..0c499dfe4 100644 --- a/backend/thirdparty/provider_custom.go +++ b/backend/thirdparty/provider_custom.go @@ -17,7 +17,7 @@ type customProvider struct { func NewCustomThirdPartyProvider(config *config.CustomThirdPartyProvider, redirectURL string) (OAuthProvider, error) { if !config.Enabled { - return nil, fmt.Errorf("provider %s is disabled", config.Name) + return nil, fmt.Errorf("provider %s is disabled", config.ID) } customProvider := &customProvider{ @@ -100,6 +100,6 @@ func (p customProvider) GetUserData(token *oauth2.Token) (*UserData, error) { }, nil } -func (p customProvider) Name() string { - return p.config.Name +func (p customProvider) ID() string { + return p.config.ID } diff --git a/backend/thirdparty/provider_discord.go b/backend/thirdparty/provider_discord.go index ddfeab2dd..16bb26e80 100644 --- a/backend/thirdparty/provider_discord.go +++ b/backend/thirdparty/provider_discord.go @@ -104,6 +104,6 @@ func (g discordProvider) buildAvatarURL(userID string, avatarHash string) string return fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", userID, avatarHash) } -func (g discordProvider) Name() string { - return g.config.Name +func (g discordProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_facebook.go b/backend/thirdparty/provider_facebook.go index 6c3c9f637..158b1e6be 100644 --- a/backend/thirdparty/provider_facebook.go +++ b/backend/thirdparty/provider_facebook.go @@ -124,6 +124,6 @@ func (f facebookProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (f facebookProvider) Name() string { - return f.config.Name +func (f facebookProvider) ID() string { + return f.config.ID } diff --git a/backend/thirdparty/provider_github.go b/backend/thirdparty/provider_github.go index 1815ccb86..33fedff37 100644 --- a/backend/thirdparty/provider_github.go +++ b/backend/thirdparty/provider_github.go @@ -114,6 +114,6 @@ func (g githubProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (g githubProvider) Name() string { - return g.config.Name +func (g githubProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_google.go b/backend/thirdparty/provider_google.go index 0f5f99c46..a6ad38c09 100644 --- a/backend/thirdparty/provider_google.go +++ b/backend/thirdparty/provider_google.go @@ -93,6 +93,6 @@ func (g googleProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (g googleProvider) Name() string { - return g.config.Name +func (g googleProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_linkedin.go b/backend/thirdparty/provider_linkedin.go index 381f20645..9dd3ee56a 100644 --- a/backend/thirdparty/provider_linkedin.go +++ b/backend/thirdparty/provider_linkedin.go @@ -107,6 +107,6 @@ func (g linkedInProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (g linkedInProvider) Name() string { - return g.config.Name +func (g linkedInProvider) ID() string { + return g.config.ID } diff --git a/backend/thirdparty/provider_microsoft.go b/backend/thirdparty/provider_microsoft.go index 9bcf240e6..6ff753d61 100644 --- a/backend/thirdparty/provider_microsoft.go +++ b/backend/thirdparty/provider_microsoft.go @@ -159,8 +159,8 @@ func (p microsoftProvider) GetUserData(token *oauth2.Token) (*UserData, error) { return data, nil } -func (p microsoftProvider) Name() string { - return p.config.Name +func (p microsoftProvider) ID() string { + return p.config.ID } func (p microsoftProvider) issuerValidator() jwt.ValidatorFunc {