Skip to content

Commit

Permalink
Merge pull request #135 from Chia-Network/EL.fix-disallow-types
Browse files Browse the repository at this point in the history
refactor: add disallow_route and disallow_startup decorators and tests
  • Loading branch information
TheLastCicada authored Jan 8, 2025
2 parents d1c8f54 + 0bf2842 commit 86c206f
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 61 deletions.
6 changes: 3 additions & 3 deletions app/api/v1/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from app.config import ExecutionMode, settings
from app.core.types import GatewayMode
from app.errors import ErrorCode
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()
logger = logging.getLogger("ClimateToken")


@router.get("/", response_model=schemas.ActivitiesResponse)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def get_activity(
search: Optional[str] = None,
search_by: Optional[schemas.ActivitySearchBy] = None,
Expand Down Expand Up @@ -135,7 +135,7 @@ async def get_activity(


@router.get("/activity-record", response_model=schemas.ActivityRecordResponse)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def get_activity_by_cw_unit_id(
cw_unit_id: str,
coin_id: str,
Expand Down
8 changes: 4 additions & 4 deletions app/api/v1/cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from app.db.session import get_engine_cls
from app.errors import ErrorCode
from app.models import State
from app.utils import disallow
from app.utils import disallow_startup

router = APIRouter()
errorcode = ErrorCode()
Expand All @@ -30,7 +30,7 @@


@router.on_event("startup")
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_startup([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def init_db() -> None:
Engine = await get_engine_cls()

Expand Down Expand Up @@ -143,7 +143,7 @@ async def _scan_token_activity(

@router.on_event("startup")
@repeat_every(seconds=60, logger=logger)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_startup([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def scan_token_activity() -> None:
if lock.locked():
return
Expand Down Expand Up @@ -194,7 +194,7 @@ async def _scan_blockchain_state(

@router.on_event("startup")
@repeat_every(seconds=10, logger=logger)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_startup([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def scan_blockchain_state() -> None:
async with (
deps.get_db_session_context() as db,
Expand Down
4 changes: 2 additions & 2 deletions app/api/v1/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from app import schemas
from app.api import dependencies as deps
from app.config import ExecutionMode
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()

Expand All @@ -26,7 +26,7 @@
"/",
response_model=schemas.Key,
)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.EXPLORER]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.EXPLORER])
async def get_key(
hardened: bool = False,
derivation_index: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions app/api/v1/organizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from app import crud
from app.config import ExecutionMode, settings
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()


# pass through resource to expose organization data from cadt
@router.get("/", response_model=Any)
@disallow([ExecutionMode.REGISTRY, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.REGISTRY, ExecutionMode.CLIENT])
async def get_organizations() -> Any:
all_organizations = crud.ClimateWareHouseCrud(
url=settings.CADT_API_SERVER_HOST,
Expand Down
12 changes: 6 additions & 6 deletions app/api/v1/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from app.core import utils
from app.core.climate_wallet.wallet import ClimateWallet
from app.core.types import ClimateTokenIndex, GatewayMode
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()
logger = logging.getLogger("ClimateToken")
Expand All @@ -30,7 +30,7 @@
"/",
response_model=schemas.TokenizationTxResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.CLIENT])
async def create_tokenization_tx(
request: schemas.TokenizationTxRequest,
wallet_rpc_client: WalletRpcClient = Depends(deps.get_wallet_rpc_client),
Expand Down Expand Up @@ -110,7 +110,7 @@ async def create_tokenization_tx(
"/{asset_id}/detokenize",
response_model=schemas.DetokenizationTxResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.CLIENT])
async def create_detokenization_tx(
asset_id: str,
request: schemas.DetokenizationTxRequest,
Expand Down Expand Up @@ -150,7 +150,7 @@ async def create_detokenization_tx(
"/{asset_id}/request-detokenization",
response_model=schemas.DetokenizationFileResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY])
async def create_detokenization_file(
asset_id: str,
request: schemas.DetokenizationFileRequest,
Expand Down Expand Up @@ -221,7 +221,7 @@ async def create_detokenization_file(
"/parse-detokenization",
response_model=schemas.DetokenizationFileParseResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.CLIENT]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.CLIENT])
async def parse_detokenization_file(
content: str,
) -> schemas.DetokenizationFileParseResponse:
Expand Down Expand Up @@ -257,7 +257,7 @@ async def parse_detokenization_file(
"/{asset_id}/permissionless-retire",
response_model=schemas.PermissionlessRetirementTxResponse,
)
@disallow([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER, ExecutionMode.REGISTRY])
async def create_permissionless_retirement_tx(
asset_id: str,
request: schemas.PermissionlessRetirementTxRequest,
Expand Down
4 changes: 2 additions & 2 deletions app/api/v1/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from app.core.chialisp.gateway import create_gateway_puzzle, parse_gateway_spend
from app.core.types import CLIMATE_WALLET_INDEX, GatewayMode
from app.schemas.types import ChiaJsonObject
from app.utils import disallow
from app.utils import disallow_route

router = APIRouter()

Expand All @@ -33,7 +33,7 @@
"/{transaction_id}",
response_model=schemas.Transaction,
)
@disallow([ExecutionMode.EXPLORER]) # type: ignore[misc]
@disallow_route([ExecutionMode.EXPLORER])
async def get_transaction(
transaction_id: str,
wallet_rpc_client: WalletRpcClient = Depends(deps.get_wallet_rpc_client),
Expand Down
81 changes: 44 additions & 37 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,65 @@
from __future__ import annotations

import functools
import logging
import os
import time
from typing import Callable
from collections.abc import Coroutine
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar

from fastapi import status
from fastapi import HTTPException

from app.config import ExecutionMode, settings

logger = logging.getLogger("ClimateToken")

# from typing import Any, Callable, Concatenate, Coroutine, List, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")


# P = ParamSpec("P")
# R = TypeVar("R")


# def disallow(
# modes: List[ExecutionMode],
# ) -> Callable[[Callable[Concatenate[P], Coroutine[Any, Any, R]]], Callable[Concatenate[P], Coroutine[Any, Any, R]],]:
# def decorator(
# f: Callable[Concatenate[P], Coroutine[Any, Any, R]]
# ) -> Callable[Concatenate[P], Coroutine[Any, Any, R]]:
# if settings.MODE in modes:

# async def not_allowed(*args: P.args, **kwargs: P.kwargs) -> Any:
# return status.HTTP_405_METHOD_NOT_ALLOWED

# return not_allowed

# return f

# return decorator


def disallow(modes: list[ExecutionMode]): # type: ignore[no-untyped-def]
def _disallow(f: Callable): # type: ignore[no-untyped-def, type-arg]
def disallow_route(
modes: list[ExecutionMode],
) -> Callable[
[Callable[Concatenate[P], Coroutine[Any, Any, R]]],
Callable[Concatenate[P], Coroutine[Any, Any, R]],
]:
def decorator(
f: Callable[Concatenate[P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[P], Coroutine[Any, Any, R]]:
if settings.MODE in modes:
# P.args & P.kwargs don't seem to work with fastapi decorators
# see https://github.com/PrefectHQ/marvin/issues/625
# putting any parameters here in this call results in fastapi always returning 422
# if the request doesn't have the proper query parameters (literally "args" and "kwargs")
# using () allows the 405 error to be returned by fastapi regardless of query parameters
async def not_allowed() -> Any:
raise HTTPException(status_code=405, detail="Method not allowed")

return not_allowed # type: ignore[return-value]

return f

return decorator


def disallow_startup(
modes: list[ExecutionMode],
) -> Callable[
[Callable[Concatenate[P], Coroutine[Any, Any, None]]],
Callable[Concatenate[P], Coroutine[Any, Any, None]],
]:
def decorator(
f: Callable[Concatenate[P], Coroutine[Any, Any, None]],
) -> Callable[Concatenate[P], Coroutine[Any, Any, None]]:
if settings.MODE in modes:

async def _f(*args, **kargs): # type: ignore[no-untyped-def]
return status.HTTP_405_METHOD_NOT_ALLOWED

else:
async def not_allowed(*args: P.args, **kwargs: P.kwargs) -> None:
return

@functools.wraps(f)
async def _f(*args, **kargs): # type: ignore[no-untyped-def]
return await f(*args, **kargs)
return not_allowed

return _f
return f

return _disallow
return decorator


def wait_until_dir_exists(path: str, interval: int = 1) -> None:
Expand Down
Loading

0 comments on commit 86c206f

Please sign in to comment.