diff --git a/.python-version b/.python-version index 8531a3b..871f80a 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12.2 +3.12.3 diff --git a/.vscode/settings.json b/.vscode/settings.json index de7acfc..4b064bf 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,5 +10,8 @@ "**/__pycache__": true, "**/.pytest_cache": true, ".venv/": true - } + }, + "python.testing.pytestArgs": ["src"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/README.md b/README.md index 4c9fb54..93aa44a 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -# llegos +# fastactor diff --git a/pyproject.toml b/pyproject.toml index 0b26944..937dfc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,15 @@ [project] -name = "llegos" +name = "fastactor" version = "0.1.0" -description = "Add your description here" -authors = [ - { name = "Cyrus Nouroozi", email = "cyrus@edendaolab.com" } -] +description = "FastActor is a fast and easy-to-use actor framework for Python." +authors = [{ name = "Cyrus Nouroozi", email = "cyrus@zenbase.ai" }] dependencies = [ - "anyio>=4.3.0", - "pyee>=11.1.0", + "anyio>=4.7.0", + "pyee>=12.1.1", + "pydantic-settings>=2.7.0", + "sorcery>=0.2.2", + "svix-ksuid>=0.6.2", + "beartype>=0.19.0", ] readme = "README.md" requires-python = ">= 3.8" @@ -21,15 +23,14 @@ managed = true dev-dependencies = [ "ipython>=8.23.0", "ipdb>=0.13.13", - "openai>=1.17.1", - "pytest>=8.1.1", + "pytest>=8.3.3", "pytest-anyio>=0.0.0", "trio>=0.25.0", - "pytest-asyncio>=0.23.6", + "pydantic>=2.10.4", ] [tool.hatch.metadata] allow-direct-references = true [tool.hatch.build.targets.wheel] -packages = ["src/llegos"] +packages = ["src/fastactor"] diff --git a/requirements-dev.lock b/requirements-dev.lock index d12ee1b..6e05413 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -6,39 +6,31 @@ # features: [] # all-features: false # with-sources: false +# generate-hashes: false +# universal: false -e file:. -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.3.0 - # via httpx - # via llegos - # via openai +anyio==4.7.0 + # via fastactor # via pytest-anyio asttokens==2.4.1 + # via sorcery # via stack-data attrs==23.2.0 # via outcome # via trio -certifi==2024.2.2 - # via httpcore - # via httpx +beartype==0.19.0 + # via fastactor decorator==5.1.1 # via ipdb # via ipython -distro==1.9.0 - # via openai executing==2.0.1 + # via sorcery # via stack-data -h11==0.14.0 - # via httpcore -httpcore==1.0.5 - # via httpx -httpx==0.27.0 - # via openai idna==3.7 # via anyio - # via httpx # via trio iniconfig==2.0.0 # via pytest @@ -47,9 +39,10 @@ ipython==8.23.0 # via ipdb jedi==0.19.1 # via ipython +littleutils==0.2.4 + # via sorcery matplotlib-inline==0.1.6 # via ipython -openai==1.17.1 outcome==1.3.0.post0 # via trio packaging==24.0 @@ -58,7 +51,7 @@ parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -pluggy==1.4.0 +pluggy==1.5.0 # via pytest prompt-toolkit==3.0.43 # via ipython @@ -66,40 +59,46 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 # via stack-data -pydantic==2.7.0 - # via openai -pydantic-core==2.18.1 +pydantic==2.10.4 + # via pydantic-settings +pydantic-core==2.27.2 # via pydantic -pyee==11.1.0 - # via llegos +pydantic-settings==2.7.0 + # via fastactor +pyee==12.1.1 + # via fastactor pygments==2.17.2 # via ipython -pytest==8.1.1 +pytest==8.3.4 # via pytest-anyio - # via pytest-asyncio pytest-anyio==0.0.0 -pytest-asyncio==0.23.6 +python-baseconv==1.2.2 + # via svix-ksuid +python-dotenv==1.0.1 + # via pydantic-settings six==1.16.0 # via asttokens sniffio==1.3.1 # via anyio - # via httpx - # via openai # via trio +sorcery==0.2.2 + # via fastactor sortedcontainers==2.4.0 # via trio stack-data==0.6.3 # via ipython -tqdm==4.66.2 - # via openai +svix-ksuid==0.6.2 + # via fastactor traitlets==5.14.2 # via ipython # via matplotlib-inline trio==0.25.0 -typing-extensions==4.11.0 - # via openai +typing-extensions==4.12.2 + # via anyio # via pydantic # via pydantic-core # via pyee wcwidth==0.2.13 # via prompt-toolkit +wrapt==1.17.0 + # via sorcery diff --git a/requirements.lock b/requirements.lock index c6d6a9c..a4d239f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -6,15 +6,46 @@ # features: [] # all-features: false # with-sources: false +# generate-hashes: false +# universal: false -e file:. -anyio==4.3.0 - # via llegos +annotated-types==0.7.0 + # via pydantic +anyio==4.7.0 + # via fastactor +asttokens==3.0.0 + # via sorcery +beartype==0.19.0 + # via fastactor +executing==2.1.0 + # via sorcery idna==3.7 # via anyio -pyee==11.1.0 - # via llegos +littleutils==0.2.4 + # via sorcery +pydantic==2.9.2 + # via pydantic-settings +pydantic-core==2.23.4 + # via pydantic +pydantic-settings==2.7.0 + # via fastactor +pyee==12.1.1 + # via fastactor +python-baseconv==1.2.2 + # via svix-ksuid +python-dotenv==1.0.1 + # via pydantic-settings sniffio==1.3.1 # via anyio +sorcery==0.2.2 + # via fastactor +svix-ksuid==0.6.2 + # via fastactor typing-extensions==4.11.0 + # via anyio + # via pydantic + # via pydantic-core # via pyee +wrapt==1.17.0 + # via sorcery diff --git a/src/fastactor/__init__.py b/src/fastactor/__init__.py new file mode 100644 index 0000000..319375d --- /dev/null +++ b/src/fastactor/__init__.py @@ -0,0 +1,5 @@ +from .settings import settings + +__all__ = [ + "settings", +] diff --git a/src/fastactor/otp.py b/src/fastactor/otp.py new file mode 100644 index 0000000..5cb9bb4 --- /dev/null +++ b/src/fastactor/otp.py @@ -0,0 +1,834 @@ +from abc import ABC +from asyncio import gather, as_completed +from collections import deque +from dataclasses import dataclass, field +from time import monotonic +import logging +import typing as t + +from anyio import ( + ClosedResourceError, + EndOfStream, + create_memory_object_stream, + create_task_group, + Event, + fail_after, + Lock, +) +from anyio.abc import TaskGroup +from anyio.streams.memory import MemoryObjectSendStream, MemoryObjectReceiveStream +from sorcery import dict_of + +# Adjust these to your actual import paths +from .settings import settings +from .utils import id_generator + +# ----------------------------------------------------------- +# Logging Setup +# ----------------------------------------------------------- +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------- +# Exceptions +# ----------------------------------------------------------- +class Crashed(Exception): + def __init__(self, reason: str): + self.reason = reason + + +class Failed(Exception): + def __init__(self, reason: str): + self.reason = reason + + +class Shutdown(Exception): + def __init__(self, reason: t.Any): + self.reason = reason + + +# ----------------------------------------------------------- +# Helpers +# ----------------------------------------------------------- +def _is_normal_shutdown_reason(reason: t.Any) -> bool: + match reason: + case "normal" | "shutdown": + return True + case Shutdown(): + return True + case _: + return False + + +# ----------------------------------------------------------- +# Base Message Types +# ----------------------------------------------------------- +@dataclass +class Message(ABC): + sender: "Process" + + +@dataclass +class Info(Message): + message: t.Any + + +@dataclass +class Stop(Message): + reason: t.Any + reply: t.Optional[t.Any] = None + + +@dataclass +class Exit(Message): + reason: t.Any + + +@dataclass +class Down(Message): + reason: t.Any + + +class Ignore: + """Used by `Process.init` to skip starting the loop if desired.""" + + +# ----------------------------------------------------------- +# Minimal Actor Process +# ----------------------------------------------------------- +@dataclass(repr=False) +class Process: + """ + A minimal actor with mailbox, linking/monitoring, start/stop lifecycle. + Does NOT implement synchronous call/cast logic; that belongs in `GenServer`. + """ + + id: str = field(default_factory=id_generator("process")) + supervisor: t.Optional["Supervisor"] = None + trap_exits: bool = False + + _inbox: MemoryObjectSendStream[Message] | None = field(default=None, init=False) + _mailbox: MemoryObjectReceiveStream[Message] | None = field( + default=None, + init=False, + ) + + _started: Event = field(default_factory=Event, init=False) + _stopped: Event = field(default_factory=Event, init=False) + _crash_exc: Exception = field(default=None, init=False) + + links: set["Process"] = field(default_factory=set, init=False) + monitors: set["Process"] = field(default_factory=set, init=False) + monitored_by: set["Process"] = field(default_factory=set, init=False) + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other: "Process"): + return self.id == other.id + + # --------------- Lifecycle --------------- # + async def _init(self, *args, **kwargs): + logger.debug("%s init", self) + try: + init_result = await self.init(*args, **kwargs) + except Exception as error: + logger.error("%s init crashed %s", self, error) + init_result = Stop(self, error) + + match init_result: + case Ignore(): + self._started.set() + self._stopped.set() + case Stop(_, reason): + self._started.set() + self._stopped.set() + raise Shutdown(reason) + case _: + self._inbox, self._mailbox = create_memory_object_stream( + settings.mailbox_size + ) + + async def init(self, *args, **kwargs) -> t.Union[Ignore, Stop, None]: + """ + Subclasses should override. Return Ignore to skip the loop, + or Stop(...) to fail initialization. + """ + return None + + def has_started(self) -> bool: + return self._started.is_set() + + async def started(self) -> t.Self: + await self._started.wait() + return self + + def has_stopped(self) -> bool: + return self._stopped.is_set() + + async def stopped(self) -> t.Self: + await self._stopped.wait() + return self + + async def handle_exit(self, message: Exit): + """Override in subclass if needed.""" + pass + + async def terminate(self, reason: t.Any): + """ + Called after the main loop ends or if forcibly stopped. + Subclasses can override for final cleanup. + """ + logger.debug("%s terminate reason=%s", self, reason) + + if self.monitored_by: + logger.debug("%s notifying monitors", self) + for process in list(self.monitored_by): + process.demonitor(self) + try: + await process.send(Down(self, reason)) + except Exception as error: + logger.error( + "%r sending Down(%s, %s) to monitor: %s", + error, + self, + reason, + process, + ) + + if self.links: + logger.debug("%s notifying links", self) + abnormal_shutdown = not _is_normal_shutdown_reason(reason) + for process in list(self.links): + self.unlink(process) + if process.trap_exits: + try: + await process.send(Exit(self, reason)) + except Exception as error: + logger.error( + "%r sending Exit(%s, %s) to linked actor: %s", + error, + self, + reason, + process, + ) + elif abnormal_shutdown: + process._crash_exc = ( + reason + if isinstance(reason, Exception) + else Crashed(str(reason)) + ) + try: + await process.stop(reason) + except Exception as error: + logger.error("%r killing %s: %s", error, process, reason) + + for process in list(self.monitors): + process.demonitor(self) + + Runtime.current().unregister(self) + logger.debug("%s terminated reason=%s", self, reason) + + # --------------- Link / Monitor --------------- # + def link(self, other: "Process"): + self.links.add(other) + other.links.add(self) + + def unlink(self, other: "Process"): + self.links.discard(other) + other.links.discard(self) + + def monitor(self, other: "Process"): + self.monitors.add(other) + other.monitored_by.add(self) + + def demonitor(self, other: "Process"): + self.monitors.discard(other) + other.monitored_by.discard(self) + + # --------------- Basic Send / Stop --------------- # + async def send(self, message: t.Any): + await self._inbox.send(message) + + def send_nowait(self, message: t.Any): + self._inbox.send_nowait(message) + + async def stop( + self, + reason: t.Any = "normal", + timeout: t.Optional[int] = 60, + sender: t.Optional["Process"] = None, + ): + """ + Gracefully stop by sending a Stop(...) message. + """ + logger.debug("%s stop %s", self, reason) + self.send_nowait(Stop(sender or self.supervisor, reason)) + with fail_after(timeout): + await self.stopped() + + async def kill(self): + """ + Forcibly kill the process by closing the mailbox. + """ + logger.debug("%s kill", self) + await self._mailbox.aclose() + await self.stopped() + + def info(self, message: t.Any, sender: t.Optional["Process"] = None): + sender = sender or self.supervisor + self.send_nowait(Info(sender, message)) + + async def handle_info(self, message: Message): ... + + # --------------- Internal Loop --------------- # + async def loop(self, args, kwargs): + """ + Wraps `_loop` in a try/finally to detect crashes. Subclasses can override or extend. + """ + await self._init(*args, **kwargs) + async with self._mailbox: + reason = "normal" + try: + await self._loop() + except ClosedResourceError: + reason = "killed" + except Shutdown as error: + if not _is_normal_shutdown_reason(error.reason): + reason = error.reason + self._crash_exc = Crashed(str(reason)) + except Exception as error: + self._crash_exc = error + reason = error + finally: + await self.terminate(reason) + self._stopped.set() + + async def _loop(self): + """ + The main loop: reads messages, calls `_handle_message`. + """ + self._started.set() + logger.debug("%s loop started", self) + while True: + try: + message = await self._mailbox.receive() + logger.debug("%s received message %s", self, message) + except (EndOfStream, ClosedResourceError): + break + + await self._handle_message(message) + + async def _handle_message(self, message: Message) -> bool: + """ + Returns True to keep going, False to break the loop. + """ + match message: + case Stop(_, reason): + raise Shutdown(reason) + case Exit(_, reason): + if not self.trap_exits and not _is_normal_shutdown_reason(reason): + self.unlink(message.sender) + self._crash_exc = Crashed(str(reason)) + raise Shutdown(reason) + else: + await self.handle_exit(message) + case _: + # Unrecognized => treat as info by default + await self.handle_info(message) + + @classmethod + async def start( + cls, + *args, + trap_exits=False, + supervisor: t.Optional["Supervisor"] = None, + **kwargs, + ) -> t.Optional[t.Self]: + runtime = Runtime.current() + supervisor = supervisor or runtime.supervisor + + process = cls(supervisor=supervisor, trap_exits=trap_exits) + return await runtime.spawn(process, *args, **kwargs) + + @classmethod + async def start_link( + cls, + *args, + trap_exits=False, + supervisor: t.Optional["Supervisor"] = None, + **kwargs, + ) -> t.Optional[t.Self]: + """ + Similar to `start`, but also link the new process with the supervisor. + """ + process = await cls.start( + *args, trap_exits=trap_exits, supervisor=supervisor, **kwargs + ) + if process is not None: + process.link(process.supervisor) + return process + + +# ----------------------------------------------------------- +# GenServer +# ----------------------------------------------------------- +@dataclass +class Call(Message): + """ + GenServer-specific synchronous call. + """ + + message: t.Any + _result: t.Any = field(default=None, init=False) + _ready: Event = field(default_factory=Event, init=False, repr=False) + + def set_result(self, value: t.Any): + self._result = value + self._ready.set() + + async def result(self, timeout=5): + with fail_after(timeout): + await self._ready.wait() + if isinstance(self._result, Exception): + raise self._result + return self._result + + +@dataclass +class Cast(Message): + message: t.Any + + +class GenServer(Process): + """ + A GenServer extends Process with `call`/`cast` semantics similar to Elixir's GenServer. + """ + + # Overridable callbacks + async def handle_call(self, request: t.Any) -> t.Any: + return None + + async def handle_cast(self, request: t.Any) -> t.Any: + pass + + async def handle_info(self, message: t.Any): + pass + + # Public API for user code + async def call( + self, + request: t.Any, + sender: t.Optional["Process"] = None, + timeout=5, + ) -> t.Any: + sender = sender or self.supervisor + callmsg = Call(sender, request) + self.send_nowait(callmsg) + return await callmsg.result(timeout) + + def cast(self, request: t.Any, sender: t.Optional["Process"] = None): + sender = sender or self.supervisor + self.send_nowait(Cast(sender, request)) + + async def _handle_message(self, message: t.Any) -> bool: + """ + We override to check for call/cast messages. If it's not, fallback to base `_handle_message`. + """ + match message: + case Call(): + try: + reply = await self.handle_call(message) + message.set_result(reply) + return True + except Exception as error: + message.set_result(error) + logger.error( + "%s encountered %r processing %s", self, error, message + ) + raise error + case Cast(): + try: + await self.handle_cast(message) + return True + except Exception as error: + logger.error( + "%s encountered %r processing %s", self, error, message + ) + raise error + case _: + return await super()._handle_message(message) + + +# ----------------------------------------------------------- +# Supervisor +# ----------------------------------------------------------- +RestartType = t.Literal["permanent", "transient", "temporary"] +ShutdownType = t.Union[int, t.Literal["brutal_kill", "infinity"]] +RestartStrategy = t.Literal["one_for_one", "one_for_all"] + + +class ChildSpec[T: type[Process]](t.NamedTuple): + type: T + args: tuple + kwargs: dict + restart_type: RestartType + shutdown_type: ShutdownType + + +class Child[T: Process](t.NamedTuple): + child_id: str + child_proc: T + child_type: str + child_modules: list[type[T]] + + +class RunningChild(t.NamedTuple): + process: Process + restart_type: RestartType + shutdown_type: ShutdownType + + +@dataclass(repr=False) +class Supervisor(Process): + # __init__ + trap_exits: bool = True + + # init + strategy: RestartStrategy = "one_for_one" + max_restarts: int = 3 + max_seconds: float = 5.0 + + child_specs: dict[str, ChildSpec] = field(default_factory=dict) + children: dict[str, RunningChild] = field(default_factory=dict, init=False) + _task_group: TaskGroup | None = field(default=None, init=False) + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other: "Supervisor"): + return self.id == other.id + + async def init( + self, + strategy: RestartStrategy = "one_for_one", + max_restarts: int = 3, + max_seconds: float = 5.0, + ): + self.strategy = strategy + self.max_restarts = max_restarts + self.max_seconds = max_seconds + + async def terminate(self, reason: t.Any): + """ + Stop all children on termination. + """ + tasks = [ + self._shutdown_child(actor, sh, reason) + for actor, _, sh in list(self.children.values()) + ] + await gather(*tasks) + self.children.clear() + + await super().terminate(reason) + + def _should_restart(self, reason: t.Any, rst: RestartType) -> bool: + if rst == "permanent": + return True + elif rst == "temporary": + return False + elif rst == "transient": + return not _is_normal_shutdown_reason(reason) + raise ValueError(f"Unsupported restart type: {rst}") + + async def _shutdown_child( + self, actor: Process, shutdown: ShutdownType, reason="normal" + ): + if actor.has_stopped(): + return + elif shutdown == "brutal_kill": + await actor.kill() + await actor.stopped() + elif shutdown == "infinity": + await actor.stop(reason, timeout=None, sender=self) + elif isinstance(shutdown, int): + await actor.stop(reason, timeout=shutdown, sender=self) + else: + raise ValueError(f"Unsupported shutdown type: {shutdown}") + + async def _run_child(self, child_id: str, child_spec: ChildSpec, ready: Event): + restart_times = deque() + while True: + process = await child_spec.type.start_link( + *child_spec.args, + supervisor=self, + **child_spec.kwargs, + ) + self.children[child_id] = RunningChild( + process, + child_spec.restart_type, + child_spec.shutdown_type, + ) + ready.set() + + # Wait for the child to stop + await process.stopped() + + reason = process._crash_exc or "normal" + if not self._should_restart(reason, child_spec.restart_type): + break + + now = monotonic() + restart_times.append(now) + while restart_times and now - restart_times[0] > self.max_seconds: + restart_times.popleft() + if len(restart_times) > self.max_restarts: + raise Failed("Max restart intensity reached") + + del self.children[child_id] + + async def _one_for_one(self): + tasks = [] + for child_id, spec in self.child_specs.items(): + ready = Event() + self._task_group.start_soon(self._run_child, child_id, spec, ready) + tasks.append(ready.wait()) + await gather(*tasks) + + async def _one_for_all(self): + async def _run_child(child_id, spec): + proc = await spec.type.start_link( + *spec.args, supervisor=self, **spec.kwargs + ) + self.children[child_id] = (proc, spec.restart_type, spec.shutdown_type) + return (child_id, proc, spec.restart_type, spec.shutdown_type) + + restarts = deque() + while True: + actors_info = await gather( + *[ + _run_child(child_id, spec) + for child_id, spec in self.child_specs.items() + ] + ) + + # Wait for any process to stop + await as_completed( + [proc.stopped() for _, proc, _, _ in actors_info], + return_when="FIRST_COMPLETED", + ) + # Get reasons for all processes (stopped or not) + reasons = [ + (proc.has_stopped(), proc._crash_exc or "normal", rst) + for _, proc, rst, _ in actors_info + ] + + # If any child should restart => restart all + if any( + has_stopped and self._should_restart(reason, restart_type) + for (has_stopped, reason, restart_type) in reasons + ): + now = monotonic() + restarts.append(now) + while restarts and now - restarts[0] > self.max_seconds: + restarts.popleft() + if len(restarts) > self.max_restarts: + raise Failed("Max restart intensity reached") + else: + break + self.children.clear() + + async def loop(self, args, kwargs): + await self._init(*args, **kwargs) + async with create_task_group() as tg: + self._task_group = tg + + await getattr(self, f"_{self.strategy}")() + + async with self._mailbox: + reason = "normal" + try: + await self._loop() + except Shutdown as error: + reason = error.reason + except ClosedResourceError: + reason = "killed" + except Exception as error: + self._crash_exc = error + reason = error + finally: + await self.terminate(reason) + self._stopped.set() + + # --- Child management + def which_children(self) -> list[Child]: + results = [] + for child_id, spec in self.child_specs.items(): + if child_id in self.children: + proc, _, _ = self.children[child_id] + child = proc + else: + child = ":undefined" + ctype = "worker" + mods = [spec.type] + results.append(Child(child_id, child, ctype, mods)) + return results + + def count_children(self) -> dict: + specs = len(self.child_specs) + active = sum( + 1 for (proc, _, _) in self.children.values() if not proc.has_stopped() + ) + supervisors = sum( + 1 for spec in self.child_specs.values() if issubclass(spec.type, Supervisor) + ) + workers = specs - supervisors + return dict_of(specs, active, supervisors, workers) + + async def terminate_child(self, child_id: str): + if child_id not in self.children: + raise RuntimeError(f"Child {child_id} not currently running.") + + child = self.children[child_id] + await self._shutdown_child(child.process, child.shutdown_type, reason="normal") + del self.children[child_id] + + def delete_child(self, child_id: str): + if child_id in self.children: + raise RuntimeError( + f"Cannot delete running child {child_id}; terminate it first." + ) + + del self.child_specs[child_id] + + def _check_task_group(self): + if not self._task_group: + raise RuntimeError("Supervisor not running") + + async def start_child[ + T: Process + ](self, child_id: str, child_spec: ChildSpec[type[T]]) -> T: + self._check_task_group() + if child_id in self.child_specs: + raise RuntimeError(f"child_id {child_id} already exists") + + self.child_specs[child_id] = child_spec + ready = Event() + self._task_group.start_soon(self._run_child, child_id, child_spec, ready) + await ready.wait() + return self.children[child_id].process + + async def restart_child(self, child_id: str) -> Process: + self._check_task_group() + if child_id not in self.child_specs: + raise RuntimeError("No such child_id.") + if child_id in self.children and not self.children[child_id][0].has_stopped(): + raise RuntimeError("Child is running, cannot restart.") + + await self._task_group.start( + self._run_child, child_id, self.child_specs[child_id] + ) + return self.children[child_id].process + + @staticmethod + def child_spec[ + T: type[Process] + ]( + module_or_map: T, + args: tuple = (), + kwargs: dict = {}, + restart: RestartType = "permanent", + shutdown: ShutdownType = 5000, + ) -> ChildSpec[T]: + return ChildSpec(module_or_map, args, kwargs, restart, shutdown) + + +# ----------------------------------------------------------- +# Runtime Supervisor +# ----------------------------------------------------------- +class RuntimeSupervisor(Supervisor): + async def handle_info(self, message: t.Any): + logger.info("RuntimeSupervisor: handle_info %s", message) + + async def handle_exit(self, message: Exit): + logger.info("RuntimeSupervisor: handle_exit %s", message) + + +# ----------------------------------------------------------- +# Global "Runtime" +# ----------------------------------------------------------- +@dataclass(repr=False) +class Runtime: + # Class variables + _current: t.ClassVar[t.Optional["Runtime"]] = None + _lock: t.ClassVar[Lock] = Lock() + + # Instance variables + supervisor: t.Optional[RuntimeSupervisor] = None + _task_group: t.Optional[TaskGroup] = None + + # Runtime state + registry: dict[str, str] = field(default_factory=dict, init=False) + _reverse_registry: dict[str, str] = field(default_factory=dict, init=False) + processes: dict[str, Process] = field(default_factory=dict, init=False) + + @classmethod + def current(cls) -> "Runtime": + if cls._current is None: + raise RuntimeError("No Runtime is currently active.") + return cls._current + + async def __aenter__(self): + async with self._lock: + if Runtime._current is not None: + raise RuntimeError("Runtime already started") + + self._task_group = await create_task_group().__aenter__() + self.supervisor = RuntimeSupervisor(trap_exits=True) + await self.spawn(self.supervisor) + + Runtime._current = self + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + async with self._lock: + if self.supervisor is not None: + if self.supervisor.has_started() and not self.supervisor.has_stopped(): + await self.supervisor.stop(sender=self) + self.supervisor = None + + if self._task_group: + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + Runtime._current = None + + async def spawn[P: Process](self, process: P, *args, **kwargs) -> P: + self._task_group.start_soon(process.loop, args, kwargs) + await process.started() + + if process.has_stopped(): + raise RuntimeError("Process crashed before it could start") + + self.register(process) + return process + + def register(self, proc: Process): + self.processes[proc.id] = proc + + def unregister(self, proc: Process): + del self.processes[proc.id] + if name := self._reverse_registry.get(proc.id): + self.unregister_name(name) + + def register_name(self, name: str, proc: Process): + self.registry[name] = proc.id + self._reverse_registry[proc.id] = name + + def unregister_name(self, name: str): + id = self.registry.pop(name) + del self._reverse_registry[id] + + def where_is(self, name: str) -> t.Optional[Process]: + if id := self.registry.get(name): + return self.processes.get(id) diff --git a/src/fastactor/settings.py b/src/fastactor/settings.py new file mode 100644 index 0000000..1075336 --- /dev/null +++ b/src/fastactor/settings.py @@ -0,0 +1,10 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_prefix="FASTACTOR_") + + mailbox_size: int = 1024 + + +settings = Settings() diff --git a/src/fastactor/utils.py b/src/fastactor/utils.py new file mode 100644 index 0000000..00062f4 --- /dev/null +++ b/src/fastactor/utils.py @@ -0,0 +1,20 @@ +import typing as t +from collections import deque + +from ksuid import Ksuid + +from fastactor.settings import settings + + +def id_generator(prefix: str = "") -> t.Callable[[], str]: + def gen_id(): + return f"{prefix}:{Ksuid()}" + + return gen_id + + +def deque_factory(maxlen: int = settings.mailbox_size): + def factory(): + return deque(maxlen=maxlen) + + return factory diff --git a/src/llegos/__init__.py b/src/llegos/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/llegos/llegos.py b/src/llegos/llegos.py deleted file mode 100644 index c7e5351..0000000 --- a/src/llegos/llegos.py +++ /dev/null @@ -1,202 +0,0 @@ -from abc import ABC -from dataclasses import dataclass, field -from uuid import uuid4 -import asyncio -import logging -import os -import random -import typing as t - -from anyio import ( - Event, - create_memory_object_stream, - create_task_group, - fail_after, -) -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pyee.asyncio import AsyncIOEventEmitter - - -logger = logging.getLogger(__name__) - -LLEGOS_MAILBOX_SIZE = int(os.getenv("LLEGOS_MAILBOX_SIZE", "128")) - - -def gen_id(): - return str(uuid4()) - - -@dataclass -class Call[T]: - sender: "Actor" - request: T - _response_stream: MemoryObjectSendStream[T] - - id: str = field(default_factory=gen_id) - - def __str__(self): - return f"Call({self.id})" - - -@dataclass -class Cast[T]: - sender: "Actor" - request: T - - id: str = field(default_factory=gen_id) - - def __str__(self): - return f"Cast({self.id})" - - -@dataclass -class Actor[Req](ABC): - """ - An actor that can handle call and cast messages. - - Call messages wait for a response, while cast messages are fire-and-forget. - """ - - id: str = field(default_factory=gen_id) - mailbox_size: int = field(default=LLEGOS_MAILBOX_SIZE) - events: AsyncIOEventEmitter = field(init=False, default_factory=AsyncIOEventEmitter) - inbox: MemoryObjectSendStream[Req] = field(init=False, default=None) - mailbox: MemoryObjectReceiveStream[Req] = field(init=False, default=None) - - _started: Event = field(init=False, default_factory=Event) - - async def started(self): - return await self._started.wait() - - async def init(self): - """Any async startup code should go here.""" - self.inbox, self.mailbox = create_memory_object_stream(self.mailbox_size) - self._started.set() - - async def handle_call(self, sender: "Actor", message: Req) -> t.Any: - """Handle a call message. Return a response.""" - - async def handle_cast(self, sender: "Actor", message: Req) -> None: - """Handle a cast message. No response.""" - - async def call[ - Res - ](self, receiver: "Actor", req: Req, timeout: float | None = 5) -> Res: - send_stream, receive_stream = create_memory_object_stream(1) - - receiver.inbox.send_nowait(Call(self, req, _response_stream=send_stream)) - - with fail_after(timeout): - response = await receive_stream.receive() - - return response - - async def cast(self, receiver: "Actor", message: Req) -> None: - await receiver.inbox.send(Cast(self, message)) - - async def loop(self): - await self.init() - with self.mailbox: - async for message in self.mailbox: - await runtime.perform(self, message) - - -# TODO: Productionization — durable restarts — Redis? Kafka? -class Runtime: - events: AsyncIOEventEmitter - - def __init__(self) -> None: - self.events = AsyncIOEventEmitter() - - async def perform(self, actor: Actor, msg: Call | Cast): - kind = msg.__class__.__name__.lower() - self.events.emit("before:perform", msg) - self.events.emit(f"before:{kind}", msg) - - match msg: - case Call(): - response = await actor.handle_call(msg.sender, msg.request) - msg._response_stream.send_nowait(response) - case Cast(): - await actor.handle_cast(msg.sender, msg.request) - response = None - - self.events.emit(f"after:{kind}", msg, response) - self.events.emit("after:perform", msg, response) - - return response - - -runtime = Runtime() - - -@dataclass -class Supervisor(Actor): - class Strategy: - @classmethod - async def run(cls, task, max_restarts: int = 3): - try: - await task() - except Exception as exc: - logger.warning("Actor crashed", exc_info=exc) - max_restarts -= 1 - if max_restarts == 0: - logger.error("Max restarts reached. Actor will not be restarted.") - raise exc - - await cls.run(task, max_restarts) - - @classmethod - async def one_for_one(cls, tasks: list, max_restarts: int = 3): - async with create_task_group() as tg: - for task in tasks: - tg.start_soon(cls.run, task, max_restarts) - - @classmethod - async def one_for_all(cls, tasks: list, max_restarts: int = 3): - async def run_tasks(tasks): - async with create_task_group() as tg: - for task in tasks: - tg.start_soon(task) - - await cls.run(run_tasks, tasks, max_restarts) - - children: list[Actor] = field(default_factory=list) - strategy: t.Callable = field(default=Strategy.one_for_one) - max_restarts: int = field(default=3) - - async def started(self): - await asyncio.gather( - super().started(), *[actor.started() for actor in self.children] - ) - - async def loop(self): - async with create_task_group() as tg: - tg.start_soon(super().loop) - tg.start_soon( - self.strategy, - [actor.loop for actor in self.children], - self.max_restarts, - ) - - -# Modelled after https://elixirschool.com/en/lessons/misc/poolboy -@dataclass -class WorkerPool(Supervisor, ABC): - """ - A worker pool that distributes messages to workers. - """ - - async def router(self, message: Call | Cast) -> Actor: - """Select a worker to handle the message.""" - return random.choice(self.children) - - async def loop(self): - async with create_task_group() as tg: - tg.start_soon(super().loop) - - await self.started() - - async for message in self.mailbox: - worker = await self.router(message) - worker.inbox.send_nowait(message) diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 0000000..4e5eaaf --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture(autouse=True) +def anyio_backend(): + return "asyncio" diff --git a/src/tests/llegos_test.py b/src/tests/llegos_test.py deleted file mode 100644 index a838056..0000000 --- a/src/tests/llegos_test.py +++ /dev/null @@ -1,117 +0,0 @@ -import pytest -import typing as t - -from anyio import create_task_group - -from llegos import llegos - - -class Pinger(llegos.Actor): - count: int - - async def init(self): - await super().init() - self.count = 0 - - async def handle_call(self, sender: llegos.Actor, message: t.Any): - self.count += 1 - return "ping" - - -class Ponger(llegos.Actor): - count: int - - async def init(self): - await super().init() - self.count = 0 - - async def handle_call(self, sender: llegos.Actor, message: t.Any): - self.count += 1 - return "pong" - - -@pytest.mark.asyncio -async def test_ping_pong(): - pinger = Pinger() - ponger = Ponger() - - async with create_task_group() as tg: - tg.start_soon(pinger.loop) - tg.start_soon(ponger.loop) - - await pinger.started() - await ponger.started() - - await pinger.call(ponger, "ping") - assert ponger.count == 1 - assert pinger.count == 0 - - await ponger.call(pinger, "pong") - assert pinger.count == 1 - - tg.cancel_scope.cancel() - - -class FaultyPonger(llegos.Actor): - count: int - - async def init(self): - await super().init() - self.count = 0 - - async def handle_call(self, sender: llegos.Actor, message: t.Any): - self.count += 1 - raise Exception("I'm faulty") - - -@pytest.mark.asyncio -async def test_supervisor_one_for_one(): - pinger = Pinger() - ponger = FaultyPonger() - - async with create_task_group() as tg: - supervisor = llegos.Supervisor( - children=[pinger, ponger], - strategy=llegos.Supervisor.Strategy.one_for_one, - ) - - tg.start_soon(supervisor.loop) - - await supervisor.started() # => supervisor, pinger, and ponger are started - - try: - await pinger.call(ponger, "ping", 0.1) - assert False, "ponger should've crashed" - except Exception: - ... - - await supervisor.started() - assert ponger.count == 0 - - tg.cancel_scope.cancel() - - -@pytest.mark.asyncio -@pytest.mark.skip -async def test_worker_pool(): - print() - llegos.runtime.events.on( - "after:perform", lambda msg, response: print(msg, response) - ) - - ponger = Ponger() - pinger_pool = llegos.WorkerPool(children=[Pinger(), Pinger(), Pinger()]) - - async with create_task_group() as tg: - tg.start_soon(ponger.loop) - tg.start_soon(pinger_pool.loop) - - await ponger.started() - await pinger_pool.started() - - for _ in range(10): - print(await ponger.call(pinger_pool, "pong")) - - tg.cancel_scope.cancel() - - assert sum(actor.count for actor in pinger_pool.children) == 10 diff --git a/src/tests/test_otp.py b/src/tests/test_otp.py new file mode 100644 index 0000000..3308d44 --- /dev/null +++ b/src/tests/test_otp.py @@ -0,0 +1,466 @@ +import anyio +import pytest +import asyncio + +from fastactor.otp import ( + Call, + Cast, + Crashed, + Down, + Exit, + Runtime, + GenServer, + Supervisor, + Failed, +) + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +async def runtime(): + """ + This fixture starts a fresh Runtime for each test, and ensures cleanup. + """ + async with Runtime() as rt: + yield rt + + +@pytest.fixture +async def supervisor(runtime: Runtime): + assert runtime.supervisor + return runtime.supervisor + + +async def test_process_lifecycle(runtime: Runtime): + """ + G: a Runtime is active, we spawn a GenServer G1. + W: G1 calls exit("normal") or stops normally. + T: runtime's dictionary no longer has G1, no abnormal signals to linked procs. + """ + + # Given + G1 = await GenServer.start() + assert G1.has_started() + + # Then + assert G1.id in runtime.processes + + # When + await G1.stop("normal") + + # Then + assert G1.has_stopped() + # We assume no abnormal exit occurred => no ActorCrashed + assert not isinstance(G1._crash_exc, Exception) + + assert G1.id not in runtime.processes + + +async def test_named_process_lookup(runtime: Runtime): + """ + G: A Runtime with a GenServer G2, and we call runtime.register_name("foo_server", G2). + W: we do pid = runtime.where_is("foo_server"). + T: pid == G2. + W: we kill G2 => the runtime's processes no longer has G2. + T: runtime.where_is("foo_server") returns None. + """ + + # Given + G2 = await GenServer.start() + runtime.register_name("foo_server", G2) + + # When + proc = runtime.where_is("foo_server") + + # Then + assert proc is not None + assert proc == G2 + + # When we kill G2 + await G2.kill() + await G2.stopped() + + assert runtime.where_is("foo_server") is None + assert G2.id not in runtime.processes + + +class CallReplyServer(GenServer): + async def handle_call(self, call: Call): + if call.message == "ping": + return "pong" + else: + raise ValueError(call.message) + + +async def test_genserver_call_reply(runtime: Runtime): + """ + G: G3 that returns "pong" if request=="ping", else raises => abnormal exit + W: call(G3, "ping") => "pong" + W: call(G3, "bad_input") => triggers an exception => G3 exits + T: G3 is removed from processes, call raises or returns error + """ + G3 = await CallReplyServer.start() + + # 1) "ping" + reply = await G3.call("ping") + assert reply == "pong" + + # 2) "bad_input" + with pytest.raises(ValueError, match="bad_input"): + await G3.call("bad_input") + + # Then G3 should be exited abnormally: + await asyncio.sleep(0.1) # give time to crash + assert G3.has_stopped() + assert isinstance(G3._crash_exc, ValueError) + + +class MathServer(GenServer): + count: int + + async def init(self, count: int = 0): + self.count = count + + async def handle_cast(self, msg: Cast): + match msg.message: + case ("add", n): + self.count += n + case ("sub", n): + self.count -= n + case ("mul", n): + self.count *= n + case ("div", n): + self.count /= n + case _: + raise ValueError(f"unknown message: {msg.message}") + return self.count + + +async def test_genserver_cast(runtime: Runtime): + """ + G: G4 has handle_cast => increments self.state["count"] + W: cast("increment") multiple times + T: calls return immediately, state increments each time + """ + + G4 = await MathServer.start(count=42) + + # cast + G4.cast(("div", 2)) + G4.cast(("add", 4)) + G4.cast(("div", 5)) + G4.cast(("mul", 2)) + + # Let them process + await asyncio.sleep(0.2) + + # Then + assert G4.count == 10 + + # Stop + await G4.stop() + assert G4.has_stopped() + + +class BoomServer(GenServer): + async def handle_call(self, msg: Call): + if msg.message == "boom": + raise RuntimeError("Boom!") + return "ok" + + +async def test_supervisor_one_for_one(runtime: Runtime, supervisor: Supervisor): + """ + G: Sup1 with child transient, if child crashes abnormally => restart, if normal => no restart + W: child crashes => new child is started + W: child later exits normal => no restart + T: verifies the child is replaced only on abnormal + """ + # Add a single child, transient + cA = await supervisor.start_child( + "childA", supervisor.child_spec(BoomServer, restart="transient") + ) + # We have a single child + assert len(supervisor.which_children()) == 1 + # abnormal crash + with pytest.raises(RuntimeError, match="Boom!"): + await cA.call("boom") + + await anyio.sleep(0.2) + # Should have restarted child + cA2 = supervisor.children["childA"][0] + assert id(cA2) != id(cA), "should have restarted the child" + + # normal exit => no restart + await cA2.stop("normal") + await anyio.sleep(0.2) + assert "childA" not in supervisor.children, "should have removed the child" + # The child is not restarted if normal exit => ephemeral + + +async def xtest_supervisor_one_for_all(runtime: Runtime): + """ + G: Sup2 with strategy=one_for_all, children [C1, C2, C3] + W: C2 crashes => all are killed => all are restarted + T: each child has new IDs + """ + + sup2_spec = runtime.supervisor.child_spec( + Supervisor, kwargs={"strategy": "one_for_all"} + ) + sup2 = await runtime.supervisor.start_child("sup2", sup2_spec) + + # Add 3 children + for cid in ["C1", "C2", "C3"]: + c_spec = sup2.child_spec(BoomServer, restart="transient") + await sup2.start_child(cid, c_spec) + + c1_old = sup2.children["C1"][0] + c2_old = sup2.children["C2"][0] + c3_old = sup2.children["C3"][0] + + # Crash C2 + with pytest.raises(RuntimeError, match="Boom!"): + await c2_old.call("boom") + await asyncio.sleep(0.3) + + # All replaced + c1_new = sup2.children["C1"][0] + c2_new = sup2.children["C2"][0] + c3_new = sup2.children["C3"][0] + + assert id(c1_new) != id(c1_old) + assert id(c2_new) != id(c2_old) + assert id(c3_new) != id(c3_old) + + +class LinkServer(GenServer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.exits_received = [] + + async def handle_exit(self, msg: Exit): + self.exits_received.append((msg.sender, msg.reason)) + + +async def test_linking_trap_exits(runtime: Runtime): + """ + G: Two GenServers A and B. A.link(B), A.trap_exits=True. + W: B.exit("error") abnormally + T: A not killed, A.handle_info receives "EXIT(B, error)" or handle_exit is called + If A.trap_exits=False => A would crash too + """ + + A = await LinkServer.start(trap_exits=True) + B = await LinkServer.start(trap_exits=False) + + A.link(B) + + # B calls stop with abnormal reason + await B.stop("error") + await B.stopped() + + # A should remain alive + assert not A.has_stopped() + + # The abnormal exit is passed to A's handle_exit + assert len(A.exits_received) == 1 + assert A.exits_received[0][1] == "error" + + # Now if we had not set trap_exits=True in A => it would have died as well + # We'll stop A now + await A.stop("normal") + + +class MonitorServer(GenServer): + down_msgs: list[Down] + + async def init(self, *args, **kwargs): + self.down_msgs = [] + + async def handle_info(self, message): + if isinstance(message, Down): + self.down_msgs.append(message) + + +async def test_monitoring(runtime: Runtime): + """ + G: M monitors T (T.monitor(M)). + W: T exits "normal". + T: M.handle_info("DOWN", T, "normal") => M remains alive + """ + + M = await MonitorServer.start() + T = await GenServer.start() + + # M monitors T + M.monitor(T) + + # T stops normally + await T.stop("normal") + await T.stopped() + + # M remains alive + assert not M.has_stopped() + + # M should eventually get a Down message + await asyncio.sleep(0.2) + assert len(M.down_msgs) == 1 + msg = M.down_msgs[0] + assert msg.reason == "normal" + assert msg.sender == T + + # stop M + await M.stop("normal") + + +async def test_chain_crash_via_linking(runtime: Runtime): + """ + G: X, Y, Z. X.link(Y), Y.link(Z), none trap_exits + W: Z crashes with reason "fatal" + T: Y sees crash => Y also crashes => X sees crash => X also crashes + => all removed + """ + + X = await GenServer.start() + Y = await GenServer.start() + Z = await GenServer.start() + + X.link(Y) + Y.link(Z) + + await X.stop("fatal") + + # chain reaction => Y, then X => all gone + await asyncio.sleep(0.3) + + assert X.has_stopped() + assert Y.has_stopped() + # all have abnormal reason + assert isinstance(X._crash_exc, Crashed) and X._crash_exc.reason == "fatal" + assert isinstance(Y._crash_exc, Crashed) and Y._crash_exc.reason == "fatal" + assert isinstance(Z._crash_exc, Crashed) and Z._crash_exc.reason == "fatal" + + +class CrashyServer(GenServer): + async def init(self, *args, **kwargs): + raise RuntimeError("I always crash") + + +async def test_supervisor_crash(runtime: Runtime, supervisor: Supervisor): + """ + G: Sup3 with max_restarts=2, child that always crashes => 3 crashes => sup fails + W: child keeps failing, eventually sup => ActorFailed + T: sup removed, any watchers notified + """ + + sup = await supervisor.start_child( + "sup3", + supervisor.child_spec(Supervisor, kwargs={"max_restarts": 2, "max_seconds": 5}), + ) + + # add child that always crashes on startup + with pytest.raises(Failed, match="Max restart intensity reached"): + await sup.start_child( + "bad_child", sup.child_spec(CrashyServer, restart="transient") + ) + await sup.stopped() + + # sup3 should be removed from parent's children + assert "sup3" not in runtime.supervisor.children + + +async def test_multiple_monitors_links(runtime: Runtime): + """ + G: process Z monitored by [M1, M2, M3] and linked to L1 + W: Z.exit("boom") + T: each M? sees ("DOWN", Z, "boom"), L1 goes down if trap_exits=False + """ + + M1 = await MonitorServer.start() + M2 = await MonitorServer.start() + M3 = await MonitorServer.start() + L1 = await GenServer.start(trap_exits=False) + Z = await GenServer.start() + + # Monitor Z + M1.monitor(Z) + M2.monitor(Z) + M3.monitor(Z) + # Link L1 with Z + L1.link(Z) + + # Z => "boom" + await Z.stop("boom") + await Z.stopped() + # Wait for chain reaction + await asyncio.sleep(0.2) + + # All M? got "DOWN" + for M in (M1, M2, M3): + assert len(M.downs) == 1 + assert M.downs[0][0] == "DOWN" + assert M.downs[0][1] == Z + assert M.downs[0][2] == "boom" + + # L1 => crashes if not trap_exits => "boom" + await L1.stopped() + assert L1._crash_exc == "boom" + + # M1, M2, M3 remain alive + for M in (M1, M2, M3): + assert not M.has_stopped() + + # Clean up + for M in (M1, M2, M3): + await M.stop("normal") + + +async def test_named_supervisor_tree(runtime: Runtime, supervisor: Supervisor): + """ + G: spawn RootSup, register as "root" + W: spawn children [SupA, SupB] in RootSup, each with own GenServer kids + T: we can do root_id = where_is("root"), from there .which_children() + """ + + # We'll define a minimal root + + # spawn root + root_spec = supervisor.child_spec(Supervisor, restart="permanent") + root = await supervisor.start_child("RootSup", root_spec) + + runtime.register_name("root", root) + assert runtime.where_is("root") == root + + # spawn sub supervisors + subA_spec = root.child_spec(Supervisor, restart="permanent") + subB_spec = root.child_spec(Supervisor, restart="permanent") + + supA = await root.start_child("SupA", subA_spec) + supB = await root.start_child("SupB", subB_spec) + + # e.g. 2 child servers each + server_spec = supA.child_spec(GenServer, restart="transient") + await supA.start_child("childA1", server_spec) + await supA.start_child("childA2", server_spec) + + server_spec2 = supB.child_spec(GenServer, restart="transient") + await supB.start_child("childB1", server_spec2) + await supB.start_child("childB2", server_spec2) + + # we can check root's children + rc = root.which_children() + assert len(rc) == 2 # SupA, SupB + + # subA has 2 children + ac = supA.which_children() + assert len(ac) == 2 + + # subB has 2 children + bc = supB.which_children() + assert len(bc) == 2 + + await root.stop("normal") + + assert root.has_stopped()