diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 4ed931276..083b21386 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -35,7 +35,8 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union - +from asyncio import TaskGroup +import asyncio import warnings from rclpy.client import Client @@ -102,14 +103,14 @@ def wait(self, timeout_sec: Optional[float] = None): return True -async def await_or_execute(callback: Union[Callable, Coroutine], *args) -> Any: +async def await_or_execute(callback: Union[Callable, Coroutine], *args, **kwargs) -> Any: """Await a callback if it is a coroutine, else execute it.""" if inspect.iscoroutinefunction(callback): # Await a coroutine - return await callback(*args) + return await callback(*args, **kwargs) else: # Call a normal function - return callback(*args) + return callback(*args, **kwargs) class TimeoutException(Exception): @@ -550,12 +551,8 @@ async def handler(entity, gc, is_shutdown, work_tracker): gc.trigger() except InvalidHandle: pass - task = Task( - handler, (entity, self._guard, self._is_shutdown, self._work_tracker), - executor=self) - with self._tasks_lock: - self._tasks.append((task, entity, node)) - return task + + return handler(entity, self._guard, self._is_shutdown, self._work_tracker) def can_execute(self, entity: WaitableEntityType) -> bool: """ @@ -566,6 +563,185 @@ def can_execute(self, entity: WaitableEntityType) -> bool: """ return not entity._executor_event and entity.callback_group.can_execute(entity) + + def _construct_wait_set_and_wait( + self, nodes_to_use: list['Node'], timeout_timer: Optional[Timer], timeout_nsec: int + ) -> list[Tuple[Coroutine, WaitableEntityType, 'Node']]: + handlers: list[Tuple[Coroutine, WaitableEntityType, 'Node']] = [] + + # Gather entities that can be waited on + subscriptions: List[Subscription] = [] + guards: List[GuardCondition] = [] + timers: List[Timer] = [] + clients: List[Client] = [] + services: List[Service] = [] + waitables: List[Waitable[Any]] = [] + for node in nodes_to_use: + subscriptions.extend(filter(self.can_execute, node.subscriptions)) + timers.extend(filter(self.can_execute, node.timers)) + clients.extend(filter(self.can_execute, node.clients)) + services.extend(filter(self.can_execute, node.services)) + node_guards = filter(self.can_execute, node.guards) + waitables.extend(filter(self.can_execute, node.waitables)) + # retrigger a guard condition that was triggered but not handled + for gc in node_guards: + if gc._executor_triggered: + gc.trigger() + guards.append(gc) + if timeout_timer is not None: + timers.append(timeout_timer) + + guards.append(self._guard) + guards.append(self._sigint_gc) + + entity_count = NumberOfEntities( + len(subscriptions), len(guards), len(timers), len(clients), len(services)) + + # Construct a wait set + wait_set = None + with ExitStack() as context_stack: + sub_handles = [] + for sub in subscriptions: + try: + context_stack.enter_context(sub.handle) + sub_handles.append(sub.handle) + except InvalidHandle: + entity_count.num_subscriptions -= 1 + + client_handles = [] + for cli in clients: + try: + context_stack.enter_context(cli.handle) + client_handles.append(cli.handle) + except InvalidHandle: + entity_count.num_clients -= 1 + + service_handles = [] + for srv in services: + try: + context_stack.enter_context(srv.handle) + service_handles.append(srv.handle) + except InvalidHandle: + entity_count.num_services -= 1 + + timer_handles = [] + for tmr in timers: + try: + context_stack.enter_context(tmr.handle) + timer_handles.append(tmr.handle) + except InvalidHandle: + entity_count.num_timers -= 1 + + guard_handles = [] + for gc in guards: + try: + context_stack.enter_context(gc.handle) + guard_handles.append(gc.handle) + except InvalidHandle: + entity_count.num_guard_conditions -= 1 + + for waitable in waitables: + try: + context_stack.enter_context(waitable) + entity_count += waitable.get_num_entities() + except InvalidHandle: + pass + + context_stack.enter_context(self._context.handle) + + wait_set = _rclpy.WaitSet( + entity_count.num_subscriptions, + entity_count.num_guard_conditions, + entity_count.num_timers, + entity_count.num_clients, + entity_count.num_services, + entity_count.num_events, + self._context.handle) + + wait_set.clear_entities() + for sub_handle in sub_handles: + wait_set.add_subscription(sub_handle) + for cli_handle in client_handles: + wait_set.add_client(cli_handle) + for srv_capsule in service_handles: + wait_set.add_service(srv_capsule) + for tmr_handle in timer_handles: + wait_set.add_timer(tmr_handle) + for gc_handle in guard_handles: + wait_set.add_guard_condition(gc_handle) + for waitable in waitables: + waitable.add_to_wait_set(wait_set) + + # Wait for something to become ready + wait_set.wait(timeout_nsec) + if self._is_shutdown: + raise ShutdownException() + if not self._context.ok(): + raise ExternalShutdownException() + # get ready entities + subs_ready = wait_set.get_ready_entities('subscription') + guards_ready = wait_set.get_ready_entities('guard_condition') + timers_ready = wait_set.get_ready_entities('timer') + clients_ready = wait_set.get_ready_entities('client') + services_ready = wait_set.get_ready_entities('service') + + # Mark all guards as triggered before yielding since they're auto-taken + for gc in guards: + if gc.handle.pointer in guards_ready: + gc._executor_triggered = True + + # Check waitables before wait set is destroyed + for node in nodes_to_use: + for wt in node.waitables: + # Only check waitables that were added to the wait set + if wt in waitables and wt.is_ready(wait_set): + if wt.callback_group.can_execute(wt): + handler = self._make_handler(wt, node, self._take_waitable) + yield handler, wt, node + + # Process ready entities one node at a time + for node in nodes_to_use: + for tmr in node.timers: + if tmr.handle.pointer in timers_ready: + # Check timer is ready to workaround rcl issue with cancelled timers + if tmr.handle.is_timer_ready(): + if tmr.callback_group.can_execute(tmr): + handler = self._make_handler(tmr, node, self._take_timer) + yield handler, tmr, node + + for sub in node.subscriptions: + if sub.handle.pointer in subs_ready: + if sub.callback_group.can_execute(sub): + handler = self._make_handler(sub, node, self._take_subscription) + yield handler, sub, node + + for gc in node.guards: + if gc._executor_triggered: + if gc.callback_group.can_execute(gc): + handler = self._make_handler(gc, node, self._take_guard_condition) + yield handler, gc, node + + for client in node.clients: + if client.handle.pointer in clients_ready: + if client.callback_group.can_execute(client): + handler = self._make_handler(client, node, self._take_client) + yield handler, client, node + + for srv in node.services: + if srv.handle.pointer in services_ready: + if srv.callback_group.can_execute(srv): + handler = self._make_handler(srv, node, self._take_service) + yield handler, srv, node + + # Check timeout timer + if ( + timeout_nsec == 0 or + (timeout_timer is not None and timeout_timer.handle.pointer in timers_ready) + ): + raise TimeoutException() + + return handlers + def _wait_for_ready_callbacks( self, timeout_sec: Optional[Union[float, TimeoutObject]] = None, @@ -611,183 +787,15 @@ def _wait_for_ready_callbacks( # Get rid of any tasks that are done self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) - # Gather entities that can be waited on - subscriptions: List[Subscription] = [] - guards: List[GuardCondition] = [] - timers: List[Timer] = [] - clients: List[Client] = [] - services: List[Service] = [] - waitables: List[Waitable[Any]] = [] - for node in nodes_to_use: - subscriptions.extend(filter(self.can_execute, node.subscriptions)) - timers.extend(filter(self.can_execute, node.timers)) - clients.extend(filter(self.can_execute, node.clients)) - services.extend(filter(self.can_execute, node.services)) - node_guards = filter(self.can_execute, node.guards) - waitables.extend(filter(self.can_execute, node.waitables)) - # retrigger a guard condition that was triggered but not handled - for gc in node_guards: - if gc._executor_triggered: - gc.trigger() - guards.append(gc) - if timeout_timer is not None: - timers.append(timeout_timer) - - guards.append(self._guard) - guards.append(self._sigint_gc) - - entity_count = NumberOfEntities( - len(subscriptions), len(guards), len(timers), len(clients), len(services)) - - # Construct a wait set - wait_set = None - with ExitStack() as context_stack: - sub_handles = [] - for sub in subscriptions: - try: - context_stack.enter_context(sub.handle) - sub_handles.append(sub.handle) - except InvalidHandle: - entity_count.num_subscriptions -= 1 - - client_handles = [] - for cli in clients: - try: - context_stack.enter_context(cli.handle) - client_handles.append(cli.handle) - except InvalidHandle: - entity_count.num_clients -= 1 - - service_handles = [] - for srv in services: - try: - context_stack.enter_context(srv.handle) - service_handles.append(srv.handle) - except InvalidHandle: - entity_count.num_services -= 1 - timer_handles = [] - for tmr in timers: - try: - context_stack.enter_context(tmr.handle) - timer_handles.append(tmr.handle) - except InvalidHandle: - entity_count.num_timers -= 1 + for coro, entity, node in self._construct_wait_set_and_wait(nodes_to_use, timeout_timer, timeout_nsec): + yielded_work = True + task = Task(coro, executor=self) + with self._tasks_lock: + self._tasks.append((task, entity, node)) - guard_handles = [] - for gc in guards: - try: - context_stack.enter_context(gc.handle) - guard_handles.append(gc.handle) - except InvalidHandle: - entity_count.num_guard_conditions -= 1 + yield task, entity, node - for waitable in waitables: - try: - context_stack.enter_context(waitable) - entity_count += waitable.get_num_entities() - except InvalidHandle: - pass - - context_stack.enter_context(self._context.handle) - - wait_set = _rclpy.WaitSet( - entity_count.num_subscriptions, - entity_count.num_guard_conditions, - entity_count.num_timers, - entity_count.num_clients, - entity_count.num_services, - entity_count.num_events, - self._context.handle) - - wait_set.clear_entities() - for sub_handle in sub_handles: - wait_set.add_subscription(sub_handle) - for cli_handle in client_handles: - wait_set.add_client(cli_handle) - for srv_capsule in service_handles: - wait_set.add_service(srv_capsule) - for tmr_handle in timer_handles: - wait_set.add_timer(tmr_handle) - for gc_handle in guard_handles: - wait_set.add_guard_condition(gc_handle) - for waitable in waitables: - waitable.add_to_wait_set(wait_set) - - # Wait for something to become ready - wait_set.wait(timeout_nsec) - if self._is_shutdown: - raise ShutdownException() - if not self._context.ok(): - raise ExternalShutdownException() - - # get ready entities - subs_ready = wait_set.get_ready_entities('subscription') - guards_ready = wait_set.get_ready_entities('guard_condition') - timers_ready = wait_set.get_ready_entities('timer') - clients_ready = wait_set.get_ready_entities('client') - services_ready = wait_set.get_ready_entities('service') - - # Mark all guards as triggered before yielding since they're auto-taken - for gc in guards: - if gc.handle.pointer in guards_ready: - gc._executor_triggered = True - - # Check waitables before wait set is destroyed - for node in nodes_to_use: - for wt in node.waitables: - # Only check waitables that were added to the wait set - if wt in waitables and wt.is_ready(wait_set): - if wt.callback_group.can_execute(wt): - handler = self._make_handler(wt, node, self._take_waitable) - yielded_work = True - yield handler, wt, node - - # Process ready entities one node at a time - for node in nodes_to_use: - for tmr in node.timers: - if tmr.handle.pointer in timers_ready: - # Check timer is ready to workaround rcl issue with cancelled timers - if tmr.handle.is_timer_ready(): - if tmr.callback_group.can_execute(tmr): - handler = self._make_handler(tmr, node, self._take_timer) - yielded_work = True - yield handler, tmr, node - - for sub in node.subscriptions: - if sub.handle.pointer in subs_ready: - if sub.callback_group.can_execute(sub): - handler = self._make_handler(sub, node, self._take_subscription) - yielded_work = True - yield handler, sub, node - - for gc in node.guards: - if gc._executor_triggered: - if gc.callback_group.can_execute(gc): - handler = self._make_handler(gc, node, self._take_guard_condition) - yielded_work = True - yield handler, gc, node - - for client in node.clients: - if client.handle.pointer in clients_ready: - if client.callback_group.can_execute(client): - handler = self._make_handler(client, node, self._take_client) - yielded_work = True - yield handler, client, node - - for srv in node.services: - if srv.handle.pointer in services_ready: - if srv.callback_group.can_execute(srv): - handler = self._make_handler(srv, node, self._take_service) - yielded_work = True - yield handler, srv, node - - # Check timeout timer - if ( - timeout_nsec == 0 or - (timeout_timer is not None and timeout_timer.handle.pointer in timers_ready) - ): - raise TimeoutException() if self._is_shutdown: raise ShutdownException() if condition(): @@ -972,3 +980,72 @@ def shutdown( success: bool = super().shutdown(timeout_sec) self._executor.shutdown(wait=wait_for_threads) return success + +# if we want to enable "spin once" behavior: +# it is possible to run the loop in _wait_for_ready_callbacks once and exit +# or inherit from asyncio.EventLoop and call _run_once() +class AsyncioExecutor(Executor): + def __init__(self): + super().__init__() + self._loop = asyncio.new_event_loop() + self._task_group: Optional[TaskGroup] = None + self._thread_pool = ThreadPoolExecutor(max_workers=1) + + def _wait_for_ready_callbacks( + self, + timeout_sec: Optional[Union[float, TimeoutObject]] = None, + nodes: Optional[List['Node']] = None, + condition: Callable[[], bool] = lambda: False, + ) -> Generator[Tuple[Task, WaitableEntityType, 'Node'], None, None]: + """ + Yield callbacks that are ready to be executed. + + :raise TimeoutException: on timeout. + :raise ShutdownException: on if executor was shut down. + + :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. + Don't wait if 0. + :param nodes: A list of nodes to wait on. Wait on all nodes if ``None``. + :param condition: A callable that makes the function return immediately when it evaluates + to True. + """ + timeout_timer = None + timeout_nsec = timeout_sec_to_nsec( + timeout_sec.timeout if isinstance(timeout_sec, TimeoutObject) else timeout_sec) + if timeout_nsec > 0: + timeout_timer = Timer(None, None, timeout_nsec, self._clock, context=self._context) + + while not self._is_shutdown and not condition(): + # Refresh "all" nodes in case executor was woken by a node being added or removed + nodes_to_use = nodes + if nodes is None: + nodes_to_use = self.get_nodes() + + new_coroutines = [] + for coro, entity, node in self._construct_wait_set_and_wait(nodes_to_use, timeout_timer, timeout_nsec): + new_coroutines.append(coro) + + self._loop.call_soon_threadsafe(self._add_coros_to_task_group, new_coroutines) + + if self._is_shutdown: + raise ShutdownException() + if condition(): + raise ConditionReachedException() + + def _add_coros_to_task_group(self, new_coroutines: list[Coroutine]): + for coro in new_coroutines: + self._task_group.create_task(coro) + + async def spin_async(self, once: bool = False): + async with TaskGroup() as tg: + self._task_group = tg + await self._loop.run_in_executor(self._thread_pool, self._wait_for_ready_callbacks) + + def spin(self): + self._loop.run_until_complete(self.spin_async()) + + def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> Task: + if not inspect.iscoroutine(callback): + callback = await_or_execute(callback, *args, **kwargs) + + return self._loop.create_task(callback)