Skip to content

Commit

Permalink
Merge pull request #163 from lanl/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
MaksimEkin authored Apr 29, 2024
2 parents 014900f + 79c76a4 commit a0b2be7
Show file tree
Hide file tree
Showing 78 changed files with 333 additions and 279 deletions.
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ authors:
- family-names: Alexandrov
given-names: Boian
title: "Tensor Extraction of Latent Features (T-ELF)"
version: 0.0.17
version: 0.0.18
url: https://github.com/lanl/T-ELF
doi: 10.5281/zenodo.10257897
date-released: 2023-12-04
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ If you use T-ELF please cite.

**APA:**
```latex
Eren, M., Solovyev, N., Barron, R., Bhattarai, M., Truong, D., Boureima, I., Skau, E., Rasmussen, K., & Alexandrov, B. (2023). Tensor Extraction of Latent Features (T-ELF) (Version 0.0.17) [Computer software]. https://doi.org/10.5281/zenodo.10257897
Eren, M., Solovyev, N., Barron, R., Bhattarai, M., Truong, D., Boureima, I., Skau, E., Rasmussen, K., & Alexandrov, B. (2023). Tensor Extraction of Latent Features (T-ELF) (Version 0.0.18) [Computer software]. https://doi.org/10.5281/zenodo.10257897
```

**BibTeX:**
Expand Down
81 changes: 57 additions & 24 deletions TELF/factorization/NMFk.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _nmf_parallel_wrapper(
mask=None,
consensus_mat=False,
predict_k=False,
predict_k_method="sill",
predict_k_method="WH_sill",
pruned=True,
perturb_rows=None,
perturb_cols=None,
Expand Down Expand Up @@ -357,9 +357,21 @@ def _nmf_parallel_wrapper(
#
if K_search_settings["k_search_method"] != "linear":
with K_search_settings['lock']:
if min(sils_min_W, sils_min_H) >= K_search_settings["sill_thresh"]:

if predict_k_method in ["WH_sill", "sill"]:
curr_score = min(sils_min_W, sils_min_H)
elif predict_k_method == "W_sill":
curr_score = sils_min_W
elif predict_k_method == "H_sill":
curr_score = sils_min_H
elif predict_k_method == "pvalue":
curr_score = sils_min_W
else:
raise Exception("Unknown predict_k_method!")

if curr_score >= K_search_settings["sill_thresh"]:
K_search_settings['k_min'] = k
if K_search_settings["H_sill_thresh"] >= 0 and (sils_min_H <= K_search_settings["H_sill_thresh"]):
if K_search_settings["H_sill_thresh"] is not None and (sils_min_H <= K_search_settings["H_sill_thresh"]):
K_search_settings['k_max'] = k

if n_nodes > 1:
Expand Down Expand Up @@ -488,7 +500,7 @@ def __init__(
save_output=True,
collect_output=False,
predict_k=False,
predict_k_method="sill",
predict_k_method="WH_sill",
verbose=True,
nmf_verbose=False,
perturb_verbose=False,
Expand All @@ -507,7 +519,7 @@ def __init__(
get_plot_data=False,
simple_plot=True,
k_search_method="linear",
H_sill_thresh=-1
H_sill_thresh=None
):
"""
NMFk is a Non-negative Matrix Factorization module with the capability to do automatic model determination.
Expand Down Expand Up @@ -549,13 +561,16 @@ def __init__(
Even when ``predict_k=False``, number of latent factors can be estimated using the figures saved in ``save_path``.
predict_k_method : str, optional
Method to use when performing automatic k prediction. Default is "sill".\n
Method to use when performing automatic k prediction. Default is "WH_sill".\n
* ``predict_k_method='pvalue'`` will use L-Statistics with column-wise error for automatically estimating the number of latent factors.\n
* ``predict_k_method='sill'`` will use Silhouette score for estimating the number of latent factors.
* ``predict_k_method='WH_sill'`` will use Silhouette scores from minimum of W and H latent factors for estimating the number of latent factors.\n
* ``predict_k_method='W_sill'`` will use Silhouette scores from W latent factor for estimating the number of latent factors.\n
* ``predict_k_method='H_sill'`` will use Silhouette scores from H latent factor for estimating the number of latent factors.\n
* ``predict_k_method='sill'`` will default to ``predict_k_method='WH_sill'``.
.. warning::
``predict_k_method='pvalue'`` prediction will result in significantly longer processing time, altough it is more accurate! ``predict_k_method='sill'``, on the other hand, will be much faster.
``predict_k_method='pvalue'`` prediction will result in significantly longer processing time, altough it is more accurate! ``predict_k_method='WH_sill'``, on the other hand, will be much faster.
verbose : bool, optional
If True, shows progress in each k. The default is True.
Expand Down Expand Up @@ -622,13 +637,13 @@ def __init__(
k_search_method : str, optional
Which approach to use when searching for the rank or k. The default is "linear".\n
* ``k_search_method='linear'`` will linearly visit each K given in ``Ks`` hyper-parameter of the ``fit()`` function.\n
* ``k_search_method='bst_post'`` will perform post-order binary search. When an ideal rank is found with ``min(W silhouette, H silhouette) >= sill_thresh``, all lower ranks are pruned from the search space.
* ``k_search_method='bst_pre'`` will perform pre-order binary search. When an ideal rank is found with ``min(W silhouette, H silhouette) >= sill_thresh``, all lower ranks are pruned from the search space.
* ``k_search_method='bst_post'`` will perform post-order binary search. When an ideal rank is found, determined by the selected ``predict_k_method``, all lower ranks are pruned from the search space.
* ``k_search_method='bst_pre'`` will perform pre-order binary search. When an ideal rank is found, determined by the selected ``predict_k_method``, all lower ranks are pruned from the search space.
H_sill_thresh : float, optional
Setting for removing higher ranks from the search space.\n
When searching for the optimal rank with binary search using ``k_search='bst_post'`` or ``k_search='bst_pre'``, this hyper-parameter can be used to cut off higher ranks from search space.\n
The cut-off of higher ranks from the search space is based on threshold for H silhouette. When a H silhouette below ``H_sill_thresh`` is found for a given rank or K, all higher ranks are removed from the search space.\n
If ``H_sill_thresh=-1``, it is not used. The default is -1.
If ``H_sill_thresh=None``, it is not used. The default is None.
Returns
-------
None.
Expand Down Expand Up @@ -680,7 +695,12 @@ def __init__(

# warnings
assert self.k_search_method in ["linear", "bst_pre", "bst_post"], "Invalid k_search_method method. Choose from linear, bst_pre, or bst_post."
assert self.predict_k_method in ["pvalue", "sill"], "Invalid predict_k_method method. Choose from pvalue, sill."
assert self.predict_k_method in ["pvalue", "WH_sill", "W_sill", "H_sill", "sill"], "Invalid predict_k_method method. Choose from pvalue, WH_sill, W_sill, H_sill, or sill. sill defaults to WH_sill."

if self.predict_k_method == "sill":
self.predict_k_method = "WH_sill"
warnings.warn("predict_k_method is defaulted to WH_sill!")

if self.calculate_pac and not self.consensus_mat:
self.consensus_mat = True
warnings.warn("consensus_mat was False when calculate_pac was True! consensus_mat changed to True.")
Expand Down Expand Up @@ -841,10 +861,12 @@ def fit(self, X, Ks, name="NMFk", note=""):
Ks.sort()
if self.K_search_settings["k_search_method"] != "linear":
node = BST.sorted_array_to_bst(Ks)
if self.K_search_settings["k_search_method"] != "bst_pre":
if self.K_search_settings["k_search_method"] == "bst_pre":
Ks = list(node.preorder())
if self.K_search_settings["k_search_method"] != "bst_post":
if self.K_search_settings["k_search_method"] == "bst_post":
Ks = list(node.postorder())
else:
raise Exception("Unknown k_search_method!")

#
# check X format
Expand Down Expand Up @@ -1136,18 +1158,29 @@ def fit(self, X, Ks, name="NMFk", note=""):
combined_result["col_err"], Ks, combined_result["sils_min_W"], SILL_thr=self.sill_thresh
)[0]

elif self.predict_k_method == "sill":
else:
if self.predict_k_method in ["WH_sill", "sill"]:
curr_sill_max_score = min([max(combined_result["sils_min_W"]), max(combined_result["sils_min_H"])])

elif self.predict_k_method == "W_sill":
curr_sill_max_score = max(combined_result["sils_min_W"])

elif self.predict_k_method == "H_sill":
curr_sill_max_score = max(combined_result["sils_min_H"])

# check if that sill threshold exist
if self.sill_thresh > min([max(combined_result["sils_min_W"]), max(combined_result["sils_min_H"])]):
self.sill_thresh = min([max(combined_result["sils_min_W"]), max(combined_result["sils_min_H"])])
if self.sill_thresh > curr_sill_max_score:
self.sill_thresh = curr_sill_max_score
warnings.warn(f'W or H Silhouettes were all less than sill_thresh. Setting sill_thresh to minimum for K prediction. sill_thresh={round(self.sill_thresh, 3)}')

k_predict_W = Ks[np.max(np.argwhere(
np.array(combined_result["sils_min_W"]) >= self.sill_thresh).flatten())]
k_predict_H = Ks[np.max(np.argwhere(
np.array(combined_result["sils_min_H"]) >= self.sill_thresh).flatten())]
k_predict = min(k_predict_W, k_predict_H)

if self.predict_k_method in ["WH_sill", "sill"]:
k_predict_W = Ks[np.max(np.argwhere(np.array(combined_result["sils_min_W"]) >= self.sill_thresh).flatten())]
k_predict_H = Ks[np.max(np.argwhere(np.array(combined_result["sils_min_H"]) >= self.sill_thresh).flatten())]
k_predict = min(k_predict_W, k_predict_H)
elif self.predict_k_method == "W_sill":
k_predict = Ks[np.max(np.argwhere(np.array(combined_result["sils_min_W"]) >= self.sill_thresh).flatten())]
elif self.predict_k_method == "H_sill":
k_predict = Ks[np.max(np.argwhere(np.array(combined_result["sils_min_H"]) >= self.sill_thresh).flatten())]

else:
k_predict = 0
Expand Down
2 changes: 1 addition & 1 deletion TELF/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.17'
__version__ = '0.0.18'
6 changes: 3 additions & 3 deletions docs/Beaver.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />

<title>TELF.pre_processing.Beaver: Fast matrix and tensor building tool &#8212; TELF 0.0.17 documentation</title>
<title>TELF.pre_processing.Beaver: Fast matrix and tensor building tool &#8212; TELF 0.0.18 documentation</title>



Expand Down Expand Up @@ -37,7 +37,7 @@
<link rel="preload" as="script" href="_static/scripts/pydata-sphinx-theme.js?digest=365ca57ee442770a23c6" />
<script src="_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=365ca57ee442770a23c6"></script>

<script src="_static/documentation_options.js?v=0150aef7"></script>
<script src="_static/documentation_options.js?v=4bf62f09"></script>
<script src="_static/doctools.js?v=888ff710"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/scripts/sphinx-book-theme.js?digest=5a5c038af52cf7bc1a1ec88eea08e6366ee68824"></script>
Expand Down Expand Up @@ -127,7 +127,7 @@



<p class="title logo__title">TELF 0.0.17 documentation</p>
<p class="title logo__title">TELF 0.0.18 documentation</p>

</a></div>
<div class="sidebar-primary-item"><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
Expand Down
6 changes: 3 additions & 3 deletions docs/Cheetah.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />

<title>TELF.applications.Cheetah: Advanced search by keywords and phrases &#8212; TELF 0.0.17 documentation</title>
<title>TELF.applications.Cheetah: Advanced search by keywords and phrases &#8212; TELF 0.0.18 documentation</title>



Expand Down Expand Up @@ -37,7 +37,7 @@
<link rel="preload" as="script" href="_static/scripts/pydata-sphinx-theme.js?digest=365ca57ee442770a23c6" />
<script src="_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=365ca57ee442770a23c6"></script>

<script src="_static/documentation_options.js?v=0150aef7"></script>
<script src="_static/documentation_options.js?v=4bf62f09"></script>
<script src="_static/doctools.js?v=888ff710"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/scripts/sphinx-book-theme.js?digest=5a5c038af52cf7bc1a1ec88eea08e6366ee68824"></script>
Expand Down Expand Up @@ -127,7 +127,7 @@



<p class="title logo__title">TELF 0.0.17 documentation</p>
<p class="title logo__title">TELF 0.0.18 documentation</p>

</a></div>
<div class="sidebar-primary-item"><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
Expand Down
6 changes: 3 additions & 3 deletions docs/HNMFk.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />

<title>TELF.factorization.HNMFk: Hierarchical Non-negative Matrix Factorization with Automatic Model Determination &#8212; TELF 0.0.17 documentation</title>
<title>TELF.factorization.HNMFk: Hierarchical Non-negative Matrix Factorization with Automatic Model Determination &#8212; TELF 0.0.18 documentation</title>



Expand Down Expand Up @@ -37,7 +37,7 @@
<link rel="preload" as="script" href="_static/scripts/pydata-sphinx-theme.js?digest=365ca57ee442770a23c6" />
<script src="_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=365ca57ee442770a23c6"></script>

<script src="_static/documentation_options.js?v=0150aef7"></script>
<script src="_static/documentation_options.js?v=4bf62f09"></script>
<script src="_static/doctools.js?v=888ff710"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/scripts/sphinx-book-theme.js?digest=5a5c038af52cf7bc1a1ec88eea08e6366ee68824"></script>
Expand Down Expand Up @@ -127,7 +127,7 @@



<p class="title logo__title">TELF 0.0.17 documentation</p>
<p class="title logo__title">TELF 0.0.18 documentation</p>

</a></div>
<div class="sidebar-primary-item"><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
Expand Down
Loading

0 comments on commit a0b2be7

Please sign in to comment.