From ac6d1d1fe9350cfab45fa7a1a298e4c80dc08475 Mon Sep 17 00:00:00 2001 From: poppoping Date: Mon, 21 Oct 2024 17:06:13 +0800 Subject: [PATCH] explicitly assign arguments to avoid incorrect argument assignments --- libmultilabel/linear/utils.py | 10 ++++++---- linear_trainer.py | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/libmultilabel/linear/utils.py b/libmultilabel/linear/utils.py index 5b258ba6..e27f16c8 100644 --- a/libmultilabel/linear/utils.py +++ b/libmultilabel/linear/utils.py @@ -76,17 +76,18 @@ class MultiLabelEstimator(sklearn.base.BaseEstimator): scoring_metric (str, optional): The scoring metric. Defaults to 'P@1'. """ - def __init__(self, options: str = "", linear_technique: str = "1vsrest", scoring_metric: str = "P@1"): + def __init__(self, options: str = "", linear_technique: str = "1vsrest", scoring_metric: str = "P@1", multiclass: bool = False): super().__init__() self.options = options self.linear_technique = linear_technique self.scoring_metric = scoring_metric self._is_fitted = False + self.multiclass = multiclass def fit(self, X: sparse.csr_matrix, y: sparse.csr_matrix): X, y = sklearn.utils.validation.check_X_y(X, y, accept_sparse=True, multi_output=True) self._is_fitted = True - self.model = LINEAR_TECHNIQUES[self.linear_technique](y, X, self.options) + self.model = LINEAR_TECHNIQUES[self.linear_technique](y, X, options=self.options) return self def predict(self, X: sparse.csr_matrix) -> np.ndarray: @@ -96,8 +97,9 @@ def predict(self, X: sparse.csr_matrix) -> np.ndarray: def score(self, X: sparse.csr_matrix, y: sparse.csr_matrix) -> float: metrics = linear.get_metrics( - [self.scoring_metric], - y.shape[1], + monitor_metrics=[self.scoring_metric], + num_classes=y.shape[1], + multiclass=self.multiclass ) preds = self.predict(X) metrics.update(preds, y.toarray()) diff --git a/linear_trainer.py b/linear_trainer.py index e0e55c04..b0524ee7 100644 --- a/linear_trainer.py +++ b/linear_trainer.py @@ -51,9 +51,9 @@ def linear_train(datasets, config): model = LINEAR_TECHNIQUES[config.linear_technique]( datasets["train"]["y"], datasets["train"]["x"], - config.liblinear_options, - config.tree_degree, - config.tree_max_depth, + options=config.liblinear_options, + K=config.tree_degree, + dmax=config.tree_max_depth, ) else: model = LINEAR_TECHNIQUES[config.linear_technique](