Skip to content

Commit

Permalink
added CLI beam width functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
Sinacam committed Apr 9, 2024
1 parent ab7685c commit e07ccdc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
6 changes: 6 additions & 0 deletions docs/api/linear.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ The simplest usage is::

.. autofunction:: get_positive_labels

.. autoclass:: FlatModel
:members:

.. autoclass:: TreeModel
:members:

Load Dataset
^^^^^^^^^^^^

Expand Down
5 changes: 4 additions & 1 deletion libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
"predict_values",
"get_topk_labels",
"get_positive_labels",
"FlatModel",
]


class FlatModel:
"""A model returned from a training function."""

def __init__(
self,
name: str,
Expand Down Expand Up @@ -619,7 +622,7 @@ def train_binary_and_multiclass(


def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
"""Calculates the decision values associated with x.
"""Calculates the decision values associated with x, equivalent to model.predict_values(x).
Args:
model: A model returned from a training function.
Expand Down
4 changes: 3 additions & 1 deletion libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from . import linear

__all__ = ["train_tree"]
__all__ = ["train_tree", "TreeModel"]


class Node:
Expand Down Expand Up @@ -38,6 +38,8 @@ def dfs(self, visit: Callable[[Node], None]):


class TreeModel:
"""A model returned from train_tree."""

def __init__(
self,
root: Node,
Expand Down
7 changes: 6 additions & 1 deletion linear_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def linear_test(config, model, datasets, label_mapping):
else:
labels = []
scores = []

predict_kwargs = {}
if model.name == "tree":
predict_kwargs["beam_width"] = config.beam_width

for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
slice = np.s_[i * config.eval_batch_size : (i + 1) * config.eval_batch_size]
preds = linear.predict_values(model, datasets["test"]["x"][slice])
preds = model.predict_values(datasets["test"]["x"][slice], **predict_kwargs)
target = datasets["test"]["y"][slice].toarray()
metrics.update(preds, target)
if k > 0:
Expand Down

0 comments on commit e07ccdc

Please sign in to comment.