diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 77423bfc3b99e..7ffa596ec67a1 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -517,7 +517,11 @@ def __new__(cls, name, bases, namespace, **kwargs): partial_desc = vars(new_cls)["partial"] if isinstance(partial_desc, _PartialDescriptor): partial_desc.class_method = classmethod(partial) - new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + + # We patch `__init__` only if the class defines it. + if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__: + new_cls.__init__ = cls._apply_defaults(new_cls.__init__) + return new_cls diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index a73db360aabb2..b94a1b9f819d6 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -42,6 +42,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance +from airflow.providers.common.sql.operators import sql from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup @@ -1115,3 +1116,23 @@ def test_get_task_instances(session): assert task.get_task_instances( session=session, start_date=second_execution_date, end_date=second_execution_date ) == [ti_2] + + +def test_mro(): + class Mixin(sql.BaseSQLOperator): + pass + + class Branch(Mixin, sql.BranchSQLOperator): + pass + + # The following throws an exception if metaclass breaks MRO: + # airflow.exceptions.AirflowException: Invalid arguments were passed to Branch (task_id: test). Invalid arguments were: + # **kwargs: {'sql': 'sql', 'follow_task_ids_if_true': ['x'], 'follow_task_ids_if_false': ['y']} + op = Branch( + task_id="test", + conn_id="abc", + sql="sql", + follow_task_ids_if_true=["x"], + follow_task_ids_if_false=["y"], + ) + assert isinstance(op, Branch)