# Review notes (free text ahead of the first "diff --git" header; not part of any hunk).
# * Line breaks and indentation were restored from a whitespace-collapsed copy of this
#   patch. Hunk line counts match the @@ headers, but context-line indentation is a
#   best-effort reconstruction -- verify against the target tree before applying.
# * reader.py, read_chunk: the StopIteration retry now passes the branch list
#   positionally, like the first iterate() call; uproot 4 iterate() takes the branch
#   selection as its first argument and has no `branches` keyword.
# * utils.py, performance: arrays read with library="np" are numpy arrays, never
#   uproot.AsJagged instances; jagged branches are now recognised by object dtype and
#   sized by summing their elements.
# * The "index" blob ids below come from the original patch and do not reflect the
#   two amended hunks.
diff --git a/src/python/MLaaS4HEP/models.py b/src/python/MLaaS4HEP/models.py
index 6ac1c7c..d8e1ae7 100644
--- a/src/python/MLaaS4HEP/models.py
+++ b/src/python/MLaaS4HEP/models.py
@@ -160,4 +160,4 @@ def train_model(model, files, labels, preproc=None, params=None, specs=None, fou
     print(f"\n####Time for training: {time.time()-time0}\n\n")
 
     if fout and hasattr(trainer, 'save'):
-        trainer.save(fout)
\ No newline at end of file
+        trainer.save(fout)
diff --git a/src/python/MLaaS4HEP/reader.py b/src/python/MLaaS4HEP/reader.py
index 517ff35..494dc88 100755
--- a/src/python/MLaaS4HEP/reader.py
+++ b/src/python/MLaaS4HEP/reader.py
@@ -45,12 +45,6 @@
 # uproot
 try:
     import uproot
-    try:
-        # uproot verion 3.X
-        from awkward import JaggedArray
-    except ImportError:
-        # uproot verion 2.X
-        from uproot.interp.jagged import JaggedArray
 except ImportError:
     pass
 
@@ -148,19 +142,18 @@ def dim_jarr(arr):
             jdim = len(item)
     return jdim
 
-def min_max_arr(arr):
+def min_max_arr(jagged_key, key, arr):
     """
     Helper function to find out min/max values of given array.
     The array can be either jagged one or normal numpy.ndarray
     """
     try:
-        if isinstance(arr, JaggedArray):
-            arr = arr.flatten()
+        if key in jagged_key:
+            arr = np.concatenate(arr, axis = 0)
         return float(np.min(arr)), float(np.max(arr))
     except ValueError:
         return 1e15, -1e15
 
-
 class FileReader(object):
     """
     FileReader represents generic interface to read data files
@@ -600,11 +593,11 @@ def __init__(self, fin, branch='Events', selected_branches=None, \
         if exclude_branches:
             print(f"Excluded branches: {exclude_branches}")
             all_branches=self.tree.keys()
-            exclude_branches=[elem.encode() for elem in exclude_branches]
+            exclude_branches=[elem for elem in exclude_branches]
             self.out_branches=[elem for elem in all_branches if (elem not in exclude_branches)]
         if selected_branches:
             print(f"Selected branches: {selected_branches}")
-            selected_branches=[elem.encode() for elem in selected_branches]
+            selected_branches=[elem for elem in selected_branches]
             self.out_branches=[elem for elem in selected_branches]
 
         # perform initialization
@@ -638,19 +631,17 @@ def load_specs(self, specs):
         self.fkeys = specs['fkeys']
         self.nans = specs['nans']
 
-        self.flat_keys_encoded = sorted([key.encode('ascii') for key in self.flat_keys()])
-        self.jagged_keys_encoded = sorted([key.encode('ascii') for key in self.jagged_keys()])
+        self.flat_keys_encoded = [key for key in self.flat_keys()]
+        self.jagged_keys_encoded = [key for key in self.jagged_keys()]
         self.keys = self.flat_keys_encoded + self.jagged_keys_encoded
-        self.min_list = [self.minv[key.decode('ascii')] for key in self.keys]
-        self.max_list = [self.maxv[key.decode('ascii')] for key in self.keys]
-        self.jdimension = [self.jdim[key.decode('ascii')] for key in self.jagged_keys_encoded]
+        self.min_list = [self.minv[key] for key in self.keys]
+        self.max_list = [self.maxv[key] for key in self.keys]
+        self.jdimension = [self.jdim[key] for key in self.jagged_keys_encoded]
         self.dimension_list = [1] * len(self.flat_keys_encoded)
         self.dimension_list = self.dimension_list + self.jdimension
 
     def fetch_data(self, key):
         "fetch data for given key from underlying ROOT tree"
-        if sys.version.startswith('3.') and isinstance(key, str):
-            key = key.encode('ascii') # convert string to binary
         if key in self.branches:
             return self.branches[key]
         raise Exception('Unable to find "%s" key in ROOT branches' % key)
@@ -661,23 +652,27 @@ def read_chunk(self, nevts, set_branches=False, set_min_max=False):
         start_time = time.time()
         if not self.gen:
             if self.out_branches:
-                self.gen = self.tree.iterate(\
-                        branches=self.out_branches, \
-                        entrysteps=nevts, keycache=self.cache)
-            else:
-                self.gen = self.tree.iterate(\
-                        entrysteps=nevts, keycache=self.cache)
+                self.gen = self.tree.iterate( \
+                        self.out_branches, \
+                        step_size=nevts, \
+                        library="np")
+            else:
+                self.gen = self.tree.iterate( \
+                        step_size=nevts, \
+                        library='np')
         self.branches = {} # start with fresh dict
         try:
             self.branches = next(self.gen) # python 3.X and 2.X
         except StopIteration:
             if self.out_branches:
-                self.gen = self.tree.iterate(\
-                        branches=self.out_branches, \
-                        entrysteps=nevts, keycache=self.cache)
+                self.gen = self.tree.iterate( \
+                        self.out_branches, \
+                        step_size=nevts, \
+                        library='np')
             else:
-                self.gen = self.tree.iterate(entrysteps=nevts, keycache=self.cache)
+                self.gen = self.tree.iterate(step_size=nevts, library='np')
             self.branches = next(self.gen) # python 3.X and 2.X
+        self.time_reading.append(time.time()-start_time)
         end_time = time.time()
         self.idx += nevts
 
@@ -685,18 +680,14 @@ def read_chunk(self, nevts, set_branches=False, set_min_max=False):
             performance(nevts, self.tree, self.branches, start_time, end_time)
         if set_branches:
             for key, val in self.branches.items():
-                if isinstance(key, bytes):
-                    key = key.decode()
-                self.minv[key], self.maxv[key] = min_max_arr(val)
-                if isinstance(val, JaggedArray):
+                if isinstance(self.tree[key].interpretation, uproot.AsJagged):
                     self.jkeys.append(key)
                 else:
                     self.fkeys.append(key)
+                self.minv[key], self.maxv[key] = min_max_arr(self.jkeys, key, val)
         if set_min_max:
             for key, val in self.branches.items():
-                if isinstance(key, bytes):
-                    key = key.decode()
-                minv, maxv = min_max_arr(val)
+                minv, maxv = min_max_arr(self.jkeys, key, val)
                 if minv < self.minv[key]:
                     self.minv[key] = minv
                 if maxv > self.maxv[key]:
@@ -760,8 +751,6 @@ def init(self):
             set_branches = False # we do it once
             for key in self.jkeys:
                 if key not in self.jdim:
-                    if isinstance(key, bytes):
-                        key = key.decode()
                     self.jdim[key] = 0
                 dim = dim_jarr(self.fetch_data(key))
                 if dim > self.jdim.get(key, 0):
@@ -785,8 +774,6 @@ def init(self):
         # initialize all nan values (zeros) in normalize phase-space
         # this should be done after we get all min/max values
         for key in self.branches.keys():
-            if isinstance(key, bytes):
-                key = key.decode()
             self.nans[key] = self.normalize(key, 0)
 
         # reset internal indexes since we done with first pass reading
diff --git a/src/python/MLaaS4HEP/utils.py b/src/python/MLaaS4HEP/utils.py
index 62620c7..b46cb43 100644
--- a/src/python/MLaaS4HEP/utils.py
+++ b/src/python/MLaaS4HEP/utils.py
@@ -20,12 +20,7 @@
 
 # uproot
 try:
-    try:
-        # uproot verion 3.X
-        from awkward import JaggedArray
-    except ImportError:
-        # uproot verion 2.X
-        from uproot.interp.jagged import JaggedArray
+    import uproot
 except ImportError:
     pass
 
@@ -121,7 +116,7 @@ def performance(nevts, tree, data, start_time, end_time, msg=""):
     "helper function to show performance metrics of data read from a given tree"
     try:
-        nbytes = sum(x.content.nbytes + x.stops.nbytes \
-                if isinstance(x, JaggedArray) \
+        nbytes = sum(sum(getattr(a, 'nbytes', 0) for a in x) \
+                if x.dtype == object \
                 else x.nbytes for x in data.values())
         print("# %s entries, %s %sbranches, %s MB, %s sec, %s MB/sec, %s kHz" % \
                 ( \