From 5a0494f83e8ad0e5cbf0d3dcad3022a3ea89d789 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com> Date: Mon, 7 Aug 2023 21:54:15 +0000 Subject: [PATCH] Refactor: Simplify code in models (#33181) --- airflow/models/base.py | 2 +- airflow/models/baseoperator.py | 2 +- airflow/models/dag.py | 31 ++++++++++++++----------------- airflow/models/dagbag.py | 2 +- airflow/models/expandinput.py | 2 +- airflow/models/taskmixin.py | 2 +- 6 files changed, 19 insertions(+), 22 deletions(-) diff --git a/airflow/models/base.py b/airflow/models/base.py index 5f6b7e9893dc4..934b9b1b74795 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -69,7 +69,7 @@ def get_id_collation_args(): # We cannot use session/dialect as at this point we are trying to determine the right connection # parameters, so we use the connection conn = conf.get("database", "sql_alchemy_conn", fallback="") - if conn.startswith("mysql") or conn.startswith("mariadb"): + if conn.startswith(("mysql", "mariadb")): return {"collation": "utf8mb3_bin"} return {} diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 7e861e20a6120..45462bf726199 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -414,7 +414,7 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: if arg not in kwargs and arg in default_args: kwargs[arg] = default_args[arg] - missing_args = non_optional_args - set(kwargs) + missing_args = non_optional_args.difference(kwargs) if len(missing_args) == 1: raise AirflowException(f"missing keyword argument {missing_args.pop()!r}") elif missing_args: diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 10f4c595d7798..1cb9220e58cac 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -740,7 +740,7 @@ def __hash__(self): for c in self._comps: # task_ids returns a list and lists can't be hashed if c == "task_ids": - val = 
tuple(self.task_dict.keys()) + val = tuple(self.task_dict) else: val = getattr(self, c, None) try: @@ -1256,7 +1256,7 @@ def tasks(self, val): @property def task_ids(self) -> list[str]: - return list(self.task_dict.keys()) + return list(self.task_dict) @property def teardowns(self) -> list[Operator]: @@ -2897,7 +2897,7 @@ def bulk_write_to_db( log.info("Sync %s DAGs", len(dags)) dag_by_ids = {dag.dag_id: dag for dag in dags} - dag_ids = set(dag_by_ids.keys()) + dag_ids = set(dag_by_ids) query = ( select(DagModel) .options(joinedload(DagModel.tags, innerjoin=False)) @@ -3235,7 +3235,7 @@ def get_serialized_fields(cls): "auto_register", "fail_stop", } - cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list + cls.__serialized_fields = frozenset(vars(DAG(dag_id="test"))) - exclusion_list return cls.__serialized_fields def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: @@ -3594,21 +3594,18 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[ .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0))) ) } - dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) + dataset_triggered_dag_ids = set(dataset_triggered_dag_info) if dataset_triggered_dag_ids: - exclusion_list = { - x - for x in ( - session.scalars( - select(DagModel.dag_id) - .join(DagRun.dag_model) - .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING))) - .where(DagModel.dag_id.in_(dataset_triggered_dag_ids)) - .group_by(DagModel.dag_id) - .having(func.count() >= func.max(DagModel.max_active_runs)) - ) + exclusion_list = set( + session.scalars( + select(DagModel.dag_id) + .join(DagRun.dag_model) + .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING))) + .where(DagModel.dag_id.in_(dataset_triggered_dag_ids)) + .group_by(DagModel.dag_id) + .having(func.count() >= func.max(DagModel.max_active_runs)) ) - } + ) if exclusion_list: 
dataset_triggered_dag_ids -= exclusion_list dataset_triggered_dag_info = { diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index e53c2ce3bd4f9..a8f5b4d6fc28c 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -169,7 +169,7 @@ def dag_ids(self) -> list[str]: :return: a list of DAG IDs in this bag """ - return list(self.dags.keys()) + return list(self.dags) @provide_session def get_dag(self, dag_id, session: Session = None): diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 8a9a3d874012d..36fb5f41650ad 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -168,7 +168,7 @@ def _expand_mapped_field(self, key: str, value: Any, context: Context, *, sessio def _find_index_for_this_field(index: int) -> int: # Need to use the original user input to retain argument order. - for mapped_key in reversed(list(self.value)): + for mapped_key in reversed(self.value): mapped_length = all_lengths[mapped_key] if mapped_length < 1: raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index f52749c7ff4b0..8c197491049d8 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -37,7 +37,7 @@ class DependencyMixin: - """Mixing implementing common dependency setting methods methods like >> and <<.""" + """Mixin implementing common dependency setting methods like >> and <<.""" @property def roots(self) -> Sequence[DependencyMixin]: