Skip to content

Commit

Permalink
Merge branch 'master' into chore/baseimage
Browse files Browse the repository at this point in the history
  • Loading branch information
Avantol13 authored Nov 11, 2021
2 parents e75e510 + 8370e56 commit 56272b7
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 33 deletions.
32 changes: 28 additions & 4 deletions fence/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def build_redirect_url(hostname, path):
return redirect_base + path


def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
def login_user(
username, provider, fence_idp=None, shib_idp=None, email=None, id_from_idp=None
):
"""
Login a user with the given username and provider. Set values in Flask
session to indicate the user being logged in. In addition, commit the user
Expand All @@ -70,6 +72,8 @@ def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
shib_idp (str, optional): Downstreawm shibboleth IdP
email (str, optional): email of user (may or may not match username depending
on the IdP)
id_from_idp (str, optional): id from the IDP (which may be different than
the username)
"""

def set_flask_session_values(user):
Expand All @@ -93,6 +97,7 @@ def set_flask_session_values(user):
user = query_for_user(session=current_session, username=username)
if user:
_update_users_email(user, email)
_update_users_id_from_idp(user, id_from_idp)

# This expression is relevant to those users who already have user and
# idp info persisted to the database. We return early to avoid
Expand All @@ -101,11 +106,16 @@ def set_flask_session_values(user):
set_flask_session_values(user)
return
else:
# we need a new user
user = User(username=username)

if email:
user = User(username=username, email=email)
else:
user = User(username=username)
user.email = email

if id_from_idp:
user.id_from_idp = id_from_idp

# setup idp connection for new user (or existing user w/o it setup)
idp = (
current_session.query(IdentityProvider)
.filter(IdentityProvider.name == provider)
Expand Down Expand Up @@ -271,3 +281,17 @@ def _update_users_email(user, email):

current_session.add(user)
current_session.commit()


def _update_users_id_from_idp(user, id_from_idp):
"""
Update id_from_idp if provided and doesn't match db entry.
"""
if id_from_idp and user.id_from_idp != id_from_idp:
logger.info(
f"Updating username {user.username}'s id_from_idp from {user.id_from_idp} to {id_from_idp}"
)
user.id_from_idp = id_from_idp

current_session.add(user)
current_session.commit()
34 changes: 24 additions & 10 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def get(self):


class DefaultOAuth2Callback(Resource):
def __init__(self, idp_name, client, username_field="email", email_field="email"):
def __init__(
self,
idp_name,
client,
username_field="email",
email_field="email",
id_from_idp_field="sub",
):
"""
Construct a resource for a login callback endpoint
Expand All @@ -65,14 +72,18 @@ def __init__(self, idp_name, client, username_field="email", email_field="email"
client (fence.resources.openid.idp_oauth2.Oauth2ClientBase):
Some instaniation of this base client class or a child class
username_field (str, optional): default field from response to
retrieve the username
retrieve the unique username
email_field (str, optional): default field from response to
retrieve the email (if available)
id_from_idp_field (str, optional): default field from response to
retrieve the idp-specific ID for this user (could be the same
as username_field)
"""
self.idp_name = idp_name
self.client = client
self.username_field = username_field
self.email_field = email_field
self.id_from_idp_field = id_from_idp_field

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -101,33 +112,36 @@ def get(self):
result = self.client.get_user_id(code)
username = result.get(self.username_field)
email = result.get(self.email_field)
id_from_idp = result.get(self.id_from_idp_field)
if username:
resp = _login(username, self.idp_name, email=email)
self.post_login(flask.g.user, result)
resp = _login(username, self.idp_name, email=email, id_from_idp=id_from_idp)
self.post_login(
user=flask.g.user, token_result=result, id_from_idp=id_from_idp
)
return resp
raise UserError(result)

def post_login(self, user=None, token_result=None):
prepare_login_log(self.idp_name)
def post_login(self, user=None, token_result=None, id_from_idp=None):
prepare_login_log(self.idp_name, id_from_idp=id_from_idp)


def prepare_login_log(idp_name):
def prepare_login_log(idp_name, id_from_idp=None):
flask.g.audit_data = {
"username": flask.g.user.username,
"sub": flask.g.user.id,
"sub": id_from_idp,
"idp": idp_name,
"fence_idp": flask.session.get("fence_idp"),
"shib_idp": flask.session.get("shib_idp"),
"client_id": flask.session.get("client_id"),
}


def _login(username, idp_name, email=None):
def _login(username, idp_name, email=None, id_from_idp=None):
"""
Login user with given username, then redirect if session has a saved
redirect.
"""
login_user(username, idp_name, email=email)
login_user(username, idp_name, email=email, id_from_idp=id_from_idp)

if config["REGISTER_USERS_ON"]:
if not flask.g.user.additional_info.get("registration_info"):
Expand Down
4 changes: 2 additions & 2 deletions fence/blueprints/login/ras.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self):
username_field="username",
)

def post_login(self, user=None, token_result=None):
def post_login(self, user=None, token_result=None, id_from_idp=None):
# TODO: I'm not convinced this code should be in post_login.
# Just putting it in here for now, but might refactor later.
# This saves us a call to RAS /userinfo, but will not make sense
Expand Down Expand Up @@ -187,4 +187,4 @@ def post_login(self, user=None, token_result=None):
)
sync.sync_single_user_visas(user, current_session)

super(RASCallback, self).post_login()
super(RASCallback, self).post_login(id_from_idp=id_from_idp)
6 changes: 3 additions & 3 deletions fence/blueprints/login/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def __init__(self):
idp_name=IdentityProvider.synapse,
client=flask.current_app.synapse_client,
username_field="fence_username",
id_from_idp_field="sub",
)

def post_login(self, user=None, token_result=None):
user.id_from_idp = token_result["sub"]
def post_login(self, user=None, token_result=None, id_from_idp=None):
user.email = token_result["email"]
user.display_name = "{given_name} {family_name}".format(**token_result)
info = {}
Expand Down Expand Up @@ -53,4 +53,4 @@ def post_login(self, user=None, token_result=None):
user.username, config["DREAM_CHALLENGE_GROUP"]
)

super(SynapseCallback, self).post_login()
super(SynapseCallback, self).post_login(id_from_idp=id_from_idp)
6 changes: 4 additions & 2 deletions fence/job/visa_update_cronjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ async def update_tokens(self, db_session):
Initialize a producer-consumer workflow.
Producer: Collects users from db and feeds it to the workers
Worker: Takes in the users from the Producer and passes it to the Updater to update the tokens and passes those updated tokens for JWT validation
Updater: Updates refresh_tokens and visas by calling the update_user_visas from the correct client
Worker: Takes in the users from the Producer and passes it to the Updater to
update the tokens and passes those updated tokens for JWT validation
Updater: Updates refresh_tokens and visas by calling the update_user_visas from
the correct client
"""
start_time = time.time()
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/cognito_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_user_id(self, code):
if claims["email"] and (
claims["email_verified"] or self.settings["assume_emails_verified"]
):
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
elif claims["email"]:
return {"error": "Email is not verified"}
else:
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/google_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims["email"] and claims["email_verified"]:
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
elif claims["email"]:
return {"error": "Email is not verified"}
else:
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/microsoft_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims.get("email"):
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
return {"error": "Can't get user's Microsoft email!"}
except Exception as exception:
self.logger.exception("Can't get user info")
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/okta_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims["email"]:
return {"email": claims["email"]}
return {"email": claims["email"], "sub": claims.get("sub")}
else:
return {"error": "Can't get user's email!"}
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/orcid_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_user_id(self, code):
claims = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code)

if claims["sub"]:
return {"orcid": claims["sub"]}
return {"orcid": claims["sub"], "sub": claims["sub"]}
else:
return {"error": "Can't get user's orcid"}
except Exception as e:
Expand Down
6 changes: 5 additions & 1 deletion fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ def get_user_id(self, code):
self.logger.exception("{}: {}".format(err_msg, e))
return {"error": err_msg}

return {"username": username, "email": userinfo.get("email")}
return {
"username": username,
"email": userinfo.get("email"),
"sub": userinfo.get("sub"),
}

def refresh_cronjob_pkey_cache(self, issuer, kid, pkey_cache):
"""
Expand Down
24 changes: 19 additions & 5 deletions tests/login/test_login_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,25 @@
def test_login_user_already_in_db(db_session):
"""
Test that if a user is already in the database and logs in, the session will contain
the user's information.
the user's information (including additional information that may have been provided
during the login like email and id_from_idp)
"""
email = "[email protected]"
provider = "Test Provider"
id_from_idp = "Provider_ID_0001"

test_user = User(username=email, is_admin=False)
db_session.add(test_user)
db_session.commit()
user_id = str(test_user.id)
assert not test_user.email
assert not test_user.id_from_idp

login_user(email, provider)
login_user(email, provider, email=email, id_from_idp=id_from_idp)

assert test_user.identity_provider.name == provider
assert test_user.id_from_idp == id_from_idp
assert test_user.email == email
assert flask.session["username"] == email
assert flask.session["provider"] == provider
assert flask.session["user_id"] == user_id
Expand All @@ -33,18 +39,23 @@ def test_login_user_with_idp_already_in_db(db_session):
"""
email = "[email protected]"
provider = "Test Provider"
id_from_idp = "Provider_ID_0001"

test_user = User(username=email, is_admin=False)
test_user = User(
username=email, email=email, id_from_idp=id_from_idp, is_admin=False
)
test_idp = IdentityProvider(name=provider)
test_user.identity_provider = test_idp

db_session.add(test_user)
db_session.commit()
user_id = str(test_user.id)

login_user(email, provider)
login_user(email, provider, email=email, id_from_idp=id_from_idp)

assert test_user.identity_provider.name == provider
assert test_user.id_from_idp == id_from_idp
assert test_user.email == email
assert flask.session["username"] == email
assert flask.session["provider"] == provider
assert flask.session["user_id"] == user_id
Expand All @@ -58,12 +69,15 @@ def test_login_new_user(db_session):
"""
email = "[email protected]"
provider = "Test Provider"
id_from_idp = "Provider_ID_0001"

login_user(email, provider)
login_user(email, provider, email=email, id_from_idp=id_from_idp)

test_user = db_session.query(User).filter(User.username == email.lower()).first()

assert test_user.identity_provider.name == provider
assert test_user.id_from_idp == id_from_idp
assert test_user.email == email
assert flask.session["username"] == email
assert flask.session["provider"] == provider
assert flask.session["user_id"] == str(test_user.id)
Expand Down
3 changes: 2 additions & 1 deletion tests/login/test_microsoft_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_get_user_id(microsoft_oauth2_client):
return_value=return_value,
):
user_id = microsoft_oauth2_client.get_user_id(code="123")
assert user_id == expected_value # nosec
for key, value in expected_value.items():
assert return_value[key] == value


def test_get_user_id_missing_claim(microsoft_oauth2_client):
Expand Down

0 comments on commit 56272b7

Please sign in to comment.