Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Nov 28, 2023
1 parent dc8c897 commit 8c871c3
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 23 deletions.
15 changes: 2 additions & 13 deletions tests/emailpassword/test_emailverify.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from supertokens_python.recipe.session.constants import ANTI_CSRF_HEADER_KEY
from supertokens_python.utils import (
is_version_gte,
set_request_in_user_context_if_not_defined,
)
from tests.utils import (
TEST_ACCESS_TOKEN_MAX_AGE_CONFIG_KEY,
Expand Down Expand Up @@ -1326,19 +1325,14 @@ async def send_email(
nonlocal email_verify_link
email_verify_link = template_vars.email_verify_link

def get_origin(
req: Optional[BaseRequest], user_context: Optional[Dict[str, Any]]
) -> str:
if req is not None:
set_request_in_user_context_if_not_defined(user_context, req)
return user_context["url"] # type: ignore
def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str:
return user_context["url"]

init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
api_domain="http://api.supertokens.io",
website_domain=None,
origin=get_origin,
api_base_path="/auth",
),
Expand All @@ -1354,11 +1348,6 @@ def get_origin(
)
start_st()

version = await Querier.get_instance().get_api_version()
if not is_version_gte(version, "2.9"):
# If the version less than 2.9, the recipe doesn't exist. So skip the test
skip()

response_1 = sign_up_request(driver_config_client, "[email protected]", "testPass123")
assert response_1.status_code == 200
dict_response = json.loads(response_1.text)
Expand Down
6 changes: 4 additions & 2 deletions tests/emailpassword/test_passwordreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,10 @@ async def test_reset_password_link_uses_correct_origin(
password_reset_url = ""

def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str:
if req is not None and req.get_header("origin") is not None:
return req.get_header("origin") # type: ignore
if req is not None:
value = req.get_header("origin")
if value is not None:
return value
return "localhost:3000"

class CustomEmailService(
Expand Down
10 changes: 6 additions & 4 deletions tests/passwordless/test_emaildelivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,14 +540,16 @@ async def send_email_override(


@mark.asyncio
async def test_reset_password_link_uses_correct_origin(
async def test_magic_link_uses_correct_origin(
driver_config_client: TestClient,
):
login_url = ""

def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str:
if req is not None and req.get_header("origin") is not None:
return req.get_header("origin") # type: ignore
def get_origin(req: Optional[BaseRequest], _: Dict[str, Any]) -> str:
if req is not None:
value = req.get_header("origin")
if value is not None:
return value
return "localhost:3000"

class CustomEmailService(
Expand Down
140 changes: 139 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# under the License.
from pytest import mark
from unittest.mock import MagicMock
from supertokens_python import InputAppInfo, SupertokensConfig, init
from supertokens_python import InputAppInfo, SupertokensConfig, init, Supertokens
from supertokens_python.normalised_url_domain import NormalisedURLDomain
from supertokens_python.normalised_url_path import NormalisedURLPath
from supertokens_python.recipe import session
from supertokens_python.recipe.session import SessionRecipe
from supertokens_python.recipe.session.asyncio import create_new_session
from typing import Optional, Dict, Any
from supertokens_python.framework import BaseRequest

from tests.utils import clean_st, reset, setup_st, start_st

Expand Down Expand Up @@ -814,3 +816,139 @@ async def test_cookie_samesite_with_ec2_public_url():
assert SessionRecipe.get_instance().config.cookie_domain is None
assert SessionRecipe.get_instance().config.get_cookie_same_site(None, {}) == "lax"
assert SessionRecipe.get_instance().config.cookie_secure is False


@mark.asyncio
async def test_samesite_explicit_config():
init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
origin="http://localhost:3000",
api_domain="http://localhost:3001",
),
framework="fastapi",
recipe_list=[
session.init(
cookie_same_site="strict",
)
],
)
assert (
SessionRecipe.get_instance().config.get_cookie_same_site(None, {}) == "strict"
)


@mark.asyncio
async def test_that_exception_is_thrown_if_website_domain_and_origin_are_not_passed():
try:
init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
api_domain="http://localhost:3001",
),
framework="fastapi",
recipe_list=[session.init()],
)
except Exception as e:
assert str(e) == "Please provide at least one of website_domain or origin"
else:
assert False, "Exception not thrown"


@mark.asyncio
async def test_that_init_works_fine_when_using_origin_string():
init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
api_domain="http://localhost:3001",
origin="localhost:3000",
),
framework="fastapi",
recipe_list=[session.init()],
)

assert (
Supertokens.get_instance()
.app_info.get_origin(None, {})
.get_as_string_dangerous()
== "http://localhost:3000"
)


@mark.asyncio
async def test_that_init_works_fine_when_using_website_domain_string():
init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
api_domain="http://localhost:3001",
website_domain="localhost:3000",
),
framework="fastapi",
recipe_list=[session.init()],
)

assert (
Supertokens.get_instance()
.app_info.get_origin(None, {})
.get_as_string_dangerous()
== "http://localhost:3000"
)


@mark.asyncio
async def test_that_init_works_fine_when_using_origin_function():
def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str:
if "input" in user_context:
return user_context["input"]
return "localhost:3000"

init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
api_domain="http://localhost:3001",
origin=get_origin,
),
framework="fastapi",
recipe_list=[session.init()],
)

assert (
Supertokens.get_instance()
.app_info.get_origin(None, {"input": "localhost:1000"})
.get_as_string_dangerous()
== "http://localhost:1000"
)

assert (
Supertokens.get_instance()
.app_info.get_origin(None, {})
.get_as_string_dangerous()
== "http://localhost:3000"
)


@mark.asyncio
async def test_that_init_chooses_origin_over_website_domain():
init(
supertokens_config=SupertokensConfig("http://localhost:3567"),
app_info=InputAppInfo(
app_name="SuperTokens Demo",
api_domain="http://localhost:3001",
website_domain="localhost:3000",
origin="supertokens.io",
),
framework="fastapi",
recipe_list=[session.init()],
)

assert (
Supertokens.get_instance()
.app_info.get_origin(None, {})
.get_as_string_dangerous()
== "https://supertokens.io"
)
8 changes: 5 additions & 3 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,11 @@ async def test_expose_access_token_to_frontend_in_cookie_based_auth(
async def test_token_transfer_method_works_when_using_origin_function(
driver_config_client: TestClient,
):
def get_origin(req: Optional[BaseRequest], _: Optional[Dict[str, Any]]) -> str:
if req is not None and req.get_header("origin") is not None:
return req.get_header("origin") # type: ignore
def get_origin(req: Optional[BaseRequest], _: Dict[str, Any]) -> str:
if req is not None:
value = req.get_header("origin")
if value is not None:
return value
return "localhost:3000"

def token_transfer_method(req: BaseRequest, _: bool, __: Dict[str, Any]):
Expand Down

0 comments on commit 8c871c3

Please sign in to comment.