Refactor: Simplify code in models (apache#33181)
eumiro authored Aug 7, 2023
1 parent 15ede4a commit 5a0494f
Showing 6 changed files with 19 additions and 22 deletions.
2 changes: 1 addition & 1 deletion airflow/models/base.py
@@ -69,7 +69,7 @@ def get_id_collation_args():
         # We cannot use session/dialect as at this point we are trying to determine the right connection
         # parameters, so we use the connection
         conn = conf.get("database", "sql_alchemy_conn", fallback="")
-        if conn.startswith("mysql") or conn.startswith("mariadb"):
+        if conn.startswith(("mysql", "mariadb")):
             return {"collation": "utf8mb3_bin"}
         return {}

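The change works because str.startswith also accepts a tuple of prefixes and returns True if the string starts with any of them, so the chain of or-ed calls collapses into one. A minimal sketch of the idiom (the connection URIs are invented for illustration):

    # str.startswith checks every prefix in the tuple.
    for conn in ("mysql://u@h/db", "mariadb://u@h/db", "postgresql://u@h/db"):
        print(conn.split(":")[0], conn.startswith(("mysql", "mariadb")))
    # mysql True
    # mariadb True
    # postgresql False
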
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
@@ -414,7 +414,7 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
                 if arg not in kwargs and arg in default_args:
                     kwargs[arg] = default_args[arg]
 
-            missing_args = non_optional_args - set(kwargs)
+            missing_args = non_optional_args.difference(kwargs)
             if len(missing_args) == 1:
                 raise AirflowException(f"missing keyword argument {missing_args.pop()!r}")
             elif missing_args:
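
Unlike the - operator, which requires a set on both sides, set.difference accepts any iterable, and iterating a dict yields its keys, so kwargs no longer needs an explicit set() conversion. A small sketch with invented names:

    non_optional_args = {"task_id", "owner", "retries"}
    kwargs = {"task_id": "t1", "owner": "airflow"}

    # The - operator needs both operands to be sets:
    assert non_optional_args - set(kwargs) == {"retries"}
    # .difference() takes any iterable, including the dict itself:
    assert non_optional_args.difference(kwargs) == {"retries"}
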
31 changes: 14 additions & 17 deletions airflow/models/dag.py
@@ -740,7 +740,7 @@ def __hash__(self):
         for c in self._comps:
             # task_ids returns a list and lists can't be hashed
             if c == "task_ids":
-                val = tuple(self.task_dict.keys())
+                val = tuple(self.task_dict)
             else:
                 val = getattr(self, c, None)
             try:
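
Iterating over a dict yields its keys, so tuple(d), list(d), set(d), and frozenset(d) are equivalent to the .keys() spellings; the same idiom drives the remaining dag.py hunks and the dagbag.py change below. A quick illustration with made-up task IDs:

    task_dict = {"extract": ..., "transform": ..., "load": ...}

    # A dict iterates over its keys, so .keys() is redundant in these calls:
    assert tuple(task_dict) == tuple(task_dict.keys()) == ("extract", "transform", "load")
    assert list(task_dict) == list(task_dict.keys())
    assert set(task_dict) == set(task_dict.keys())
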
@@ -1256,7 +1256,7 @@ def tasks(self, val):

     @property
     def task_ids(self) -> list[str]:
-        return list(self.task_dict.keys())
+        return list(self.task_dict)
 
     @property
     def teardowns(self) -> list[Operator]:
@@ -2897,7 +2897,7 @@ def bulk_write_to_db(
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}
 
-        dag_ids = set(dag_by_ids.keys())
+        dag_ids = set(dag_by_ids)
         query = (
             select(DagModel)
             .options(joinedload(DagModel.tags, innerjoin=False))
@@ -3235,7 +3235,7 @@ def get_serialized_fields(cls):
"auto_register",
"fail_stop",
}
cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
cls.__serialized_fields = frozenset(vars(DAG(dag_id="test"))) - exclusion_list
return cls.__serialized_fields

def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
@@ -3594,21 +3594,18 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
                 .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
             )
         }
-        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
+        dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
         if dataset_triggered_dag_ids:
-            exclusion_list = {
-                x
-                for x in (
-                    session.scalars(
-                        select(DagModel.dag_id)
-                        .join(DagRun.dag_model)
-                        .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
-                        .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
-                        .group_by(DagModel.dag_id)
-                        .having(func.count() >= func.max(DagModel.max_active_runs))
-                    )
-                )
-            }
+            exclusion_list = set(
+                session.scalars(
+                    select(DagModel.dag_id)
+                    .join(DagRun.dag_model)
+                    .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
+                    .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+                    .group_by(DagModel.dag_id)
+                    .having(func.count() >= func.max(DagModel.max_active_runs))
+                )
+            )
             if exclusion_list:
                 dataset_triggered_dag_ids -= exclusion_list
                 dataset_triggered_dag_info = {
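
The last dag.py hunk also drops an identity comprehension: {x for x in iterable} builds the same set as passing the iterable straight to the set constructor. A sketch with a plain list standing in for the session.scalars(...) result:

    scalars = ["dag_a", "dag_b", "dag_a"]  # stand-in for session.scalars(...)

    # The comprehension and the constructor produce the same set:
    assert {x for x in scalars} == set(scalars) == {"dag_a", "dag_b"}
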
2 changes: 1 addition & 1 deletion airflow/models/dagbag.py
@@ -169,7 +169,7 @@ def dag_ids(self) -> list[str]:
         :return: a list of DAG IDs in this bag
         """
-        return list(self.dags.keys())
+        return list(self.dags)
 
     @provide_session
     def get_dag(self, dag_id, session: Session = None):
2 changes: 1 addition & 1 deletion airflow/models/expandinput.py
@@ -168,7 +168,7 @@ def _expand_mapped_field(self, key: str, value: Any, context: Context, *, sessio

         def _find_index_for_this_field(index: int) -> int:
             # Need to use the original user input to retain argument order.
-            for mapped_key in reversed(list(self.value)):
+            for mapped_key in reversed(self.value):
                 mapped_length = all_lengths[mapped_key]
                 if mapped_length < 1:
                     raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
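
Since Python 3.8, dicts preserve insertion order and support reversed() directly, so materializing the keys into an intermediate list is unnecessary. For instance:

    value = {"a": 1, "b": 2, "c": 3}

    # dict and its views are reversible on Python 3.8+; no list() needed.
    assert list(reversed(value)) == ["c", "b", "a"]
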
2 changes: 1 addition & 1 deletion airflow/models/taskmixin.py
@@ -37,7 +37,7 @@


 class DependencyMixin:
-    """Mixing implementing common dependency setting methods methods like >> and <<."""
+    """Mixing implementing common dependency setting methods like >> and <<."""
 
     @property
     def roots(self) -> Sequence[DependencyMixin]:
