Skip to content

Commit

Permalink
enable subclassing of Taskify
Browse files Browse the repository at this point in the history
  • Loading branch information
cardinam committed Feb 13, 2024
1 parent 52ce49e commit d7dd719
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 221 deletions.
47 changes: 5 additions & 42 deletions taskq/consumer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import logging
import threading
from time import sleep
Expand All @@ -11,11 +10,10 @@
from django_pglocks import advisory_lock

from .constants import TASKQ_DEFAULT_CONSUMER_SLEEP_RATE, TASKQ_DEFAULT_TASK_TIMEOUT
from .exceptions import Cancel, TaskLoadingError, TaskFatalError
from .exceptions import Cancel, TaskFatalError
from .models import Task
from .scheduler import Scheduler
from .task import Taskify
from .utils import task_from_scheduled_task, traceback_filter_taskq_frames, ordinal
from .utils import traceback_filter_taskq_frames, ordinal

logger = logging.getLogger("taskq")

Expand Down Expand Up @@ -85,8 +83,7 @@ def create_scheduled_tasks(self):
if task_exists:
continue

task = task_from_scheduled_task(scheduled_task)
task.save()
scheduled_task.create_task()

self._scheduler.update_all_tasks_due_dates()

Expand Down Expand Up @@ -167,8 +164,8 @@ def process_task(self, task):
logger.info("%s : Started (%s retry)", task, nth)

def _execute_task():
function, args, kwargs = self.load_task(task)
self.execute_task(function, args, kwargs)
with transaction.atomic():
task.execute()

try:
task.status = Task.STATUS_RUNNING
Expand Down Expand Up @@ -218,37 +215,3 @@ def fail_task(self, task, error):
type_name = type(error).__name__
exc_info = (type(error), error, exc_traceback)
logger.exception("%s : %s %s", task, type_name, error, exc_info=exc_info)

def load_task(self, task):
    """Resolve `task` into the callable and arguments needed to run it.

    :param task: Object exposing `function_name` and
        `decode_function_args()`.
    :return: A `(function, args, kwargs)` tuple.
    """
    taskified = self.import_taskified_function(task.function_name)
    positional, keyword = task.decode_function_args()
    return taskified, positional, keyword

def import_taskified_function(self, import_path):
    """Load a @taskified function from a python module.

    :param str import_path: Dotted path of the form "package.module.name".
    :return: The `Taskify` instance found at `import_path`.
    :raises TaskLoadingError: If `import_path` is not a dotted path, the
        module cannot be imported, the attribute does not exist, or the
        attribute is not a `Taskify` instance.
    """
    # https://stackoverflow.com/questions/3606202
    module_name, _, unit_name = import_path.rpartition(".")
    # A path without a module part or an attribute part cannot be resolved;
    # report it as a loading error instead of leaking a bare ValueError.
    if not (module_name and unit_name):
        msg = f'"{import_path}" is not a valid dotted import path'
        raise TaskLoadingError(msg)

    try:
        module = importlib.import_module(module_name)
    except (ImportError, SyntaxError) as e:
        # Chain the original error so the root cause stays in the traceback.
        raise TaskLoadingError(e) from e

    try:
        obj = getattr(module, unit_name)
    except AttributeError as e:
        raise TaskLoadingError(e) from e

    if not isinstance(obj, Taskify):
        msg = f'Object "{import_path}" is not a task'
        raise TaskLoadingError(msg)

    return obj

def execute_task(self, function, args, kwargs):
    """Run the task's code inside a `transaction.atomic()` block.

    :param function: Object exposing `_protected_call(args, kwargs)`.
    :param args: Positional arguments forwarded to the task.
    :param kwargs: Keyword arguments forwarded to the task.
    """
    atomic_block = transaction.atomic()
    with atomic_block:
        function._protected_call(args, kwargs)
112 changes: 112 additions & 0 deletions taskq/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import copy
import datetime
import importlib
import logging
import uuid

from django.core.exceptions import ValidationError
from django.db import models
from django.utils import timezone

from .exceptions import TaskLoadingError
from .json import JSONDecoder, JSONEncoder
from .utils import parse_timedelta

logger = logging.getLogger("taskq")


def generate_task_uuid():
Expand Down Expand Up @@ -100,6 +106,40 @@ def update_due_at_after_failure(self):

self.due_at = timezone.now() + delay

def load_task(self):
    """Resolve this task into the callable and arguments needed to run it.

    :return: A `(taskified_function, args, kwargs)` tuple.
    """
    target = self.import_taskified_function(self.function_name)
    call_args, call_kwargs = self.decode_function_args()
    return target, call_args, call_kwargs

@staticmethod
def import_taskified_function(import_path):
    """Load a @taskified function from a python module.

    :param str import_path: Dotted path of the form "package.module.name".
    :return: The `Taskify` instance found at `import_path`.
    :raises TaskLoadingError: If `import_path` is not a dotted path, the
        module cannot be imported, the attribute does not exist, or the
        attribute is not a `Taskify` instance.
    """
    # https://stackoverflow.com/questions/3606202
    module_name, _, unit_name = import_path.rpartition(".")
    # A path without a module part or an attribute part cannot be resolved;
    # report it as a loading error instead of leaking a bare ValueError.
    if not (module_name and unit_name):
        msg = f'"{import_path}" is not a valid dotted import path'
        raise TaskLoadingError(msg)

    try:
        module = importlib.import_module(module_name)
    except (ImportError, SyntaxError) as e:
        # Chain the original error so the root cause stays in the traceback.
        raise TaskLoadingError(e) from e

    try:
        obj = getattr(module, unit_name)
    except AttributeError as e:
        raise TaskLoadingError(e) from e

    if not isinstance(obj, Taskify):
        msg = f'Object "{import_path}" is not a task'
        raise TaskLoadingError(msg)

    return obj

def execute(self):
    """Load this task's function and arguments, then run it.

    The result of the task function is not returned.
    """
    loaded = self.load_task()
    taskified_function, args, kwargs = loaded
    taskified_function._protected_call(args, kwargs)

def __str__(self):
status = dict(self.STATUS_CHOICES)[self.status]

Expand All @@ -109,3 +149,75 @@ def __str__(self):
str_repr += f"{self.uuid}, status={status}>"

return str_repr


class Taskify:
    """Callable wrapper around a function that can also be queued as a Task."""

    def __init__(self, function, name=None):
        """Store the wrapped callable and its optional task name.

        :param function: The callable this task runs.
        :param name: Optional task name; when falsy, `func_name` is used.
        """
        self._function = function
        self._name = name

    def __call__(self, *args, **kwargs):
        """Run the wrapped function synchronously and return its result."""
        return self._function(*args, **kwargs)

    # If you rename this method, update the code in utils.traceback_filter_taskq_frames
    def _protected_call(self, args, kwargs):
        """Invoke the task with packed arguments, discarding the result."""
        self(*args, **kwargs)

    def apply(self, *args, **kwargs):
        """Run the task synchronously and return its result."""
        return self(*args, **kwargs)

    def apply_async(
        self,
        due_at=None,
        max_retries=3,
        retry_delay=0,
        retry_backoff=False,
        retry_backoff_factor=2,
        timeout=None,
        args=None,
        kwargs=None,
    ):
        """Apply a task asynchronously.

        :param due_at: When the task should be executed (None = now).
            NOTE(review): presumably a datetime, like the `timezone.now()`
            default used below - confirm.
        :param max_retries: Stored as-is on the created task.
        :param retry_delay: Converted with `parse_timedelta` before being
            stored on the task.
        :param retry_backoff: Stored as-is on the created task.
        :param retry_backoff_factor: Stored as-is on the created task.
        :param timeout: The maximum time a task may run.
            (None = no timeout)
            (int = number of seconds)
        :type timeout: timedelta or int or None
        :param args: The positional arguments to pass on to the task.
        :param kwargs: The keyword arguments to pass on to the task.
        :return: The saved Task.
        """
        due_at = timezone.now() if due_at is None else due_at
        args = [] if args is None else args
        kwargs = {} if kwargs is None else kwargs

        queued = Task()
        queued.due_at = due_at
        queued.name = self.name
        queued.status = Task.STATUS_QUEUED
        queued.function_name = self.func_name
        queued.encode_function_args(args, kwargs)
        queued.max_retries = max_retries
        queued.retry_delay = parse_timedelta(retry_delay)
        queued.retry_backoff = retry_backoff
        queued.retry_backoff_factor = retry_backoff_factor
        queued.timeout = parse_timedelta(timeout, nullable=True)
        queued.save()

        return queued

    @property
    def func_name(self):
        """Dotted "<module>.<name>" path of the wrapped function."""
        wrapped = self._function
        return f"{wrapped.__module__}.{wrapped.__name__}"

    @property
    def name(self):
        """The explicit task name, or `func_name` when none was given."""
        return self._name or self.func_name
24 changes: 23 additions & 1 deletion taskq/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime

from croniter import croniter
from django.conf import settings
from django.utils import timezone
from croniter import croniter

from .models import Task
from .utils import parse_timedelta


Expand Down Expand Up @@ -49,6 +50,27 @@ def is_due(self):
now = timezone.now()
return self.due_at <= now

@property
def as_task(self):
    """Build a new Task populated from this scheduled task.

    Note that the returned Task is not saved in database, you still need
    to call .save() on it.
    """
    task = Task()
    for field in ("name", "due_at", "function_name"):
        setattr(task, field, getattr(self, field))

    # The scheduled task's `args` are handed over as keyword arguments.
    task.encode_function_args(kwargs=self.args)

    for field in (
        "max_retries",
        "retry_delay",
        "retry_backoff",
        "retry_backoff_factor",
        "timeout",
    ):
        setattr(task, field, getattr(self, field))

    return task

def create_task(self):
    """Build a Task from this scheduled task (via `as_task`) and save it."""
    new_task = self.as_task
    new_task.save()


class Scheduler:
def __init__(self):
Expand Down
88 changes: 12 additions & 76 deletions taskq/task.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,21 @@
import logging
import importlib

from django.utils import timezone
from django.conf import settings

from .models import Task as TaskModel
from .utils import parse_timedelta
from taskq.models import Taskify

logger = logging.getLogger("taskq")

def taskify(func=None, *, name=None, base=None, **kwargs):
if base is None:
default_cls_str = getattr(settings, "TASKQ", {}).get("default_taskify_class")
if default_cls_str:
module_name, unit_name = default_cls_str.rsplit(".", 1)
base = getattr(importlib.import_module(module_name), unit_name)
else:
base = Taskify

class Taskify:
    """Wrapper around a function that can be run directly or queued as a task."""

    def __init__(self, function, name=None):
        """Store the wrapped callable and its optional task name.

        :param function: The callable this task runs.
        :param name: Optional task name; when falsy, `func_name` is used.
        """
        self._function = function
        self._name = name

    # If you rename this method, update the code in utils.traceback_filter_taskq_frames
    def _protected_call(self, args, kwargs):
        """Invoke the wrapped function with packed arguments, discarding the result."""
        self._function(*args, **kwargs)

    def apply(self, *args, **kwargs):
        """Run the wrapped function synchronously and return its result."""
        wrapped = self._function
        return wrapped(*args, **kwargs)

    def apply_async(
        self,
        due_at=None,
        max_retries=3,
        retry_delay=0,
        retry_backoff=False,
        retry_backoff_factor=2,
        timeout=None,
        args=None,
        kwargs=None,
    ):
        """Apply a task asynchronously.

        :param due_at: When the task should be executed (None = now).
            NOTE(review): presumably a datetime, like the `timezone.now()`
            default used below - confirm.
        :param max_retries: Stored as-is on the created task.
        :param retry_delay: Converted with `parse_timedelta` before being
            stored on the task.
        :param retry_backoff: Stored as-is on the created task.
        :param retry_backoff_factor: Stored as-is on the created task.
        :param timeout: The maximum time a task may run.
            (None = no timeout)
            (int = number of seconds)
        :type timeout: timedelta or int or None
        :param args: The positional arguments to pass on to the task.
        :param kwargs: The keyword arguments to pass on to the task.
        :return: The saved task model instance.
        """
        due_at = timezone.now() if due_at is None else due_at
        args = [] if args is None else args
        kwargs = {} if kwargs is None else kwargs

        record = TaskModel()
        record.due_at = due_at
        record.name = self.name
        record.status = TaskModel.STATUS_QUEUED
        record.function_name = self.func_name
        record.encode_function_args(args, kwargs)
        record.max_retries = max_retries
        record.retry_delay = parse_timedelta(retry_delay)
        record.retry_backoff = retry_backoff
        record.retry_backoff_factor = retry_backoff_factor
        record.timeout = parse_timedelta(timeout, nullable=True)
        record.save()

        return record

    @property
    def func_name(self):
        """Dotted "<module>.<name>" path of the wrapped function."""
        wrapped = self._function
        return f"{wrapped.__module__}.{wrapped.__name__}"

    @property
    def name(self):
        """The explicit task name, or `func_name` when none was given."""
        return self._name or self.func_name


def taskify(func=None, name=None):
def wrapper_taskify(_func):
return Taskify(_func, name=name)
return base(_func, name=name, **kwargs)

if func is None:
return wrapper_taskify
Expand Down
22 changes: 0 additions & 22 deletions taskq/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import datetime
import traceback

from .models import Task


def ordinal(n: int):
"""Output the ordinal representation ("1st", "2nd", "3rd", etc.) of any number."""
Expand Down Expand Up @@ -31,26 +29,6 @@ def parse_timedelta(delay, nullable=False):
raise TypeError("Unexpected delay type")


def task_from_scheduled_task(scheduled_task):
    """Create a new Task initialized with the content of `scheduled_task`.

    Note that the returned Task is not saved in database, you still need to
    call .save() on it.
    """
    new_task = Task()
    for field in ("name", "due_at", "function_name"):
        setattr(new_task, field, getattr(scheduled_task, field))

    # The scheduled task's `args` are handed over as keyword arguments.
    new_task.encode_function_args(kwargs=scheduled_task.args)

    for field in (
        "max_retries",
        "retry_delay",
        "retry_backoff",
        "retry_backoff_factor",
        "timeout",
    ):
        setattr(new_task, field, getattr(scheduled_task, field))

    return new_task


def traceback_filter_taskq_frames(exception):
"""Will return the traceback of the passed exception without the taskq
internal frames except the last one (which will be "_protected_call" in
Expand Down
Loading

0 comments on commit d7dd719

Please sign in to comment.