Skip to content

Commit

Permalink
Allow truncating task executions (#251)
Browse files Browse the repository at this point in the history
* Allow truncating task executions

* Get rid of truncated_rpush

* Fix test
  • Loading branch information
neob91-close authored Dec 19, 2022
1 parent 9b807dd commit 60e3b56
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 13 deletions.
20 changes: 17 additions & 3 deletions tasktiger/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
retry_on=None,
retry_method=None,
max_queue_size=None,
max_stored_executions=None,
runner_class=None,
# internal variables
_data=None,
Expand Down Expand Up @@ -95,6 +96,11 @@ def __init__(
if max_queue_size is None:
max_queue_size = getattr(func, "_task_max_queue_size", None)

if max_stored_executions is None:
max_stored_executions = getattr(
func, "_task_max_stored_executions", None
)

if runner_class is None:
runner_class = getattr(func, "_task_runner_class", None)

Expand Down Expand Up @@ -141,6 +147,8 @@ def __init__(
]
if max_queue_size:
task["max_queue_size"] = max_queue_size
if max_stored_executions is not None:
task["max_stored_executions"] = max_stored_executions
if runner_class:
serialized_runner_class = serialize_func_name(runner_class)
task["runner_class"] = serialized_runner_class
Expand Down Expand Up @@ -237,6 +245,10 @@ def func(self):
self._func = import_attribute(self.serialized_func)
return self._func

@property
def max_stored_executions(self):
return self._data.get("max_stored_executions")

@property
def serialized_runner_class(self):
return self._data.get("runner_class")
Expand Down Expand Up @@ -558,11 +570,13 @@ def n_executions(self):
"""
pipeline = self.tiger.connection.pipeline()
pipeline.exists(self.tiger._key("task", self.id))
pipeline.llen(self.tiger._key("task", self.id, "executions"))
exists, n_executions = pipeline.execute()
pipeline.get(self.tiger._key("task", self.id, "executions_count"))

exists, executions_count = pipeline.execute()
if not exists:
raise TaskNotFound("Task {} not found.".format(self.id))
return n_executions

return int(executions_count or 0)

def retry(self):
"""
Expand Down
5 changes: 5 additions & 0 deletions tasktiger/tasktiger.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def task(
schedule=None,
batch=False,
max_queue_size=None,
max_stored_executions=None,
runner_class=None,
):
"""
Expand Down Expand Up @@ -333,6 +334,8 @@ def _wrap(func):
func._task_schedule = schedule
if max_queue_size is not None:
func._task_max_queue_size = max_queue_size
if max_stored_executions is not None:
func._task_max_stored_executions = max_stored_executions
if runner_class is not None:
func._task_runner_class = runner_class

Expand Down Expand Up @@ -407,6 +410,7 @@ def delay(
retry_on=None,
retry_method=None,
max_queue_size=None,
max_stored_executions=None,
runner_class=None,
):
"""
Expand All @@ -427,6 +431,7 @@ def delay(
retry=retry,
retry_on=retry_on,
retry_method=retry_method,
max_stored_executions=max_stored_executions,
runner_class=runner_class,
)

Expand Down
15 changes: 11 additions & 4 deletions tasktiger/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,11 +1186,18 @@ def _store_task_execution(self, tasks, execution) -> None:
serialized_execution = json.dumps(execution)

for task in tasks:
pipeline = self.connection.pipeline()
pipeline.incr(self._key("task", task.id, "executions_count"))
pipeline.rpush(
self._key("task", task.id, "executions"), serialized_execution
executions_key = self._key("task", task.id, "executions")
executions_count_key = self._key(
"task", task.id, "executions_count"
)

pipeline = self.connection.pipeline()
pipeline.incr(executions_count_key)
pipeline.rpush(executions_key, serialized_execution)

if task.max_stored_executions:
pipeline.ltrim(executions_key, -task.max_stored_executions, -1)

pipeline.execute()

def run(self, once=False, force_once=False):
Expand Down
31 changes: 29 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ def test_exception_task(self, store_tracebacks):
else:
assert "traceback" not in execution

@pytest.mark.parametrize("max_stored_executions", [2, 3, 6, 11, None])
def test_max_stored_executions(self, max_stored_executions):
def _get_stored_executions():
return self.conn.llen(f"t:task:{task.id}:executions")

task = self.tiger.delay(
exception_task,
max_stored_executions=max_stored_executions,
retry_method=fixed(DELAY, 20),
)

assert _get_stored_executions() == 0

for __ in range(6):
Worker(self.tiger).run(once=True)
time.sleep(DELAY)

Worker(self.tiger).run(once=True)

assert task.n_executions() == 6
assert _get_stored_executions() == min(max_stored_executions or 6, 6)

def test_long_task_ok(self):
self.tiger.delay(long_task_ok)
Worker(self.tiger).run(once=True)
Expand Down Expand Up @@ -433,9 +455,14 @@ def test_lock_key(self):
error={"default": 0},
)

def test_retry(self):
@pytest.mark.parametrize("max_stored_executions", [0, 1, 2, 3, 4, None])
def test_retry(self, max_stored_executions):
# Use the default retry method we configured.
task = self.tiger.delay(exception_task, retry=True)
task = self.tiger.delay(
exception_task,
max_stored_executions=max_stored_executions,
retry=True,
)
self._ensure_queues(
queued={"default": 1},
scheduled={"default": 0},
Expand Down
31 changes: 27 additions & 4 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from .utils import get_tiger


class TestTaskFromId:
@pytest.fixture
def tiger(self):
return get_tiger()
@pytest.fixture
def tiger():
return get_tiger()


class TestTaskFromId:
@pytest.fixture
def queued_task(self, tiger):
return tiger.delay(simple_task)
Expand All @@ -26,3 +27,25 @@ def test_task_wrong_state(self, tiger, queued_task):
def test_task_wrong_queue(self, tiger, queued_task):
with pytest.raises(TaskNotFound):
Task.from_id(tiger, "other", "active", queued_task.id)


class TestTaskMaxTrackedExecutions:
def test_max_stored_executions_passed_to_tiger_delay(self, tiger):
task = tiger.delay(simple_task, max_stored_executions=17)
assert task.max_stored_executions == 17

def test_max_stored_executions_passed_to_decorator(self, tiger):
@tiger.task(max_stored_executions=17)
def some_task():
pass

task = some_task.delay()
assert task.max_stored_executions == 17

def test_max_stored_executions_overridden_in_tiger_delay(self, tiger):
@tiger.task(max_stored_executions=17)
def some_task():
pass

task = tiger.delay(some_task, max_stored_executions=11)
assert task.max_stored_executions == 11

0 comments on commit 60e3b56

Please sign in to comment.