integrate rabbitmq wrapper to sdk (#45)
* add: add rabbit_mq_wrapper.py and tests

* test: add CI and local markers to tests for improved filtering

* Update sdk-build.yml

Only CI tests are run in the workflow.

* update: use rabbit_worker in api

* update: req and unit_config.json

* add: priority in rabbit wrapper
1martin1 authored Jan 23, 2025
1 parent c3eaa57 commit 709c027
Showing 13 changed files with 125 additions and 62 deletions.
8 changes: 6 additions & 2 deletions protollm_tools/llm-api/docker-compose.yml
@@ -32,7 +32,7 @@ services:
- RABBITMQ_DEFAULT_USER=admin
- RABBITMQ_DEFAULT_PASS=admin
volumes:
- "rabbitmq_data:/var/lib/rabbitmq"
- rabbitmq_data:/var/lib/rabbitmq
networks:
- llm_wrap_network

@@ -48,4 +48,8 @@ services:
networks:
llm_wrap_network:
name: llm_wrap_network
driver: bridge
driver: bridge

volumes:
rabbitmq_data:
redis_data:
10 changes: 5 additions & 5 deletions protollm_tools/llm-api/poetry.lock

(Generated lockfile; diff not rendered.)

39 changes: 9 additions & 30 deletions protollm_tools/llm-api/protollm_api/backend/broker.py
@@ -1,6 +1,6 @@
import pika
import logging
import json

from protollm_sdk.object_interface import RabbitMQWrapper

from protollm_api.config import Config
from protollm_sdk.models.job_context_models import (
@@ -10,35 +10,24 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def send_task(config: Config, queue_name: str, transaction: PromptTransactionModel | ChatCompletionTransactionModel, task_type='generate'):
async def send_task(config: Config,
queue_name: str,
transaction: PromptTransactionModel | ChatCompletionTransactionModel,
rabbitmq: RabbitMQWrapper,
task_type='generate'):
"""
Sends a task to the RabbitMQ queue.
Args:
config (Config): Configuration object containing RabbitMQ connection details.
queue_name (str): Name of the RabbitMQ queue where the task will be published.
transaction (PromptTransactionModel | ChatCompletionTransactionModel): Transaction data to be sent.
rabbitmq (RabbitMQWrapper): RabbitMQ wrapper used to publish messages to the queue.
task_type (str, optional): The type of task to be executed (default is 'generate').
Raises:
Exception: If there is an error during the connection or message publishing process.
"""
connection = pika.BlockingConnection(
pika.ConnectionParameters(
host=config.rabbit_host,
port=config.rabbit_port,
virtual_host='/',
credentials=pika.PlainCredentials(
username=config.rabbit_login,
password=config.rabbit_password
)
)
)
channel = connection.channel()

# Declare the queue if it does not exist
channel.queue_declare(queue=queue_name)

task = {
"type": "task",
"task": task_type,
@@ -49,18 +49,8 @@ async def send_task(config: Config, queue_name: str, transaction: PromptTransact
"eta": None
}

message = json.dumps(task)
rabbitmq.publish_message(queue_name, task)

# Publish the message to the RabbitMQ queue
channel.basic_publish(
exchange='',
routing_key=queue_name,
body=message,
properties=pika.BasicProperties(
delivery_mode=2, # Make message persistent
)
)
connection.close()


async def get_result(config: Config, task_id: str, redis_db: RedisWrapper) -> ResponseModel:
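For orientation, a hedged usage sketch of the new send_task signature (not part of the diff): the wrapper is built from the same Config fields the removed pika code read, and the prompt-model fields mirror the integration test further down; meta=None and the empty messages list are illustrative assumptions.

    # Illustrative sketch only — model field values are placeholders.
    import asyncio
    import uuid

    from protollm_sdk.models.job_context_models import (
        ChatCompletionModel, ChatCompletionTransactionModel, PromptTypes,
    )
    from protollm_sdk.object_interface import RabbitMQWrapper

    from protollm_api.backend.broker import send_task
    from protollm_api.config import Config


    async def main() -> None:
        config = Config()  # assumption: connection details come from env defaults
        rabbitmq = RabbitMQWrapper(
            config.rabbit_host,
            config.rabbit_port,
            config.rabbit_login,
            config.rabbit_password,
        )
        prompt = ChatCompletionModel(
            job_id=str(uuid.uuid4()),
            meta=None,    # assumption: meta may be omitted
            messages=[],  # real calls pass ChatCompletionUnit entries
        )
        transaction = ChatCompletionTransactionModel(
            prompt=prompt,
            prompt_type=PromptTypes.CHAT_COMPLETION.value,
        )
        await send_task(config, config.queue_name, transaction, rabbitmq)


    asyncio.run(main())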
6 changes: 4 additions & 2 deletions protollm_tools/llm-api/protollm_api/backend/endpoints.py
@@ -8,6 +8,7 @@
PromptTypes
)
from protollm_sdk.object_interface.redis_wrapper import RedisWrapper
from protollm_sdk.object_interface.rabbit_mq_wrapper import RabbitMQWrapper

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -20,14 +20,15 @@ def get_router(config: Config) -> APIRouter:
)

redis_db = RedisWrapper(config.redis_host, config.redis_port)
rabbitmq = RabbitMQWrapper(config.rabbit_host, config.rabbit_port, config.rabbit_login, config.rabbit_password)

@router.post('/generate', response_model=ResponseModel)
async def generate(prompt_data: PromptModel, queue_name: str = config.queue_name):
transaction_model = ChatCompletionTransactionModel(
prompt=ChatCompletionModel.from_prompt_model(prompt_data),
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
await send_task(config, queue_name, transaction_model)
await send_task(config, queue_name, transaction_model, rabbitmq)
logger.info(f"Task {prompt_data.job_id} was sent to LLM.")
return await get_result(config, prompt_data.job_id, redis_db)

@@ -37,7 +39,7 @@ async def chat_completion(prompt_data: ChatCompletionModel, queue_name: str = co
prompt=prompt_data,
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
await send_task(config, queue_name, transaction_model)
await send_task(config, queue_name, transaction_model, rabbitmq)
logger.info(f"Task {prompt_data.job_id} was sent to LLM.")
return await get_result(config, prompt_data.job_id, redis_db)

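A minimal wiring sketch for these endpoints (not from the diff; assumes the app is FastAPI-based and that Config() reads its connection settings from the environment):

    from fastapi import FastAPI

    from protollm_api.backend.endpoints import get_router
    from protollm_api.config import Config

    app = FastAPI()
    # get_router() constructs the Redis and RabbitMQ wrappers once and shares
    # them across /generate and /chat_completion.
    app.include_router(get_router(Config()))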
2 changes: 1 addition & 1 deletion protollm_tools/llm-api/pyproject.toml
@@ -12,7 +12,7 @@ pika = "^1.3.2"
pydantic = "^2.7.4"
flower = "2.0.1"

protollm_sdk = "^1.0.0"
protollm_sdk = "^1.1.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
2 changes: 1 addition & 1 deletion protollm_tools/llm-api/requirements.txt
@@ -52,7 +52,7 @@ pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0"
prometheus-client==0.21.0 ; python_version >= "3.10" and python_version < "4.0"
prompt-toolkit==3.0.48 ; python_version >= "3.10" and python_version < "4.0"
propcache==0.2.0 ; python_version >= "3.10" and python_version < "4.0"
protollm-sdk==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
protollm-sdk==1.1.0 ; python_version >= "3.10" and python_version < "4.0"
pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "4.0"
pydantic==2.9.2 ; python_version >= "3.10" and python_version < "4.0"
pygments==2.18.0 ; python_version >= "3.10" and python_version < "4.0"
14 changes: 12 additions & 2 deletions protollm_tools/llm-api/tests/integration/test_local_RMQ.py
@@ -5,9 +5,19 @@
import pytest
from protollm_sdk.models.job_context_models import (ChatCompletionModel, PromptMeta, ChatCompletionUnit,
ChatCompletionTransactionModel, PromptTypes)
from protollm_sdk.object_interface import RabbitMQWrapper

from protollm_api.backend.broker import send_task

@pytest.fixture(scope="module")
def rabbit_client(test_local_config):
assert test_local_config.rabbit_host == "localhost"
client = RabbitMQWrapper(test_local_config.rabbit_host,
test_local_config.rabbit_port,
test_local_config.rabbit_login,
test_local_config.rabbit_password)
return client


@pytest.fixture(scope="module")
def rabbitmq_connection(test_local_config):
@@ -31,7 +41,7 @@


@pytest.mark.asyncio
async def test_task_in_queue(test_local_config, rabbitmq_connection):
async def test_task_in_queue(test_local_config, rabbitmq_connection, rabbit_client):
queue_name = "test_queue"
prompt = ChatCompletionModel(
job_id=str(uuid.uuid4()),
@@ -40,7 +50,7 @@ async def test_task_in_queue(test_local_config, rabbitmq_connection):
)
transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value)

await send_task(test_local_config, queue_name, transaction)
await send_task(test_local_config, queue_name, transaction, rabbit_client)

method_frame, header_frame, body = rabbitmq_connection.basic_get(queue=queue_name, auto_ack=True)

15 changes: 7 additions & 8 deletions protollm_tools/llm-api/tests/unit/test_brocker.py
@@ -1,6 +1,6 @@
import json
import uuid
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch, MagicMock, ANY

import pytest
from protollm_sdk.models.job_context_models import ResponseModel, ChatCompletionTransactionModel, ChatCompletionModel, \
@@ -18,15 +18,14 @@ async def test_send_task(test_local_config):
)
transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value)

with patch("protollm_api.backend.broker.pika.BlockingConnection") as mock_connection:
mock_channel = MagicMock()
mock_connection.return_value.channel.return_value = mock_channel
mock_rabbit = MagicMock()

await send_task(test_local_config, test_local_config.queue_name, transaction)
await send_task(test_local_config, test_local_config.queue_name, transaction, mock_rabbit)

mock_connection.assert_called_once()
mock_channel.queue_declare.assert_called_with(queue=test_local_config.queue_name)
mock_channel.basic_publish.assert_called_once()
mock_rabbit.publish_message.assert_called_once_with(test_local_config.queue_name, ANY)


@pytest.mark.asyncio
5 changes: 3 additions & 2 deletions protollm_tools/llm-api/tests/unit/test_endpoints.py
@@ -6,6 +6,7 @@
from httpx import AsyncClient, ASGITransport
from protollm_sdk.models.job_context_models import ResponseModel, PromptTransactionModel, PromptModel, \
PromptTypes, ChatCompletionModel, ChatCompletionTransactionModel
from protollm_sdk.object_interface import RabbitMQWrapper

from protollm_api.backend.endpoints import get_router

@@ -47,7 +48,7 @@ async def test_generate_endpoint(test_app, test_local_config):
prompt=ChatCompletionModel.from_prompt_model(prompt_data),
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model, ANY)

get_result_mock.assert_called_once_with(test_local_config, "test-job-id", ANY)

@@ -85,7 +86,7 @@ async def test_chat_completion_endpoint(test_app, test_local_config):
prompt=prompt_data,
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model, ANY)

get_result_mock.assert_called_once_with(test_local_config, "test-job-id", ANY)

15 changes: 15 additions & 0 deletions protollm_tools/llm-api/unit_config.json
@@ -0,0 +1,15 @@
{
"listeners": {
"*:6672": {
"pass": "applications/backend"
}
},
"applications": {
"backend": {
"type": "python",
"path": ".",
"module": "protollm_api.backend.main",
"callable": "app"
}
}
}
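unit_config.json is an NGINX Unit application config: it serves the ASGI callable app from protollm_api.backend.main on port 6672. For local runs without Unit, a rough equivalent (assuming uvicorn is installed; module, callable, and port are taken from the config above):

    # Dev-server sketch mirroring unit_config.json — same module, callable, and port.
    import uvicorn

    if __name__ == "__main__":
        uvicorn.run("protollm_api.backend.main:app", host="0.0.0.0", port=6672)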
protollm_sdk/object_interface/rabbit_mq_wrapper.py
@@ -39,26 +39,34 @@ def get_channel(self):
channel.close()
connection.close()

def publish_message(self, queue_name: str, message: dict):
def publish_message(self, queue_name: str, message: dict, priority: int = None):
"""
Publish a message to a specified queue.
Publish a message to a specified queue with an optional priority.
:param queue_name: Name of the queue to publish to
:param message: Message to publish (dictionary will be serialized to JSON)
:param priority: Optional message priority. RabbitMQ allows 0-255, but the queue is declared with x-max-priority=10, so values above 10 are treated as 10.
"""
try:
with self.get_channel() as channel:
channel.queue_declare(queue=queue_name, durable=True)
arguments = {}
if priority is not None:
arguments['x-max-priority'] = 10

channel.queue_declare(queue=queue_name, durable=True, arguments=arguments)

properties = pika.BasicProperties(
delivery_mode=2,
priority=priority if priority is not None else 0
)
channel.basic_publish(
exchange='',
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2 # Make message persistent
)
properties=properties
)
logger.info(f"Message published to queue '{queue_name}'")
logger.info(
f"Message published to queue '{queue_name}' with priority {priority if priority is not None else 'None'}")
except Exception as ex:
logger.error(f"Failed to publish message to queue '{queue_name}'. Error: {ex}")
raise Exception(f"Failed to publish message to queue '{queue_name}'. Error: {ex}") from ex
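A short usage sketch for the new priority argument (queue names are illustrative; admin/admin matches the docker-compose defaults above). One design caveat: RabbitMQ queue arguments are immutable after declaration, so a queue first declared without x-max-priority cannot later be redeclared with it — the broker replies PRECONDITION_FAILED. Keeping prioritized and plain traffic on separate queues, as below, avoids this; the local and CI tests that follow exercise both paths.

    # Sketch only — queue names are made up; credentials follow docker-compose.yml.
    from protollm_sdk.object_interface import RabbitMQWrapper

    rabbit = RabbitMQWrapper("localhost", 5672, "admin", "admin")

    # Prioritized queue: declared with x-max-priority=10; message carries priority=9.
    rabbit.publish_message("llm-priority-queue", {"job_id": "a1", "task": "generate"}, priority=9)

    # Plain queue: declared without arguments; message published with priority=0.
    rabbit.publish_message("llm-plain-queue", {"job_id": "b2", "task": "generate"})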
@@ -44,6 +44,24 @@ def test_publish_message(rabbit_wrapper):
assert method_frame is not None, "Message not found in the queue"
assert json.loads(body) == message, "Message in the queue does not match the sent message"

@pytest.mark.local
def test_publish_message_with_priority(rabbit_wrapper):
"""
Tests successful message publishing to a queue with priority.
"""
queue_name = "test_priority_queue"
message = {"key": "value"}
priority = 5

rabbit_wrapper.publish_message(queue_name, message, priority=priority)

with rabbit_wrapper.get_channel() as channel:
method_frame, header_frame, body = channel.basic_get(queue_name, auto_ack=True)
assert method_frame is not None, "Message not found in the queue"
assert json.loads(body) == message, "Message in the queue does not match the sent message"
assert header_frame.priority == priority, f"Expected priority {priority}, but got {header_frame.priority}"


@pytest.mark.local
def test_consume_message(rabbit_wrapper):
"""
@@ -27,12 +27,39 @@ def test_publish_message(rabbit_wrapper, mock_pika):

rabbit_wrapper.publish_message(queue_name, message)

mock_pika.queue_declare.assert_called_once_with(queue=queue_name, durable=True)
mock_pika.queue_declare.assert_called_once_with(queue=queue_name, durable=True, arguments={})
mock_pika.basic_publish.assert_called_once_with(
exchange="",
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(delivery_mode=2, priority=0),
)

@pytest.mark.ci
def test_publish_message_with_priority(rabbit_wrapper, mock_pika):
"""
Tests successful message publishing to a queue with priority.
"""
queue_name = "test_queue"
message = {"key": "value"}
priority = 5

rabbit_wrapper.publish_message(queue_name, message, priority=priority)

mock_pika.queue_declare.assert_called_once_with(
queue=queue_name,
durable=True,
arguments={"x-max-priority": 10}
)

mock_pika.basic_publish.assert_called_once_with(
exchange="",
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(delivery_mode=2),
properties=pika.BasicProperties(
delivery_mode=2,
priority=priority
),
)

@pytest.mark.ci
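To see the broker-level behavior these mocked tests assert, a standalone pika sketch (assumes a local RabbitMQ with the docker-compose credentials): once a queue is declared with x-max-priority, ready messages are delivered highest-priority first.

    # Standalone demo, not from the diff. Assumes RabbitMQ on localhost:5672
    # with the admin/admin credentials from docker-compose.yml.
    import json
    import pika

    conn = pika.BlockingConnection(pika.ConnectionParameters(
        host="localhost",
        port=5672,
        credentials=pika.PlainCredentials("admin", "admin"),
    ))
    channel = conn.channel()
    channel.queue_declare(queue="priority-demo", durable=True,
                          arguments={"x-max-priority": 10})

    for prio, text in [(1, "low"), (9, "high")]:
        channel.basic_publish(
            exchange="",
            routing_key="priority-demo",
            body=json.dumps({"msg": text}),
            properties=pika.BasicProperties(delivery_mode=2, priority=prio),
        )

    # Drains highest priority first: prints "high" before "low".
    for _ in range(2):
        _method, props, body = channel.basic_get(queue="priority-demo", auto_ack=True)
        print(props.priority, json.loads(body)["msg"])

    conn.close()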
