Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only save data from one preprocessing task at a time #2610

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 56 additions & 39 deletions esmvalcore/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime
import importlib
import logging
import multiprocessing
import numbers
import os
import pprint
Expand All @@ -14,7 +15,6 @@
import threading
import time
from copy import deepcopy
from multiprocessing import Pool
from pathlib import Path, PosixPath
from shutil import which
from typing import Optional
Expand Down Expand Up @@ -260,6 +260,7 @@ def __init__(self, ancestors=None, name="", products=None):
self.name = name
self.activity = None
self.priority = 0
self.scheduler_lock = None

def initialize_provenance(self, recipe_entity):
"""Initialize task provenance activity."""
Expand Down Expand Up @@ -854,60 +855,76 @@ def done(task):
"""Assume a task is done if it not scheduled or running."""
return not (task in scheduled or task in running)

with Pool(processes=max_parallel_tasks) as pool:
while scheduled or running:
# Submit new tasks to pool
for task in sorted(scheduled, key=lambda t: t.priority):
if len(running) >= max_parallel_tasks:
break
if all(done(t) for t in task.ancestors):
future = pool.apply_async(
_run_task, [task, scheduler_address]
with multiprocessing.Manager() as manager:
# Use a lock to avoid overloading the Dask workers by making only
# one :class:`esmvalcore.preprocessor.PreprocessingTask` submit its
# data save task graph to the distributed scheduler at a time.
#
# See https://github.com/ESMValGroup/ESMValCore/issues/2609 for
# additional detail.
scheduler_lock = (
None if scheduler_address is None else manager.Lock()
)

with multiprocessing.Pool(processes=max_parallel_tasks) as pool:
while scheduled or running:
# Submit new tasks to pool
for task in sorted(scheduled, key=lambda t: t.priority):
if len(running) >= max_parallel_tasks:
break
if all(done(t) for t in task.ancestors):
future = pool.apply_async(
_run_task,
[task, scheduler_address, scheduler_lock],
)
running[task] = future
scheduled.remove(task)

# Handle completed tasks
ready = {t for t in running if running[t].ready()}
for task in ready:
_copy_results(task, running[task])
running.pop(task)

# Wait if there are still tasks running
if running:
time.sleep(0.1)

# Log progress message
if (
len(scheduled) != n_scheduled
or len(running) != n_running
):
n_scheduled, n_running = len(scheduled), len(running)
n_done = n_tasks - n_scheduled - n_running
logger.info(
"Progress: %s tasks running, %s tasks waiting for "
"ancestors, %s/%s done",
n_running,
n_scheduled,
n_done,
n_tasks,
)
running[task] = future
scheduled.remove(task)

# Handle completed tasks
ready = {t for t in running if running[t].ready()}
for task in ready:
_copy_results(task, running[task])
running.pop(task)

# Wait if there are still tasks running
if running:
time.sleep(0.1)

# Log progress message
if len(scheduled) != n_scheduled or len(running) != n_running:
n_scheduled, n_running = len(scheduled), len(running)
n_done = n_tasks - n_scheduled - n_running
logger.info(
"Progress: %s tasks running, %s tasks waiting for "
"ancestors, %s/%s done",
n_running,
n_scheduled,
n_done,
n_tasks,
)

logger.info("Successfully completed all tasks.")
pool.close()
pool.join()
logger.info("Successfully completed all tasks.")
pool.close()
pool.join()


def _copy_results(task, future):
"""Update task with the results from the remote process."""
task.output_files, task.products = future.get()


def _run_task(task, scheduler_address):
def _run_task(task, scheduler_address, scheduler_lock):
"""Run task and return the result."""
if scheduler_address is None:
client = contextlib.nullcontext()
else:
client = Client(scheduler_address)

with client:
task.scheduler_lock = scheduler_lock
output_files = task.run()

return output_files, task.products
18 changes: 16 additions & 2 deletions esmvalcore/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,23 @@ def _run(self, _) -> list[str]:
delayed = product.close()
delayeds.append(delayed)

logger.info("Computing and saving data for task %s", self.name)
delayeds = [d for d in delayeds if d is not None]
_compute_with_progress(delayeds, description=self.name)

if self.scheduler_lock is not None:
logger.debug("Acquiring save lock for task %s", self.name)
self.scheduler_lock.acquire()
logger.debug("Acquired save lock for task %s", self.name)
try:
logger.info(
"Computing and saving data for preprocessing task %s",
self.name,
)
_compute_with_progress(delayeds, description=self.name)
finally:
if self.scheduler_lock is not None:
self.scheduler_lock.release()
logger.debug("Released save lock for task %s", self.name)

metadata_files = write_metadata(
self.products, self.write_ncl_interface
)
Expand Down
13 changes: 12 additions & 1 deletion tests/integration/preprocessor/test_preprocessing_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import iris
import iris.cube
import pytest
from prov.model import ProvDocument

import esmvalcore.preprocessor
from esmvalcore.dataset import Dataset
from esmvalcore.preprocessor import PreprocessingTask, PreprocessorFile


def test_load_save_task(tmp_path):
@pytest.mark.parametrize("scheduler_lock", [False, True])
def test_load_save_task(tmp_path, mocker, scheduler_lock):
"""Test that a task that just loads and saves a file."""
# Prepare a test dataset
cube = iris.cube.Cube(data=[273.0], var_name="tas", units="K")
Expand All @@ -36,6 +38,9 @@ def test_load_save_task(tmp_path):
activity = provenance.activity("software:esmvalcore")
task.initialize_provenance(activity)

if scheduler_lock:
task.scheduler_lock = mocker.Mock()

task.run()

assert len(task.products) == 1
Expand All @@ -45,6 +50,12 @@ def test_load_save_task(tmp_path):
result.attributes.clear()
assert result == cube

if scheduler_lock:
task.scheduler_lock.acquire.assert_called_once_with()
task.scheduler_lock.release.assert_called_once_with()
else:
assert task.scheduler_lock is None


def test_load_save_and_other_task(tmp_path, monkeypatch):
"""Test that a task just copies one file and preprocesses another file."""
Expand Down
14 changes: 11 additions & 3 deletions tests/integration/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def test_run_tasks(monkeypatch, max_parallel_tasks, example_tasks, mpmethod):
get_distributed_client_mock(None),
)
monkeypatch.setattr(
esmvalcore._task, "Pool", multiprocessing.get_context(mpmethod).Pool
esmvalcore._task.multiprocessing,
"Pool",
multiprocessing.get_context(mpmethod).Pool,
)
example_tasks.run(max_parallel_tasks=max_parallel_tasks)

Expand Down Expand Up @@ -152,7 +154,7 @@ def _run(self, input_files):
return [f"{self.name}_test.nc"]

monkeypatch.setattr(MockBaseTask, "_run", _run)
monkeypatch.setattr(esmvalcore._task, "Pool", ThreadPool)
monkeypatch.setattr(esmvalcore._task.multiprocessing, "Pool", ThreadPool)

runner(example_tasks)
print(order)
Expand All @@ -165,11 +167,17 @@ def test_run_task(mocker, address):
# Set up mock Dask distributed client
mocker.patch.object(esmvalcore._task, "Client")

# Set up a mock multiprocessing.Lock
scheduler_lock = mocker.sentinel

task = mocker.create_autospec(DiagnosticTask, instance=True)
task.products = mocker.Mock()
output_files, products = _run_task(task, scheduler_address=address)
output_files, products = _run_task(
task, scheduler_address=address, scheduler_lock=scheduler_lock
)
assert output_files == task.run.return_value
assert products == task.products
assert task.scheduler_lock == scheduler_lock
if address is None:
esmvalcore._task.Client.assert_not_called()
else:
Expand Down