diff --git a/ictv/pages/shibboleth.py b/ictv/pages/shibboleth.py index ea718fc..18e29b6 100644 --- a/ictv/pages/shibboleth.py +++ b/ictv/pages/shibboleth.py @@ -21,6 +21,7 @@ from onelogin.saml2.auth import OneLogin_Saml2_Auth from onelogin.saml2.utils import OneLogin_Saml2_Utils +from urllib.parse import urlparse from ictv.models.user import User from ictv.pages.utils import ICTVPage @@ -65,20 +66,19 @@ def init_saml_auth(req, settings): def prepare_request(): # If server is behind proxys or balancers use the HTTP_X_FORWARDED fields - data = Storage(flask.request.form) + url_data = urlparse(flask.request.url) return { 'https': 'on' if flask.request.scheme == 'https' else 'off', - 'http_host': flask.request.environ["SERVER_NAME"], - 'server_port': flask.request.environ["SERVER_PORT"], - 'script_name': flask.g.homepath, - 'get_data': data.copy(), - 'post_data': data.copy(), + 'http_host': flask.request.host, + 'server_port': url_data.port, + 'script_name': flask.request.path, + 'get_data': flask.request.args.copy(), + 'post_data': flask.request.form.copy(), # Uncomment if using ADFS as IdP, https://github.com/onelogin/python-saml/pull/144 # 'lowercase_urlencoding': True, - 'query_string': flask.g.query + 'query_string': flask.request.query_string } - class MetadataPage(ICTVPage): def get(self): req = prepare_request() @@ -105,12 +105,12 @@ def get(self): not_auth_warn = False success_slo = False - input_data = self.form + input_data = flask.request.args if 'sso' in input_data: resp.seeother(auth.login()) - resp.seeother('/') + return resp.seeother('/') def post(self): """ @@ -132,9 +132,9 @@ def post(self): not_auth_warn = False success_slo = False - input_data = self.form + input_data = flask.request.form - if 'acs' in input_data: + if 'acs' in flask.request.args: auth.process_response() # decrypt and extract informations errors = auth.get_errors() not_auth_warn = not auth.is_authenticated() @@ -161,6 +161,6 @@ def post(self): self_url = OneLogin_Saml2_Utils.get_self_url(req) if 'RelayState' in input_data and self_url != input_data['RelayState']: - resp.seeother(auth.redirect_to(input_data['RelayState'])) + return resp.seeother(auth.redirect_to(input_data['RelayState'])) - resp.seeother('/') + return resp.seeother('/')