diff --git a/test_launch_ros/test/test_launch_ros/actions/test_load_composable_nodes.py b/test_launch_ros/test/test_launch_ros/actions/test_load_composable_nodes.py index c7b7f4df..f2413187 100644 --- a/test_launch_ros/test/test_launch_ros/actions/test_load_composable_nodes.py +++ b/test_launch_ros/test/test_launch_ros/actions/test_load_composable_nodes.py @@ -44,14 +44,11 @@ class MockComponentContainer(rclpy.node.Node): - def __init__(self): + def __init__(self, context): # List of LoadNode requests received self.requests = [] - self._context = rclpy.context.Context() - rclpy.init(context=self._context) - - super().__init__(TEST_CONTAINER_NAME, context=self._context) + super().__init__(TEST_CONTAINER_NAME, context=context) self.load_node_service = self.create_service( LoadNode, @@ -59,16 +56,6 @@ def __init__(self): self.load_node_callback ) - self._executor = rclpy.executors.SingleThreadedExecutor(context=self._context) - - # Start spinning in a thread - self._thread = threading.Thread( - target=rclpy.spin, - args=(self, self._executor), - daemon=True - ) - self._thread.start() - def load_node_callback(self, request, response): self.requests.append(request) response.success = True @@ -79,12 +66,6 @@ def load_node_callback(self, request, response): response.unique_id = len(self.requests) return response - def shutdown(self): - self._executor.shutdown() - rclpy.shutdown(context=self._context) - self.destroy_node() - self._thread.join() - def _assert_launch_no_errors(actions): ld = LaunchDescription(actions) @@ -122,9 +103,20 @@ def _load_composable_node( @pytest.fixture def mock_component_container(): - container = MockComponentContainer() - yield container - container.shutdown() + context = rclpy.context.Context() + with rclpy.init(context=context): + executor = rclpy.executors.SingleThreadedExecutor(context=context) + + container = MockComponentContainer(context) + executor.add_node(container) + + # Start spinning in a thread + thread = threading.Thread(target=lambda executor: executor.spin(), args=(executor,)) + thread.start() + yield container + executor.remove_node(container) + executor.shutdown() + thread.join() def test_load_node(mock_component_container):