From 0d790428cafd52f92d2756bdcbc7ab2f4c7014b9 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 26 Oct 2024 11:20:15 +0100 Subject: [PATCH] avoid refcycles in trio.run, especially if an Exception is raised --- demo.py | 72 ++++++++++++++++++++++++++++++++++++++++++ src/trio/_core/_run.py | 19 +++++++---- 2 files changed, 85 insertions(+), 6 deletions(-) create mode 100644 demo.py diff --git a/demo.py b/demo.py new file mode 100644 index 000000000..054d8fd6f --- /dev/null +++ b/demo.py @@ -0,0 +1,72 @@ +import trio + + +async def main(): + err = None + with trio.CancelScope() as scope: + scope.cancel() + try: + await trio.sleep_forever() + except BaseException as e: + err = e + raise + breakpoint() + + +# trio.run(main) + +import gc + +import objgraph +from anyio import CancelScope, get_cancelled_exc_class + + +async def test_exception_refcycles_propagate_cancellation_error() -> None: + """Test that TaskGroup deletes cancelled_exc""" + exc = None + + with CancelScope() as cs: + cs.cancel() + try: + await trio.sleep_forever() + except get_cancelled_exc_class() as e: + exc = e + raise + + assert isinstance(exc, get_cancelled_exc_class()) + gc.collect() + objgraph.show_chain( + objgraph.find_backref_chain( + gc.get_referrers(exc)[0], + objgraph.is_proper_module, + ), + ) + + +# trio.run(test_exception_refcycles_propagate_cancellation_error) + + +class MyException(Exception): + pass + + +async def main(): + raise MyException + + +def inner(): + try: + trio.run(main) + except MyException: + pass + + +import refcycle + +gc.disable() +gc.collect() +inner() +garbage = refcycle.garbage() +for i, component in enumerate(garbage.source_components()): + component.export_image(f"{i}_example.svg") +garbage.export_image("example.svg") diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 3961a6e10..76fc7cb06 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -1706,6 +1706,7 @@ def close(self) -> None: self.asyncgens.close() if "after_run" in self.instruments: self.instruments.call("after_run") + self.system_nursery: Nursery | None = None # This is where KI protection gets disabled, so we do it last self.ki_manager.close() @@ -1920,6 +1921,7 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: task._activate_cancel_status(None) self.tasks.remove(task) if task is self.init_task: + self.init_task = None # If the init task crashed, then something is very wrong and we # let the error propagate. (It'll eventually be wrapped in a # TrioInternalError.) @@ -1930,6 +1932,7 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: raise TrioInternalError else: if task is self.main_task: + self.main_task = None self.main_task_outcome = outcome outcome = Value(None) assert task._parent_nursery is not None, task @@ -2394,12 +2397,15 @@ def run( sniffio_library.name = prev_library # Inlined copy of runner.main_task_outcome.unwrap() to avoid # cluttering every single Trio traceback with an extra frame. - if isinstance(runner.main_task_outcome, Value): - return cast(RetT, runner.main_task_outcome.value) - elif isinstance(runner.main_task_outcome, Error): - raise runner.main_task_outcome.error - else: # pragma: no cover - raise AssertionError(runner.main_task_outcome) + try: + if isinstance(runner.main_task_outcome, Value): + return cast(RetT, runner.main_task_outcome.value) + elif isinstance(runner.main_task_outcome, Error): + raise runner.main_task_outcome.error + else: # pragma: no cover + raise AssertionError(runner.main_task_outcome) + finally: + del runner def start_guest_run( @@ -2808,6 +2814,7 @@ def unrolled_run( if isinstance(runner.main_task_outcome, Error): ki.__context__ = runner.main_task_outcome.error runner.main_task_outcome = Error(ki) + del runner ################################################################