diff --git a/fence/auth.py b/fence/auth.py index 8cd9ed15c..37eb9a221 100644 --- a/fence/auth.py +++ b/fence/auth.py @@ -57,7 +57,9 @@ def build_redirect_url(hostname, path): return redirect_base + path -def login_user(username, provider, fence_idp=None, shib_idp=None, email=None): +def login_user( + username, provider, fence_idp=None, shib_idp=None, email=None, id_from_idp=None +): """ Login a user with the given username and provider. Set values in Flask session to indicate the user being logged in. In addition, commit the user @@ -70,6 +72,8 @@ def login_user(username, provider, fence_idp=None, shib_idp=None, email=None): shib_idp (str, optional): Downstreawm shibboleth IdP email (str, optional): email of user (may or may not match username depending on the IdP) + id_from_idp (str, optional): id from the IDP (which may be different than + the username) """ def set_flask_session_values(user): @@ -93,6 +97,7 @@ def set_flask_session_values(user): user = query_for_user(session=current_session, username=username) if user: _update_users_email(user, email) + _update_users_id_from_idp(user, id_from_idp) # This expression is relevant to those users who already have user and # idp info persisted to the database. We return early to avoid @@ -101,11 +106,16 @@ def set_flask_session_values(user): set_flask_session_values(user) return else: + # we need a new user + user = User(username=username) + if email: - user = User(username=username, email=email) - else: - user = User(username=username) + user.email = email + + if id_from_idp: + user.id_from_idp = id_from_idp + # setup idp connection for new user (or existing user w/o it setup) idp = ( current_session.query(IdentityProvider) .filter(IdentityProvider.name == provider) @@ -271,3 +281,17 @@ def _update_users_email(user, email): current_session.add(user) current_session.commit() + + +def _update_users_id_from_idp(user, id_from_idp): + """ + Update id_from_idp if provided and doesn't match db entry. + """ + if id_from_idp and user.id_from_idp != id_from_idp: + logger.info( + f"Updating username {user.username}'s id_from_idp from {user.id_from_idp} to {id_from_idp}" + ) + user.id_from_idp = id_from_idp + + current_session.add(user) + current_session.commit() diff --git a/fence/blueprints/login/base.py b/fence/blueprints/login/base.py index 61ce1ca44..580dd0a3c 100644 --- a/fence/blueprints/login/base.py +++ b/fence/blueprints/login/base.py @@ -56,7 +56,14 @@ def get(self): class DefaultOAuth2Callback(Resource): - def __init__(self, idp_name, client, username_field="email", email_field="email"): + def __init__( + self, + idp_name, + client, + username_field="email", + email_field="email", + id_from_idp_field="sub", + ): """ Construct a resource for a login callback endpoint @@ -65,14 +72,18 @@ def __init__(self, idp_name, client, username_field="email", email_field="email" client (fence.resources.openid.idp_oauth2.Oauth2ClientBase): Some instaniation of this base client class or a child class username_field (str, optional): default field from response to - retrieve the username + retrieve the unique username email_field (str, optional): default field from response to retrieve the email (if available) + id_from_idp_field (str, optional): default field from response to + retrieve the idp-specific ID for this user (could be the same + as username_field) """ self.idp_name = idp_name self.client = client self.username_field = username_field self.email_field = email_field + self.id_from_idp_field = id_from_idp_field def get(self): # Check if user granted access @@ -101,20 +112,23 @@ def get(self): result = self.client.get_user_id(code) username = result.get(self.username_field) email = result.get(self.email_field) + id_from_idp = result.get(self.id_from_idp_field) if username: - resp = _login(username, self.idp_name, email=email) - self.post_login(flask.g.user, result) + resp = _login(username, self.idp_name, email=email, id_from_idp=id_from_idp) + self.post_login( + user=flask.g.user, token_result=result, id_from_idp=id_from_idp + ) return resp raise UserError(result) - def post_login(self, user=None, token_result=None): - prepare_login_log(self.idp_name) + def post_login(self, user=None, token_result=None, id_from_idp=None): + prepare_login_log(self.idp_name, id_from_idp=id_from_idp) -def prepare_login_log(idp_name): +def prepare_login_log(idp_name, id_from_idp=None): flask.g.audit_data = { "username": flask.g.user.username, - "sub": flask.g.user.id, + "sub": id_from_idp, "idp": idp_name, "fence_idp": flask.session.get("fence_idp"), "shib_idp": flask.session.get("shib_idp"), @@ -122,12 +136,12 @@ def prepare_login_log(idp_name): } -def _login(username, idp_name, email=None): +def _login(username, idp_name, email=None, id_from_idp=None): """ Login user with given username, then redirect if session has a saved redirect. """ - login_user(username, idp_name, email=email) + login_user(username, idp_name, email=email, id_from_idp=id_from_idp) if config["REGISTER_USERS_ON"]: if not flask.g.user.additional_info.get("registration_info"): diff --git a/fence/blueprints/login/ras.py b/fence/blueprints/login/ras.py index c8df791e1..8120af507 100644 --- a/fence/blueprints/login/ras.py +++ b/fence/blueprints/login/ras.py @@ -35,7 +35,7 @@ def __init__(self): username_field="username", ) - def post_login(self, user=None, token_result=None): + def post_login(self, user=None, token_result=None, id_from_idp=None): # TODO: I'm not convinced this code should be in post_login. # Just putting it in here for now, but might refactor later. # This saves us a call to RAS /userinfo, but will not make sense @@ -187,4 +187,4 @@ def post_login(self, user=None, token_result=None): ) sync.sync_single_user_visas(user, current_session) - super(RASCallback, self).post_login() + super(RASCallback, self).post_login(id_from_idp=id_from_idp) diff --git a/fence/blueprints/login/synapse.py b/fence/blueprints/login/synapse.py index d1dce907e..b6f3cbeb7 100644 --- a/fence/blueprints/login/synapse.py +++ b/fence/blueprints/login/synapse.py @@ -22,10 +22,10 @@ def __init__(self): idp_name=IdentityProvider.synapse, client=flask.current_app.synapse_client, username_field="fence_username", + id_from_idp_field="sub", ) - def post_login(self, user=None, token_result=None): - user.id_from_idp = token_result["sub"] + def post_login(self, user=None, token_result=None, id_from_idp=None): user.email = token_result["email"] user.display_name = "{given_name} {family_name}".format(**token_result) info = {} @@ -53,4 +53,4 @@ def post_login(self, user=None, token_result=None): user.username, config["DREAM_CHALLENGE_GROUP"] ) - super(SynapseCallback, self).post_login() + super(SynapseCallback, self).post_login(id_from_idp=id_from_idp) diff --git a/fence/job/visa_update_cronjob.py b/fence/job/visa_update_cronjob.py index cff3bc58f..969a37426 100644 --- a/fence/job/visa_update_cronjob.py +++ b/fence/job/visa_update_cronjob.py @@ -67,8 +67,10 @@ async def update_tokens(self, db_session): Initialize a producer-consumer workflow. Producer: Collects users from db and feeds it to the workers - Worker: Takes in the users from the Producer and passes it to the Updater to update the tokens and passes those updated tokens for JWT validation - Updater: Updates refresh_tokens and visas by calling the update_user_visas from the correct client + Worker: Takes in the users from the Producer and passes it to the Updater to + update the tokens and passes those updated tokens for JWT validation + Updater: Updates refresh_tokens and visas by calling the update_user_visas from + the correct client """ start_time = time.time() diff --git a/fence/resources/openid/cognito_oauth2.py b/fence/resources/openid/cognito_oauth2.py index 100b93f9e..b85484041 100644 --- a/fence/resources/openid/cognito_oauth2.py +++ b/fence/resources/openid/cognito_oauth2.py @@ -57,7 +57,7 @@ def get_user_id(self, code): if claims["email"] and ( claims["email_verified"] or self.settings["assume_emails_verified"] ): - return {"email": claims["email"]} + return {"email": claims["email"], "sub": claims.get("sub")} elif claims["email"]: return {"error": "Email is not verified"} else: diff --git a/fence/resources/openid/google_oauth2.py b/fence/resources/openid/google_oauth2.py index b1329c7df..d98cfc6d1 100644 --- a/fence/resources/openid/google_oauth2.py +++ b/fence/resources/openid/google_oauth2.py @@ -53,7 +53,7 @@ def get_user_id(self, code): claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims["email"] and claims["email_verified"]: - return {"email": claims["email"]} + return {"email": claims["email"], "sub": claims.get("sub")} elif claims["email"]: return {"error": "Email is not verified"} else: diff --git a/fence/resources/openid/microsoft_oauth2.py b/fence/resources/openid/microsoft_oauth2.py index 081701e5c..ec9b4a17b 100755 --- a/fence/resources/openid/microsoft_oauth2.py +++ b/fence/resources/openid/microsoft_oauth2.py @@ -52,7 +52,7 @@ def get_user_id(self, code): claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims.get("email"): - return {"email": claims["email"]} + return {"email": claims["email"], "sub": claims.get("sub")} return {"error": "Can't get user's Microsoft email!"} except Exception as exception: self.logger.exception("Can't get user info") diff --git a/fence/resources/openid/okta_oauth2.py b/fence/resources/openid/okta_oauth2.py index 0cb2ab7f8..4d84658c2 100644 --- a/fence/resources/openid/okta_oauth2.py +++ b/fence/resources/openid/okta_oauth2.py @@ -41,7 +41,7 @@ def get_user_id(self, code): claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims["email"]: - return {"email": claims["email"]} + return {"email": claims["email"], "sub": claims.get("sub")} else: return {"error": "Can't get user's email!"} except Exception as e: diff --git a/fence/resources/openid/orcid_oauth2.py b/fence/resources/openid/orcid_oauth2.py index cd7190684..d06724ab1 100644 --- a/fence/resources/openid/orcid_oauth2.py +++ b/fence/resources/openid/orcid_oauth2.py @@ -45,7 +45,7 @@ def get_user_id(self, code): claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) if claims["sub"]: - return {"orcid": claims["sub"]} + return {"orcid": claims["sub"], "sub": claims["sub"]} else: return {"error": "Can't get user's orcid"} except Exception as e: diff --git a/fence/resources/openid/ras_oauth2.py b/fence/resources/openid/ras_oauth2.py index f301d8f52..49ce37430 100644 --- a/fence/resources/openid/ras_oauth2.py +++ b/fence/resources/openid/ras_oauth2.py @@ -201,7 +201,11 @@ def get_user_id(self, code): self.logger.exception("{}: {}".format(err_msg, e)) return {"error": err_msg} - return {"username": username, "email": userinfo.get("email")} + return { + "username": username, + "email": userinfo.get("email"), + "sub": userinfo.get("sub"), + } def refresh_cronjob_pkey_cache(self, issuer, kid, pkey_cache): """ diff --git a/tests/login/test_login_user.py b/tests/login/test_login_user.py index bba7a0cbe..c18bdc3cb 100644 --- a/tests/login/test_login_user.py +++ b/tests/login/test_login_user.py @@ -7,19 +7,25 @@ def test_login_user_already_in_db(db_session): """ Test that if a user is already in the database and logs in, the session will contain - the user's information. + the user's information (including additional information that may have been provided + during the login like email and id_from_idp) """ email = "testuser@gmail.com" provider = "Test Provider" + id_from_idp = "Provider_ID_0001" test_user = User(username=email, is_admin=False) db_session.add(test_user) db_session.commit() user_id = str(test_user.id) + assert not test_user.email + assert not test_user.id_from_idp - login_user(email, provider) + login_user(email, provider, email=email, id_from_idp=id_from_idp) assert test_user.identity_provider.name == provider + assert test_user.id_from_idp == id_from_idp + assert test_user.email == email assert flask.session["username"] == email assert flask.session["provider"] == provider assert flask.session["user_id"] == user_id @@ -33,8 +39,11 @@ def test_login_user_with_idp_already_in_db(db_session): """ email = "testuser@gmail.com" provider = "Test Provider" + id_from_idp = "Provider_ID_0001" - test_user = User(username=email, is_admin=False) + test_user = User( + username=email, email=email, id_from_idp=id_from_idp, is_admin=False + ) test_idp = IdentityProvider(name=provider) test_user.identity_provider = test_idp @@ -42,9 +51,11 @@ def test_login_user_with_idp_already_in_db(db_session): db_session.commit() user_id = str(test_user.id) - login_user(email, provider) + login_user(email, provider, email=email, id_from_idp=id_from_idp) assert test_user.identity_provider.name == provider + assert test_user.id_from_idp == id_from_idp + assert test_user.email == email assert flask.session["username"] == email assert flask.session["provider"] == provider assert flask.session["user_id"] == user_id @@ -58,12 +69,15 @@ def test_login_new_user(db_session): """ email = "testuser@gmail.com" provider = "Test Provider" + id_from_idp = "Provider_ID_0001" - login_user(email, provider) + login_user(email, provider, email=email, id_from_idp=id_from_idp) test_user = db_session.query(User).filter(User.username == email.lower()).first() assert test_user.identity_provider.name == provider + assert test_user.id_from_idp == id_from_idp + assert test_user.email == email assert flask.session["username"] == email assert flask.session["provider"] == provider assert flask.session["user_id"] == str(test_user.id) diff --git a/tests/login/test_microsoft_login.py b/tests/login/test_microsoft_login.py index 43aaf2c40..23343a7d2 100755 --- a/tests/login/test_microsoft_login.py +++ b/tests/login/test_microsoft_login.py @@ -24,7 +24,8 @@ def test_get_user_id(microsoft_oauth2_client): return_value=return_value, ): user_id = microsoft_oauth2_client.get_user_id(code="123") - assert user_id == expected_value # nosec + for key, value in expected_value.items(): + assert return_value[key] == value def test_get_user_id_missing_claim(microsoft_oauth2_client):