diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 51e23ca18..0aef70c40 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -56,7 +56,6 @@ from supertokens_python.recipe.session.constants import ANTI_CSRF_HEADER_KEY from supertokens_python.utils import ( is_version_gte, - set_request_in_user_context_if_not_defined, ) from tests.utils import ( TEST_ACCESS_TOKEN_MAX_AGE_CONFIG_KEY, @@ -1326,19 +1325,14 @@ async def send_email( nonlocal email_verify_link email_verify_link = template_vars.email_verify_link - def get_origin( - req: Optional[BaseRequest], user_context: Optional[Dict[str, Any]] - ) -> str: - if req is not None: - set_request_in_user_context_if_not_defined(user_context, req) - return user_context["url"] # type: ignore + def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str: + return user_context["url"] init( supertokens_config=SupertokensConfig("http://localhost:3567"), app_info=InputAppInfo( app_name="SuperTokens Demo", api_domain="http://api.supertokens.io", - website_domain=None, origin=get_origin, api_base_path="/auth", ), @@ -1354,11 +1348,6 @@ def get_origin( ) start_st() - version = await Querier.get_instance().get_api_version() - if not is_version_gte(version, "2.9"): - # If the version less than 2.9, the recipe doesn't exist. So skip the test - skip() - response_1 = sign_up_request(driver_config_client, "test@gmail.com", "testPass123") assert response_1.status_code == 200 dict_response = json.loads(response_1.text) diff --git a/tests/emailpassword/test_passwordreset.py b/tests/emailpassword/test_passwordreset.py index 632c2ab04..60ead26ef 100644 --- a/tests/emailpassword/test_passwordreset.py +++ b/tests/emailpassword/test_passwordreset.py @@ -409,8 +409,10 @@ async def test_reset_password_link_uses_correct_origin( password_reset_url = "" def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str: - if req is not None and req.get_header("origin") is not None: - return req.get_header("origin") # type: ignore + if req is not None: + value = req.get_header("origin") + if value is not None: + return value return "localhost:3000" class CustomEmailService( diff --git a/tests/passwordless/test_emaildelivery.py b/tests/passwordless/test_emaildelivery.py index ad5575f02..95d289c01 100644 --- a/tests/passwordless/test_emaildelivery.py +++ b/tests/passwordless/test_emaildelivery.py @@ -540,14 +540,16 @@ async def send_email_override( @mark.asyncio -async def test_reset_password_link_uses_correct_origin( +async def test_magic_link_uses_correct_origin( driver_config_client: TestClient, ): login_url = "" - def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str: - if req is not None and req.get_header("origin") is not None: - return req.get_header("origin") # type: ignore + def get_origin(req: Optional[BaseRequest], _: Dict[str, Any]) -> str: + if req is not None: + value = req.get_header("origin") + if value is not None: + return value return "localhost:3000" class CustomEmailService( diff --git a/tests/test_config.py b/tests/test_config.py index 80b711793..653cf32f7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,12 +13,14 @@ # under the License. from pytest import mark from unittest.mock import MagicMock -from supertokens_python import InputAppInfo, SupertokensConfig, init +from supertokens_python import InputAppInfo, SupertokensConfig, init, Supertokens from supertokens_python.normalised_url_domain import NormalisedURLDomain from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe import session from supertokens_python.recipe.session import SessionRecipe from supertokens_python.recipe.session.asyncio import create_new_session +from typing import Optional, Dict, Any +from supertokens_python.framework import BaseRequest from tests.utils import clean_st, reset, setup_st, start_st @@ -814,3 +816,139 @@ async def test_cookie_samesite_with_ec2_public_url(): assert SessionRecipe.get_instance().config.cookie_domain is None assert SessionRecipe.get_instance().config.get_cookie_same_site(None, {}) == "lax" assert SessionRecipe.get_instance().config.cookie_secure is False + + +@mark.asyncio +async def test_samesite_explicit_config(): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + origin="http://localhost:3000", + api_domain="http://localhost:3001", + ), + framework="fastapi", + recipe_list=[ + session.init( + cookie_same_site="strict", + ) + ], + ) + assert ( + SessionRecipe.get_instance().config.get_cookie_same_site(None, {}) == "strict" + ) + + +@mark.asyncio +async def test_that_exception_is_thrown_if_website_domain_and_origin_are_not_passed(): + try: + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://localhost:3001", + ), + framework="fastapi", + recipe_list=[session.init()], + ) + except Exception as e: + assert str(e) == "Please provide at least one of website_domain or origin" + else: + assert False, "Exception not thrown" + + +@mark.asyncio +async def test_that_init_works_fine_when_using_origin_string(): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://localhost:3001", + origin="localhost:3000", + ), + framework="fastapi", + recipe_list=[session.init()], + ) + + assert ( + Supertokens.get_instance() + .app_info.get_origin(None, {}) + .get_as_string_dangerous() + == "http://localhost:3000" + ) + + +@mark.asyncio +async def test_that_init_works_fine_when_using_website_domain_string(): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://localhost:3001", + website_domain="localhost:3000", + ), + framework="fastapi", + recipe_list=[session.init()], + ) + + assert ( + Supertokens.get_instance() + .app_info.get_origin(None, {}) + .get_as_string_dangerous() + == "http://localhost:3000" + ) + + +@mark.asyncio +async def test_that_init_works_fine_when_using_origin_function(): + def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str: + if "input" in user_context: + return user_context["input"] + return "localhost:3000" + + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://localhost:3001", + origin=get_origin, + ), + framework="fastapi", + recipe_list=[session.init()], + ) + + assert ( + Supertokens.get_instance() + .app_info.get_origin(None, {"input": "localhost:1000"}) + .get_as_string_dangerous() + == "http://localhost:1000" + ) + + assert ( + Supertokens.get_instance() + .app_info.get_origin(None, {}) + .get_as_string_dangerous() + == "http://localhost:3000" + ) + + +@mark.asyncio +async def test_that_init_chooses_origin_over_website_domain(): + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://localhost:3001", + website_domain="localhost:3000", + origin="supertokens.io", + ), + framework="fastapi", + recipe_list=[session.init()], + ) + + assert ( + Supertokens.get_instance() + .app_info.get_origin(None, {}) + .get_as_string_dangerous() + == "https://supertokens.io" + ) diff --git a/tests/test_session.py b/tests/test_session.py index b670b2e14..715e0dbf0 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -828,9 +828,11 @@ async def test_expose_access_token_to_frontend_in_cookie_based_auth( async def test_token_transfer_method_works_when_using_origin_function( driver_config_client: TestClient, ): - def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str: - if req is not None and req.get_header("origin") is not None: - return req.get_header("origin") # type: ignore + def get_origin(req: Optional[BaseRequest], _: Dict[str, Any]) -> str: + if req is not None: + value = req.get_header("origin") + if value is not None: + return value return "localhost:3000" def token_transfer_method(req: BaseRequest, _: bool, __: Dict[str, Any]):