Feature/rabbitmq worker #45

Merged · 8 commits · Jan 23, 2025
8 changes: 6 additions & 2 deletions protollm_tools/llm-api/docker-compose.yml
@@ -32,7 +32,7 @@ services:
- RABBITMQ_DEFAULT_USER=admin
- RABBITMQ_DEFAULT_PASS=admin
volumes:
- "rabbitmq_data:/var/lib/rabbitmq"
- rabbitmq_data:/var/lib/rabbitmq
networks:
- llm_wrap_network

@@ -48,4 +48,8 @@ services:
networks:
llm_wrap_network:
name: llm_wrap_network
driver: bridge
driver: bridge

volumes:
rabbitmq_data:
redis_data:
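The new top-level volumes: block declares the named volumes rabbitmq_data (mounted by the rabbitmq service above) and redis_data (presumably mounted by the redis service in the collapsed part of the diff); data kept in named Docker volumes persists across container recreation.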
10 changes: 5 additions & 5 deletions protollm_tools/llm-api/poetry.lock

Some generated files are not rendered by default.

39 changes: 9 additions & 30 deletions protollm_tools/llm-api/protollm_api/backend/broker.py
@@ -1,6 +1,6 @@
import pika
import logging
import json

from protollm_sdk.object_interface import RabbitMQWrapper

from protollm_api.config import Config
from protollm_sdk.models.job_context_models import (
@@ -10,35 +10,24 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def send_task(config: Config, queue_name: str, transaction: PromptTransactionModel | ChatCompletionTransactionModel, task_type='generate'):
async def send_task(config: Config,
queue_name: str,
transaction: PromptTransactionModel | ChatCompletionTransactionModel,
rabbitmq: RabbitMQWrapper,
task_type='generate'):
"""
Sends a task to the RabbitMQ queue.

Args:
config (Config): Configuration object containing RabbitMQ connection details.
queue_name (str): Name of the RabbitMQ queue where the task will be published.
transaction (PromptTransactionModel | ChatCompletionTransactionModel): Transaction data to be sent.
rabbitmq (RabbitMQWrapper): Rabbit wrapper object to interact with the Rabbit queue.
task_type (str, optional): The type of task to be executed (default is 'generate').

Raises:
Exception: If there is an error during the connection or message publishing process.
"""
connection = pika.BlockingConnection(
pika.ConnectionParameters(
host=config.rabbit_host,
port=config.rabbit_port,
virtual_host='/',
credentials=pika.PlainCredentials(
username=config.rabbit_login,
password=config.rabbit_password
)
)
)
channel = connection.channel()

# Declare the queue if it does not exist
channel.queue_declare(queue=queue_name)

task = {
"type": "task",
"task": task_type,
@@ -49,18 +38,8 @@ async def send_task(config: Config, queue_name: str, transaction: PromptTransact
"eta": None
}

message = json.dumps(task)
rabbitmq.publish_message(queue_name, task)

# Publish the message to the RabbitMQ queue
channel.basic_publish(
exchange='',
routing_key=queue_name,
body=message,
properties=pika.BasicProperties(
delivery_mode=2, # Make message persistent
)
)
connection.close()


async def get_result(config: Config, task_id: str, redis_db: RedisWrapper) -> ResponseModel:
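After this refactor, send_task no longer opens its own pika connection: the caller passes a RabbitMQWrapper and the function only assembles the task payload and delegates publishing to publish_message. A minimal driver sketch follows (not part of this PR; Config construction and the prompt field names are assumptions based on the integration test):

# Sketch only: illustrates the new send_task signature, not code from this PR.
import asyncio
import uuid

from protollm_sdk.object_interface import RabbitMQWrapper
from protollm_sdk.models.job_context_models import (
    ChatCompletionModel, ChatCompletionUnit, PromptMeta,
    ChatCompletionTransactionModel, PromptTypes,
)
from protollm_api.config import Config
from protollm_api.backend.broker import send_task


async def main() -> None:
    config = Config()  # assumed: default construction reads the connection settings
    rabbitmq = RabbitMQWrapper(  # positional order as used in endpoints.py and the test fixture
        config.rabbit_host,
        config.rabbit_port,
        config.rabbit_login,
        config.rabbit_password,
    )
    prompt = ChatCompletionModel(
        job_id=str(uuid.uuid4()),
        meta=PromptMeta(),  # assumed defaults
        messages=[ChatCompletionUnit(role="user", content="ping")],  # assumed field names
    )
    transaction = ChatCompletionTransactionModel(
        prompt=prompt,
        prompt_type=PromptTypes.CHAT_COMPLETION.value,
    )
    # Queue declaration and persistent delivery now happen inside RabbitMQWrapper.publish_message.
    await send_task(config, config.queue_name, transaction, rabbitmq)


if __name__ == "__main__":
    asyncio.run(main())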
6 changes: 4 additions & 2 deletions protollm_tools/llm-api/protollm_api/backend/endpoints.py
@@ -8,6 +8,7 @@
PromptTypes
)
from protollm_sdk.object_interface.redis_wrapper import RedisWrapper
from protollm_sdk.object_interface.rabbit_mq_wrapper import RabbitMQWrapper

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -20,14 +21,15 @@ def get_router(config: Config) -> APIRouter:
)

redis_db = RedisWrapper(config.redis_host, config.redis_port)
rabbitmq = RabbitMQWrapper(config.rabbit_host, config.rabbit_port, config.rabbit_login, config.rabbit_password)

@router.post('/generate', response_model=ResponseModel)
async def generate(prompt_data: PromptModel, queue_name: str = config.queue_name):
transaction_model = ChatCompletionTransactionModel(
prompt=ChatCompletionModel.from_prompt_model(prompt_data),
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
await send_task(config, queue_name, transaction_model)
await send_task(config, queue_name, transaction_model, rabbitmq)
logger.info(f"Task {prompt_data.job_id} was sent to LLM.")
return await get_result(config, prompt_data.job_id, redis_db)

@@ -37,7 +39,7 @@ async def chat_completion(prompt_data: ChatCompletionModel, queue_name: str = co
prompt=prompt_data,
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
await send_task(config, queue_name, transaction_model)
await send_task(config, queue_name, transaction_model, rabbitmq)
logger.info(f"Task {prompt_data.job_id} was sent to LLM.")
return await get_result(config, prompt_data.job_id, redis_db)

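For reference, a sketch of mounting the updated router and exercising /generate; only the internals of get_router change in this PR. The absence of a router prefix and the PromptModel request shape are assumptions, and a reachable RabbitMQ, Redis, and LLM worker are needed for the call to return a result:

# Sketch only, not code from this PR.
from fastapi import FastAPI
from fastapi.testclient import TestClient

from protollm_api.config import Config
from protollm_api.backend.endpoints import get_router

config = Config()  # assumed default construction
app = FastAPI()
app.include_router(get_router(config))

client = TestClient(app)
response = client.post(
    "/generate",
    params={"queue_name": config.queue_name},  # optional; defaults to config.queue_name
    json={"job_id": "test-job-id", "meta": {}, "content": "Hello"},  # assumed PromptModel shape
)
print(response.status_code, response.json())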
2 changes: 1 addition & 1 deletion protollm_tools/llm-api/pyproject.toml
@@ -12,7 +12,7 @@ pika = "^1.3.2"
pydantic = "^2.7.4"
flower = "2.0.1"

protollm_sdk = "^1.0.0"
protollm_sdk = "^1.1.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
2 changes: 1 addition & 1 deletion protollm_tools/llm-api/requirements.txt
@@ -52,7 +52,7 @@ pika==1.3.2 ; python_version >= "3.10" and python_version < "4.0"
prometheus-client==0.21.0 ; python_version >= "3.10" and python_version < "4.0"
prompt-toolkit==3.0.48 ; python_version >= "3.10" and python_version < "4.0"
propcache==0.2.0 ; python_version >= "3.10" and python_version < "4.0"
protollm-sdk==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
protollm-sdk==1.1.0 ; python_version >= "3.10" and python_version < "4.0"
pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "4.0"
pydantic==2.9.2 ; python_version >= "3.10" and python_version < "4.0"
pygments==2.18.0 ; python_version >= "3.10" and python_version < "4.0"
14 changes: 12 additions & 2 deletions protollm_tools/llm-api/tests/integration/test_local_RMQ.py
@@ -5,9 +5,19 @@
import pytest
from protollm_sdk.models.job_context_models import (ChatCompletionModel, PromptMeta, ChatCompletionUnit,
ChatCompletionTransactionModel, PromptTypes)
from protollm_sdk.object_interface import RabbitMQWrapper

from protollm_api.backend.broker import send_task

@pytest.fixture(scope="module")
def rabbit_client(test_local_config):
assert test_local_config.rabbit_host == "localhost"
client = RabbitMQWrapper(test_local_config.rabbit_host,
test_local_config.rabbit_port,
test_local_config.rabbit_login,
test_local_config.rabbit_password)
return client


@pytest.fixture(scope="module")
def rabbitmq_connection(test_local_config):
@@ -31,7 +41,7 @@ def rabbitmq_connection(test_local_config):


@pytest.mark.asyncio
async def test_task_in_queue(test_local_config, rabbitmq_connection):
async def test_task_in_queue(test_local_config, rabbitmq_connection, rabbit_client):
queue_name = "test_queue"
prompt = ChatCompletionModel(
job_id=str(uuid.uuid4()),
@@ -40,7 +50,7 @@ async def test_task_in_queue(test_local_config, rabbitmq_connection):
)
transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value)

await send_task(test_local_config, queue_name, transaction)
await send_task(test_local_config, queue_name, transaction, rabbit_client)

method_frame, header_frame, body = rabbitmq_connection.basic_get(queue=queue_name, auto_ack=True)

15 changes: 7 additions & 8 deletions protollm_tools/llm-api/tests/unit/test_brocker.py
@@ -1,6 +1,6 @@
import json
import uuid
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch, MagicMock, ANY

import pytest
from protollm_sdk.models.job_context_models import ResponseModel, ChatCompletionTransactionModel, ChatCompletionModel, \
@@ -18,15 +18,14 @@ async def test_send_task(test_local_config):
)
transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value)

with patch("protollm_api.backend.broker.pika.BlockingConnection") as mock_connection:
mock_channel = MagicMock()
mock_connection.return_value.channel.return_value = mock_channel
mock_rabbit = MagicMock()
#mock_connection.return_value.channel.return_value = mock_channel

await send_task(test_local_config, test_local_config.queue_name, transaction)
await send_task(test_local_config, test_local_config.queue_name, transaction, mock_rabbit)

mock_connection.assert_called_once()
mock_channel.queue_declare.assert_called_with(queue=test_local_config.queue_name)
mock_channel.basic_publish.assert_called_once()
#mock_connection.assert_called_once()
mock_rabbit.publish_message.assert_called_once_with(test_local_config.queue_name, ANY)
#mock_channel.basic_publish.assert_called_once()


@pytest.mark.asyncio
5 changes: 3 additions & 2 deletions protollm_tools/llm-api/tests/unit/test_endpoints.py
@@ -6,6 +6,7 @@
from httpx import AsyncClient, ASGITransport
from protollm_sdk.models.job_context_models import ResponseModel, PromptTransactionModel, PromptModel, \
PromptTypes, ChatCompletionModel, ChatCompletionTransactionModel
from protollm_sdk.object_interface import RabbitMQWrapper

from protollm_api.backend.endpoints import get_router

@@ -47,7 +48,7 @@ async def test_generate_endpoint(test_app, test_local_config):
prompt=ChatCompletionModel.from_prompt_model(prompt_data),
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model, ANY)

get_result_mock.assert_called_once_with(test_local_config, "test-job-id", ANY)

@@ -85,7 +86,7 @@ async def test_chat_completion_endpoint(test_app, test_local_config):
prompt=prompt_data,
prompt_type=PromptTypes.CHAT_COMPLETION.value
)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model)
send_task_mock.assert_called_once_with(test_local_config, "llm-api-queue", transaction_model, ANY)

get_result_mock.assert_called_once_with(test_local_config, "test-job-id", ANY)

15 changes: 15 additions & 0 deletions protollm_tools/llm-api/unit_config.json
@@ -0,0 +1,15 @@
{
"listeners": {
"*:6672": {
"pass": "applications/backend"
}
},
"applications": {
"backend": {
"type": "python",
"path": ".",
"module": "protollm_api.backend.main",
"callable": "app"
}
}
}
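The new unit_config.json is an NGINX Unit configuration: Unit listens on port 6672 and routes requests to a Python application whose callable is app in the module protollm_api.backend.main. How the file is loaded into Unit (control API or image entrypoint) is not part of this diff.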
protollm_sdk/object_interface/rabbit_mq_wrapper.py (file header not rendered in the copied diff)
@@ -39,26 +39,34 @@ def get_channel(self):
channel.close()
connection.close()

def publish_message(self, queue_name: str, message: dict):
def publish_message(self, queue_name: str, message: dict, priority: int = None):
"""
Publish a message to a specified queue.
Publish a message to a specified queue with an optional priority.

:param queue_name: Name of the queue to publish to
:param message: Message to publish (dictionary will be serialized to JSON)
:param priority: Optional priority of the message (0-255)
"""
try:
with self.get_channel() as channel:
channel.queue_declare(queue=queue_name, durable=True)
arguments = {}
if priority is not None:
arguments['x-max-priority'] = 10

channel.queue_declare(queue=queue_name, durable=True, arguments=arguments)

properties = pika.BasicProperties(
delivery_mode=2,
priority=priority if priority is not None else 0
)
channel.basic_publish(
exchange='',
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2 # Make message persistent
)
properties=properties
)
logger.info(f"Message published to queue '{queue_name}'")
logger.info(
f"Message published to queue '{queue_name}' with priority {priority if priority is not None else 'None'}")
except Exception as ex:
logger.error(f"Failed to publish message to queue '{queue_name}'. Error: {ex}")
raise Exception(f"Failed to publish message to queue '{queue_name}'. Error: {ex}") from ex
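A short usage sketch of the extended publish_message (host, port, and credentials below are assumptions taken from the docker-compose defaults). Passing priority makes the queue declaration set x-max-priority=10, and RabbitMQ rejects redeclaring an existing queue with different arguments, so priority and non-priority publishes should target different queue names:

# Sketch only, not code from this PR.
from protollm_sdk.object_interface import RabbitMQWrapper

rabbit = RabbitMQWrapper("localhost", 5672, "admin", "admin")  # assumed connection details

# Without priority: the queue is declared durable with no extra arguments and
# the message is published persistent (delivery_mode=2) with priority 0.
rabbit.publish_message("llm-api-queue", {"type": "task", "task": "generate"})

# With priority: the queue is declared with x-max-priority=10 and the message
# carries priority=5 (hypothetical queue name, kept separate from the one above).
rabbit.publish_message("llm-api-priority-queue", {"type": "task", "task": "generate"}, priority=5)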
Tests for RabbitMQWrapper against a live broker, pytest.mark.local (file header not rendered in the copied diff)
@@ -44,6 +44,24 @@ def test_publish_message(rabbit_wrapper):
assert method_frame is not None, "Message not found in the queue"
assert json.loads(body) == message, "Message in the queue does not match the sent message"

@pytest.mark.local
def test_publish_message_with_priority(rabbit_wrapper):
"""
Tests successful message publishing to a queue with priority.
"""
queue_name = "test_priority_queue"
message = {"key": "value"}
priority = 5

rabbit_wrapper.publish_message(queue_name, message, priority=priority)

with rabbit_wrapper.get_channel() as channel:
method_frame, header_frame, body = channel.basic_get(queue_name, auto_ack=True)
assert method_frame is not None, "Message not found in the queue"
assert json.loads(body) == message, "Message in the queue does not match the sent message"
assert header_frame.priority == priority, f"Expected priority {priority}, but got {header_frame.priority}"


@pytest.mark.local
def test_consume_message(rabbit_wrapper):
"""
Unit tests for RabbitMQWrapper with mocked pika, pytest.mark.ci (file header not rendered in the copied diff)
@@ -27,12 +27,39 @@ def test_publish_message(rabbit_wrapper, mock_pika):

rabbit_wrapper.publish_message(queue_name, message)

mock_pika.queue_declare.assert_called_once_with(queue=queue_name, durable=True)
mock_pika.queue_declare.assert_called_once_with(queue=queue_name, durable=True, arguments={})
mock_pika.basic_publish.assert_called_once_with(
exchange="",
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(delivery_mode=2, priority=0),
)

@pytest.mark.ci
def test_publish_message_with_priority(rabbit_wrapper, mock_pika):
"""
Tests successful message publishing to a queue with priority.
"""
queue_name = "test_queue"
message = {"key": "value"}
priority = 5

rabbit_wrapper.publish_message(queue_name, message, priority=priority)

mock_pika.queue_declare.assert_called_once_with(
queue=queue_name,
durable=True,
arguments={"x-max-priority": 10}
)

mock_pika.basic_publish.assert_called_once_with(
exchange="",
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(delivery_mode=2),
properties=pika.BasicProperties(
delivery_mode=2,
priority=priority
),
)

@pytest.mark.ci