Skip to content

Commit

Permalink
Merge pull request #101 from lanl/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
MaksimEkin authored Apr 1, 2024
2 parents a7e1036 + eb89c74 commit 943a50c
Show file tree
Hide file tree
Showing 89 changed files with 486 additions and 341 deletions.
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ authors:
- family-names: Alexandrov
given-names: Boian
title: "Tensor Extraction of Latent Features (T-ELF)"
version: 0.0.11
version: 0.0.12
url: https://github.com/lanl/T-ELF
doi: 10.5281/zenodo.10257897
date-released: 2023-12-04
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ If you use T-ELF please cite.

**APA:**
```latex
Eren, M., Solovyev, N., Barron, R., Bhattarai, M., Truong, D., Boureima, I., Skau, E., Rasmussen, K., & Alexandrov, B. (2023). Tensor Extraction of Latent Features (T-ELF) (Version 0.0.11) [Computer software]. https://doi.org/10.5281/zenodo.10257897
Eren, M., Solovyev, N., Barron, R., Bhattarai, M., Truong, D., Boureima, I., Skau, E., Rasmussen, K., & Alexandrov, B. (2023). Tensor Extraction of Latent Features (T-ELF) (Version 0.0.12) [Computer software]. https://doi.org/10.5281/zenodo.10257897
```

**BibTeX:**
Expand Down
66 changes: 44 additions & 22 deletions TELF/factorization/NMFk.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,10 @@ def _nmf_parallel_wrapper(

#
# cluster the solutions
#
#
W, W_clust = custom_k_means(W_all, use_gpu=False)
sils_all = silhouettes(W_clust)
sils_all_W = silhouettes(W_clust)
sils_all_H = silhouettes(np.array(H_all).transpose((1, 0, 2)))

#
# consensus matrix
Expand Down Expand Up @@ -320,7 +321,8 @@ def _nmf_parallel_wrapper(
save_data = {
"W": W,
"H": H,
"sils_all": sils_all,
"sils_all_W": sils_all_W,
"sils_all_H": sils_all_H,
"error_reg": error_reg,
"errors": errors,
"reordered_con_mat": reordered_con_mat,
Expand All @@ -341,12 +343,18 @@ def _nmf_parallel_wrapper(
for key in logging_stats:
if key == 'k':
plot_data["k"] = k
elif key == 'sils_min':
sils_min = np.min(np.mean(sils_all, 1))
plot_data["sils_min"] = '{0:.3f}'.format(sils_min)
elif key == 'sils_mean':
sils_mean = np.mean(np.mean(sils_all, 1))
plot_data["sils_mean"] = '{0:.3f}'.format(sils_mean)
elif key == 'sils_min_W':
sils_min_W = np.min(np.mean(sils_all_W, 1))
plot_data["sils_min_W"] = '{0:.3f}'.format(sils_min_W)
elif key == 'sils_mean_W':
sils_mean_W = np.mean(np.mean(sils_all_W, 1))
plot_data["sils_mean_W"] = '{0:.3f}'.format(sils_mean_W)
elif key == 'sils_min_H':
sils_min_H = np.min(np.mean(sils_all_H, 1))
plot_data["sils_min_H"] = '{0:.3f}'.format(sils_min_H)
elif key == 'sils_mean_H':
sils_mean_H = np.mean(np.mean(sils_all_H, 1))
plot_data["sils_mean_H"] = '{0:.3f}'.format(sils_mean_H)
elif key == 'err_mean':
err_mean = np.mean(errors)
plot_data["err_mean"] = '{0:.3f}'.format(err_mean)
Expand All @@ -373,10 +381,17 @@ def _nmf_parallel_wrapper(
"err_mean":np.mean(errors),
"err_std":np.std(errors),
"err_reg":error_reg,
"sils_min":np.min(np.mean(sils_all, 1)),
"sils_mean":np.mean(np.mean(sils_all, 1)),
"sils_std":np.std(np.mean(sils_all, 1)),
"sils_all":sils_all,

"sils_min_W":np.min(np.mean(sils_all_W, 1)),
"sils_mean_W":np.mean(np.mean(sils_all_W, 1)),
"sils_std_W":np.std(np.mean(sils_all_W, 1)),
"sils_all_W":sils_all_W,

"sils_min_H":np.min(np.mean(sils_all_H, 1)),
"sils_mean_H":np.mean(np.mean(sils_all_H, 1)),
"sils_std_H":np.std(np.mean(sils_all_H, 1)),
"sils_all_H":sils_all_H,

"cophenetic_coeff":coeff_k,
"col_err":curr_col_err,
}
Expand Down Expand Up @@ -404,7 +419,7 @@ def __init__(
save_output=True,
collect_output=False,
predict_k=False,
predict_k_method="pvalue",
predict_k_method="sill",
verbose=True,
nmf_verbose=False,
perturb_verbose=False,
Expand Down Expand Up @@ -462,13 +477,13 @@ def __init__(
Even when ``predict_k=False``, number of latent factors can be estimated using the figures saved in ``save_path``.
predict_k_method : str, optional
Method to use when performing automatic k prediction. Default is "pvalue".\n
Method to use when performing automatic k prediction. Default is "sill".\n
* ``predict_k_method='pvalue'`` will use L-Statistics with column-wise error for automatically estimating the number of latent factors.\n
* ``predict_k_method='sill'`` will use Silhouette score for estimating the number of latent factors.
.. warning::
``predict_k_method='pvalue'`` prediction will result in significantly longer processing time! ``predict_k_method='sill'``, on the other hand, will be much faster.
``predict_k_method='pvalue'`` prediction will result in significantly longer processing time, although it is more accurate! ``predict_k_method='sill'``, on the other hand, will be much faster.
verbose : bool, optional
If True, shows progress in each k. The default is True.
Expand Down Expand Up @@ -789,8 +804,11 @@ def fit(self, X, Ks, name="NMFk", note=""):
# init the stats header
# this will setup the logging for all configurations of nmfk
stats_header = {'k': 'k',
'sils_min': 'Min. Silhouette',
'sils_mean': 'Mean Silhouette'}
'sils_min_W': 'W Min. Silhouette',
'sils_mean_W': 'W Mean Silhouette',
'sils_min_H': 'H Min. Silhouette',
'sils_mean_H': 'H Mean Silhouette',
}
if self.calculate_error:
stats_header['err_mean'] = 'Mean Error'
stats_header['err_std'] = 'STD Error'
Expand Down Expand Up @@ -950,11 +968,14 @@ def fit(self, X, Ks, name="NMFk", note=""):
if self.predict_k:
if self.predict_k_method == "pvalue":
k_predict = pvalue_analysis(
combined_result["col_err"], Ks, combined_result["sils_min"], SILL_thr=self.sill_thresh
combined_result["col_err"], Ks, combined_result["sils_min_W"], SILL_thr=self.sill_thresh
)[0]
elif self.predict_k_method == "sill":
k_predict = Ks[np.max(np.argwhere(
np.array(combined_result["sils_min"]) >= self.sill_thresh).flatten())]
k_predict_W = Ks[np.max(np.argwhere(
np.array(combined_result["sils_min_W"]) >= self.sill_thresh).flatten())]
k_predict_H = Ks[np.max(np.argwhere(
np.array(combined_result["sils_min_H"]) >= self.sill_thresh).flatten())]
k_predict = min(k_predict_W, k_predict_H)
else:
k_predict = 0

Expand Down Expand Up @@ -998,7 +1019,8 @@ def fit(self, X, Ks, name="NMFk", note=""):
self.save_path_full,
plot_predict=self.predict_k,
plot_final=True,
simple_plot=self.simple_plot
simple_plot=self.simple_plot,
calculate_error=self.calculate_error
)
append_to_note(["#" * 100], self.save_path_full, name=note_name, lock=self.lock)
append_to_note(["end_time= "+str(datetime.now())], self.save_path_full, name=note_name, lock=self.lock)
Expand Down
100 changes: 66 additions & 34 deletions TELF/factorization/utilities/plot_NMFk.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def plot_BNMFk(Ks, sils, bool_err, path=None, name=None):
return None


def plot_NMFk(data, k_predict, name, path, plot_predict=False, plot_final=False, simple_plot=False):
def plot_NMFk(data, k_predict, name, path, plot_predict=False, plot_final=False, simple_plot=False, calculate_error=True):
"""
Expand Down Expand Up @@ -144,33 +144,52 @@ def plot_NMFk(data, k_predict, name, path, plot_predict=False, plot_final=False,
ax1.set_xlabel("latent dimension")

if pac:
ax1.set_ylabel("silhouette - pac", color=color)
ax1.set_ylabel("silhouette - pac", color="black")
else:
ax1.set_ylabel("silhouette", color=color)


# W sill
ax1.plot(
list(data["Ks"]),
data["sils_min"],
data["sils_min_W"],
"o-",
color=color,
label="minimum silhouette",
label="W minimum silhouette",
)

# H sill
ax1.plot(
list(data["Ks"]),
data["sils_min_H"],
"o-.",
color="purple",
label="H minimum silhouette",
)

# add a vertical line for xtick k-values to make it easier to see which point corresponds to which k
if not isinstance(data["sils_min"], np.float64): # if plot contains more than one k-value
if not isinstance(data["sils_min_W"], np.float64): # if plot contains more than one k-value
for xtick in ax1.get_xticks():
if xtick in data["Ks"]:
y = data["sils_min"][np.where(data["Ks"] == xtick)[0][0]] # get the y value that corresponds to xtick
plt.vlines(xtick, min(data["sils_min"] + [0]), y, colors='black', alpha=0.4)
y = data["sils_min_W"][np.where(data["Ks"] == xtick)[0][0]] # get the y value that corresponds to xtick
plt.vlines(xtick, min([0, np.min(data["sils_min_H"]), np.min(data["sils_min_W"])]), y, colors='black', alpha=0.4)

if not simple_plot:
ax1.errorbar(
list(data["Ks"]),
data["sils_mean"],
yerr=data["sils_std"],
data["sils_mean_W"],
yerr=data["sils_std_W"],
fmt="^:",
color="tab:green",
label=r"mean +- std silhouette",
label=r"W mean +- std silhouette",
)

ax1.errorbar(
list(data["Ks"]),
data["sils_mean_H"],
yerr=data["sils_std_H"],
fmt="^-.",
color="tab:green",
label=r"H mean +- std silhouette",
)

# pac
Expand All @@ -184,9 +203,10 @@ def plot_NMFk(data, k_predict, name, path, plot_predict=False, plot_final=False,
color=color,
label="PAC",
)
ax1.set_ylim(min(0, min(np.min(data["sils_min"]), np.min(data["pac"]))), 1)

ax1.set_ylim(min([0, np.min(data["sils_min_H"]), np.min(data["sils_min_W"]), np.min(data["pac"])]), 1)
else:
ax1.set_ylim(min(0, np.min(data["sils_min"])), 1)
ax1.set_ylim(min([0, np.min(data["sils_min_H"]), np.min(data["sils_min_W"])]), 1)

ax1.tick_params(axis="y", labelcolor=color)

Expand All @@ -209,35 +229,47 @@ def plot_NMFk(data, k_predict, name, path, plot_predict=False, plot_final=False,
)

# relative error
ax2 = ax1.twinx()
color = "tab:blue"
ax2.set_ylabel("relative error", color=color)
ax2.plot(
list(data["Ks"]),
data["err_reg"],
"o-",
color=color,
label="regression relative error",
)

if not simple_plot:
ax2.errorbar(
if calculate_error:
ax2 = ax1.twinx()
color = "tab:blue"
ax2.set_ylabel("relative error", color=color)
ax2.plot(
list(data["Ks"]),
data["err_mean"],
yerr=data["err_std"],
fmt="^:",
color="tab:orange",
label="perturbation relative error mean +- std",
data["err_reg"],
"o-",
color=color,
label="regression relative error",
)

if not simple_plot:
ax2.errorbar(
list(data["Ks"]),
data["err_mean"],
yerr=data["err_std"],
fmt="^:",
color="tab:orange",
label="perturbation relative error mean +- std",
)

ax2.tick_params(axis="y", labelcolor=color)
ax2.legend(
ax2.tick_params(axis="y", labelcolor=color)
ax2.legend(
loc="upper right",
bbox_to_anchor=(0.5, -0.07),
fancybox=True,
shadow=True,
handlelength=4,
)

ax1.legend(
loc="upper right",
bbox_to_anchor=(0.5, -0.07),
bbox_to_anchor=(.9, -0.07),
fancybox=True,
shadow=True,
handlelength=4,
)



# finalize
fig.tight_layout()
plt.title(name + " " + str(min(list(data["Ks"]))) + "-" + str(max(list(data["Ks"]))))
Expand Down
2 changes: 1 addition & 1 deletion TELF/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.11'
__version__ = '0.0.12'
6 changes: 3 additions & 3 deletions docs/Beaver.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />

<title>TELF.pre_processing.Beaver: Fast matrix and tensor building tool &#8212; TELF 0.0.11 documentation</title>
<title>TELF.pre_processing.Beaver: Fast matrix and tensor building tool &#8212; TELF 0.0.12 documentation</title>



Expand Down Expand Up @@ -37,7 +37,7 @@
<link rel="preload" as="script" href="_static/scripts/pydata-sphinx-theme.js?digest=365ca57ee442770a23c6" />
<script src="_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=365ca57ee442770a23c6"></script>

<script src="_static/documentation_options.js?v=2fb9ae3b"></script>
<script src="_static/documentation_options.js?v=38cd2e5d"></script>
<script src="_static/doctools.js?v=888ff710"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/scripts/sphinx-book-theme.js?digest=5a5c038af52cf7bc1a1ec88eea08e6366ee68824"></script>
Expand Down Expand Up @@ -127,7 +127,7 @@



<p class="title logo__title">TELF 0.0.11 documentation</p>
<p class="title logo__title">TELF 0.0.12 documentation</p>

</a></div>
<div class="sidebar-primary-item"><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
Expand Down
6 changes: 3 additions & 3 deletions docs/Cheetah.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />

<title>TELF.applications.Cheetah: Advanced search by keywords and phrases &#8212; TELF 0.0.11 documentation</title>
<title>TELF.applications.Cheetah: Advanced search by keywords and phrases &#8212; TELF 0.0.12 documentation</title>



Expand Down Expand Up @@ -37,7 +37,7 @@
<link rel="preload" as="script" href="_static/scripts/pydata-sphinx-theme.js?digest=365ca57ee442770a23c6" />
<script src="_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=365ca57ee442770a23c6"></script>

<script src="_static/documentation_options.js?v=2fb9ae3b"></script>
<script src="_static/documentation_options.js?v=38cd2e5d"></script>
<script src="_static/doctools.js?v=888ff710"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/scripts/sphinx-book-theme.js?digest=5a5c038af52cf7bc1a1ec88eea08e6366ee68824"></script>
Expand Down Expand Up @@ -127,7 +127,7 @@



<p class="title logo__title">TELF 0.0.11 documentation</p>
<p class="title logo__title">TELF 0.0.12 documentation</p>

</a></div>
<div class="sidebar-primary-item"><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
Expand Down
Loading

0 comments on commit 943a50c

Please sign in to comment.