Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fastui #9

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

app.include_router(api.router)
app.include_router(view.router)
app.include_router(view.admin.api_rotuer)

# prometheus
from prometheus_fastapi_instrumentator import Instrumentator # noqa
Expand Down
8 changes: 4 additions & 4 deletions app/api/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@ async def admin_checker(
user: auth.CURR_USER_SAFE,
token_header: str | None = Header(None, alias="X-Token"),
token_query: str | None = Query(None, alias="token"),
) -> schema.User:
) -> UserDB:
if user and user.is_admin:
return user

if token_header and token_header == settings.API_TOKEN:
return _fake_admin_user
return _fake_admin_user # type: ignore
if token_query and token_query == settings.API_TOKEN:
return _fake_admin_user
return _fake_admin_user # type: ignore

raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="No.",
)


CURR_ADMIN = Annotated[schema.User, Depends(admin_checker)]
CURR_ADMIN = Annotated[UserDB, Depends(admin_checker)]

logger = get_logger("api.admin")
router = APIRouter(
Expand Down
3 changes: 2 additions & 1 deletion app/api/admin/admin_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
from typing import Annotated, Mapping
from collections.abc import Mapping
from typing import Annotated

from beanie import BulkWriter
from beanie.operators import Set
Expand Down
36 changes: 18 additions & 18 deletions app/api/admin/admin_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def api_admin_users_internal() -> Mapping[uuid.UUID, schema.User]:
return all_users


async def api_admin_user_get_internal(user_id: uuid.UUID) -> UserDB:
async def get_user(user_id: uuid.UUID) -> UserDB:
user = await UserDB.find_by_user_uuid(user_id)
if not user:
raise HTTPException(
Expand All @@ -27,29 +27,33 @@ async def api_admin_user_get_internal(user_id: uuid.UUID) -> UserDB:
return user


CURR_USER = Annotated[UserDB, Depends(get_user)]


class PasswordChangeForm(BaseModel):
new_password: str


@router.get("/user/{user_id}")
async def api_admin_user(user_id: uuid.UUID, user: CURR_ADMIN) -> schema.User.admin_model:
ret_user = await api_admin_user_get_internal(user_id)
return ret_user
async def api_admin_user(admin: CURR_ADMIN, user: CURR_USER) -> schema.User.admin_model:
return user


@router.post("/user/{user_id}")
async def api_admin_user_edit(new_user: schema.User, user_id: uuid.UUID, user: CURR_ADMIN) -> schema.User.admin_model:
new_user = await db.update_user_admin(user_id, new_user)
async def api_admin_user_edit(new_user: schema.User, user_id: uuid.UUID, admin: CURR_ADMIN) -> schema.User.admin_model:
# new_user = await db.update_user_admin(user_id, new_user)
raise Exception

return new_user


@router.get("/users/me")
async def api_admin_users_me(user: CURR_ADMIN) -> schema.User.admin_model:
return user
async def api_admin_users_me(admin: CURR_ADMIN) -> schema.User.admin_model:
return admin


@router.get("/users")
async def api_admin_users(user: CURR_ADMIN) -> Mapping[uuid.UUID, schema.User.admin_model]:
async def api_admin_users(admin: CURR_ADMIN) -> Mapping[uuid.UUID, schema.User.admin_model]:
all_users = await api_admin_users_internal()
return all_users

Expand All @@ -58,7 +62,7 @@ async def api_admin_users(user: CURR_ADMIN) -> Mapping[uuid.UUID, schema.User.ad
async def api_admin_user_edit_password(
new_password: PasswordChangeForm,
admin: CURR_ADMIN,
user: schema.User = Depends(api_admin_user_get_internal),
user: CURR_USER,
) -> schema.User.admin_model:
au = user.auth_source
if not isinstance(au, schema.auth.SimpleAuth.AuthModel):
Expand All @@ -73,7 +77,7 @@ async def api_admin_user_edit_password(
@router.get("/user/{user_id}/score")
async def api_admin_user_recalc_score(
admin: CURR_ADMIN,
user: UserDB = Depends(api_admin_user_get_internal),
user: CURR_USER,
) -> schema.User.admin_model:
await user.recalc_score_one()
return user
Expand All @@ -82,18 +86,14 @@ async def api_admin_user_recalc_score(
@router.delete("/user/{user_id}")
async def api_admin_user_delete(
admin: CURR_ADMIN,
user: schema.User = Depends(api_admin_user_get_internal),
user: CURR_USER,
) -> str:
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="user not exist",
)
raise Exception

if len(user.solved_tasks) > 0:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user have solved tasks",
)
await db.delete_user(user)
# await db.delete_user(user)
return "deleted"
4 changes: 2 additions & 2 deletions app/api/api_users.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from collections.abc import Sequence
from typing import Iterable, TypeVar
from collections.abc import Iterable, Sequence
from typing import TypeVar

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status

Expand Down
6 changes: 3 additions & 3 deletions app/cli/cmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from . import load as load_cmds
from . import stress as stress_cmds

get_cmds = get_cmds
stress_cmds = stress_cmds
load_cmds = load_cmds
get_cmds = get_cmds # noqa: PLW0127
stress_cmds = stress_cmds # noqa: PLW0127
load_cmds = load_cmds # noqa: PLW0127

tasks_to_create: list[RawTask] = [
RawTask(
Expand Down
4 changes: 3 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class DefaultTokenError(ValueError):

class Settings(BaseSettings):
DEBUG: bool = False
TESTING: bool = False

PROFILING: bool = False

TOKEN_PATH: str = "/api/users/login"
Expand Down Expand Up @@ -65,7 +67,7 @@ class Settings(BaseSettings):

@model_validator(mode="after")
def check_non_default_tokens(self) -> Self:
if self.DEBUG:
if self.DEBUG or self.TESTING:
return self

token_check_list = ["JWT_SECRET_KEY", "FLAG_SIGN_KEY", "API_TOKEN", "WS_API_TOKEN"]
Expand Down
9 changes: 6 additions & 3 deletions app/db/beanie.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,15 +311,18 @@ def __init__(self) -> None:
async def init(self) -> None:
self.client = AsyncIOMotorClient(str(settings.MONGO), tz_aware=True)
self.db = self.client[settings.DB_NAME]
await init_beanie(database=self.db, document_models=[TaskDB, UserDB]) # type: ignore # bad library ;(
await init_beanie(
database=self.db,
document_models=[TaskDB, UserDB]
)
logger.info("Beanie init ok")

async def close(self) -> None:
logger.info("DB close ok")

async def reset_db(self) -> None:
if not settings.DEBUG:
logger.warning("DB Reset without debug")
if not (settings.DEBUG or settings.TESTING):
logger.warning(f"DB Reset without debug ({settings.DEBUG = }) or testing {settings.TESTING = }")
return

await self.client.drop_database(settings.DB_NAME)
Expand Down
3 changes: 2 additions & 1 deletion app/schema/auth/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ def check_password(self, model: "SimpleAuth.AuthModel") -> bool:
def check_valid(self) -> bool:
if settings.DEBUG:
return True

if (
len(self.internal.username) < SimpleAuth.auth_settings.MIN_USERNAME_LEN
or len(self.internal.username) > SimpleAuth.auth_settings.MAX_USERNAME_LEN
):
return False
if len(self.internal.password) < SimpleAuth.auth_settings.MIN_PASSWORD_LEN:
if len(self.internal.password) < SimpleAuth.auth_settings.MIN_PASSWORD_LEN: # noqa: SIM103
return False
return True

Expand Down
2 changes: 1 addition & 1 deletion app/schema/ebasemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def build_model( # noqa: PLR0912, C901 # WTF: refactor & simplify

new_union = Union[tuple(new_union_base)] # type: ignore # noqa: UP007, PGH003 # так надо.

target_fields[field_name] = (
target_fields[field_name] = ( # type: ignore # WTF: ???
new_union,
field_value,
)
Expand Down
14 changes: 13 additions & 1 deletion app/schema/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Annotated, ClassVar
from zoneinfo import ZoneInfo

from pydantic import Field
from pydantic import Field, computed_field

from .. import config
from ..config import settings
Expand Down Expand Up @@ -52,6 +52,8 @@ class Task(EBaseModel):
"description",
"flag",
"hidden",
"points",
"solves",
}

task_id: uuid.UUID = Field(default_factory=uuid.uuid4)
Expand Down Expand Up @@ -147,6 +149,16 @@ def first_pwned_str(self) -> tuple[uuid.UUID, str]:
def short_desc(self) -> str:
return f"task_id={self.task_id} task_name={self.task_name} hidden={self.hidden} points={self.scoring.points}"

@computed_field
@property
def points(self) -> int:
return self.scoring.points

@computed_field
@property
def solves(self) -> int:
return len(self.pwned_by)


class TaskForm(EBaseModel):
task_name: str
Expand Down
10 changes: 10 additions & 0 deletions app/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: S101, S106, ANN201, T201 # this is a __test file__

import typing
from contextlib import contextmanager

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -127,5 +128,14 @@ def client(request):
client.__exit__()


@contextmanager
def enable_debug() -> typing.Generator[None, typing.Any, None]:
settings.DEBUG = True
try:
yield
finally:
settings.DEBUG = False


# from . import test_auth # noqa
# from . import test_main # noqa
46 changes: 27 additions & 19 deletions app/test/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import status

from .. import config, schema
from . import ClientEx, app
from . import ClientEx, app, enable_debug
from . import client as client_cl

client = client_cl
Expand All @@ -12,65 +12,73 @@


def test_register(client: ClientEx):
resp = client.simple_register_raw(username="Rubikoid", password="123")
resp = client.simple_register_raw(username="Rubikoid", password="123456789")

assert resp.status_code == status.HTTP_200_OK, resp.text
assert resp.text == '"ok"', resp.text


def test_login(client: ClientEx):
test_register(client)
resp = client.simple_login_raw(username="Rubikoid", password="123")
resp = client.simple_login_raw(username="Rubikoid", password="123456789")

assert resp.status_code == status.HTTP_200_OK, resp.text
assert resp.text == '"ok"', resp.text


def test_admin(client: ClientEx):
# config.settings.DEBUG = True
test_login(client)
# config.settings.DEBUG = False
# need to enable debug here, because `Rubikoid-as-default-admin` is debug feature
with enable_debug():
test_login(client)

resp = client.get(app.url_path_for("api_admin_users_me"))
# print(resp.json())
assert resp.status_code == status.HTTP_200_OK, resp.text
assert resp.json()["is_admin"] is True, resp.json()
assert resp.json()["username"] == "Rubikoid", resp.json()


def test_not_admin_without_debug(client: ClientEx):
test_login(client)
resp = client.get(app.url_path_for("api_admin_users_me"))
assert resp.status_code == status.HTTP_403_FORBIDDEN, resp.text


def test_admin_fail(client: ClientEx):
resp1 = client.simple_register_raw(username="Not_Rubikoid", password="123")
assert resp1.status_code == status.HTTP_200_OK, resp1.text
assert resp1.text == '"ok"', resp1.text
with enable_debug():
resp1 = client.simple_register_raw(username="Not_Rubikoid", password="123456789")
assert resp1.status_code == status.HTTP_200_OK, resp1.text
assert resp1.text == '"ok"', resp1.text

resp2 = client.simple_login_raw(username="Not_Rubikoid", password="123")
assert resp2.status_code == status.HTTP_200_OK, resp2.text
assert resp2.text == '"ok"', resp2.text
resp2 = client.simple_login_raw(username="Not_Rubikoid", password="123456789")
assert resp2.status_code == status.HTTP_200_OK, resp2.text
assert resp2.text == '"ok"', resp2.text

resp3 = client.get(app.url_path_for("api_admin_users_me"))
assert resp3.status_code == status.HTTP_403_FORBIDDEN, resp3.text
resp3 = client.get(app.url_path_for("api_admin_users_me"))
assert resp3.status_code == status.HTTP_403_FORBIDDEN, resp3.text


def test_not_existing_user(client: ClientEx):
resp1 = client.post(
app.url_path_for("api_auth_simple_login"),
json=LoginForm(username="Not_Existing_Account", password="123").model_dump(mode="json"),
json=LoginForm(username="Not_Existing_Account", password="123456789").model_dump(mode="json"),
)
assert resp1.status_code == status.HTTP_401_UNAUTHORIZED, resp1.text


def test_invalid_password(client: ClientEx):
resp1 = client.simple_register_raw(username="Not_Rubikoid", password="123")
resp1 = client.simple_register_raw(username="Not_Rubikoid", password="123456789")
assert resp1.status_code == status.HTTP_200_OK, resp1.text
assert resp1.text == '"ok"', resp1.text

resp2 = client.simple_login_raw(username="Not_Rubikoid", password="1234")
resp2 = client.simple_login_raw(username="Not_Rubikoid", password="1234567890")
assert resp2.status_code == status.HTTP_401_UNAUTHORIZED, resp2.text


def test_register_existing_user(client: ClientEx):
resp1 = client.simple_register_raw(username="Not_Rubikoid", password="123")
resp1 = client.simple_register_raw(username="Not_Rubikoid", password="123456789")
assert resp1.status_code == status.HTTP_200_OK, resp1.text
assert resp1.text == '"ok"', resp1.text

resp2 = client.simple_register_raw(username="Not_Rubikoid", password="1234")
resp2 = client.simple_register_raw(username="Not_Rubikoid", password="1234567890")
assert resp2.status_code == status.HTTP_403_FORBIDDEN, resp2.text
10 changes: 2 additions & 8 deletions app/test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import uuid

import pytest

from .. import schema
from . import TestClient, app
from . import TestClient
from . import client as client_cl
from . import test_auth

client = client_cl


def test_read_main(client: TestClient):
def test_read_main(client: TestClient) -> None:
resp = client.get("/")
assert resp.status_code == 200
Loading