Skip to content

Commit

Permalink
print estimated model size
Browse files Browse the repository at this point in the history
  • Loading branch information
ericliu8168 authored Jul 24, 2024
1 parent 2982ed3 commit 1e6ad94
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
return scores


def get_estimated_model_size(root):
    """Estimate the size in bytes of the tree model rooted at ``root``.

    Every node visited by ``root.dfs`` contributes
    ``node.num_nnz_feat * num_branches`` weights, where ``num_branches`` is
    ``len(node.label_map)`` for a leaf node and ``len(node.children)``
    otherwise. Each weight is costed as one stored sparse-matrix entry.

    Args:
        root: Root node of the tree. Every visited node must already carry
            a ``num_nnz_feat`` attribute.

    Returns:
        The estimated model size in bytes, as a NumPy 64-bit integer.
    """
    # One stored sparse entry = a 4-byte index plus an 8-byte float.
    bytes_per_weight = 4 + 8

    num_nnz_feat_ls, num_branches_ls = [], []

    def collect_stat(node: "Node"):
        num_nnz_feat_ls.append(node.num_nnz_feat)
        # A leaf holds one classifier per label; an internal node holds one
        # classifier per child.
        branches = node.label_map if node.isLeaf() else node.children
        num_branches_ls.append(len(branches))

    root.dfs(collect_stat)

    # Force int64: the platform default integer can be 32-bit (e.g. Windows
    # with NumPy < 2.0), where the weight count of a large tree would
    # silently overflow. It also keeps an empty traversal an integer 0
    # rather than a float.
    num_nnz_feat = np.asarray(num_nnz_feat_ls, dtype=np.int64)
    num_branches = np.asarray(num_branches_ls, dtype=np.int64)
    return np.dot(num_nnz_feat, num_branches) * bytes_per_weight


def train_tree(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
Expand Down Expand Up @@ -139,9 +155,14 @@ def train_tree(
def count(node):
nonlocal num_nodes
num_nodes += 1
# count the number of dimensions (features) used by the label representations in a node
node.num_nnz_feat = np.count_nonzero(label_representation[node.label_map,:].sum(axis=0))

root.dfs(count)

model_size = get_estimated_model_size(root)
print(f'*** model_size: {model_size / (1024**3):.3f} GB')

pbar = tqdm(total=num_nodes, disable=not verbose)

def visit(node):
Expand Down

0 comments on commit 1e6ad94

Please sign in to comment.