Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add disallow_route and disallow_startup decorators and tests #135

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions app/api/v1/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from app.config import ExecutionMode, settings
from app.core.types import GatewayMode
from app.errors import ErrorCode
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()
logger = logging.getLogger("ClimateToken")


@router.get("/", response_model=schemas.ActivitiesResponse)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def get_activity(
search: Optional[str] = None,
search_by: Optional[schemas.ActivitySearchBy] = None,
Expand Down Expand Up @@ -135,7 +135,7 @@ async def get_activity(


@router.get("/activity-record", response_model=schemas.ActivityRecordResponse)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def get_activity_by_cw_unit_id(
cw_unit_id: str,
coin_id: str,
Expand Down
8 changes: 4 additions & 4 deletions app/api/v1/cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from app.db.session import get_engine_cls
from app.errors import ErrorCode
from app.models import State
from app.utils import disallow
from app.utils import disallow_startup

router = APIRouter()
errorcode = ErrorCode()
Expand All @@ -30,7 +30,7 @@


@router.on_event("startup")
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_startup([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def init_db() -> None:
Engine = await get_engine_cls()

Expand Down Expand Up @@ -143,7 +143,7 @@ async def _scan_token_activity(

@router.on_event("startup")
@repeat_every(seconds=60, logger=logger)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_startup([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def scan_token_activity() -> None:
if lock.locked():
return
Expand Down Expand Up @@ -194,7 +194,7 @@ async def _scan_blockchain_state(

@router.on_event("startup")
@repeat_every(seconds=10, logger=logger)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_startup([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def scan_blockchain_state() -> None:
async with (
deps.get_db_session_context() as db,
Expand Down
4 changes: 2 additions & 2 deletions app/api/v1/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from app import schemas
from app.api import dependencies as deps
from app.config import ExecutionMode
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()

Expand All @@ -26,7 +26,7 @@
"/",
response_model=schemas.Key,
)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.EXPLORER]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.EXPLORER])
async def get_key(
hardened: bool = False,
derivation_index: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions app/api/v1/organizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from app import crud
from app.config import ExecutionMode, settings
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()


# pass through resource to expose organization data from cadt
@router.get("/", response_model=Any)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def get_organizations() -> Any:
all_organizations = crud.ClimateWareHouseCrud(
url=settings.CADT_API_SERVER_HOST,
Expand Down
12 changes: 6 additions & 6 deletions app/api/v1/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from app.core import utils
from app.core.climate_wallet.wallet import ClimateWallet
from app.core.types import ClimateTokenIndex, GatewayMode
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()
logger = logging.getLogger("ClimateToken")
Expand All @@ -30,7 +30,7 @@
"/",
response_model=schemas.TokenizationTxResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.CLIENT])
async def create_tokenization_tx(
request: schemas.TokenizationTxRequest,
wallet_rpc_client: WalletRpcClient = Depends(deps.get_wallet_rpc_client),
Expand Down Expand Up @@ -110,7 +110,7 @@ async def create_tokenization_tx(
"/{asset_id}/detokenize",
response_model=schemas.DetokenizationTxResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.CLIENT])
async def create_detokenization_tx(
asset_id: str,
request: schemas.DetokenizationTxRequest,
Expand Down Expand Up @@ -150,7 +150,7 @@ async def create_detokenization_tx(
"/{asset_id}/request-detokenization",
response_model=schemas.DetokenizationFileResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY])
async def create_detokenization_file(
asset_id: str,
request: schemas.DetokenizationFileRequest,
Expand Down Expand Up @@ -221,7 +221,7 @@ async def create_detokenization_file(
"/parse-detokenization",
response_model=schemas.DetokenizationFileParseResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.CLIENT])
async def parse_detokenization_file(
content: str,
) -> schemas.DetokenizationFileParseResponse:
Expand Down Expand Up @@ -257,7 +257,7 @@ async def parse_detokenization_file(
"/{asset_id}/permissionless-retire",
response_model=schemas.PermissionlessRetirementTxResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY])
async def create_permissionless_retirement_tx(
asset_id: str,
request: schemas.PermissionlessRetirementTxRequest,
Expand Down
4 changes: 2 additions & 2 deletions app/api/v1/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from app.core.chialisp.gateway import create_gateway_puzzle, parse_gateway_spend
from app.core.types import CLIMATE_WALLET_INDEX, GatewayMode
from app.schemas.types import ChiaJsonObject
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()

Expand All @@ -33,7 +33,7 @@
"/{transaction_id}",
response_model=schemas.Transaction,
)
@disallow([ExecutionMode.EXPLORER]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER])
async def get_transaction(
transaction_id: str,
wallet_rpc_client: WalletRpcClient = Depends(deps.get_wallet_rpc_client),
Expand Down
81 changes: 44 additions & 37 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,65 @@
from __future__ import annotations

import functools
import logging
import os
import time
from typing import Callable
from collections.abc import Coroutine
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar

from fastapi import status
from fastapi import HTTPException

from app.config import ExecutionMode, settings

logger = logging.getLogger("ClimateToken")

# from typing import Any, Callable, Concatenate, Coroutine, List, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")


# P = ParamSpec("P")
# R = TypeVar("R")


# def disallow(
# modes: List[ExecutionMode],
# ) -> Callable[[Callable[Concatenate[P], Coroutine[Any, Any, R]]], Callable[Concatenate[P], Coroutine[Any, Any, R]],]:
# def decorator(
# f: Callable[Concatenate[P], Coroutine[Any, Any, R]]
# ) -> Callable[Concatenate[P], Coroutine[Any, Any, R]]:
# if settings.MODE in modes:

# async def not_allowed(*args: P.args, **kwargs: P.kwargs) -> Any:
# return status.HTTP_405_METHOD_NOT_ALLOWED

# return not_allowed

# return f

# return decorator


def disallow(modes: list[ExecutionMode]): # type: ignore[no-untyped-def]
def _disallow(f: Callable): # type: ignore[no-untyped-def, type-arg]
def disallow_route(
modes: list[ExecutionMode],
) -> Callable[
[Callable[Concatenate[P], Coroutine[Any, Any, R]]],
Callable[Concatenate[P], Coroutine[Any, Any, R]],
]:
def decorator(
f: Callable[Concatenate[P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[P], Coroutine[Any, Any, R]]:
if settings.MODE in modes:
# P.args & P.kwargs don't seem to work with fastapi decorators
# see https://github.com/PrefectHQ/marvin/issues/625
# putting any parameters here in this call results in fastapi always returning 422
# if the request doesn't have the proper query parameters (literally "args" and "kwargs")
# using () allows the 405 error to be returned by fastapi regardless of query parameters
async def not_allowed() -> Any:
raise HTTPException(status_code=405, detail="Method not allowed")

return not_allowed # type: ignore[return-value]

return f

return decorator


def disallow_startup(
modes: list[ExecutionMode],
) -> Callable[
[Callable[Concatenate[P], Coroutine[Any, Any, None]]],
Callable[Concatenate[P], Coroutine[Any, Any, None]],
]:
def decorator(
f: Callable[Concatenate[P], Coroutine[Any, Any, None]],
) -> Callable[Concatenate[P], Coroutine[Any, Any, None]]:
if settings.MODE in modes:

async def _f(*args, **kargs): # type: ignore[no-untyped-def]
return status.HTTP_405_METHOD_NOT_ALLOWED

else:
async def not_allowed(*args: P.args, **kwargs: P.kwargs) -> None:
return

@functools.wraps(f)
async def _f(*args, **kargs): # type: ignore[no-untyped-def]
return await f(*args, **kargs)
return not_allowed

return _f
return f

return _disallow
return decorator


def wait_until_dir_exists(path: str, interval: int = 1) -> None:
Expand Down
Loading
Loading