From 39e48e96d886df2dc53d659ee14c94dc38133845 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Thu, 5 Dec 2024 11:54:53 +0100 Subject: [PATCH 1/4] Only save one preprocessing task at a time --- esmvalcore/_task.py | 93 +++++++++++++++++------------ esmvalcore/preprocessor/__init__.py | 18 +++++- tests/integration/test_task.py | 14 ++++- 3 files changed, 81 insertions(+), 44 deletions(-) diff --git a/esmvalcore/_task.py b/esmvalcore/_task.py index 27a6b83d14..30b029ac34 100644 --- a/esmvalcore/_task.py +++ b/esmvalcore/_task.py @@ -5,6 +5,7 @@ import datetime import importlib import logging +import multiprocessing import numbers import os import pprint @@ -14,7 +15,6 @@ import threading import time from copy import deepcopy -from multiprocessing import Pool from pathlib import Path, PosixPath from shutil import which from typing import Optional @@ -260,6 +260,7 @@ def __init__(self, ancestors=None, name="", products=None): self.name = name self.activity = None self.priority = 0 + self.scheduler_lock = None def initialize_provenance(self, recipe_entity): """Initialize task provenance activity.""" @@ -854,45 +855,58 @@ def done(task): """Assume a task is done if it not scheduled or running.""" return not (task in scheduled or task in running) - with Pool(processes=max_parallel_tasks) as pool: - while scheduled or running: - # Submit new tasks to pool - for task in sorted(scheduled, key=lambda t: t.priority): - if len(running) >= max_parallel_tasks: - break - if all(done(t) for t in task.ancestors): - future = pool.apply_async( - _run_task, [task, scheduler_address] + with multiprocessing.Manager() as manager: + # Use a lock to avoid overloading the Dask workers by making only + # one :class:`esmvalcore.preprocessor.PreprocessingTask` submit its + # data save task graph to the scheduler at a time. + # + # See https://github.com/ESMValGroup/ESMValCore/issues/2609 for + # additional detail. 
+ scheduler_lock = manager.Lock() + + with multiprocessing.Pool(processes=max_parallel_tasks) as pool: + while scheduled or running: + # Submit new tasks to pool + for task in sorted(scheduled, key=lambda t: t.priority): + if len(running) >= max_parallel_tasks: + break + if all(done(t) for t in task.ancestors): + future = pool.apply_async( + _run_task, + [task, scheduler_address, scheduler_lock], + ) + running[task] = future + scheduled.remove(task) + + # Handle completed tasks + ready = {t for t in running if running[t].ready()} + for task in ready: + _copy_results(task, running[task]) + running.pop(task) + + # Wait if there are still tasks running + if running: + time.sleep(0.1) + + # Log progress message + if ( + len(scheduled) != n_scheduled + or len(running) != n_running + ): + n_scheduled, n_running = len(scheduled), len(running) + n_done = n_tasks - n_scheduled - n_running + logger.info( + "Progress: %s tasks running, %s tasks waiting for " + "ancestors, %s/%s done", + n_running, + n_scheduled, + n_done, + n_tasks, ) - running[task] = future - scheduled.remove(task) - - # Handle completed tasks - ready = {t for t in running if running[t].ready()} - for task in ready: - _copy_results(task, running[task]) - running.pop(task) - - # Wait if there are still tasks running - if running: - time.sleep(0.1) - - # Log progress message - if len(scheduled) != n_scheduled or len(running) != n_running: - n_scheduled, n_running = len(scheduled), len(running) - n_done = n_tasks - n_scheduled - n_running - logger.info( - "Progress: %s tasks running, %s tasks waiting for " - "ancestors, %s/%s done", - n_running, - n_scheduled, - n_done, - n_tasks, - ) - logger.info("Successfully completed all tasks.") - pool.close() - pool.join() + logger.info("Successfully completed all tasks.") + pool.close() + pool.join() def _copy_results(task, future): @@ -900,7 +914,7 @@ def _copy_results(task, future): task.output_files, task.products = future.get() -def _run_task(task, scheduler_address): +def _run_task(task, scheduler_address, scheduler_lock): """Run task and return the result.""" if scheduler_address is None: client = contextlib.nullcontext() @@ -908,6 +922,7 @@ def _run_task(task, scheduler_address): client = Client(scheduler_address) with client: + task.scheduler_lock = scheduler_lock output_files = task.run() return output_files, task.products diff --git a/esmvalcore/preprocessor/__init__.py b/esmvalcore/preprocessor/__init__.py index 2c956aa0ad..6ba0d7c946 100644 --- a/esmvalcore/preprocessor/__init__.py +++ b/esmvalcore/preprocessor/__init__.py @@ -736,9 +736,23 @@ def _run(self, _) -> list[str]: delayed = product.close() delayeds.append(delayed) - logger.info("Computing and saving data for task %s", self.name) delayeds = [d for d in delayeds if d is not None] - _compute_with_progress(delayeds, description=self.name) + + if self.scheduler_lock is not None: + logger.debug("Acquiring save lock for task %s", self.name) + self.scheduler_lock.acquire() + logger.debug("Acquired save lock for task %s", self.name) + try: + logger.info( + "Computing and saving data for preprocessing task %s", + self.name, + ) + _compute_with_progress(delayeds, description=self.name) + finally: + if self.scheduler_lock is not None: + self.scheduler_lock.release() + logger.debug("Released save lock for task %s", self.name) + metadata_files = write_metadata( self.products, self.write_ncl_interface ) diff --git a/tests/integration/test_task.py b/tests/integration/test_task.py index 9570ec8e58..d8fec5a416 100644 --- 
a/tests/integration/test_task.py +++ b/tests/integration/test_task.py @@ -92,7 +92,9 @@ def test_run_tasks(monkeypatch, max_parallel_tasks, example_tasks, mpmethod): get_distributed_client_mock(None), ) monkeypatch.setattr( - esmvalcore._task, "Pool", multiprocessing.get_context(mpmethod).Pool + esmvalcore._task.multiprocessing, + "Pool", + multiprocessing.get_context(mpmethod).Pool, ) example_tasks.run(max_parallel_tasks=max_parallel_tasks) @@ -152,7 +154,7 @@ def _run(self, input_files): return [f"{self.name}_test.nc"] monkeypatch.setattr(MockBaseTask, "_run", _run) - monkeypatch.setattr(esmvalcore._task, "Pool", ThreadPool) + monkeypatch.setattr(esmvalcore._task.multiprocessing, "Pool", ThreadPool) runner(example_tasks) print(order) @@ -165,11 +167,17 @@ def test_run_task(mocker, address): # Set up mock Dask distributed client mocker.patch.object(esmvalcore._task, "Client") + # Set up a mock multiprocessing.Lock + scheduler_lock = mocker.sentinel + task = mocker.create_autospec(DiagnosticTask, instance=True) task.products = mocker.Mock() - output_files, products = _run_task(task, scheduler_address=address) + output_files, products = _run_task( + task, scheduler_address=address, scheduler_lock=scheduler_lock + ) assert output_files == task.run.return_value assert products == task.products + assert task.scheduler_lock == scheduler_lock if address is None: esmvalcore._task.Client.assert_not_called() else: From e2bd7e89bbea0e4b23ac210d4ab0c28a73836609 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Thu, 5 Dec 2024 12:13:38 +0100 Subject: [PATCH 2/4] Add test --- .../preprocessor/test_preprocessing_task.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/integration/preprocessor/test_preprocessing_task.py b/tests/integration/preprocessor/test_preprocessing_task.py index 5b74a94cda..7bb3b89dbf 100644 --- a/tests/integration/preprocessor/test_preprocessing_task.py +++ b/tests/integration/preprocessor/test_preprocessing_task.py @@ -2,6 +2,7 @@ import iris import iris.cube +import pytest from prov.model import ProvDocument import esmvalcore.preprocessor @@ -9,7 +10,8 @@ from esmvalcore.preprocessor import PreprocessingTask, PreprocessorFile -def test_load_save_task(tmp_path): +@pytest.mark.parametrize("scheduler_lock", [False, True]) +def test_load_save_task(tmp_path, mocker, scheduler_lock): """Test that a task that just loads and saves a file.""" # Prepare a test dataset cube = iris.cube.Cube(data=[273.0], var_name="tas", units="K") @@ -36,6 +38,9 @@ def test_load_save_task(tmp_path): activity = provenance.activity("software:esmvalcore") task.initialize_provenance(activity) + if scheduler_lock: + task.scheduler_lock = mocker.Mock() + task.run() assert len(task.products) == 1 @@ -45,6 +50,12 @@ def test_load_save_task(tmp_path): result.attributes.clear() assert result == cube + if scheduler_lock: + assert task.scheduler_lock.acquire.called_once_with() + assert task.scheduler_lock.release.called_once_with() + else: + assert task.scheduler_lock is None + def test_load_save_and_other_task(tmp_path, monkeypatch): """Test that a task just copies one file and preprocesses another file.""" From a0ded99ff7c0d594f9e3976ae72b9d0a449285d5 Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Thu, 5 Dec 2024 12:20:43 +0100 Subject: [PATCH 3/4] Fix test --- tests/integration/preprocessor/test_preprocessing_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/preprocessor/test_preprocessing_task.py 
b/tests/integration/preprocessor/test_preprocessing_task.py index 7bb3b89dbf..43dc7af6a6 100644 --- a/tests/integration/preprocessor/test_preprocessing_task.py +++ b/tests/integration/preprocessor/test_preprocessing_task.py @@ -51,8 +51,8 @@ def test_load_save_task(tmp_path, mocker, scheduler_lock): assert result == cube if scheduler_lock: - assert task.scheduler_lock.acquire.called_once_with() - assert task.scheduler_lock.release.called_once_with() + task.scheduler_lock.acquire.assert_called_once_with() + task.scheduler_lock.release.assert_called_once_with() else: assert task.scheduler_lock is None From 1d4e23fa8dce9a66076c1df0128990cf1bf10a5b Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Fri, 13 Dec 2024 17:48:03 +0100 Subject: [PATCH 4/4] Only use lock with distributed scheduler --- esmvalcore/_task.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/esmvalcore/_task.py b/esmvalcore/_task.py index 30b029ac34..cb9269b087 100644 --- a/esmvalcore/_task.py +++ b/esmvalcore/_task.py @@ -858,11 +858,13 @@ def done(task): with multiprocessing.Manager() as manager: # Use a lock to avoid overloading the Dask workers by making only # one :class:`esmvalcore.preprocessor.PreprocessingTask` submit its - # data save task graph to the scheduler at a time. + # data save task graph to the distributed scheduler at a time. # # See https://github.com/ESMValGroup/ESMValCore/issues/2609 for # additional detail. - scheduler_lock = manager.Lock() + scheduler_lock = ( + None if scheduler_address is None else manager.Lock() + ) with multiprocessing.Pool(processes=max_parallel_tasks) as pool: while scheduled or running:
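
Reviewer note (illustrative sketch, not part of the patches above): the patches hand the lock created by multiprocessing.Manager() to the pool workers instead of a plain multiprocessing.Lock because a plain Lock generally cannot be pickled and passed as an argument to Pool.apply_async, whereas a Manager lock is a picklable proxy. The following minimal, self-contained sketch shows the same pattern under assumed, hypothetical names (do_work, main, the sleep timings); it only illustrates how one guarded section is serialized across pool workers while the rest of each task runs in parallel, and how passing None instead of a lock (as PATCH 4/4 does when no distributed scheduler is configured) skips the locking entirely.

    # Illustrative sketch only -- hypothetical names, not ESMValCore code.
    # Shows how a multiprocessing.Manager lock shared with pool workers lets
    # only one worker at a time run a guarded section (standing in for
    # submitting a data-save task graph to the Dask distributed scheduler),
    # while the rest of each task still runs in parallel.
    import multiprocessing
    import time


    def do_work(task_id, lock):
        # Parallel part: every worker may run this concurrently.
        time.sleep(0.1)

        if lock is not None:
            lock.acquire()
        try:
            # Guarded part: at most one worker is in here at any time.
            print(f"task {task_id}: saving output")
            time.sleep(0.1)
        finally:
            if lock is not None:
                lock.release()
        return task_id


    def main():
        with multiprocessing.Manager() as manager:
            # A Manager lock is a picklable proxy, so it can be passed to
            # pool workers as a regular argument; pass None instead to
            # disable locking, analogous to running without a distributed
            # scheduler in the patches above.
            lock = manager.Lock()
            with multiprocessing.Pool(processes=4) as pool:
                futures = [
                    pool.apply_async(do_work, [i, lock]) for i in range(8)
                ]
                print([f.get() for f in futures])


    if __name__ == "__main__":
        main()

Running the sketch shows the "saving output" lines appearing one at a time even though eight tasks are spread over four workers, which is the behaviour the lock is meant to guarantee for the save step of each PreprocessingTask.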