Skip to content

Commit

Permalink
Make (s)publish commands available in transactions: Make (s)publish c…
Browse files Browse the repository at this point in the history
…ommands available in transactions: 1. Add publish apis to ClusterTransaction and Transaction. 2. Reverted StandaloneCommands.publish() to return int since it was not a good idea to diverge from the protocol. 3. Improved docs. 4 Updated tests. 5. Added handling of the missing unsubscriptions notifications.
  • Loading branch information
ikolomi committed Jul 1, 2024
1 parent fbbcb17 commit 36d5c01
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 67 deletions.
11 changes: 6 additions & 5 deletions python/python/glide/async_commands/cluster_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,22 +593,23 @@ async def publish(self, message: str, channel: str, sharded: bool = False) -> in
"""
Publish a message on pubsub channel.
This command aggregates PUBLISH and SPUBLISH commands functionalities.
The mode is selected using the 'sharded' parameter
The mode is selected using the 'sharded' parameter.
For both sharded and non-sharded mode, request is routed using hashed channel as key.
See https://valkey.io/commands/publish and https://valkey.io/commands/spublish for more details.
Args:
message (str): Message to publish
channel (str): Channel to publish the message on.
sharded (bool): Use sharded pubsub mode.
sharded (bool): Use sharded pubsub mode. Available since Redis version 7.0.
Returns:
int: Number of subscriptions in that shard that received the message.
int: Number of subscriptions in that node that received the message.
Examples:
>>> await client.publish("Hi all!", "global-channel", False)
1 # Publishes "Hi all!" message on global-channel channel using non-sharded mode
1 # Published 1 instance of "Hi all!" message on global-channel channel using non-sharded mode
>>> await client.publish("Hi to sharded channel1!", "channel1, True)
2 # Publishes "Hi to sharded channel1!" message on channel1 using sharded mode
2 # Published 2 instances of "Hi to sharded channel1!" message on channel1 using sharded mode
"""
result = await self._execute_command(
RequestType.SPublish if sharded else RequestType.Publish, [channel, message]
Expand Down
11 changes: 6 additions & 5 deletions python/python/glide/async_commands/standalone_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ async def sort_store(
result = await self._execute_command(RequestType.Sort, args)
return cast(int, result)

async def publish(self, message: str, channel: str) -> TOK:
async def publish(self, message: str, channel: str) -> int:
"""
Publish a message on pubsub channel.
See https://valkey.io/commands/publish for more details.
Expand All @@ -507,14 +507,15 @@ async def publish(self, message: str, channel: str) -> TOK:
channel (str): Channel to publish the message on.
Returns:
TOK: a simple `OK` response.
int: Number of subscriptions in primary node that received the message.
Note that this value does not include subscriptions that configured on replicas.
Examples:
>>> await client.publish("Hi all!", "global-channel")
"OK"
1 # This message was posted to 1 subscription which is configured on primary node
"""
await self._execute_command(RequestType.Publish, [channel, message])
return cast(TOK, OK)
result = await self._execute_command(RequestType.Publish, [channel, message])
return cast(int, result)

async def flushall(self, flush_mode: Optional[FlushMode] = None) -> TOK:
"""
Expand Down
36 changes: 36 additions & 0 deletions python/python/glide/async_commands/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4456,6 +4456,21 @@ def copy(

return self.append_command(RequestType.Copy, args)

def publish(self: TTransaction, message: str, channel: str) -> TTransaction:
"""
Publish a message on pubsub channel.
See https://valkey.io/commands/publish for more details.
Args:
message (str): Message to publish
channel (str): Channel to publish the message on.
Returns:
TOK: a simple `OK` response.
"""
return self.append_command(RequestType.Publish, [channel, message])


class ClusterTransaction(BaseTransaction):
"""
Expand Down Expand Up @@ -4551,4 +4566,25 @@ def copy(

return self.append_command(RequestType.Copy, args)

def publish(
self: TTransaction, message: str, channel: str, sharded: bool = False
) -> TTransaction:
"""
Publish a message on pubsub channel.
This command aggregates PUBLISH and SPUBLISH commands functionalities.
The mode is selected using the 'sharded' parameter
See https://valkey.io/commands/publish and https://valkey.io/commands/spublish for more details.
Args:
message (str): Message to publish
channel (str): Channel to publish the message on.
sharded (bool): Use sharded pubsub mode. Available since Redis version 7.0.
Returns:
int: Number of subscriptions in that shard that received the message.
"""
return self.append_command(
RequestType.SPublish if sharded else RequestType.Publish, [channel, message]
)

# TODO: add all CLUSTER commands
2 changes: 1 addition & 1 deletion python/python/glide/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class PubSubChannelModes(IntEnum):
Pattern = 1
""" Use channel name patterns """
Sharded = 2
""" Use sharded pubsub """
""" Use sharded pubsub. Available since Redis version 7.0. """

@dataclass
class PubSubSubscriptions:
Expand Down
2 changes: 2 additions & 0 deletions python/python/glide/glide_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ def _notification_to_pubsub_message_safe(
or message_kind == "Subscribe"
or message_kind == "SSubscribe"
or message_kind == "Unsubscribe"
or message_kind == "PUnsubscribe"
or message_kind == "SUnsubscribe"
):
pass
else:
Expand Down
86 changes: 33 additions & 53 deletions python/python/tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ async def test_pubsub_exact_happy_path(
"""
channel = get_random_string(10)
message = get_random_string(5)
publish_response = 1 if cluster_mode else OK

callback, context = None, None
callback_messages: List[CoreCommands.PubSubMsg] = []
Expand All @@ -190,7 +189,9 @@ async def test_pubsub_exact_happy_path(
)

try:
assert await publishing_client.publish(message, channel) == publish_response
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1
# allow the message to propagate
await asyncio.sleep(1)

Expand Down Expand Up @@ -224,7 +225,6 @@ async def test_pubsub_exact_happy_path_coexistence(
channel = get_random_string(10)
message = get_random_string(5)
message2 = get_random_string(7)
publish_response = 1 if cluster_mode else OK

pub_sub = create_pubsub_subscription(
cluster_mode,
Expand All @@ -237,10 +237,11 @@ async def test_pubsub_exact_happy_path_coexistence(
)

try:
assert await publishing_client.publish(message, channel) == publish_response
assert (
await publishing_client.publish(message2, channel) == publish_response
)
for msg in [message, message2]:
result = await publishing_client.publish(msg, channel)
if cluster_mode:
assert result == 1

# allow the message to propagate
await asyncio.sleep(1)

Expand Down Expand Up @@ -288,7 +289,6 @@ async def test_pubsub_exact_happy_path_many_channels(
"""
NUM_CHANNELS = 256
shard_prefix = "{same-shard}"
publish_response = 1 if cluster_mode else OK

# Create a map of channels to random messages with shard prefix
channels_and_messages = {
Expand Down Expand Up @@ -324,10 +324,9 @@ async def test_pubsub_exact_happy_path_many_channels(
try:
# Publish messages to each channel
for channel, message in channels_and_messages.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# Allow the messages to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -371,7 +370,6 @@ async def test_pubsub_exact_happy_path_many_channels_co_existence(
"""
NUM_CHANNELS = 256
shard_prefix = "{same-shard}"
publish_response = 1 if cluster_mode else OK

# Create a map of channels to random messages with shard prefix
channels_and_messages = {
Expand Down Expand Up @@ -400,10 +398,9 @@ async def test_pubsub_exact_happy_path_many_channels_co_existence(
try:
# Publish messages to each channel
for channel, message in channels_and_messages.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# Allow the messages to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -684,7 +681,6 @@ async def test_pubsub_pattern(
"{{{}}}:{}".format("channel", get_random_string(5)): get_random_string(5),
"{{{}}}:{}".format("channel", get_random_string(5)): get_random_string(5),
}
publish_response = 1 if cluster_mode else OK

callback, context = None, None
callback_messages: List[CoreCommands.PubSubMsg] = []
Expand All @@ -705,10 +701,9 @@ async def test_pubsub_pattern(

try:
for channel, message in channels.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# allow the message to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -749,7 +744,6 @@ async def test_pubsub_pattern_co_existence(self, request, cluster_mode: bool):
"{{{}}}:{}".format("channel", get_random_string(5)): get_random_string(5),
"{{{}}}:{}".format("channel", get_random_string(5)): get_random_string(5),
}
publish_response = 1 if cluster_mode else OK

pub_sub = create_pubsub_subscription(
cluster_mode,
Expand All @@ -763,10 +757,9 @@ async def test_pubsub_pattern_co_existence(self, request, cluster_mode: bool):

try:
for channel, message in channels.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# allow the message to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -817,7 +810,6 @@ async def test_pubsub_pattern_many_channels(
"{{{}}}:{}".format("channel", get_random_string(5)): get_random_string(5)
for _ in range(NUM_CHANNELS)
}
publish_response = 1 if cluster_mode else OK

callback, context = None, None
callback_messages: List[CoreCommands.PubSubMsg] = []
Expand All @@ -838,10 +830,9 @@ async def test_pubsub_pattern_many_channels(

try:
for channel, message in channels.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# allow the message to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -904,8 +895,6 @@ async def test_pubsub_combined_exact_and_pattern_one_client(
**pattern_channels_and_messages,
}

publish_response = 1 if cluster_mode else OK

callback, context = None, None
callback_messages: List[CoreCommands.PubSubMsg] = []

Expand Down Expand Up @@ -941,10 +930,9 @@ async def test_pubsub_combined_exact_and_pattern_one_client(
try:
# Publish messages to all channels
for channel, message in all_channels_and_messages.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# allow the message to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -1018,8 +1006,6 @@ async def test_pubsub_combined_exact_and_pattern_multiple_clients(
**pattern_channels_and_messages,
}

publish_response = 1 if cluster_mode else OK

callback, context = None, None
callback_messages: List[CoreCommands.PubSubMsg] = []

Expand Down Expand Up @@ -1071,10 +1057,9 @@ async def test_pubsub_combined_exact_and_pattern_multiple_clients(
try:
# Publish messages to all channels
for channel, message in all_channels_and_messages.items():
assert (
await publishing_client.publish(message, channel)
== publish_response
)
result = await publishing_client.publish(message, channel)
if cluster_mode:
assert result == 1

# allow the messages to propagate
await asyncio.sleep(1)
Expand Down Expand Up @@ -1638,7 +1623,6 @@ async def test_pubsub_two_publishing_clients_same_name(
CHANNEL_NAME = "channel-name"
MESSAGE_EXACT = get_random_string(10)
MESSAGE_PATTERN = get_random_string(7)
publish_response = 2 if cluster_mode else OK
callback, context_exact, context_pattern = None, None, None
callback_messages_exact: List[CoreCommands.PubSubMsg] = []
callback_messages_pattern: List[CoreCommands.PubSubMsg] = []
Expand Down Expand Up @@ -1671,14 +1655,10 @@ async def test_pubsub_two_publishing_clients_same_name(

try:
# Publish messages to each channel - both clients publishing
assert (
await client_pattern.publish(MESSAGE_EXACT, CHANNEL_NAME)
== publish_response
)
assert (
await client_exact.publish(MESSAGE_PATTERN, CHANNEL_NAME)
== publish_response
)
for msg in [MESSAGE_EXACT, MESSAGE_PATTERN]:
result = await client_pattern.publish(msg, CHANNEL_NAME)
if cluster_mode:
assert result == 2

# allow the message to propagate
await asyncio.sleep(1)
Expand Down
12 changes: 9 additions & 3 deletions python/python/tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ async def test_cluster_transaction(self, redis_client: GlideClusterClient):
keyslot = get_random_string(3)
transaction = ClusterTransaction()
transaction.info()
if await check_if_server_version_lt(redis_client, "7.0.0"):
transaction.publish("test_message", keyslot, False)
else:
transaction.publish("test_message", keyslot, True)
expected = await transaction_test(transaction, keyslot, redis_client)
result = await redis_client.exec(transaction)
assert isinstance(result, list)
Expand All @@ -768,7 +772,8 @@ async def test_cluster_transaction(self, redis_client: GlideClusterClient):
assert isinstance(result[0], str)
# Making sure the "info" command is indeed a return at position 0
assert "# Memory" in result[0]
assert result[1:] == expected
assert result[1] == 0
assert result[2:] == expected

@pytest.mark.parametrize("cluster_mode", [True, False])
@pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3])
Expand Down Expand Up @@ -845,6 +850,7 @@ async def test_standalone_transaction(self, redis_client: GlideClient):
)
transaction.select(0)
transaction.get(key)
transaction.publish("test_message", "test_channel")
expected = await transaction_test(transaction, keyslot, redis_client)
result = await redis_client.exec(transaction)
assert isinstance(result, list)
Expand All @@ -853,8 +859,8 @@ async def test_standalone_transaction(self, redis_client: GlideClient):
assert isinstance(result[0], str)
assert "# Memory" in result[0]
assert result[1:5] == [OK, False, OK, value.encode()]
assert result[5:12] == [2, 2, 2, [b"Bob", b"Alice"], 2, OK, None]
assert result[12:] == expected
assert result[5:13] == [2, 2, 2, [b"Bob", b"Alice"], 2, OK, None, 0]
assert result[13:] == expected

def test_transaction_clear(self):
transaction = Transaction()
Expand Down

0 comments on commit 36d5c01

Please sign in to comment.