Skip to content

Commit

Permalink
Merge pull request #1 from palamatt95/master_uproot4
Browse files Browse the repository at this point in the history
Master uproot4
  • Loading branch information
Luca Giommi authored Jan 18, 2022
2 parents f8f9b50 + b37a933 commit b30b1be
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/python/MLaaS4HEP/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ def train_model(model, files, labels, preproc=None, params=None, specs=None, fou
print(f"\n####Time for training: {time.time()-time0}\n\n")

if fout and hasattr(trainer, 'save'):
trainer.save(fout)
trainer.save(fout)
67 changes: 27 additions & 40 deletions src/python/MLaaS4HEP/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,6 @@
# uproot
try:
import uproot
try:
# uproot version 3.X
from awkward import JaggedArray
except ImportError:
# uproot version 2.X
from uproot.interp.jagged import JaggedArray
except ImportError:
pass

Expand Down Expand Up @@ -148,19 +142,18 @@ def dim_jarr(arr):
jdim = len(item)
return jdim

def min_max_arr(arr):
def min_max_arr(jagged_key, key, arr):
"""
Helper function to find out min/max values of given array.
The array can be either jagged one or normal numpy.ndarray
"""
try:
if isinstance(arr, JaggedArray):
arr = arr.flatten()
if key in jagged_key:
arr = np.concatenate(arr, axis = 0)
return float(np.min(arr)), float(np.max(arr))
except ValueError:
return 1e15, -1e15


class FileReader(object):
"""
FileReader represents generic interface to read data files
Expand Down Expand Up @@ -600,11 +593,11 @@ def __init__(self, fin, branch='Events', selected_branches=None, \
if exclude_branches:
print(f"Excluded branches: {exclude_branches}")
all_branches=self.tree.keys()
exclude_branches=[elem.encode() for elem in exclude_branches]
exclude_branches=[elem for elem in exclude_branches]
self.out_branches=[elem for elem in all_branches if (elem not in exclude_branches)]
if selected_branches:
print(f"Selected branches: {selected_branches}")
selected_branches=[elem.encode() for elem in selected_branches]
selected_branches=[elem for elem in selected_branches]
self.out_branches=[elem for elem in selected_branches]

# perform initialization
Expand Down Expand Up @@ -638,19 +631,17 @@ def load_specs(self, specs):
self.fkeys = specs['fkeys']
self.nans = specs['nans']

self.flat_keys_encoded = sorted([key.encode('ascii') for key in self.flat_keys()])
self.jagged_keys_encoded = sorted([key.encode('ascii') for key in self.jagged_keys()])
self.flat_keys_encoded = [key for key in self.flat_keys()]
self.jagged_keys_encoded = [key for key in self.jagged_keys()]
self.keys = self.flat_keys_encoded + self.jagged_keys_encoded
self.min_list = [self.minv[key.decode('ascii')] for key in self.keys]
self.max_list = [self.maxv[key.decode('ascii')] for key in self.keys]
self.jdimension = [self.jdim[key.decode('ascii')] for key in self.jagged_keys_encoded]
self.min_list = [self.minv[key] for key in self.keys]
self.max_list = [self.maxv[key] for key in self.keys]
self.jdimension = [self.jdim[key] for key in self.jagged_keys_encoded]
self.dimension_list = [1] * len(self.flat_keys_encoded)
self.dimension_list = self.dimension_list + self.jdimension

def fetch_data(self, key):
"fetch data for given key from underlying ROOT tree"
if sys.version.startswith('3.') and isinstance(key, str):
key = key.encode('ascii') # convert string to binary
if key in self.branches:
return self.branches[key]
raise Exception('Unable to find "%s" key in ROOT branches' % key)
Expand All @@ -661,42 +652,42 @@ def read_chunk(self, nevts, set_branches=False, set_min_max=False):
start_time = time.time()
if not self.gen:
if self.out_branches:
self.gen = self.tree.iterate(\
branches=self.out_branches, \
entrysteps=nevts, keycache=self.cache)
else:
self.gen = self.tree.iterate(\
entrysteps=nevts, keycache=self.cache)
self.gen = self.tree.iterate( \
self.out_branches, \
step_size=nevts, \
library="np")
else:
self.gen = self.tree.iterate( \
step_size=nevts, \
library='np')
self.branches = {} # start with fresh dict
try:
self.branches = next(self.gen) # python 3.X and 2.X
except StopIteration:
if self.out_branches:
self.gen = self.tree.iterate(\
branches=self.out_branches, \
entrysteps=nevts, keycache=self.cache)
self.gen = self.tree.iterate( \
branches=self.out_branches, \
step_size=nevts, \
library='np')
else:
self.gen = self.tree.iterate(entrysteps=nevts, keycache=self.cache)
self.gen = self.tree.iterate(step_size=nevts, library='np')
self.branches = next(self.gen) # python 3.X and 2.X

self.time_reading.append(time.time()-start_time)
end_time = time.time()
self.idx += nevts
if self.verbose:
performance(nevts, self.tree, self.branches, start_time, end_time)
if set_branches:
for key, val in self.branches.items():
if isinstance(key, bytes):
key = key.decode()
self.minv[key], self.maxv[key] = min_max_arr(val)
if isinstance(val, JaggedArray):
if isinstance(self.tree[key].interpretation, uproot.AsJagged):
self.jkeys.append(key)
else:
self.fkeys.append(key)
self.minv[key], self.maxv[key] = min_max_arr(self.jkeys, key, val)
if set_min_max:
for key, val in self.branches.items():
if isinstance(key, bytes):
key = key.decode()
minv, maxv = min_max_arr(val)
minv, maxv = min_max_arr(self.jkeys, key, val)
if minv < self.minv[key]:
self.minv[key] = minv
if maxv > self.maxv[key]:
Expand Down Expand Up @@ -760,8 +751,6 @@ def init(self):
set_branches = False # we do it once
for key in self.jkeys:
if key not in self.jdim:
if isinstance(key, bytes):
key = key.decode()
self.jdim[key] = 0
dim = dim_jarr(self.fetch_data(key))
if dim > self.jdim.get(key, 0):
Expand All @@ -785,8 +774,6 @@ def init(self):
# initialize all nan values (zeros) in normalize phase-space
# this should be done after we get all min/max values
for key in self.branches.keys():
if isinstance(key, bytes):
key = key.decode()
self.nans[key] = self.normalize(key, 0)

# reset internal indexes since we are done with the first pass reading
Expand Down
9 changes: 2 additions & 7 deletions src/python/MLaaS4HEP/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@

# uproot
try:
try:
# uproot version 3.X
from awkward import JaggedArray
except ImportError:
# uproot version 2.X
from uproot.interp.jagged import JaggedArray
import uproot
except ImportError:
pass

Expand Down Expand Up @@ -121,7 +116,7 @@ def performance(nevts, tree, data, start_time, end_time, msg=""):
"helper function to show performance metrics of data read from a given tree"
try:
nbytes = sum(x.content.nbytes + x.stops.nbytes \
if isinstance(x, JaggedArray) \
if isinstance(x, uproot.AsJagged) \
else x.nbytes for x in data.values())
print("# %s entries, %s %sbranches, %s MB, %s sec, %s MB/sec, %s kHz" % \
( \
Expand Down

0 comments on commit b30b1be

Please sign in to comment.