diff --git a/protollm_tools/llm-api/docker-compose.yml b/protollm_tools/llm-api/docker-compose.yml
index d8f05da..a2c8d7a 100644
--- a/protollm_tools/llm-api/docker-compose.yml
+++ b/protollm_tools/llm-api/docker-compose.yml
@@ -32,7 +32,7 @@ services:
       - RABBITMQ_DEFAULT_USER=admin
       - RABBITMQ_DEFAULT_PASS=admin
     volumes:
-      - "rabbitmq_data:/var/lib/rabbitmq"
+      - rabbitmq_data:/var/lib/rabbitmq
     networks:
       - llm_wrap_network

@@ -48,4 +48,8 @@ services:
 networks:
   llm_wrap_network:
     name: llm_wrap_network
-    driver: bridge
\ No newline at end of file
+    driver: bridge
+
+volumes:
+  rabbitmq_data:
+  redis_data:
\ No newline at end of file
diff --git a/protollm_tools/llm-api/poetry.lock b/protollm_tools/llm-api/poetry.lock
index df10de2..0309784 100644
--- a/protollm_tools/llm-api/poetry.lock
+++ b/protollm_tools/llm-api/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -1709,13 +1709,13 @@ files = [

 [[package]]
 name = "protollm-sdk"
-version = "1.0.0"
+version = "1.1.0"
 description = ""
 optional = false
 python-versions = "<4.0,>=3.10"
 files = [
-    {file = "protollm_sdk-1.0.0-py3-none-any.whl", hash = "sha256:bd53331811e788c606551a7c19d2c59496612db8ac237f2c28a5bec6cf373c89"},
-    {file = "protollm_sdk-1.0.0.tar.gz", hash = "sha256:7f34ab288115a33d44047d689703af731abd66af11846b5233c7083b3df8a8e7"},
+    {file = "protollm_sdk-1.1.0-py3-none-any.whl", hash = "sha256:2f49516d0229a85fa0abf7d2bead2ed96a565686eb87b9559468eeb3198db7c7"},
+    {file = "protollm_sdk-1.1.0.tar.gz", hash = "sha256:165d7d1270dadc7eacbf6d973f8410cd15fe2c7452cf6fb5f784cd5abc1f4c0e"},
 ]

 [package.dependencies]
@@ -2746,4 +2746,4 @@ propcache = ">=0.2.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "1754efe7f94f40030f01885cd8befd08bd2a7a13010e4a85cef95dc50b5844fd"
+content-hash = "31ced0fad9bd63b4f70d6e10e4a9d286399888423dcff7c05556ca5f20e3e93f"
diff --git a/protollm_tools/llm-api/protollm_api/backend/broker.py b/protollm_tools/llm-api/protollm_api/backend/broker.py
index 9cb9cbb..64123b3 100644
--- a/protollm_tools/llm-api/protollm_api/backend/broker.py
+++ b/protollm_tools/llm-api/protollm_api/backend/broker.py
@@ -1,6 +1,6 @@
-import pika
 import logging
-import json
+
+from protollm_sdk.object_interface import RabbitMQWrapper

 from protollm_api.config import Config
 from protollm_sdk.models.job_context_models import (
@@ -10,7 +10,11 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-async def send_task(config: Config, queue_name: str, transaction: PromptTransactionModel | ChatCompletionTransactionModel, task_type='generate'):
+async def send_task(config: Config,
+                    queue_name: str,
+                    transaction: PromptTransactionModel | ChatCompletionTransactionModel,
+                    rabbitmq: RabbitMQWrapper,
+                    task_type='generate'):
     """
     Sends a task to the RabbitMQ queue.

@@ -18,27 +22,12 @@
         config (Config): Configuration object containing RabbitMQ connection details.
         queue_name (str): Name of the RabbitMQ queue where the task will be published.
         transaction (PromptTransactionModel | ChatCompletionTransactionModel): Transaction data to be sent.
+        rabbitmq (RabbitMQWrapper): RabbitMQ wrapper used to publish the task to the queue.
         task_type (str, optional): The type of task to be executed (default is 'generate').

     Raises:
         Exception: If there is an error during the connection or message publishing process.
     """
-    connection = pika.BlockingConnection(
-        pika.ConnectionParameters(
-            host=config.rabbit_host,
-            port=config.rabbit_port,
-            virtual_host='/',
-            credentials=pika.PlainCredentials(
-                username=config.rabbit_login,
-                password=config.rabbit_password
-            )
-        )
-    )
-    channel = connection.channel()
-
-    # Declare the queue if it does not exist
-    channel.queue_declare(queue=queue_name)
-
     task = {
         "type": "task",
         "task": task_type,
@@ -49,18 +38,8 @@
         "eta": None
     }

-    message = json.dumps(task)
+    rabbitmq.publish_message(queue_name, task)

-    # Publish the message to the RabbitMQ queue
-    channel.basic_publish(
-        exchange='',
-        routing_key=queue_name,
-        body=message,
-        properties=pika.BasicProperties(
-            delivery_mode=2,  # Make message persistent
-        )
-    )
-    connection.close()


 async def get_result(config: Config, task_id: str, redis_db: RedisWrapper) -> ResponseModel:
diff --git a/protollm_tools/llm-api/protollm_api/backend/endpoints.py b/protollm_tools/llm-api/protollm_api/backend/endpoints.py
index d48f4d3..51a1d64 100644
--- a/protollm_tools/llm-api/protollm_api/backend/endpoints.py
+++ b/protollm_tools/llm-api/protollm_api/backend/endpoints.py
@@ -8,6 +8,7 @@
     PromptTypes
 )
 from protollm_sdk.object_interface.redis_wrapper import RedisWrapper
+from protollm_sdk.object_interface.rabbit_mq_wrapper import RabbitMQWrapper

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -20,6 +21,7 @@ def get_router(config: Config) -> APIRouter:
     )

     redis_db = RedisWrapper(config.redis_host, config.redis_port)
+    rabbitmq = RabbitMQWrapper(config.rabbit_host, config.rabbit_port, config.rabbit_login, config.rabbit_password)

     @router.post('/generate', response_model=ResponseModel)
     async def generate(prompt_data: PromptModel, queue_name: str = config.queue_name):
@@ -27,7 +29,7 @@
             prompt=ChatCompletionModel.from_prompt_model(prompt_data),
             prompt_type=PromptTypes.CHAT_COMPLETION.value
         )
-        await send_task(config, queue_name, transaction_model)
+        await send_task(config, queue_name, transaction_model, rabbitmq)
         logger.info(f"Task {prompt_data.job_id} was sent to LLM.")
         return await get_result(config, prompt_data.job_id, redis_db)

@@ -37,7 +39,7 @@
             prompt=prompt_data,
             prompt_type=PromptTypes.CHAT_COMPLETION.value
         )
-        await send_task(config, queue_name, transaction_model)
+        await send_task(config, queue_name, transaction_model, rabbitmq)
         logger.info(f"Task {prompt_data.job_id} was sent to LLM.")
         return await get_result(config, prompt_data.job_id, redis_db)

diff --git a/protollm_tools/llm-api/pyproject.toml b/protollm_tools/llm-api/pyproject.toml
index 1d99903..0a0d0cf 100644
--- a/protollm_tools/llm-api/pyproject.toml
+++ b/protollm_tools/llm-api/pyproject.toml
@@ -12,7 +12,7 @@
 pika = "^1.3.2"
 pydantic = "^2.7.4"
 flower = "2.0.1"
-protollm_sdk = "^1.0.0"
+protollm_sdk = "^1.1.0"

 [tool.poetry.group.dev.dependencies]
 pytest = "^8.2.2"
diff --git a/protollm_tools/llm-api/requirements.txt b/protollm_tools/llm-api/requirements.txt
index 052278c..e584b21 100644
--- a/protollm_tools/llm-api/requirements.txt
+++ b/protollm_tools/llm-api/requirements.txt
@@ -52,7 +52,7 @@ pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0"
 prometheus-client==0.21.0 ; python_version >= "3.10" and python_version < "4.0"
 prompt-toolkit==3.0.48 ; python_version >= "3.10" and python_version < "4.0"
 propcache==0.2.0 ; python_version >= "3.10" and python_version < "4.0"
-protollm-sdk==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
+protollm-sdk==1.1.0 ; python_version >= "3.10" and python_version < "4.0"
 pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "4.0"
 pydantic==2.9.2 ; python_version >= "3.10" and python_version < "4.0"
 pygments==2.18.0 ; python_version >= "3.10" and python_version < "4.0"
diff --git a/protollm_tools/llm-api/tests/integration/test_local_RMQ.py b/protollm_tools/llm-api/tests/integration/test_local_RMQ.py
index 053f507..60f7784 100644
--- a/protollm_tools/llm-api/tests/integration/test_local_RMQ.py
+++ b/protollm_tools/llm-api/tests/integration/test_local_RMQ.py
@@ -5,9 +5,19 @@
 import pytest
 from protollm_sdk.models.job_context_models import (ChatCompletionModel, PromptMeta, ChatCompletionUnit,
                                                     ChatCompletionTransactionModel, PromptTypes)
+from protollm_sdk.object_interface import RabbitMQWrapper

 from protollm_api.backend.broker import send_task


+@pytest.fixture(scope="module")
+def rabbit_client(test_local_config):
+    assert test_local_config.rabbit_host == "localhost"
+    client = RabbitMQWrapper(test_local_config.rabbit_host,
+                             test_local_config.rabbit_port,
+                             test_local_config.rabbit_login,
+                             test_local_config.rabbit_password)
+    return client
+
 @pytest.fixture(scope="module")
 def rabbitmq_connection(test_local_config):
@@ -31,7 +41,7 @@ def rabbitmq_connection(test_local_config):


 @pytest.mark.asyncio
-async def test_task_in_queue(test_local_config, rabbitmq_connection):
+async def test_task_in_queue(test_local_config, rabbitmq_connection, rabbit_client):
     queue_name = "test_queue"
     prompt = ChatCompletionModel(
         job_id=str(uuid.uuid4()),
@@ -40,7 +50,7 @@
     )
     transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value)

-    await send_task(test_local_config, queue_name, transaction)
+    await send_task(test_local_config, queue_name, transaction, rabbit_client)

     method_frame, header_frame, body = rabbitmq_connection.basic_get(queue=queue_name, auto_ack=True)

diff --git a/protollm_tools/llm-api/tests/unit/test_brocker.py b/protollm_tools/llm-api/tests/unit/test_brocker.py
index 930aaa2..d29d20c 100644
--- a/protollm_tools/llm-api/tests/unit/test_brocker.py
+++ b/protollm_tools/llm-api/tests/unit/test_brocker.py
@@ -1,6 +1,6 @@
 import json
 import uuid
-from unittest.mock import AsyncMock, patch, MagicMock
+from unittest.mock import AsyncMock, patch, MagicMock, ANY

 import pytest
 from protollm_sdk.models.job_context_models import ResponseModel, ChatCompletionTransactionModel, ChatCompletionModel, \
@@ -18,15 +18,11 @@
     )
     transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value)

-    with patch("protollm_api.backend.broker.pika.BlockingConnection") as mock_connection:
-        mock_channel = MagicMock()
-        mock_connection.return_value.channel.return_value = mock_channel
+    mock_rabbit = MagicMock()

-        await send_task(test_local_config, test_local_config.queue_name, transaction)
+    await send_task(test_local_config, test_local_config.queue_name, transaction, mock_rabbit)

-        mock_connection.assert_called_once()
-        mock_channel.queue_declare.assert_called_with(queue=test_local_config.queue_name)
-        mock_channel.basic_publish.assert_called_once()
+    mock_rabbit.publish_message.assert_called_once_with(test_local_config.queue_name, ANY)


 @pytest.mark.asyncio
diff --git a/protollm_tools/llm-api/tests/unit/test_endpoints.py b/protollm_tools/llm-api/tests/unit/test_endpoints.py
index 3dbe258..bce84ca 100644
--- a/protollm_tools/llm-api/tests/unit/test_endpoints.py
+++ b/protollm_tools/llm-api/tests/unit/test_endpoints.py
@@ -6,6 +6,7 @@
 from httpx import AsyncClient, ASGITransport

 from protollm_sdk.models.job_context_models import ResponseModel, PromptTransactionModel, PromptModel, \
     PromptTypes, ChatCompletionModel, ChatCompletionTransactionModel
+from protollm_sdk.object_interface import RabbitMQWrapper

 from protollm_api.backend.endpoints import get_router
@@ -47,7 +48,7 @@
             prompt=ChatCompletionModel.from_prompt_model(prompt_data),
             prompt_type=PromptTypes.CHAT_COMPLETION.value
         )
-        send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model)
+        send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model, ANY)
         get_result_mock.assert_called_once_with(test_local_config, "test-job-id", ANY)


@@ -85,7 +86,7 @@
             prompt=prompt_data,
             prompt_type=PromptTypes.CHAT_COMPLETION.value
         )
-        send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model)
+        send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model, ANY)
         get_result_mock.assert_called_once_with(test_local_config, "test-job-id", ANY)


diff --git a/protollm_tools/llm-api/unit_config.json b/protollm_tools/llm-api/unit_config.json
new file mode 100644
index 0000000..bde7f09
--- /dev/null
+++ b/protollm_tools/llm-api/unit_config.json
@@ -0,0 +1,15 @@
+{
+  "listeners": {
+    "*:6672": {
+      "pass": "applications/backend"
+    }
+  },
+  "applications": {
+    "backend": {
+      "type": "python",
+      "path": ".",
+      "module": "protollm_api.backend.main",
+      "callable": "app"
+    }
+  }
+}
\ No newline at end of file
diff --git a/protollm_tools/sdk/protollm_sdk/object_interface/rabbit_mq_wrapper.py b/protollm_tools/sdk/protollm_sdk/object_interface/rabbit_mq_wrapper.py
index 07e588e..f315d5c 100644
--- a/protollm_tools/sdk/protollm_sdk/object_interface/rabbit_mq_wrapper.py
+++ b/protollm_tools/sdk/protollm_sdk/object_interface/rabbit_mq_wrapper.py
@@ -39,26 +39,34 @@ def get_channel(self):
             channel.close()
             connection.close()

-    def publish_message(self, queue_name: str, message: dict):
+    def publish_message(self, queue_name: str, message: dict, priority: int = None):
         """
-        Publish a message to a specified queue.
+        Publish a message to a specified queue with an optional priority.

         :param queue_name: Name of the queue to publish to
         :param message: Message to publish (dictionary will be serialized to JSON)
+        :param priority: Optional message priority (the queue is declared with x-max-priority=10, so values above 10 are capped)
         """
         try:
             with self.get_channel() as channel:
-                channel.queue_declare(queue=queue_name, durable=True)
+                arguments = {}
+                if priority is not None:
+                    arguments['x-max-priority'] = 10
+                channel.queue_declare(queue=queue_name, durable=True, arguments=arguments)
+
+                properties = pika.BasicProperties(
+                    delivery_mode=2,
+                    priority=priority if priority is not None else 0
+                )
                 channel.basic_publish(
                     exchange='',
                     routing_key=queue_name,
                     body=json.dumps(message),
-                    properties=pika.BasicProperties(
-                        delivery_mode=2  # Make message persistent
-                    )
+                    properties=properties
                 )
-            logger.info(f"Message published to queue '{queue_name}'")
+            logger.info(
+                f"Message published to queue '{queue_name}' with priority {priority if priority is not None else 'None'}")
         except Exception as ex:
             logger.error(f"Failed to publish message to queue '{queue_name}'. Error: {ex}")
             raise Exception(f"Failed to publish message to queue '{queue_name}'. Error: {ex}") from ex

diff --git a/protollm_tools/sdk/tests/protollm_sdk/object_interface/integration/test_rabbit_mq_wrapper.py b/protollm_tools/sdk/tests/protollm_sdk/object_interface/integration/test_rabbit_mq_wrapper.py
index 1f78f8e..cbb01ca 100644
--- a/protollm_tools/sdk/tests/protollm_sdk/object_interface/integration/test_rabbit_mq_wrapper.py
+++ b/protollm_tools/sdk/tests/protollm_sdk/object_interface/integration/test_rabbit_mq_wrapper.py
@@ -44,6 +44,24 @@ def test_publish_message(rabbit_wrapper):
     assert method_frame is not None, "Message not found in the queue"
     assert json.loads(body) == message, "Message in the queue does not match the sent message"

+@pytest.mark.local
+def test_publish_message_with_priority(rabbit_wrapper):
+    """
+    Tests successful message publishing to a queue with priority.
+    """
+    queue_name = "test_priority_queue"
+    message = {"key": "value"}
+    priority = 5
+
+    rabbit_wrapper.publish_message(queue_name, message, priority=priority)
+
+    with rabbit_wrapper.get_channel() as channel:
+        method_frame, header_frame, body = channel.basic_get(queue_name, auto_ack=True)
+        assert method_frame is not None, "Message not found in the queue"
+        assert json.loads(body) == message, "Message in the queue does not match the sent message"
+        assert header_frame.priority == priority, f"Expected priority {priority}, but got {header_frame.priority}"
+
+
 @pytest.mark.local
 def test_consume_message(rabbit_wrapper):
     """
diff --git a/protollm_tools/sdk/tests/protollm_sdk/object_interface/unit/test_rabbit_mq_wrapper.py b/protollm_tools/sdk/tests/protollm_sdk/object_interface/unit/test_rabbit_mq_wrapper.py
index a50d101..b9d6117 100644
--- a/protollm_tools/sdk/tests/protollm_sdk/object_interface/unit/test_rabbit_mq_wrapper.py
+++ b/protollm_tools/sdk/tests/protollm_sdk/object_interface/unit/test_rabbit_mq_wrapper.py
@@ -27,12 +27,39 @@ def test_publish_message(rabbit_wrapper, mock_pika):

     rabbit_wrapper.publish_message(queue_name, message)

-    mock_pika.queue_declare.assert_called_once_with(queue=queue_name, durable=True)
+    mock_pika.queue_declare.assert_called_once_with(queue=queue_name, durable=True, arguments={})
+    mock_pika.basic_publish.assert_called_once_with(
+        exchange="",
+        routing_key=queue_name,
+        body=json.dumps(message),
+        properties=pika.BasicProperties(delivery_mode=2, priority=0),
+    )
+
+@pytest.mark.ci
+def test_publish_message_with_priority(rabbit_wrapper, mock_pika):
+    """
+    Tests successful message publishing to a queue with priority.
+    """
+    queue_name = "test_queue"
+    message = {"key": "value"}
+    priority = 5
+
+    rabbit_wrapper.publish_message(queue_name, message, priority=priority)
+
+    mock_pika.queue_declare.assert_called_once_with(
+        queue=queue_name,
+        durable=True,
+        arguments={"x-max-priority": 10}
+    )
+
     mock_pika.basic_publish.assert_called_once_with(
         exchange="",
         routing_key=queue_name,
         body=json.dumps(message),
-        properties=pika.BasicProperties(delivery_mode=2),
+        properties=pika.BasicProperties(
+            delivery_mode=2,
+            priority=priority
+        ),
     )

 @pytest.mark.ci