Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added CLI beam width functionality #367

Merged
merged 1 commit into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/api/linear.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ The simplest usage is::

.. autofunction:: get_positive_labels

.. autoclass:: FlatModel
:members:

.. autoclass:: TreeModel
:members:

Load Dataset
^^^^^^^^^^^^

Expand Down
5 changes: 4 additions & 1 deletion libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
"predict_values",
"get_topk_labels",
"get_positive_labels",
"FlatModel",
]


class FlatModel:
"""A model returned from a training function."""

def __init__(
self,
name: str,
Expand Down Expand Up @@ -619,7 +622,7 @@ def train_binary_and_multiclass(


def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
"""Calculates the decision values associated with x.
"""Calculates the decision values associated with x, equivalent to model.predict_values(x).
Args:
model: A model returned from a training function.
Expand Down
4 changes: 3 additions & 1 deletion libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from . import linear

__all__ = ["train_tree"]
__all__ = ["train_tree", "TreeModel"]


class Node:
Expand Down Expand Up @@ -38,6 +38,8 @@ def dfs(self, visit: Callable[[Node], None]):


class TreeModel:
"""A model returned from train_tree."""

def __init__(
self,
root: Node,
Expand Down
7 changes: 6 additions & 1 deletion linear_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def linear_test(config, model, datasets, label_mapping):
else:
labels = []
scores = []

predict_kwargs = {}
if model.name == "tree":
predict_kwargs["beam_width"] = config.beam_width

for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
slice = np.s_[i * config.eval_batch_size : (i + 1) * config.eval_batch_size]
preds = linear.predict_values(model, datasets["test"]["x"][slice])
preds = model.predict_values(datasets["test"]["x"][slice], **predict_kwargs)
target = datasets["test"]["y"][slice].toarray()
metrics.update(preds, target)
if k > 0:
Expand Down
Loading