diff --git a/docs/api/linear.rst b/docs/api/linear.rst index bf11ecab1..23361a6bf 100644 --- a/docs/api/linear.rst +++ b/docs/api/linear.rst @@ -32,6 +32,12 @@ The simplest usage is:: .. autofunction:: get_positive_labels +.. autoclass:: FlatModel + :members: + +.. autoclass:: TreeModel + :members: + Load Dataset ^^^^^^^^^^^^ diff --git a/libmultilabel/linear/linear.py b/libmultilabel/linear/linear.py index 62b5bbba8..377bc2fc3 100644 --- a/libmultilabel/linear/linear.py +++ b/libmultilabel/linear/linear.py @@ -17,10 +17,13 @@ "predict_values", "get_topk_labels", "get_positive_labels", + "FlatModel", ] class FlatModel: + """A model returned from a training function.""" + def __init__( self, name: str, @@ -619,7 +622,7 @@ def train_binary_and_multiclass( def predict_values(model, x: sparse.csr_matrix) -> np.ndarray: - """Calculates the decision values associated with x. + """Calculates the decision values associated with x, equivalent to model.predict_values(x). Args: model: A model returned from a training function. diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index a6b05f3e9..0c428fec2 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -10,7 +10,7 @@ from . import linear -__all__ = ["train_tree"] +__all__ = ["train_tree", "TreeModel"] class Node: @@ -38,6 +38,8 @@ def dfs(self, visit: Callable[[Node], None]): class TreeModel: + """A model returned from train_tree.""" + def __init__( self, root: Node, diff --git a/linear_trainer.py b/linear_trainer.py index 7ab6c9164..e0e55c049 100644 --- a/linear_trainer.py +++ b/linear_trainer.py @@ -19,9 +19,14 @@ def linear_test(config, model, datasets, label_mapping): else: labels = [] scores = [] + + predict_kwargs = {} + if model.name == "tree": + predict_kwargs["beam_width"] = config.beam_width + for i in tqdm(range(ceil(num_instance / config.eval_batch_size))): slice = np.s_[i * config.eval_batch_size : (i + 1) * config.eval_batch_size] - preds = linear.predict_values(model, datasets["test"]["x"][slice]) + preds = model.predict_values(datasets["test"]["x"][slice], **predict_kwargs) target = datasets["test"]["y"][slice].toarray() metrics.update(preds, target) if k > 0: