Skip to content

Commit

Permalink
print estimated tree model size
Browse files Browse the repository at this point in the history
  • Loading branch information
ericliu8168 committed Jul 30, 2024
1 parent 2982ed3 commit 436a03e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
32 changes: 32 additions & 0 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sklearn.cluster
import sklearn.preprocessing
from tqdm import tqdm
import psutil

from . import linear

Expand Down Expand Up @@ -108,6 +109,25 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
return scores


def get_estimated_model_size(root, num_nodes):
    """Estimate the size of the tree model from per-node statistics.

    The tree is walked with ``root.dfs``. For every visited node the
    function records ``node.num_nnz_feat`` and a branch count:
    ``len(node.label_map)`` for a leaf, ``len(node.children)`` otherwise.
    The estimate is the sum over nodes of the product of these two
    numbers, scaled by 16.

    Args:
        root: Tree root; must provide ``dfs(callback)``. Each node handed
            to the callback must provide ``num_nnz_feat``, ``isLeaf()``,
            ``label_map`` and ``children``.
        num_nodes: Number of slots to pre-allocate, one per visited node.

    Returns:
        The estimate as a NumPy floating-point scalar.

    Raises:
        IndexError: If the callback is invoked more than ``num_nodes``
            times, since the pre-allocated arrays are then too short.
    """
    feature_counts = np.zeros(num_nodes)
    branch_counts = np.zeros(num_nodes)
    slot = 0  # next free position in the two statistics arrays

    def record(node: Node):
        nonlocal slot
        feature_counts[slot] = node.num_nnz_feat
        # Leaves branch over their labels, inner nodes over their children.
        branch_counts[slot] = len(node.label_map if node.isLeaf() else node.children)
        slot += 1

    root.dfs(record)

    # NOTE(review): 16 is presumably the bytes per stored weight (e.g. an
    # 8-byte index plus an 8-byte value) -- confirm with the author.
    return np.dot(feature_counts, branch_counts) * 16


def train_tree(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
Expand Down Expand Up @@ -135,13 +155,25 @@ def train_tree(
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)

num_nodes = 0
used_dim_labels = ((x != 0).T * y).tocsr()

def count(node):
nonlocal num_nodes
num_nodes += 1
node.num_nnz_feat = np.count_nonzero(used_dim_labels[:,node.label_map].sum(axis=1))

root.dfs(count)

# calculate total memory in local machine
total_memory = psutil.virtual_memory().total
print(f'{total_memory / (1024**3):.3f} GB')

model_size = get_estimated_model_size(root, num_nodes)
print(f'*** model_size: {model_size / (1024**3):.3f} GB')

if (total_memory <= model_size):
raise MemoryError(f'Not enough memory to train the model. model_size: {model_size / (1024**3):.3f} GB')

pbar = tqdm(total=num_nodes, disable=not verbose)

def visit(node):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ PyYAML
scikit-learn
scipy
tqdm
psutil

0 comments on commit 436a03e

Please sign in to comment.