Scheduling based on dataset aliases (apache#40693)
Lee-W authored Jul 22, 2024
1 parent e30f810 commit 8dff8ae
Showing 17 changed files with 2,354 additions and 1,949 deletions.
46 changes: 45 additions & 1 deletion airflow/datasets/__init__.py
@@ -23,12 +23,16 @@
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator

import attr
from sqlalchemy import select

from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from urllib.parse import SplitResult

from sqlalchemy.orm.session import Session


from airflow.configuration import conf

@@ -127,6 +131,23 @@ def extract_event_key(value: str | Dataset | DatasetAlias) -> str:
return _sanitize_uri(str(value))


@provide_session
def expand_alias_to_datasets(
alias: str | DatasetAlias, *, session: Session = NEW_SESSION
) -> list[BaseDataset]:
"""Expand dataset alias to resolved datasets."""
from airflow.models.dataset import DatasetAliasModel

alias_name = alias.name if isinstance(alias, DatasetAlias) else alias

dataset_alias_obj = session.scalar(
select(DatasetAliasModel).where(DatasetAliasModel.name == alias_name).limit(1)
)
if dataset_alias_obj:
return [Dataset(uri=dataset.uri, extra=dataset.extra) for dataset in dataset_alias_obj.datasets]
return []


class BaseDataset:
"""
Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
@@ -233,7 +254,10 @@ class _DatasetBooleanCondition(BaseDataset):
def __init__(self, *objects: BaseDataset) -> None:
if not all(isinstance(o, BaseDataset) for o in objects):
raise TypeError("expect dataset expressions in condition")
self.objects = objects

self.objects = [
_DatasetAliasCondition(obj.name) if isinstance(obj, DatasetAlias) else obj for obj in objects
]

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
@@ -271,6 +295,26 @@ def as_expression(self) -> dict[str, Any]:
return {"any": [o.as_expression() for o in self.objects]}


class _DatasetAliasCondition(DatasetAny):
"""
Use to expand DatasetAlias as DatasetAny of its resolved Datasets.

:meta private:
"""

def __init__(self, name: str) -> None:
self.name = name
self.objects = expand_alias_to_datasets(name)

def as_expression(self) -> Any:
"""
Serialize the dataset into its scheduling expression.

:meta private:
"""
return {"alias": self.name}


class DatasetAll(_DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "or" relationship."""

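Taken together, the ``airflow/datasets/__init__.py`` changes let a ``DatasetAlias`` appear anywhere a dataset expression is accepted: the boolean conditions wrap each alias in a ``_DatasetAliasCondition``, which expands to whatever datasets the alias has resolved to at the time the condition is constructed. A minimal sketch of the intended usage, assuming an initialised Airflow metadata database and an illustrative alias name:

from airflow.datasets import Dataset, DatasetAlias, DatasetAll, expand_alias_to_datasets

# A DatasetAlias can be mixed with concrete datasets in a boolean condition.
# Internally it is replaced by a _DatasetAliasCondition over its resolved datasets.
cond = DatasetAll(Dataset("s3://bucket/report"), DatasetAlias("nightly-export"))

# Expanding an alias by hand returns the resolved datasets,
# or an empty list if nothing has been associated with the alias yet.
resolved = expand_alias_to_datasets("nightly-export")
print([d.uri for d in resolved])
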
6 changes: 4 additions & 2 deletions airflow/datasets/manager.py
@@ -108,12 +108,14 @@ def register_dataset_change(
}
)
dataset_event = DatasetEvent(**event_kwargs)
session.add(dataset_event)
if source_alias_names:
dataset_alias_models = session.scalars(
select(DatasetAliasModel).where(DatasetAliasModel.name.in_(source_alias_names))
)
dataset_event.source_aliases.extend(dataset_alias_models)
session.add(dataset_event)
for dsa in dataset_alias_models:
dsa.dataset_events.append(dataset_event)
session.add(dsa)
session.flush()

cls.notify_dataset_changed(dataset=dataset)
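The manager now reverses the direction of the association: instead of extending ``dataset_event.source_aliases``, each matching ``DatasetAliasModel`` appends the new event to its own ``dataset_events`` collection, and with ``back_populates`` on both sides (see the model changes below) the two views stay consistent. A rough way to inspect the linkage, assuming an existing metadata database and an illustrative alias name:

from sqlalchemy import select

from airflow.models.dataset import DatasetAliasModel
from airflow.utils.session import create_session

with create_session() as session:
    alias = session.scalar(
        select(DatasetAliasModel).where(DatasetAliasModel.name == "example-alias").limit(1)
    )
    if alias:
        # Events registered with source_alias_names=["example-alias"] appear here,
        # and each event lists the alias in its source_aliases in turn.
        for event in alias.dataset_events:
            print(event.dataset_id, [a.name for a in event.source_aliases])
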
69 changes: 69 additions & 0 deletions airflow/migrations/versions/0150_2_10_0_dataset_alias_dataset.py
@@ -0,0 +1,69 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add dataset_alias_dataset association table.

Revision ID: 8684e37832e6
Revises: 41b3bc7c0272
Create Date: 2024-07-18 06:21:06.242569
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "8684e37832e6"
down_revision = "41b3bc7c0272"
branch_labels = None
depends_on = None
airflow_version = "2.10.0"


def upgrade():
"""Add dataset_alias_dataset association table."""
op.create_table(
"dataset_alias_dataset",
sa.Column("alias_id", sa.Integer(), nullable=False),
sa.Column("dataset_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["alias_id"],
["dataset_alias.id"],
name=op.f("dataset_alias_dataset_alias_id_fkey"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["dataset_id"],
["dataset.id"],
name=op.f("dataset_alias_dataset_dataset_id_fkey"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("alias_id", "dataset_id", name=op.f("dataset_alias_dataset_pkey")),
)
op.create_index(
"idx_dataset_alias_dataset_alias_dataset_id", "dataset_alias_dataset", ["dataset_id"], unique=False
)
op.create_index("idx_dataset_alias_dataset_alias_id", "dataset_alias_dataset", ["alias_id"], unique=False)


def downgrade():
"""Drop dataset_alias_dataset association table."""
op.drop_table("dataset_alias_dataset")
9 changes: 6 additions & 3 deletions airflow/models/dag.py
@@ -173,7 +173,9 @@
# but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
# See also: https://discuss.python.org/t/9126/7
ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, BaseDataset, Collection["Dataset"]]
ScheduleArg = Union[
ArgNotSet, ScheduleInterval, Timetable, BaseDataset, Collection[Union["Dataset", "DatasetAlias"]]
]

SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]

@@ -669,8 +671,8 @@ def __init__(
self.timetable = DatasetTriggeredTimetable(schedule)
self.schedule_interval = self.timetable.summary
elif isinstance(schedule, Collection) and not isinstance(schedule, str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
if not all(isinstance(x, (Dataset, DatasetAlias)) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets or dataset aliases")
self.timetable = DatasetTriggeredTimetable(DatasetAll(*schedule))
self.schedule_interval = self.timetable.summary
elif isinstance(schedule, ArgNotSet):
@@ -4009,6 +4011,7 @@ def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None:
for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
statuses = dag_statuses[dag_id]

if not dag_ready(dag_id, cond=ser_dag.dag.timetable.dataset_condition, statuses=statuses):
del by_dag[dag_id]
del dag_statuses[dag_id]
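
With ``ScheduleArg`` widened and the validation relaxed, a DAG can now be scheduled on a collection that mixes datasets and dataset aliases; the collection is still wrapped in ``DatasetAll``, i.e. combined with "and" semantics. A hedged example with placeholder URIs and alias name:

from datetime import datetime

from airflow import DAG
from airflow.datasets import Dataset, DatasetAlias

# Every element must be a Dataset or a DatasetAlias; anything else still raises ValueError.
with DAG(
    dag_id="mixed_schedule_example",
    start_date=datetime(2024, 1, 1),
    schedule=[Dataset("s3://bucket/raw"), DatasetAlias("nightly-export")],
):
    ...
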
76 changes: 54 additions & 22 deletions airflow/models/dataset.py
@@ -40,6 +40,48 @@
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime

alias_association_table = Table(
"dataset_alias_dataset",
Base.metadata,
Column("alias_id", ForeignKey("dataset_alias.id", ondelete="CASCADE"), primary_key=True),
Column("dataset_id", ForeignKey("dataset.id", ondelete="CASCADE"), primary_key=True),
Index("idx_dataset_alias_dataset_alias_id", "alias_id"),
Index("idx_dataset_alias_dataset_alias_dataset_id", "dataset_id"),
ForeignKeyConstraint(
("alias_id",),
["dataset_alias.id"],
name="ds_dsa_alias_id",
ondelete="CASCADE",
),
ForeignKeyConstraint(
("dataset_id",),
["dataset.id"],
name="ds_dsa_dataset_id",
ondelete="CASCADE",
),
)

dataset_alias_dataset_event_assocation_table = Table(
"dataset_alias_dataset_event",
Base.metadata,
Column("alias_id", ForeignKey("dataset_alias.id", ondelete="CASCADE"), primary_key=True),
Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
Index("idx_dataset_alias_dataset_event_alias_id", "alias_id"),
Index("idx_dataset_alias_dataset_event_event_id", "event_id"),
ForeignKeyConstraint(
("alias_id",),
["dataset_alias.id"],
name="dss_de_alias_id",
ondelete="CASCADE",
),
ForeignKeyConstraint(
("event_id",),
["dataset_event.id"],
name="dss_de_event_id",
ondelete="CASCADE",
),
)


class DatasetAliasModel(Base):
"""
@@ -64,6 +106,17 @@ class DatasetAliasModel(Base):

__tablename__ = "dataset_alias"

datasets = relationship(
"DatasetModel",
secondary=alias_association_table,
backref="aliases",
)
dataset_events = relationship(
"DatasetEvent",
secondary=dataset_alias_dataset_event_assocation_table,
back_populates="source_aliases",
)

@classmethod
def from_public(cls, obj: DatasetAlias) -> DatasetAliasModel:
return cls(name=obj.name)
@@ -284,27 +337,6 @@ def __repr__(self):
Index("idx_dagrun_dataset_events_event_id", "event_id"),
)

dataset_alias_dataset_event_assocation_table = Table(
"dataset_alias_dataset_event",
Base.metadata,
Column("alias_id", ForeignKey("dataset_alias.id", ondelete="CASCADE"), primary_key=True),
Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
Index("idx_dataset_alias_dataset_event_alias_id", "alias_id"),
Index("idx_dataset_alias_dataset_event_event_id", "event_id"),
ForeignKeyConstraint(
("alias_id",),
["dataset_alias.id"],
name="dss_de_alias_id",
ondelete="CASCADE",
),
ForeignKeyConstraint(
("event_id",),
["dataset_event.id"],
name="dss_de_event_id",
ondelete="CASCADE",
),
)


class DatasetEvent(Base):
"""
@@ -346,7 +378,7 @@ class DatasetEvent(Base):
source_aliases = relationship(
"DatasetAliasModel",
secondary=dataset_alias_dataset_event_assocation_table,
backref="dataset_events",
back_populates="dataset_events",
)

source_task_instance = relationship(
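
The new ``dataset_alias_dataset`` association table is mirrored in the ORM, so an alias exposes the datasets it has resolved to (``DatasetAliasModel.datasets``) and, through the ``aliases`` backref, a dataset exposes every alias that has resolved to it. A small query sketch under the same assumptions as above (existing metadata database, illustrative URI):

from sqlalchemy import select

from airflow.models.dataset import DatasetModel
from airflow.utils.session import create_session

with create_session() as session:
    dataset = session.scalar(
        select(DatasetModel).where(DatasetModel.uri == "s3://bucket/my-task").limit(1)
    )
    if dataset:
        # Populated by the "aliases" backref defined on DatasetAliasModel.datasets.
        print([alias.name for alias in dataset.aliases])
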
8 changes: 7 additions & 1 deletion airflow/models/taskinstance.py
@@ -92,7 +92,7 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetModel
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
@@ -2996,6 +2996,12 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
self.log.warning('Created a new Dataset(uri="%s") as it did not exist.', uri)
dataset_objs_cache[uri] = dataset_obj

for alias in alias_names:
alias_obj = session.scalar(
select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
)
dataset_obj.aliases.append(alias_obj)

extra = {k: v for k, v in extra_items}
self.log.info(
'Create dataset event Dataset(uri="%s", extra="%s") through dataset aliases "%s"',
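
On the task side, ``_register_dataset_changes`` now also records which aliases an outlet event came through, appending each ``DatasetAliasModel`` to the dataset's ``aliases`` collection. From a DAG author's point of view nothing extra is required beyond adding the dataset through the alias at runtime; a sketch with illustrative names:

from airflow.datasets import Dataset, DatasetAlias
from airflow.decorators import task


@task(outlets=[DatasetAlias("example-alias")])
def produce(*, outlet_events):
    # Emits a dataset event for the dataset and, via _register_dataset_changes,
    # persists the alias-to-dataset association in the metadata database.
    outlet_events["example-alias"].add(Dataset("s3://bucket/my-task"), extra={"row_count": 42})
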
10 changes: 9 additions & 1 deletion airflow/timetables/simple.py
@@ -18,6 +18,7 @@

from typing import TYPE_CHECKING, Any, Collection, Sequence

from airflow.datasets import DatasetAlias, _DatasetAliasCondition
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
from airflow.utils import timezone

@@ -165,6 +166,13 @@ class DatasetTriggeredTimetable(_TrivialTimetable):
def __init__(self, datasets: BaseDataset) -> None:
super().__init__()
self.dataset_condition = datasets
if isinstance(self.dataset_condition, DatasetAlias):
self.dataset_condition = _DatasetAliasCondition(self.dataset_condition.name)

if not next(self.dataset_condition.iter_datasets(), False):
self._summary = "unresolved DatasetAlias"
else:
self._summary = "Dataset"

@classmethod
def deserialize(cls, data: dict[str, Any]) -> Timetable:
@@ -174,7 +182,7 @@ def deserialize(cls, data: dict[str, Any]) -> Timetable:

@property
def summary(self) -> str:
return "Dataset"
return self._summary

def serialize(self) -> dict[str, Any]:
from airflow.serialization.serialized_objects import encode_dataset_condition
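
The timetable now reports whether an alias-only schedule has resolved yet: until the alias is associated with at least one dataset, ``summary`` returns ``unresolved DatasetAlias`` instead of ``Dataset``. Roughly, assuming a metadata database and an illustrative alias name:

from airflow.datasets import DatasetAlias
from airflow.timetables.simple import DatasetTriggeredTimetable

timetable = DatasetTriggeredTimetable(DatasetAlias("example-alias"))
# "unresolved DatasetAlias" before the alias resolves to a dataset, "Dataset" afterwards.
print(timetable.summary)
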
2 changes: 1 addition & 1 deletion airflow/utils/db.py
@@ -118,7 +118,7 @@ class MappedClassProtocol(Protocol):
"2.8.1": "88344c1d9134",
"2.9.0": "1949afb29106",
"2.9.2": "686269002441",
"2.10.0": "41b3bc7c0272",
"2.10.0": "8684e37832e6",
}


27 changes: 27 additions & 0 deletions docs/apache-airflow/authoring-and-scheduling/datasets.rst
@@ -439,6 +439,33 @@ Only one dataset event is emitted for an added dataset, even if it is added to t
        # This line will emit an additional dataset event as the extra is different.
        outlet_events["my-task-outputs-3"].add(Dataset("s3://bucket/my-task"), extra={"k2": "v2"})

Scheduling based on dataset aliases
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Since dataset events added to an alias are just ordinary dataset events, a downstream DAG that depends on the actual dataset can read those events as usual, without considering the associated aliases. A downstream DAG can also depend on a dataset alias directly: the authoring syntax is to reference the ``DatasetAlias`` by name, and the dataset events associated with that alias are picked up for scheduling. Note that a DAG scheduled on an alias is triggered by a task with ``outlets=DatasetAlias("xxx")`` only once that alias has been resolved into an actual dataset. The DAG runs whenever a task with outlet ``DatasetAlias("out")`` associates at least one dataset with the alias at runtime, regardless of which dataset it is, and it is not triggered for a task run in which no dataset is associated with the alias. Because the association happens at runtime, this also makes conditional dataset triggering possible.

.. code-block:: python

    with DAG(dag_id="dataset-producer"):

        @task(outlets=[Dataset("s3://bucket/my-task")])
        def produce_dataset_events():
            pass


    with DAG(dag_id="dataset-alias-producer"):

        @task(outlets=[DatasetAlias("example-alias")])
        def produce_dataset_events(*, outlet_events):
            outlet_events["example-alias"].add(Dataset("s3://bucket/my-task"))


    with DAG(dag_id="dataset-consumer", schedule=Dataset("s3://bucket/my-task")):
        ...


    with DAG(dag_id="dataset-alias-consumer", schedule=DatasetAlias("example-alias")):
        ...

In the example above, before the DAG "dataset-alias-producer" is executed, the dataset alias ``DatasetAlias("example-alias")`` is not yet resolved to ``Dataset("s3://bucket/my-task")``. Consequently, completing the execution of the DAG "dataset-producer" will only trigger the DAG "dataset-consumer" and not the DAG "dataset-alias-consumer". However, upon triggering the DAG "dataset-alias-producer", the ``DatasetAlias("example-alias")`` will be resolved to ``Dataset("s3://bucket/my-task")``, and it will produce a dataset event that triggers the DAG "dataset-consumer". At this point, ``DatasetAlias("example-alias")`` is resolved to ``Dataset("s3://bucket/my-task")``. Therefore, completing the execution of either DAG "dataset-producer" or "dataset-alias-producer" will trigger both the DAG "dataset-consumer" and "dataset-alias-consumer".
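
Because the association happens at runtime, the producing task can decide per run whether, and to which dataset, the alias resolves, and the alias consumer only runs when it does. A sketch of that conditional-triggering pattern; ``some_condition`` is a placeholder for whatever runtime check applies:

.. code-block:: python

    with DAG(dag_id="conditional-dataset-alias-producer"):

        @task(outlets=[DatasetAlias("example-alias")])
        def produce_dataset_events_conditionally(*, outlet_events):
            if some_condition():  # placeholder: only attach the dataset when appropriate
                outlet_events["example-alias"].add(Dataset("s3://bucket/my-task"))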

Combining dataset and time-based schedules
------------------------------------------
