Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support aioredis 2.0 #99

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version:
- 3.6
- 3.7
- 3.8
- 3.9
aioredis-version:
- 'aioredis<2.0.0'
- 'aioredis>=2.0.0'

services:
redis:
Expand All @@ -27,6 +34,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .[test,cicd,package]
pip install "${{ matrix.aioredis-version }}"
- name: Check syntax
run: |
make syntax
Expand Down
49 changes: 36 additions & 13 deletions aioredlock/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

import aioredis

try:
from aioredis.errors import ReplyError as NoScriptError, RedisError
except ImportError:
from aioredis.exceptions import NoScriptError, RedisError

from aioredlock.errors import LockError, LockAcquiringError, LockRuntimeError
from aioredlock.sentinel import Sentinel
from aioredlock.utility import clean_password
Expand Down Expand Up @@ -115,6 +120,21 @@ async def _create_redis_pool(*args, **kwargs):
else: # pragma no cover
return await aioredis.create_pool(*args, **kwargs)

def _evalsha(self, redis, sha, keys, args):
if StrictVersion(aioredis.__version__) >= StrictVersion('2.0.0'):
return redis.evalsha(
sha,
len(keys),
*keys,
*args,
)
else:
return redis.evalsha(
digest=sha,
keys=keys,
args=args,
)

async def _register_scripts(self, redis):
tasks = []
for script in [
Expand All @@ -123,7 +143,7 @@ async def _register_scripts(self, redis):
self.GET_LOCK_TTL_SCRIPT,
]:
script = re.sub(r'^\s+', '', script, flags=re.M).strip()
tasks.append(redis.script_load(script))
tasks.append(redis.script_load(script=script))
(
self.set_lock_script_sha1,
self.unset_lock_script_sha1,
Expand Down Expand Up @@ -195,18 +215,19 @@ async def set_lock(self, resource, lock_identifier, lock_timeout, register_scrip
with await self.connect() as redis:
if register_scripts is True:
await self._register_scripts(redis)
await redis.evalsha(
await self._evalsha(
redis,
self.set_lock_script_sha1,
keys=[resource],
args=[lock_identifier, lock_timeout_ms]
)
except aioredis.errors.ReplyError as exc: # script fault
if exc.args[0].startswith('NOSCRIPT'):
except NoScriptError as exc: # script fault
if exc.__class__.__name__ == 'NoScriptError' or exc.args[0].startswith('NOSCRIPT'):
return await self.set_lock(resource, lock_identifier, lock_timeout, register_scripts=True)
self.log.debug('Can not set lock "%s" on %s',
resource, repr(self))
raise LockAcquiringError('Can not set lock') from exc
except (aioredis.errors.RedisError, OSError) as exc:
except (RedisError, OSError) as exc:
self.log.error('Can not set lock "%s" on %s: %s',
resource, repr(self), repr(exc))
raise LockRuntimeError('Can not set lock') from exc
Expand All @@ -233,18 +254,19 @@ async def get_lock_ttl(self, resource, lock_identifier, register_scripts=False):
with await self.connect() as redis:
if register_scripts is True:
await self._register_scripts(redis)
ttl = await redis.evalsha(
ttl = await self._evalsha(
redis,
self.get_lock_ttl_script_sha1,
keys=[resource],
args=[lock_identifier]
)
except aioredis.errors.ReplyError as exc: # script fault
if exc.args[0].startswith('NOSCRIPT'):
except NoScriptError as exc: # script fault
if exc.__class__.__name__ == 'NoScriptError' or exc.args[0].startswith('NOSCRIPT'):
return await self.get_lock_ttl(resource, lock_identifier, register_scripts=True)
self.log.debug('Can not get lock "%s" on %s',
resource, repr(self))
raise LockAcquiringError('Can not get lock') from exc
except (aioredis.errors.RedisError, OSError) as exc:
except (RedisError, OSError) as exc:
self.log.error('Can not get lock "%s" on %s: %s',
resource, repr(self), repr(exc))
raise LockRuntimeError('Can not get lock') from exc
Expand All @@ -271,18 +293,19 @@ async def unset_lock(self, resource, lock_identifier, register_scripts=False):
with await self.connect() as redis:
if register_scripts is True:
await self._register_scripts(redis)
await redis.evalsha(
await self._evalsha(
redis,
self.unset_lock_script_sha1,
keys=[resource],
args=[lock_identifier]
)
except aioredis.errors.ReplyError as exc: # script fault
if exc.args[0].startswith('NOSCRIPT'):
except NoScriptError as exc: # script fault
if exc.__class__.__name__ == 'NoScriptError' or exc.args[0].startswith('NOSCRIPT'):
return await self.unset_lock(resource, lock_identifier, register_scripts=True)
self.log.debug('Can not unset lock "%s" on %s',
resource, repr(self))
raise LockAcquiringError('Can not unset lock') from exc
except (aioredis.errors.RedisError, OSError) as exc:
except (RedisError, OSError) as exc:
self.log.error('Can not unset lock "%s" on %s: %s',
resource, repr(self), repr(exc))
raise LockRuntimeError('Can not unset lock') from exc
Expand Down
7 changes: 5 additions & 2 deletions aioredlock/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import ssl
import urllib.parse

import aioredis.sentinel
try:
from aioredis.sentinel import create_sentinel
except ImportError:
from aioredis.sentinel import Sentinel as create_sentinel


class SentinelConfigError(Exception):
Expand Down Expand Up @@ -102,7 +105,7 @@ async def get_sentinel(self):
'''
Retrieve sentinel object from aioredis.
'''
return await aioredis.sentinel.create_sentinel(
return await create_sentinel(
sentinels=self.connection,
**self.redis_kwargs,
)
Expand Down
74 changes: 53 additions & 21 deletions tests/ut/test_redis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import asyncio
import hashlib
import sys
from distutils.version import StrictVersion
from unittest.mock import MagicMock, call, patch

import aioredis
import pytest

try:
from aioredis.errors import ReplyError as ResponseError
except ImportError:
from aioredis.exceptions import ResponseError

from aioredlock.errors import LockError, LockAcquiringError, LockRuntimeError
from aioredlock.redis import Instance, Redis
from aioredlock.sentinel import Sentinel
Expand All @@ -19,7 +25,7 @@ def callculate_sha1(text):


EVAL_OK = b'OK'
EVAL_ERROR = aioredis.errors.ReplyError('ERROR')
EVAL_ERROR = ResponseError('ERROR')
CANCELLED = asyncio.CancelledError('CANCELLED')
CONNECT_ERROR = OSError('ERROR')
RANDOM_ERROR = Exception('FAULT')
Expand Down Expand Up @@ -221,9 +227,11 @@ async def test_lock(self, fake_instance):
await instance.set_lock('resource', 'lock_id', 10.0)

pool.evalsha.assert_called_once_with(
instance.set_lock_script_sha1,
keys=['resource'],
args=['lock_id', 10000]
**_setup_evalsha_call_args(
digest=instance.set_lock_script_sha1,
keys=['resource'],
args=['lock_id', 10000],
).kwargs
)

@pytest.mark.asyncio
Expand All @@ -234,9 +242,11 @@ async def test_get_lock_ttl(self, fake_instance):

await instance.get_lock_ttl('resource', 'lock_id')
pool.evalsha.assert_called_with(
instance.get_lock_ttl_script_sha1,
keys=['resource'],
args=['lock_id']
**_setup_evalsha_call_args(
instance.get_lock_ttl_script_sha1,
keys=['resource'],
args=['lock_id'],
).kwargs
)

@pytest.mark.asyncio
Expand All @@ -256,9 +266,11 @@ async def hold_lock(instance):
await instance.set_lock('resource', 'lock_id', 10.0)

pool.evalsha.assert_called_once_with(
instance.set_lock_script_sha1,
keys=['resource'],
args=['lock_id', 10000]
**_setup_evalsha_call_args(
instance.set_lock_script_sha1,
keys=['resource'],
args=['lock_id', 10000],
).kwargs
)

instance._pool = None
Expand Down Expand Up @@ -286,9 +298,11 @@ async def test_lock_without_scripts(self, fake_coro, fake_instance, func, args,
assert pool.script_load.call_count == 6 # for 3 scripts.

pool.evalsha.assert_called_with(
getattr(instance, '{0}_script_sha1'.format(func)),
keys=expected_keys,
args=expected_args,
**_setup_evalsha_call_args(
getattr(instance, '{0}_script_sha1'.format(func)),
keys=expected_keys,
args=expected_args,
).kwargs
)

@pytest.mark.asyncio
Expand All @@ -300,9 +314,11 @@ async def test_unset_lock(self, fake_instance):
await instance.unset_lock('resource', 'lock_id')

pool.evalsha.assert_called_once_with(
instance.unset_lock_script_sha1,
keys=['resource'],
args=['lock_id']
**_setup_evalsha_call_args(
instance.unset_lock_script_sha1,
keys=['resource'],
args=['lock_id'],
).kwargs
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -363,6 +379,22 @@ def mock_redis_three_instances(redis_three_connections):
yield redis, pool


def _setup_evalsha_call_args(digest, keys, args):
if StrictVersion(aioredis.__version__) >= StrictVersion('2.0.0'):
return call(
digest,
len(keys),
*keys,
*args,
)
else:
return call(
digest=digest,
keys=keys,
args=args,
)


class TestRedis:

def test_initialization(self, redis_two_connections):
Expand Down Expand Up @@ -398,7 +430,7 @@ async def test_lock(

script_sha1 = getattr(redis.instances[0], '%s_script_sha1' % method_name)

calls = [call(script_sha1, **call_args)] * 2
calls = [_setup_evalsha_call_args(script_sha1, **call_args)] * 2
pool.evalsha.assert_has_calls(calls)

@pytest.mark.asyncio
Expand Down Expand Up @@ -434,7 +466,7 @@ async def test_lock_one_of_two_instances_failed(

script_sha1 = getattr(redis.instances[0], '%s_script_sha1' % method_name)

calls = [call(script_sha1, **call_args)] * 2
calls = [_setup_evalsha_call_args(script_sha1, **call_args)] * 2
pool.evalsha.assert_has_calls(calls)

@pytest.mark.asyncio
Expand Down Expand Up @@ -472,7 +504,7 @@ async def test_three_instances_combination(
script_sha1 = getattr(redis.instances[0],
'%s_script_sha1' % method_name)

calls = [call(script_sha1, **call_args)] * 3
calls = [_setup_evalsha_call_args(script_sha1, **call_args)] * 3
pool.evalsha.assert_has_calls(calls)

@pytest.mark.asyncio
Expand Down Expand Up @@ -508,7 +540,7 @@ async def test_three_instances_combination_errors(
script_sha1 = getattr(redis.instances[0],
'%s_script_sha1' % method_name)

calls = [call(script_sha1, **call_args)] * 3
calls = [_setup_evalsha_call_args(script_sha1, **call_args)] * 3
pool.evalsha.assert_has_calls(calls)

@pytest.mark.asyncio
Expand Down Expand Up @@ -538,6 +570,6 @@ async def test_get_lock(self, mock_redis_two_instances, ):

script_sha1 = getattr(redis.instances[0], 'get_lock_ttl_script_sha1')

calls = [call(script_sha1, keys=['resource'], args=['lock_id'])]
calls = [_setup_evalsha_call_args(script_sha1, keys=['resource'], args=['lock_id'])]
pool.evalsha.assert_has_calls(calls)
# assert 0
25 changes: 7 additions & 18 deletions tests/ut/test_sentinel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import contextlib
import ssl
import sys
from unittest import mock

import aioredlock.sentinel
Expand All @@ -15,19 +14,12 @@

@contextlib.contextmanager
def mock_aioredis_sentinel():
if sys.version_info < (3, 8, 0):
mock_obj = mock.MagicMock()
mock_obj.master_for.return_value = asyncio.Future()
mock_obj.master_for.return_value.set_result(True)
else:
mock_obj = mock.AsyncMock()
mock_obj.master_for.return_value = True
with mock.patch.object(aioredlock.sentinel.aioredis.sentinel, 'create_sentinel') as mock_sentinel:
if sys.version_info < (3, 8, 0):
mock_sentinel.return_value = asyncio.Future()
mock_sentinel.return_value.set_result(mock_obj)
else:
mock_sentinel.return_value = mock_obj
mock_obj = mock.MagicMock()
mock_obj.master_for.return_value = asyncio.Future()
mock_obj.master_for.return_value.set_result(True)
with mock.patch.object(aioredlock.sentinel, 'create_sentinel') as mock_sentinel:
mock_sentinel.return_value = asyncio.Future()
mock_sentinel.return_value.set_result(mock_obj)
yield mock_sentinel


Expand Down Expand Up @@ -143,10 +135,7 @@ async def test_sentinel(ssl_context, connection, kwargs, expected_kwargs, expect
if with_ssl or kwargs.get('ssl_context') is True:
expected_kwargs['ssl'] = ssl_context
mock_sentinel.assert_called_with(**expected_kwargs)
if sys.version_info < (3, 8, 0):
result = mock_sentinel.return_value.result()
else:
result = mock_sentinel.return_value
result = mock_sentinel.return_value.result()
assert result.master_for.called
result.master_for.assert_called_with(expected_master)
if with_ssl:
Expand Down