Skip to content

Commit

Permalink
Fix MRO in operators without __init__ (apache#41086)
Browse files Browse the repository at this point in the history
* Fix MRO in operators without `__init__`

* Add mro test

* Add test comment
  • Loading branch information
grihabor authored Aug 1, 2024
1 parent 9fbaca0 commit 5922d9c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
6 changes: 5 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,11 @@ def __new__(cls, name, bases, namespace, **kwargs):
partial_desc = vars(new_cls)["partial"]
if isinstance(partial_desc, _PartialDescriptor):
partial_desc.class_method = classmethod(partial)
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)

# We patch `__init__` only if the class defines it.
if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__:
new_cls.__init__ = cls._apply_defaults(new_cls.__init__)

return new_cls


Expand Down
21 changes: 21 additions & 0 deletions tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.sql.operators import sql
from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
from airflow.utils.edgemodifier import Label
from airflow.utils.task_group import TaskGroup
Expand Down Expand Up @@ -1115,3 +1116,23 @@ def test_get_task_instances(session):
assert task.get_task_instances(
session=session, start_date=second_execution_date, end_date=second_execution_date
) == [ti_2]


def test_mro():
class Mixin(sql.BaseSQLOperator):
pass

class Branch(Mixin, sql.BranchSQLOperator):
pass

# The following throws an exception if metaclass breaks MRO:
# airflow.exceptions.AirflowException: Invalid arguments were passed to Branch (task_id: test). Invalid arguments were:
# **kwargs: {'sql': 'sql', 'follow_task_ids_if_true': ['x'], 'follow_task_ids_if_false': ['y']}
op = Branch(
task_id="test",
conn_id="abc",
sql="sql",
follow_task_ids_if_true=["x"],
follow_task_ids_if_false=["y"],
)
assert isinstance(op, Branch)

0 comments on commit 5922d9c

Please sign in to comment.