Skip to content

Commit

Permalink
provide interface for serializing taskgraphs to/from disk
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Feb 8, 2024
1 parent 041a931 commit aafec0a
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 44 deletions.
9 changes: 8 additions & 1 deletion src/coffea/dataset_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from coffea.dataset_tools.apply_processor import apply_to_dataset, apply_to_fileset
from coffea.dataset_tools.apply_processor import (
apply_to_dataset,
apply_to_fileset,
load_taskgraph,
save_taskgraph,
)
from coffea.dataset_tools.manipulations import (
filter_files,
get_failed_steps_for_dataset,
Expand All @@ -14,6 +19,8 @@
"preprocess",
"apply_to_dataset",
"apply_to_fileset",
"save_taskgraph",
"load_taskgraph",
"max_chunks",
"slice_chunks",
"filter_files",
Expand Down
123 changes: 91 additions & 32 deletions src/coffea/dataset_tools/apply_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from coffea.nanoevents import BaseSchema, NanoAODSchema, NanoEventsFactory
from coffea.processor import ProcessorABC
from coffea.util import decompress_form
from coffea.util import decompress_form, load, save

DaskOutputBaseType = Union[
dask.base.DaskMethodsMixin,
Expand Down Expand Up @@ -48,8 +48,6 @@ def _pack_meta_to_wire(*collections):
attrs=unpacked[i]._meta.attrs,
)
packed_out = repacker(output)
if len(packed_out) == 1:
return packed_out[0]
return packed_out


Expand All @@ -68,21 +66,13 @@ def _unpack_meta_from_wire(*collections):
attrs=unpacked[i]._meta.attrs,
)
packed_out = repacker(output)
if len(packed_out) == 1:
return packed_out[0]
return packed_out


def _apply_analysis_wire(analysis, events_wire):
    """Run ``analysis`` on a wire-packed events collection and re-pack the result.

    Parameters
    ----------
        analysis : callable
            Called with the unpacked events object as its only argument.
        events_wire
            A single packed events collection, as found inside the sequence
            returned by ``_pack_meta_to_wire``.

    Returns
    -------
        The packed sequence returned by ``_pack_meta_to_wire`` for the
        analysis output (one entry).
    """
    # _unpack_meta_from_wire returns a sequence even for a single input,
    # so destructure the one-element result.
    (events,) = _unpack_meta_from_wire(events_wire)
    # Point the meta's "@original_array" attribute back at the events
    # collection itself.
    events._meta.attrs["@original_array"] = events

    out = analysis(events)
    return _pack_meta_to_wire(out)


Expand Down Expand Up @@ -145,16 +135,14 @@ def apply_to_dataset(

out = None
if parallelize_with_dask:
if not isinstance(events_and_maybe_report, tuple):
events_and_maybe_report = (events_and_maybe_report,)
wired_events = _pack_meta_to_wire(*events_and_maybe_report)
(wired_events,) = _pack_meta_to_wire(events)
out = dask.delayed(partial(_apply_analysis_wire, analysis, wired_events))()
else:
out = analysis(events)

if report is not None:
return out, report
return out
return events, out, report
return events, out


def apply_to_fileset(
Expand Down Expand Up @@ -184,11 +172,14 @@ def apply_to_fileset(
Returns
-------
events: dict[str, dask_awkward.Array]
The NanoEvents objects the analysis function was applied to.
out : dict[str, DaskOutputType]
The output of the analysis workflow applied to the datasets, keyed by dataset name.
report : dask_awkward.Array, optional
The file access report for running the analysis on the input dataset. Needs to be computed in simultaneously with the analysis to be accurate.
"""
events = {}
out = {}
analyses_to_compute = {}
report = {}
Expand All @@ -206,24 +197,92 @@ def apply_to_fileset(
parallelize_with_dask,
)
if parallelize_with_dask:
analyses_to_compute[name] = dataset_out
elif isinstance(dataset_out, tuple):
out[name], report[name] = dataset_out
if len(dataset_out) == 3:
events[name], analyses_to_compute[name], report[name] = dataset_out
elif len(dataset_out) == 2:
events[name], analyses_to_compute[name] = dataset_out
else:
raise ValueError(
"apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
)
elif isinstance(dataset_out, tuple) and len(dataset_out) == 3:
events[name], out[name], report[name] = dataset_out
elif isinstance(dataset_out, tuple) and len(dataset_out) == 2:
events[name], out[name] = dataset_out
else:
out[name] = dataset_out
raise ValueError(
"apply_to_dataset only returns (events, outputs) or (events, outputs, reports)"
)

if parallelize_with_dask:
(calculated_graphs,) = dask.compute(analyses_to_compute, scheduler=scheduler)
for name, dataset_out_wire in calculated_graphs.items():
to_unwire = dataset_out_wire
if not isinstance(dataset_out_wire, tuple):
to_unwire = (dataset_out_wire,)
dataset_out = _unpack_meta_from_wire(*to_unwire)
if isinstance(dataset_out, tuple):
out[name], report[name] = dataset_out
else:
out[name] = dataset_out
(out[name],) = _unpack_meta_from_wire(*dataset_out_wire)

if len(report) > 0:
return out, report
return out
return events, out, report
return events, out


def save_taskgraph(filename, events, *data_products, optimize_graph=False):
    """
    Serialize a task graph, together with the nanoevents it came from, to a file.

    Parameters
    ----------
        filename: str
            Destination of the serialized taskgraph and nanoevents.
            Suggested postfix ".hlg", after dask's HighLevelGraph object.
        events: dict[str, dask_awkward.Array]
            A dictionary of nanoevents objects.
        data_products: dict[str, DaskOutputBaseType]
            The data products resulting from applying an analysis to
            a NanoEvents object. This may include report objects.
            At least one must be given.
        optimize_graph: bool, default False
            Whether or not to save the task graph in its optimized form.

    Returns
    -------
        None
            The packed events, the packed data products and the
            ``optimize_graph`` flag are written to ``filename``.

    Raises
    ------
        ValueError
            If no data products are supplied.
    """
    (wired_events,) = _pack_meta_to_wire(events)

    if not data_products:
        raise ValueError(
            "You must supply at least one analysis data product to save a task graph!"
        )

    # NOTE(review): dask.optimize is handed the tuple of products as ONE
    # argument, so its result is presumably a 1-tuple wrapping that tuple,
    # i.e. nested one level deeper than the unoptimized path. Confirm this
    # is intended before relying on the loaded shape.
    to_serialize = dask.optimize(data_products) if optimize_graph else data_products

    payload = {
        "events": wired_events,
        "data_products": _pack_meta_to_wire(*to_serialize),
        "optimized": optimize_graph,
    }
    save(payload, filename)


def load_taskgraph(filename):
    """
    Read a task graph and its originating nanoevents back from a file.

    Parameters
    ----------
        filename: str
            The file from which to load the task graph.

    Returns
    -------
        events
            The unpacked "events" entry of the file. Every member has its
            meta "@original_array" attribute pointed back at itself.
        data_products
            The single unpacked entry of the file's "data_products".
        optimized
            The file's "optimized" entry, stored unchanged.
    """
    # NOTE: ``load`` unpickles the file, so only open files you trust.
    stored = load(filename)

    (events,) = _unpack_meta_from_wire(stored["events"])
    # Exactly one unpacked data product is expected here; any other count
    # makes this destructuring raise.
    (data_products,) = _unpack_meta_from_wire(*stored["data_products"])
    was_optimized = stored["optimized"]

    for dataset_name in events:
        dataset_events = events[dataset_name]
        dataset_events._meta.attrs["@original_array"] = dataset_events

    return events, data_products, was_optimized
11 changes: 5 additions & 6 deletions src/coffea/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,20 @@
import lz4.frame


def load(filename, mode="rb"):
    """Load a coffea file from disk

    Parameters
    ----------
        filename
            Path of the lz4-framed pickle to read.
        mode : str, default "rb"
            Passed straight through to ``lz4.frame.open``.

    Returns
    -------
        The object unpickled from the file.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # come from a trusted source.
    with lz4.frame.open(filename, mode) as fin:
        return cloudpickle.load(fin)


def save(output, filename, mode="wb"):
    """Save a coffea object or collection thereof to disk

    This function can accept any picklable object. Suggested suffix: ``.coffea``

    Parameters
    ----------
        output
            The object to pickle.
        filename
            Path of the lz4-framed file to write.
        mode : str, default "wb"
            Passed straight through to ``lz4.frame.open``.
    """
    with lz4.frame.open(filename, mode) as fout:
        # Stream the pickle into the compressed file rather than building
        # the whole byte string in memory first.
        cloudpickle.dump(output, fout)


def _hex(string):
Expand Down
59 changes: 54 additions & 5 deletions tests/test_dataset_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
apply_to_fileset,
filter_files,
get_failed_steps_for_fileset,
load_taskgraph,
max_chunks,
max_files,
preprocess,
save_taskgraph,
slice_chunks,
slice_files,
)
Expand Down Expand Up @@ -202,7 +204,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc):
proc, schemaclass = proc_and_schema

with Client() as _:
to_compute = apply_to_fileset(
_, to_compute = apply_to_fileset(
proc(),
_runnable_result,
schemaclass=schemaclass,
Expand All @@ -215,7 +217,7 @@ def test_apply_to_fileset(proc_and_schema, delayed_taskgraph_calc):
assert out["Data"]["cutflow"]["Data_pt"] == 84
assert out["Data"]["cutflow"]["Data_mass"] == 66

to_compute = apply_to_fileset(
_, to_compute = apply_to_fileset(
proc(),
max_chunks(_runnable_result, 1),
schemaclass=schemaclass,
Expand All @@ -240,7 +242,7 @@ def test_apply_to_fileset_hinted_form():
save_form=True,
)

to_compute = apply_to_fileset(
_, to_compute = apply_to_fileset(
NanoEventsProcessor(),
dataset_runnable,
schemaclass=NanoAODSchema,
Expand Down Expand Up @@ -445,14 +447,14 @@ def test_slice_chunks():
@pytest.mark.parametrize("delayed_taskgraph_calc", [True, False])
def test_recover_failed_chunks(delayed_taskgraph_calc):
with Client() as _:
to_compute = apply_to_fileset(
_, to_compute, reports = apply_to_fileset(
NanoEventsProcessor(),
_starting_fileset_with_steps,
schemaclass=NanoAODSchema,
uproot_options={"allow_read_errors_with_report": True},
parallelize_with_dask=delayed_taskgraph_calc,
)
out, reports = dask.compute(*to_compute)
out, reports = dask.compute(to_compute, reports)

failed_fset = get_failed_steps_for_fileset(_starting_fileset_with_steps, reports)
assert failed_fset == {
Expand All @@ -474,3 +476,50 @@ def test_recover_failed_chunks(delayed_taskgraph_calc):
}
}
}


@pytest.mark.parametrize(
    "proc_and_schema",
    [(NanoTestProcessor, BaseSchema), (NanoEventsProcessor, NanoAODSchema)],
)
@pytest.mark.parametrize(
    "with_report",
    [True, False],
)
def test_task_graph_serialization(proc_and_schema, with_report, tmp_path):
    """Round-trip a task graph through save_taskgraph/load_taskgraph and compute it."""
    proc, schemaclass = proc_and_schema

    # Write into pytest's per-test temporary directory so the parametrized
    # cases do not share a file and nothing is left in the working directory.
    graph_file = str(tmp_path / "test_task_graph_serialization.hlg")

    with Client() as _:
        output = apply_to_fileset(
            proc(),
            _runnable_result,
            schemaclass=schemaclass,
            parallelize_with_dask=False,
            uproot_options={"allow_read_errors_with_report": with_report},
        )

        # apply_to_fileset returns (events, out) or (events, out, report).
        events = output[0]
        to_compute = output[1:]

        save_taskgraph(
            graph_file,
            events,
            to_compute,
            optimize_graph=False,
        )

        _, to_compute_serdes, is_optimized = load_taskgraph(graph_file)

        # The flag stored at save time must survive the round trip.
        assert not is_optimized

        if len(to_compute_serdes) > 1:
            (out, _) = dask.compute(*to_compute_serdes)
        else:
            (out,) = dask.compute(*to_compute_serdes)

        assert out["ZJets"]["cutflow"]["ZJets_pt"] == 18
        assert out["ZJets"]["cutflow"]["ZJets_mass"] == 6
        assert out["Data"]["cutflow"]["Data_pt"] == 84
        assert out["Data"]["cutflow"]["Data_mass"] == 66

0 comments on commit aafec0a

Please sign in to comment.