Skip to content

Commit

Permalink
Add PythonVirtualenvDecorator to Taskflow API (apache#14761)
Browse files Browse the repository at this point in the history
To improve the usability of the TaskFlow API, we will add the ability to
define virtualenvs so users can run tasks in
environments that differ from that of the Airflow system.
  • Loading branch information
dimberman authored Apr 8, 2021
1 parent ce91872 commit 5661273
Show file tree
Hide file tree
Showing 12 changed files with 716 additions and 97 deletions.
85 changes: 84 additions & 1 deletion airflow/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.

from typing import Callable, Optional
from typing import Callable, Dict, Iterable, List, Optional, Union

from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import _virtualenv_task
from airflow.decorators.task_group import task_group # noqa # pylint: disable=unused-import
from airflow.models.dag import dag # noqa # pylint: disable=unused-import

Expand Down Expand Up @@ -56,5 +57,87 @@ def python(python_callable: Optional[Callable] = None, multiple_outputs: Optiona
"""
return python_task(python_callable=python_callable, multiple_outputs=multiple_outputs, **kwargs)

@staticmethod
def virtualenv(
    python_callable: Optional[Callable] = None,
    multiple_outputs: Optional[bool] = None,
    requirements: Optional[Iterable[str]] = None,
    python_version: Optional[Union[str, int, float]] = None,
    use_dill: bool = False,
    system_site_packages: bool = True,
    string_args: Optional[Iterable[str]] = None,
    templates_dict: Optional[Dict] = None,
    templates_exts: Optional[List[str]] = None,
    **kwargs,
):
    """
    Run a function in a virtualenv that is created and destroyed
    automatically (with certain caveats).

    The function must be defined using def, and not be part of a class.
    All imports must happen inside the function and no variables outside
    of the scope may be referenced. A global scope variable named
    virtualenv_string_args will be available (populated by string_args).
    In addition, one can pass values through op_args and op_kwargs, and
    one can use a return value.

    Note that if your virtualenv runs in a different Python major version
    than Airflow, you cannot use return values, op_args, op_kwargs, or use
    any macros that are being provided to Airflow through plugins. You can
    use string_args though.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:PythonVirtualenvOperator`

    :param python_callable: A python function with no references to outside
        variables, defined with def, which will be run in a virtualenv
    :type python_callable: function
    :param multiple_outputs: if set, function return value will be
        unrolled to multiple XCom values. List/Tuples will unroll to xcom values
        with index as key. Dict will unroll to xcom values with keys as XCom keys.
        Left as ``None`` when not given.
    :type multiple_outputs: bool
    :param requirements: A list of requirements as specified in a pip install command
    :type requirements: list[str]
    :param python_version: The Python version to run the virtualenv with. Note that
        both 2 and 2.7 are acceptable forms.
    :type python_version: Optional[Union[str, int, float]]
    :param use_dill: Whether to use dill to serialize
        the args and result (pickle is default). This allows more complex types
        but requires you to include dill in your requirements.
    :type use_dill: bool
    :param system_site_packages: Whether to include
        system_site_packages in your virtualenv.
        See virtualenv documentation for more information.
    :type system_site_packages: bool
    :param op_args: A list of positional arguments to pass to python_callable.
        NOTE(review): not an explicit parameter of this function -- confirm how
        it is expected to be supplied.
    :type op_args: list
    :param op_kwargs: A dict of keyword arguments to pass to python_callable.
        NOTE(review): not an explicit parameter of this function -- confirm how
        it is expected to be supplied.
    :type op_kwargs: dict
    :param string_args: Strings that are present in the global var virtualenv_string_args,
        available to python_callable at runtime as a list[str]. Note that args are split
        by newline.
    :type string_args: list[str]
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute`` takes place and are made available
        in your callable's context after the template has been applied
    :type templates_dict: dict of str
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for examples ``['.sql', '.hql']``
    :type templates_exts: list[str]
    :param kwargs: any further keyword arguments, forwarded unchanged to
        ``_virtualenv_task``
    :return: whatever ``_virtualenv_task`` returns
    """
    # Group the virtualenv-specific settings so the forwarding call stays short;
    # keyword order is kept the same as the parameter order.
    virtualenv_options = {
        "requirements": requirements,
        "python_version": python_version,
        "use_dill": use_dill,
        "system_site_packages": system_site_packages,
        "string_args": string_args,
        "templates_dict": templates_dict,
        "templates_exts": templates_exts,
    }
    return _virtualenv_task(
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        **virtualenv_options,
        **kwargs,
    )


task = _TaskDecorator()
147 changes: 90 additions & 57 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,60 @@
from airflow.utils.task_group import TaskGroup, TaskGroupContext


class BaseDecoratedOperator(BaseOperator):
def validate_python_callable(python_callable):
    """
    Check that ``python_callable`` can be wrapped by an operator.

    The object must be callable and its signature must not contain a
    parameter named ``self``, which is the test used here to reject methods.

    :param python_callable: Python object to be validated
    :raises TypeError: if ``python_callable`` is not callable
    :raises AirflowException: if the signature has a ``self`` parameter
    """
    if not callable(python_callable):
        raise TypeError('`python_callable` param must be callable')

    parameter_names = signature(python_callable).parameters
    if 'self' in parameter_names:
        raise AirflowException('@task does not support methods')


def get_unique_task_id(
    task_id: str, dag: Optional[DAG] = None, task_group: Optional[TaskGroup] = None
) -> str:
    """
    Generate unique task id given a DAG (or if run in a DAG context).

    Ids are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20

    :param task_id: the requested task id
    :param dag: DAG whose ``task_ids`` are checked for collisions; when not
        given, ``DagContext.get_current_dag()`` is used
    :param task_group: group used to qualify ``task_id`` via ``child_id``; when
        not given, ``TaskGroupContext.get_current_task_group(dag)`` is used
    :return: ``task_id`` unchanged when there is no DAG or the (group-qualified)
        id is not yet in ``dag.task_ids``; otherwise ``<core>__<n>`` where ``n``
        is one more than the highest suffix already present
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    core = re.split(r'__\d+$', task_id)[0]
    # The collision check above is made on the group-qualified id, so the
    # existing suffixes have to be searched with the group-qualified prefix as
    # well; searching with the bare ``core`` never matches inside a group and
    # would hand out ``core__1`` over and over.
    qualified_core = re.split(r'__\d+$', tg_task_id)[0]
    # Escape the prefix: ids may contain regex metacharacters such as ``.``.
    suffix_pattern = re.compile(rf'^{re.escape(qualified_core)}__(\d+)$')
    matches = (suffix_pattern.match(existing_id) for existing_id in dag.task_ids)
    highest_suffix = max((int(found.group(1)) for found in matches if found), default=0)
    return f'{core}__{highest_suffix + 1}'


class DecoratedOperator(BaseOperator):
"""
Wraps a Python callable and captures args/kwargs when called for execution.
Expand All @@ -45,6 +98,10 @@ class BaseDecoratedOperator(BaseOperator):
unrolled to multiple XCom values. Dict will unroll to xcom values with keys as keys.
Defaults to False.
:type multiple_outputs: bool
:param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
PythonOperator). This gives a user the option to upstream kwargs as needed.
:type kwargs_to_upstream: dict
"""

template_fields = ('op_args', 'op_kwargs')
Expand All @@ -63,73 +120,49 @@ def __init__(
op_args: Tuple[Any],
op_kwargs: Dict[str, Any],
multiple_outputs: bool = False,
kwargs_to_upstream: dict = None,
**kwargs,
) -> None:
kwargs['task_id'] = self._get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
super().__init__(**kwargs)
kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
self.python_callable = python_callable
kwargs_to_upstream = kwargs_to_upstream or {}

# Check that arguments can be binded
signature(python_callable).bind(*op_args, **op_kwargs)
self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
super().__init__(**kwargs_to_upstream, **kwargs)

@staticmethod
def _get_unique_task_id(
task_id: str, dag: Optional[DAG] = None, task_group: Optional[TaskGroup] = None
) -> str:
"""
Generate unique task id given a DAG (or if run in a DAG context)
Ids are generated by appending a unique number to the end of
the original task id.
Example:
task_id
task_id__1
task_id__2
...
task_id__20
"""
dag = dag or DagContext.get_current_dag()
if not dag:
return task_id

# We need to check if we are in the context of TaskGroup as the task_id may
# already be altered
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
tg_task_id = task_group.child_id(task_id) if task_group else task_id

if tg_task_id not in dag.task_ids:
return task_id
core = re.split(r'__\d+$', task_id)[0]
suffixes = sorted(
[
int(re.split(r'^.+__', task_id)[1])
for task_id in dag.task_ids
if re.match(rf'^{core}__\d+$', task_id)
]
)
if not suffixes:
return f'{core}__1'
return f'{core}__{suffixes[-1] + 1}'

@staticmethod
def validate_python_callable(python_callable):
"""
Validate that python callable can be wrapped by operator.
Raises exception if invalid.
def execute(self, context: Dict):
    """
    Run the parent class's ``execute`` and hand its result to ``_handle_output``.

    :param context: execution context, passed on to the parent ``execute`` and
        to ``_handle_output``
    :return: the value returned by the parent ``execute``
    """
    result = super().execute(context)
    self._handle_output(return_value=result, context=context, xcom_push=self.xcom_push)
    return result

:param python_callable: Python object to be validated
:raises: TypeError, AirflowException
def _handle_output(self, return_value: Any, context: Dict, xcom_push: Callable):
"""
if not callable(python_callable):
raise TypeError('`python_callable` param must be callable')
if 'self' in signature(python_callable).parameters.keys():
raise AirflowException('@task does not support methods')
Handles logic for whether a decorator needs to push a single return value or multiple return values.
def execute(self, context: Dict):
raise NotImplementedError()
:param return_value:
:param context:
:param xcom_push:
"""
if not self.multiple_outputs:
return return_value
if isinstance(return_value, dict):
for key in return_value.keys():
if not isinstance(key, str):
raise AirflowException(
'Returned dictionary keys must be strings when using '
f'multiple_outputs, found {key} ({type(key)}) instead'
)
for key, value in return_value.items():
xcom_push(context, key, value)
else:
raise AirflowException(
f'Returned output was type {type(return_value)} expected dictionary for multiple_outputs'
)
return return_value


T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
Expand All @@ -138,7 +171,7 @@ def execute(self, context: Dict):
def task_decorator_factory(
python_callable: Optional[Callable] = None,
multiple_outputs: Optional[bool] = None,
decorated_operator_class: BaseDecoratedOperator = None,
decorated_operator_class: BaseOperator = None,
**kwargs,
) -> Callable[[T], T]:
"""
Expand Down Expand Up @@ -169,7 +202,7 @@ def wrapper(f: T):
Python wrapper to generate PythonDecoratedOperator out of simple python functions.
Used for Airflow Decorated interface
"""
BaseDecoratedOperator.validate_python_callable(f)
validate_python_callable(f)
kwargs.setdefault('task_id', f.__name__)

@functools.wraps(f)
Expand Down
38 changes: 10 additions & 28 deletions airflow/decorators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@
# specific language governing permissions and limitations
# under the License.

from typing import Callable, Dict, Optional, TypeVar
from typing import Callable, Optional, TypeVar

from airflow.decorators.base import BaseDecoratedOperator, task_decorator_factory
from airflow.exceptions import AirflowException
from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.operators.python import PythonOperator
from airflow.utils.decorators import apply_defaults

PYTHON_OPERATOR_UI_COLOR = '#ffefeb'


class _PythonDecoratedOperator(BaseDecoratedOperator):
class _PythonDecoratedOperator(DecoratedOperator, PythonOperator):
"""
Wraps a Python callable and captures args/kwargs when called for execution.
Expand All @@ -45,8 +43,6 @@ class _PythonDecoratedOperator(BaseDecoratedOperator):
template_fields = ('op_args', 'op_kwargs')
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

ui_color = PYTHON_OPERATOR_UI_COLOR

# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
shallow_copy_attrs = ('python_callable',)
Expand All @@ -56,27 +52,13 @@ def __init__(
self,
**kwargs,
) -> None:
super().__init__(**kwargs)
kwargs_to_upstream = {
"python_callable": kwargs["python_callable"],
"op_args": kwargs["op_args"],
"op_kwargs": kwargs["op_kwargs"],
}

def execute(self, context: Dict):
    """
    Call the wrapped callable and optionally push its result as several XComs.

    The callable is invoked with ``op_args`` / ``op_kwargs``. When
    ``multiple_outputs`` is falsy the result is returned as is. Otherwise the
    result must be a dict with string keys; every key/value pair is pushed
    through ``xcom_push`` before the dict is returned.

    :param context: execution context, passed to ``xcom_push``
    :return: the value returned by the wrapped callable
    :raises AirflowException: if ``multiple_outputs`` is set and the result is
        not a dict, or one of its keys is not a string
    """
    return_value = self.python_callable(*self.op_args, **self.op_kwargs)
    self.log.debug("Done. Returned value was: %s", return_value)
    if not self.multiple_outputs:
        return return_value

    # Reject non-dict results up front so the happy path below stays flat.
    if not isinstance(return_value, dict):
        raise AirflowException(
            f'Returned output was type {type(return_value)} expected dictionary ' 'for multiple_outputs'
        )

    # Validate every key before pushing anything, so a bad key pushes nothing.
    for key in return_value.keys():
        if not isinstance(key, str):
            raise AirflowException(
                'Returned dictionary keys must be strings when using '
                f'multiple_outputs, found {key} ({type(key)}) instead'
            )
    for key, value in return_value.items():
        self.xcom_push(context, key, value)
    return return_value
super().__init__(kwargs_to_upstream=kwargs_to_upstream, **kwargs)


T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
Expand Down
Loading

0 comments on commit 5661273

Please sign in to comment.