Skip to content

Commit

Permalink
feat: add publish batch messages (#167)
Browse files Browse the repository at this point in the history
* feat: add publish batch messages

* feat: remove get attr module

* feat: update variable name to make code legible
  • Loading branch information
gabriel-f-santos authored Feb 19, 2024
1 parent 0e4deec commit 9a9079a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
8 changes: 7 additions & 1 deletion serpens/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

Expand All @@ -20,6 +20,7 @@ def __init__(self, provider: Optional[MessageProvider] = None):
logger.debug(f"Provider: {self._provider.value}")
module = importlib.import_module(f"serpens.{self._provider.value}")
self._publish = module.publish_message
self._publish_batch = module.publish_message_batch

def publish(
self,
Expand All @@ -30,6 +31,11 @@ def publish(
) -> Dict[str, Any]:
return self._publish(destination, body, order_key, attributes)

def publish_batch(
self, destination: str, messages: List[Any], order_key: Optional[str] = None
) -> Dict[str, Any]:
return self._publish_batch(destination, messages, order_key)

@classmethod
def instance(cls):
if cls._instance is None:
Expand Down
3 changes: 2 additions & 1 deletion serpens/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def build_message_attributes(attributes):
return message_attributes


def publish_message_batch(queue_url, messages, message_group_id=None):
def publish_message_batch(queue_url, messages, order_key=None):
message_group_id = order_key
client = boto3.client("sqs")
entries = []
result = []
Expand Down
50 changes: 50 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,44 @@ def setUp(self):
self.attributes = {"app_name": "platform-default"}
self.order_key = "group-test-id"

self.messages = [
{
"body": "message 1",
"attributes": {
"key1": "value1",
"key2": 123,
"key3": b"binary data",
},
},
{
"body": "message 2",
"attributes": {
"key1": "value2",
"key2": "123",
"key3": 123456,
},
},
]

self.expected_entries = [
{
"MessageBody": "message 1",
"MessageAttributes": {
"key1": {"StringValue": "value1", "DataType": "String"},
"key2": {"StringValue": 123, "DataType": "Number"},
"key3": {"BinaryValue": b"binary data", "DataType": "Binary"},
},
},
{
"MessageBody": "message 2",
"MessageAttributes": {
"key1": {"StringValue": "value2", "DataType": "String"},
"key2": {"StringValue": "123", "DataType": "String"},
"key3": {"StringValue": 123456, "DataType": "Number"},
},
},
]

def tearDown(self):
self.patch_boto3.stop()
self.patch_pubsub_v1.stop()
Expand All @@ -42,6 +80,18 @@ def test_publish_message_sqs(self):
MessageDeduplicationId=self.order_key,
)

@patch.dict(os.environ, {"MESSAGE_PROVIDER": "sqs"})
def test_publish_message_batch_sqs(self):
MessageClient().publish_batch(self.destination, self.messages)

call_entries = self.sqs_client.send_message_batch.call_args.kwargs["Entries"]

for entry in call_entries:
del entry["Id"]

self.sqs_client.send_message_batch.assert_called_once()
self.assertListEqual(call_entries, self.expected_entries)

@patch.dict(os.environ, {"MESSAGE_PROVIDER": "pubsub"})
def test_publish_message_pubsub(self):
MessageClient().publish(self.destination, self.body, self.order_key, self.attributes)
Expand Down

0 comments on commit 9a9079a

Please sign in to comment.