Skip to content

Commit

Permalink
use unpack collections to consistently calculate metas
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Jan 13, 2024
1 parent cbbb5b1 commit b23ce90
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/coffea/ml_tools/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dask
import dask_awkward
import numpy
from dask.base import unpack_collections


class nonserializable_attribute:
Expand Down Expand Up @@ -345,7 +346,12 @@ def __call__(self, wrapper, *args):
dak_args, dak_kwargs = self.prepare_awkward(*args, **kwargs)
wrap = _callable_wrap((dak_args, dak_kwargs))
packed_args = wrap.pair_to_args(*dak_args, **dak_kwargs)
wrap_meta = wrap(self, *tuple(arg._meta for arg in packed_args))

flattened_args, repack = unpack_collections(*packed_args, traverse=True)
flattened_metas = tuple(arg._meta for arg in flattened_args)
packed_metas = repack(flattened_metas)

wrap_meta = wrap(self, *packed_metas)
delayed_wrapper = dask.delayed(self)
arr = dask_awkward.lib.core.map_partitions(
wrap,
Expand Down

0 comments on commit b23ce90

Please sign in to comment.