Skip to content

Commit

Permalink
server: Introduce ApiProtocol (Chia-Network#15466)
Browse files Browse the repository at this point in the history
* server: Introduce `ApiProtocol`

* genericize (#5)

* `ApiProtocol.api_ready` -> `ApiProtocol.ready()`

* Add `ApiProtocol.log` and give APIs separate loggers

* Fix `CrawlerAPI`

* Drop some unrelated removals

* Fix some of the generic hinting

* Revert some changes in `timelord_api.py`

* Fix `CawlerAPI` readiness

* Fix hinting

* Get some `CrawlerAPI` coverage

---------

Co-authored-by: Kyle Altendorf <[email protected]>
  • Loading branch information
xdustinface and altendky authored Jun 14, 2023
1 parent 46a244a commit 49140b2
Show file tree
Hide file tree
Showing 36 changed files with 288 additions and 112 deletions.
9 changes: 3 additions & 6 deletions chia/data_layer/data_layer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@


class DataLayerAPI:
log: logging.Logger
data_layer: DataLayer

def __init__(self, data_layer: DataLayer) -> None:
self.log = logging.getLogger(__name__)
self.data_layer = data_layer

# def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
Expand All @@ -19,10 +21,5 @@ def __init__(self, data_layer: DataLayer) -> None:
def server(self) -> ChiaServer:
return self.data_layer.server

@property
def log(self) -> logging.Logger:
return self.data_layer.log

@property
def api_ready(self) -> bool:
def ready(self) -> bool:
return self.data_layer.initialized
6 changes: 6 additions & 0 deletions chia/farmer/farmer_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -53,11 +54,16 @@ def strip_old_entries(pairs: List[Tuple[float, Any]], before: float) -> List[Tup


class FarmerAPI:
log: logging.Logger
farmer: Farmer

def __init__(self, farmer: Farmer) -> None:
self.log = logging.getLogger(__name__)
self.farmer = farmer

def ready(self) -> bool:
return self.farmer.started

@api_request(peer_required=True)
async def new_proof_of_space(
self, new_proof_of_space: harvester_protocol.NewProofOfSpace, peer: WSChiaConnection
Expand Down
9 changes: 3 additions & 6 deletions chia/full_node/full_node_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@


class FullNodeAPI:
log: logging.Logger
full_node: FullNode
executor: ThreadPoolExecutor

def __init__(self, full_node: FullNode) -> None:
self.log = logging.getLogger(__name__)
self.full_node = full_node
self.executor = ThreadPoolExecutor(max_workers=1)

Expand All @@ -81,12 +83,7 @@ def server(self) -> ChiaServer:
assert self.full_node.server is not None
return self.full_node.server

@property
def log(self) -> logging.Logger:
return self.full_node.log

@property
def api_ready(self) -> bool:
def ready(self) -> bool:
return self.full_node.initialized

@api_request(peer_required=True, reply_types=[ProtocolMessageTypes.respond_peers])
Expand Down
6 changes: 6 additions & 0 deletions chia/harvester/harvester_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -29,11 +30,16 @@


class HarvesterAPI:
log: logging.Logger
harvester: Harvester

def __init__(self, harvester: Harvester):
self.log = logging.getLogger(__name__)
self.harvester = harvester

def ready(self) -> bool:
return True

@api_request(peer_required=True)
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake, peer: WSChiaConnection
Expand Down
6 changes: 6 additions & 0 deletions chia/introducer/introducer_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import Optional

from chia.introducer.introducer import Introducer
Expand All @@ -14,11 +15,16 @@


class IntroducerAPI:
log: logging.Logger
introducer: Introducer

def __init__(self, introducer) -> None:
self.log = logging.getLogger(__name__)
self.introducer = introducer

def ready(self) -> bool:
return True

def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
pass

Expand Down
7 changes: 4 additions & 3 deletions chia/seeder/crawler_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@


class CrawlerAPI:
log: logging.Logger
crawler: Crawler

def __init__(self, crawler: Crawler) -> None:
self.log = logging.getLogger(__name__)
self.crawler = crawler

@property
def server(self) -> ChiaServer:
assert self.crawler.server is not None
return self.crawler.server

@property
def log(self) -> logging.Logger:
return self.crawler.log
def ready(self) -> bool:
return True

@api_request(peer_required=True)
async def request_peers(
Expand Down
2 changes: 1 addition & 1 deletion chia/seeder/start_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_full_node_crawler_service(
config: Dict,
consensus_constants: ConsensusConstants,
connect_to_daemon: bool = True,
) -> Service[Crawler]:
) -> Service[Crawler, CrawlerAPI]:
service_config = config[SERVICE_NAME]

crawler = Crawler(
Expand Down
12 changes: 12 additions & 0 deletions chia/server/api_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

from logging import Logger

from typing_extensions import Protocol


class ApiProtocol(Protocol):
log: Logger

def ready(self) -> bool:
...
5 changes: 3 additions & 2 deletions chia/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from chia.protocols.protocol_state_machine import message_requires_reply
from chia.protocols.protocol_timing import INVALID_PROTOCOL_BAN_SECONDS
from chia.protocols.shared_protocol import protocol_version
from chia.server.api_protocol import ApiProtocol
from chia.server.introducer_peers import IntroducerPeers
from chia.server.outbound_message import Message, NodeType
from chia.server.ssl_context import private_ssl_paths, public_ssl_paths
Expand Down Expand Up @@ -122,7 +123,7 @@ class ChiaServer:
_network_id: str
_inbound_rate_limit_percent: int
_outbound_rate_limit_percent: int
api: Any
api: ApiProtocol
node: Any
root_path: Path
config: Dict[str, Any]
Expand All @@ -147,7 +148,7 @@ def create(
cls,
port: int,
node: Any,
api: Any,
api: ApiProtocol,
local_type: NodeType,
ping_interval: int,
network_id: str,
Expand Down
5 changes: 3 additions & 2 deletions chia/server/start_data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16
from chia.wallet.wallet_node import WalletNode
from chia.wallet.wallet_node_api import WalletNodeAPI

# See: https://bugs.python.org/issue29288
"".encode("idna")
Expand All @@ -31,9 +32,9 @@ def create_data_layer_service(
config: Dict[str, Any],
downloaders: List[str],
uploaders: List[str], # dont add FilesystemUploader to this, it is the default uploader
wallet_service: Optional[Service[WalletNode]] = None,
wallet_service: Optional[Service[WalletNode, WalletNodeAPI]] = None,
connect_to_daemon: bool = True,
) -> Service[DataLayer]:
) -> Service[DataLayer, DataLayerAPI]:
if uploaders is None:
uploaders = []
if downloaders is None:
Expand Down
2 changes: 1 addition & 1 deletion chia/server/start_farmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_farmer_service(
consensus_constants: ConsensusConstants,
keychain: Optional[Keychain] = None,
connect_to_daemon: bool = True,
) -> Service[Farmer]:
) -> Service[Farmer, FarmerAPI]:
service_config = config[SERVICE_NAME]

fnp = service_config.get("full_node_peer")
Expand Down
2 changes: 1 addition & 1 deletion chia/server/start_full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_full_node_service(
consensus_constants: ConsensusConstants,
connect_to_daemon: bool = True,
override_capabilities: Optional[List[Tuple[uint16, str]]] = None,
) -> Service[FullNode]:
) -> Service[FullNode, FullNodeAPI]:
service_config = config[SERVICE_NAME]

full_node = FullNode(
Expand Down
2 changes: 1 addition & 1 deletion chia/server/start_harvester.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def create_harvester_service(
consensus_constants: ConsensusConstants,
farmer_peer: Optional[UnresolvedPeerInfo],
connect_to_daemon: bool = True,
) -> Service[Harvester]:
) -> Service[Harvester, HarvesterAPI]:
service_config = config[SERVICE_NAME]

overrides = service_config["network_overrides"]["constants"][service_config["selected_network"]]
Expand Down
2 changes: 1 addition & 1 deletion chia/server/start_introducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_introducer_service(
config: Dict[str, Any],
advertised_port: Optional[int] = None,
connect_to_daemon: bool = True,
) -> Service[Introducer]:
) -> Service[Introducer, IntroducerAPI]:
service_config = config[SERVICE_NAME]

if advertised_port is None:
Expand Down
6 changes: 4 additions & 2 deletions chia/server/start_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from chia.cmds.init_funcs import chia_full_version_str
from chia.daemon.server import service_launch_lock_path
from chia.rpc.rpc_server import RpcApiProtocol, RpcServer, RpcServiceProtocol, start_rpc_server
from chia.server.api_protocol import ApiProtocol
from chia.server.chia_policy import set_chia_policy
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer
Expand All @@ -34,6 +35,7 @@

T = TypeVar("T")
_T_RpcServiceProtocol = TypeVar("_T_RpcServiceProtocol", bound=RpcServiceProtocol)
_T_ApiProtocol = TypeVar("_T_ApiProtocol", bound=ApiProtocol)

RpcInfo = Tuple[Type[RpcApiProtocol], int]

Expand All @@ -42,12 +44,12 @@ class ServiceException(Exception):
pass


class Service(Generic[_T_RpcServiceProtocol]):
class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol]):
def __init__(
self,
root_path: Path,
node: _T_RpcServiceProtocol,
peer_api: Any,
peer_api: _T_ApiProtocol,
node_type: NodeType,
advertised_port: int,
service_name: str,
Expand Down
2 changes: 1 addition & 1 deletion chia/server/start_timelord.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def create_timelord_service(
config: Dict[str, Any],
constants: ConsensusConstants,
connect_to_daemon: bool = True,
) -> Service[Timelord]:
) -> Service[Timelord, TimelordAPI]:
service_config = config[SERVICE_NAME]

connect_peers = {
Expand Down
2 changes: 1 addition & 1 deletion chia/server/start_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_wallet_service(
consensus_constants: ConsensusConstants,
keychain: Optional[Keychain] = None,
connect_to_daemon: bool = True,
) -> Service[WalletNode]:
) -> Service[WalletNode, WalletNodeAPI]:
service_config = config[SERVICE_NAME]

overrides = service_config["network_overrides"]["constants"][service_config["selected_network"]]
Expand Down
11 changes: 6 additions & 5 deletions chia/server/ws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from chia.protocols.protocol_state_machine import message_response_ok
from chia.protocols.protocol_timing import API_EXCEPTION_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
from chia.protocols.shared_protocol import Capability, Handshake
from chia.server.api_protocol import ApiProtocol
from chia.server.capabilities import known_active_capabilities
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
Expand Down Expand Up @@ -66,7 +67,7 @@ class WSChiaConnection:
"""

ws: WebSocket = field(repr=False)
api: Any = field(repr=False)
api: ApiProtocol = field(repr=False)
local_type: NodeType
local_port: int
local_capabilities_for_handshake: List[Tuple[uint16, str]] = field(repr=False)
Expand Down Expand Up @@ -123,7 +124,7 @@ def create(
cls,
local_type: NodeType,
ws: WebSocket,
api: Any,
api: ApiProtocol,
server_port: int,
log: logging.Logger,
is_outbound: bool,
Expand Down Expand Up @@ -373,9 +374,9 @@ async def _api_call(self, full_message: Message, task_id: bytes32) -> None:
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])

# If api is not ready ignore the request
if hasattr(self.api, "api_ready"):
if self.api.api_ready is False:
return None
if not self.api.ready():
self.log.warning(f"API not ready, ignore request: {full_message}")
return None

timeout: Optional[int] = 600
if metadata.execute_task:
Expand Down
Loading

0 comments on commit 49140b2

Please sign in to comment.