Skip to content

Commit

Permalink
add rabbit_mq_wrapper.py and tests (#44)
Browse files Browse the repository at this point in the history
* add: add rabbit_mq_wrapper.py and tests

* test: add CI and local markers to tests for improved filtering

* Update sdk-build.yml

only ci tests are run
  • Loading branch information
1martin1 authored Jan 21, 2025
1 parent 956b731 commit c3eaa57
Show file tree
Hide file tree
Showing 31 changed files with 1,013 additions and 536 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/sdk-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ jobs:
- name: Test sdk with pytest
run: |
cd ./protollm_tools/sdk
pytest -s ./tests
pytest -s ./tests -m ci
6 changes: 2 additions & 4 deletions protollm_tools/llm-api/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM nginx/unit:1.28.0-python3.10

COPY protollm_api /app/llm_api
COPY protollm_api /app/protollm_api
COPY requirements.txt /app
COPY unit_config.json /docker-entrypoint.d/config.json
WORKDIR /app
Expand All @@ -9,6 +9,4 @@ RUN pip install --upgrade pip
RUN pip install -r requirements.txt
RUN apt install git

COPY unit_config.json /docker-entrypoint.d/config.json

# CMD ["unitd", "--no-daemon", "--control", "unix:/var/run/control.unit.sock"]
COPY unit_config.json /docker-entrypoint.d/config.json
2 changes: 1 addition & 1 deletion protollm_tools/sdk/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SDK is a module for intaraction with LLMs in asynchronous task-based mode.
Пока тут пустовато
3 changes: 1 addition & 2 deletions protollm_tools/sdk/examples/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from protollm_sdk.config import Config
from protollm_sdk.celery.app import task_test
from protollm_sdk.celery.job import TextEmbedderJob, ResultStorageJob, LLMAPIJob, \
VectorDBJob, OuterLLMAPIJob
VectorDBJob, OuterLLMAPIJob # , LangchainLLMAPIJob
from protollm_sdk.object_interface import RedisWrapper


Expand Down Expand Up @@ -53,7 +53,6 @@ async def out_llm_resp(redis_client: RedisWrapper):
task_test.apply_async(args=(OuterLLMAPIJob.__name__, llm_request["job_id"]), kwargs=llm_request)
result = await redis_client.wait_item(f"{OuterLLMAPIJob.__name__}:{llm_request['job_id']}", timeout=60)


def get_dict(key):
rd = RedisWrapper(redis_host=Config.redis_host,
redis_port=Config.redis_port)
Expand Down
923 changes: 465 additions & 458 deletions protollm_tools/sdk/poetry.lock

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion protollm_tools/sdk/protollm_sdk/celery/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion protollm_tools/sdk/protollm_sdk/jobs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class Job(ABC):
"""
Job interface for integration with outer modules.
Job interface for integration with outer modules to SDK.
All the required data for job should be passed and parameters should be defined in advance,
this also applies to the run method
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class PromptMeta(BaseModel):
temperature: float | None = 0.2
tokens_limit: int | None = 8096
stop_words: list[str] | None = None
model: str | None = None
model: str | None = Field(default=None, examples=[None])


class PromptModel(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from protollm_sdk.object_interface.redis_wrapper import RedisWrapper
from protollm_sdk.object_interface.rabbit_mq_wrapper import RabbitMQWrapper
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import logging
from contextlib import contextmanager
import pika

logger = logging.getLogger(__name__)


class RabbitMQWrapper:
def __init__(self, rabbit_host: str, rabbit_port: int, rabbit_user: str, rabbit_password: str, virtual_host: str = '/'):
"""
Initialize RabbitMQ wrapper.
:param rabbit_host: RabbitMQ host
:param rabbit_port: RabbitMQ port
:param rabbit_user: RabbitMQ username
:param rabbit_password: RabbitMQ password
:param virtual_host: RabbitMQ virtual host
"""
self.connection_params = pika.ConnectionParameters(
host=rabbit_host,
port=rabbit_port,
virtual_host=virtual_host,
credentials=pika.PlainCredentials(rabbit_user, rabbit_password)
)

@contextmanager
def get_channel(self):
"""
Provide a channel for RabbitMQ operations.
:yield: pika channel
"""
connection = pika.BlockingConnection(self.connection_params)
channel = connection.channel()
try:
yield channel
finally:
channel.close()
connection.close()

def publish_message(self, queue_name: str, message: dict):
"""
Publish a message to a specified queue.
:param queue_name: Name of the queue to publish to
:param message: Message to publish (dictionary will be serialized to JSON)
"""
try:
with self.get_channel() as channel:
channel.queue_declare(queue=queue_name, durable=True)

channel.basic_publish(
exchange='',
routing_key=queue_name,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2 # Make message persistent
)
)
logger.info(f"Message published to queue '{queue_name}'")
except Exception as ex:
logger.error(f"Failed to publish message to queue '{queue_name}'. Error: {ex}")
raise Exception(f"Failed to publish message to queue '{queue_name}'. Error: {ex}") from ex

def consume_messages(self, queue_name: str, callback):
"""
Start consuming messages from a specified queue.
:param queue_name: Name of the queue to consume from
:param callback: Callback function to process messages
"""
try:
connection = pika.BlockingConnection(self.connection_params)
channel = connection.channel()

channel.queue_declare(queue=queue_name, durable=True)

channel.basic_consume(
queue=queue_name,
on_message_callback=callback,
auto_ack=True
)
logger.info(f"Started consuming messages from queue '{queue_name}'")
channel.start_consuming()
except Exception as ex:
logger.error(f"Failed to consume messages from queue '{queue_name}'. Error: {ex}")
raise Exception(f"Failed to consume messages from queue '{queue_name}'. Error: {ex}")
8 changes: 7 additions & 1 deletion protollm_tools/sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "protollm-sdk"
version = "1.0.1"
version = "1.1.0"
description = ""
authors = ["aimclub"]
readme = "README.md"
Expand All @@ -27,6 +27,12 @@ openai = "^1.42.0"
pytest = "^8.2.2"
pytest-asyncio = "^0.24.0"

[tool.pytest.ini_options]
markers = [
"local: Mark tests as part of the local pipeline (e.g., for Redis/Rabbit/etc)",
"ci: Mark tests as part of the CI pipeline"
]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
1 change: 0 additions & 1 deletion protollm_tools/sdk/tests/celery/__init__.py

This file was deleted.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def result_storage():
return {"question": "What is the ultimate question answer?",
"answers": "42"}


@pytest.mark.ci
def test_task_test_unknown_job_class(caplog):
task_id = str(uuid.uuid4())
task_class = "unknown_class"
Expand All @@ -24,7 +24,7 @@ def test_task_test_unknown_job_class(caplog):

assert f"Error in task '{task_id}'. Unknown job class: '{task_class}'." in caplog.text


@pytest.mark.local
def test_task_test_known_job_class(caplog, result_storage):
caplog.set_level(logging.INFO)
task_id = str(uuid.uuid4())
Expand All @@ -51,7 +51,7 @@ def run(self, task_id, ctx, **kwargs):
def dummy_job():
return DummyJob()


@pytest.mark.ci
def test_abstract_task_class_input(caplog, dummy_job):
caplog.set_level("INFO")

Expand All @@ -70,7 +70,7 @@ def test_abstract_task_class_input(caplog, dummy_job):
assert job_instance.task_id == task_id
assert job_instance.kwargs == {"test_arg": "value"}


@pytest.mark.ci
def test_abstract_task_instance_input(caplog, dummy_job):
caplog.set_level("INFO")

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,17 @@ def test_llm_request(llm_request):
res = LLMResponse(job_id=llm_request["job_id"], text=r.content)
assert isinstance(res, LLMResponse)


@pytest.mark.skip(reason="Test waits infinitely in GitHub Action")
@pytest.mark.local
def test_text_embedder_request(text_embedder_request):
random_id = uuid.uuid4()
result = task_test.apply_async(args=(TextEmbedderJob.__name__, random_id), kwargs=text_embedder_request)
assert isinstance(result.get(), TextEmbedderResponse)


@pytest.mark.skip(reason="Test waits infinitely in GitHub Action")
@pytest.mark.local
def test_result_storage(result_storage):
random_id = uuid.uuid4()
task_test.apply_async(args=(ResultStorageJob.__name__, random_id), kwargs=result_storage)


@pytest.mark.skip(reason="We don't have local vector DB")
def test_ping_vector_db():
random_id = uuid.uuid4()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def job_invoker():

# ---------------------------- Functional test of BlockingJobResult ----------------------------

@pytest.mark.ci
def test_blocking_job_result_initialization(blocking_job_result):
"""
Test that BlockingJobResult initializes correctly and calls _get_result.
Expand All @@ -60,7 +61,7 @@ def test_blocking_job_result_initialization(blocking_job_result):
assert isinstance(blocking_job_result.result, dict)
mock_storage.for_job.assert_called_once_with("test_job")


@pytest.mark.ci
def test_blocking_job_result_get_result(blocking_job_result):
"""
Test that get_result returns the correct result.
Expand All @@ -72,7 +73,7 @@ def test_blocking_job_result_get_result(blocking_job_result):
assert result == {"status": "success", "job_id": "12345"}
mock_storage.for_job().load_dict.assert_called_once_with("12345")


@pytest.mark.ci
@patch('time.sleep', return_value=None)
def test_blocking_job_result_timeout(mock_sleep, blocking_job_result):
"""
Expand All @@ -88,7 +89,7 @@ def test_blocking_job_result_timeout(mock_sleep, blocking_job_result):

assert mock_sleep.call_count > 0


@pytest.mark.ci
@patch('time.sleep', return_value=None)
def test_blocking_job_result_retries_and_succeeds(mock_sleep, blocking_job_result):
"""
Expand All @@ -103,7 +104,7 @@ def test_blocking_job_result_retries_and_succeeds(mock_sleep, blocking_job_resul
assert result == {"status": "success", "job_id": "12345"}
assert mock_storage.for_job().load_dict.call_count == 5


@pytest.mark.ci
def test_blocking_job_ping_result(blocking_job_result):
"""
Test that _ping_result calls the storage and returns the correct result.
Expand All @@ -121,7 +122,7 @@ def test_blocking_job_ping_result(blocking_job_result):

# ---------------------------- Functional test of WorkerJobResult ----------------------------


@pytest.mark.ci
def test_worker_job_result_initialization(worker_job_result):
"""
Test that WorkerJobResult initializes correctly and calls _get_result.
Expand All @@ -131,7 +132,7 @@ def test_worker_job_result_initialization(worker_job_result):
assert worker_job_result.job_id == "12345"
mock_storage.for_job.assert_called_once_with("test_job")


@pytest.mark.ci
def test_worker_job_result_get_result(worker_job_result):
"""
Test that get_result returns the correct result.
Expand All @@ -143,7 +144,7 @@ def test_worker_job_result_get_result(worker_job_result):
assert result == {"status": "success", "job_id": "12345"}
mock_storage.for_job().load_dict.assert_called_once_with("12345")


@pytest.mark.ci
@patch('time.sleep', return_value=None)
def test_worker_job_result_timeout(mock_sleep, worker_job_result):
"""
Expand All @@ -159,7 +160,7 @@ def test_worker_job_result_timeout(mock_sleep, worker_job_result):

assert mock_sleep.call_count > 0


@pytest.mark.ci
@patch('time.sleep', return_value=None)
def test_worker_job_result_retries_and_succeeds(mock_sleep, worker_job_result):
"""
Expand All @@ -174,7 +175,7 @@ def test_worker_job_result_retries_and_succeeds(mock_sleep, worker_job_result):
assert result == {"status": "success", "job_id": "12345"}
assert mock_storage.for_job().load_dict.call_count == 4


@pytest.mark.ci
def test_worker_job_ping_result(worker_job_result):
"""
Test that _ping_result calls the storage and returns the correct result.
Expand All @@ -192,7 +193,7 @@ def test_worker_job_ping_result(worker_job_result):

# ---------------------------- Functional test of JobInvoke ----------------------------


@pytest.mark.ci
def test_job_invoker_worker_success(job_invoker):
"""
Test JobInvoker with InvokeType.Worker.
Expand All @@ -206,7 +207,7 @@ def test_job_invoker_worker_success(job_invoker):
mock_task.apply_async.assert_called_once_with(args=(mock_job_class, mock.ANY), kwargs={})
assert isinstance(result, WorkerJobResult)


@pytest.mark.ci
def test_job_invoker_blocking_success(job_invoker):
"""
Test JobInvoker with InvokeType.Blocking.
Expand All @@ -222,7 +223,7 @@ def test_job_invoker_blocking_success(job_invoker):
mock_task.assert_called_once_with(mock_job_class, mock.ANY)
assert isinstance(result, BlockingJobResult)


@pytest.mark.ci
def test_job_invoker_no_task(job_invoker):
"""
Test JobInvoker when no task is provided.
Expand All @@ -236,7 +237,7 @@ def test_job_invoker_no_task(job_invoker):
with pytest.raises(Exception, match="Calling JobInvoker without abstract task."):
invoker.invoke(mock_job_class)


@pytest.mark.ci
def test_job_invoker_unknown_invoke_type(job_invoker):
"""
Test JobInvoker with an unknown InvokeType.
Expand Down
Loading

0 comments on commit c3eaa57

Please sign in to comment.