Skip to content

Commit

Permalink
Merge branch 'fork' into deab-savable-inh
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Feb 5, 2025
2 parents 845a7d6 + b485d7d commit 0a3b65e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 11 deletions.
62 changes: 59 additions & 3 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol
import re
from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol

if TYPE_CHECKING:
# identifiers for subscribers
Expand All @@ -23,8 +24,8 @@ def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE |
def add_broadcast_subscriber(
self,
subscriber: 'BroadcastSubscriber',
subject_filters: list[Hashable | Pattern[str]] | None = None,
sender_filters: list[Hashable | Pattern[str]] | None = None,
subject_filters: list[Hashable | re.Pattern[str]] | None = None,
sender_filters: list[Hashable | re.Pattern[str]] | None = None,
identifier: 'ID_TYPE | None' = None,
) -> Any: ...

Expand All @@ -50,3 +51,58 @@ def broadcast_send(
def task_send(self, task: Any, no_reply: bool = False) -> Any: ...

def close(self) -> None: ...


class BroadcastFilter:
"""A filter that can be used to limit the subjects and/or senders that will be received"""

def __init__(self, subscriber, subject=None, sender=None): # type: ignore
self._subscriber = subscriber
self._subject_filters = []
self._sender_filters = []
if subject is not None:
self.add_subject_filter(subject)
if sender is not None:
self.add_sender_filter(sender)

@property
def __name__(self): # type: ignore
return 'BroadcastFilter'

def __call__(self, communicator, body, sender=None, subject=None, correlation_id=None): # type: ignore
if self.is_filtered(sender, subject):
return None
return self._subscriber(communicator, body, sender, subject, correlation_id)

def is_filtered(self, sender, subject) -> bool: # type: ignore
if subject is not None and self._subject_filters and not any(check(subject) for check in self._subject_filters):
return True

if sender is not None and self._sender_filters and not any(check(sender) for check in self._sender_filters):
return True

return False

def add_subject_filter(self, subject_filter: re.Pattern[str] | None) -> None:
self._subject_filters.append(self._ensure_filter(subject_filter)) # type: ignore

def add_sender_filter(self, sender_filter: re.Pattern[str]) -> None:
self._sender_filters.append(self._ensure_filter(sender_filter)) # type: ignore

@classmethod
def _ensure_filter(cls, filter_value): # type: ignore
if isinstance(filter_value, str):
return re.compile(filter_value.replace('.', '[.]').replace('*', '.*')).match
if isinstance(filter_value, re.Pattern): # pylint: disable=isinstance-second-argument-not-valid-type
return filter_value.match

return lambda val: val == filter_value

@classmethod
def _make_regex(cls, filter_str): # type: ignore
"""
:param filter_str: The filter string
:type filter_str: str
:return: The regular expression object
"""
return re.compile(filter_str.replace('.', '[.]'))
14 changes: 6 additions & 8 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
cast,
)

import kiwipy

from plumpy.coordinator import Coordinator
from plumpy.coordinator import BroadcastFilter, Coordinator
from plumpy.persistence import ensure_object_loader

try:
Expand Down Expand Up @@ -390,12 +388,12 @@ def init(self) -> None:

try:
# filter out state change broadcasts
# XXX: remove dep on kiwipy
subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*'))
subscriber = BroadcastFilter( # type: ignore
self.broadcast_receive,
subject=re.compile(r'^(?!state_changed).*'),
)
identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid))
# identifier = self._coordinator.add_broadcast_subscriber(
# subscriber, subject_filters=[re.compile(r'^(?!state_changed).*')], identifier=str(self.pid)
# )

self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier))
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid)
Expand Down

0 comments on commit 0a3b65e

Please sign in to comment.