From 5922d9c62523abf744fbe3578a58a7df0b5b046f Mon Sep 17 00:00:00 2001 From: Gregory Borodin Date: Thu, 1 Aug 2024 23:37:36 +0200 Subject: [PATCH] Fix MRO in operators without `__init__` (#41086) * Fix MRO in operators without `__init__` * Add mro test * Add test comment --- airflow/models/baseoperator.py | 6 +++++- tests/models/test_baseoperator.py | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 77423bfc3b99e..7ffa596ec67a1 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -517,7 +517,11 @@ def __new__(cls, name, bases, namespace, **kwargs): partial_desc = vars(new_cls)["partial"] if isinstance(partial_desc, _PartialDescriptor): partial_desc.class_method = classmethod(partial) - new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + + # We patch `__init__` only if the class defines it. + if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: + new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + return new_cls diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index a73db360aabb2..b94a1b9f819d6 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -42,6 +42,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance +from airflow.providers.common.sql.operators import sql from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup @@ -1115,3 +1116,23 @@ def test_get_task_instances(session): assert task.get_task_instances( session=session, start_date=second_execution_date, end_date=second_execution_date ) == [ti_2] + + +def test_mro(): + class Mixin(sql.BaseSQLOperator): + pass + + class Branch(Mixin, sql.BranchSQLOperator): + pass + + # The following throws an exception if metaclass breaks MRO: + # airflow.exceptions.AirflowException: Invalid arguments were passed to Branch (task_id: test). Invalid arguments were: + # **kwargs: {'sql': 'sql', 'follow_task_ids_if_true': ['x'], 'follow_task_ids_if_false': ['y']} + op = Branch( + task_id="test", + conn_id="abc", + sql="sql", + follow_task_ids_if_true=["x"], + follow_task_ids_if_false=["y"], + ) + assert isinstance(op, Branch)