Load AttentionXML vocabulary and labels from checkpoints at test time

PLTTrainer.test() drops its `classes` argument: it loads the level-0 and
level-1 checkpoints up front and takes `word_dict` and `classes` from the
level-1 model. TorchTrainer no longer stores `self.classes`; the vocabulary
and label list are built only when `config.checkpoint_path` is None
(AttentionXML branch) or when `_setup_model` initializes a model from scratch.
TorchTrainer.train() now also calls `dump_log` after the AttentionXML fit.

NOTE(review): this patch was received with its newlines collapsed. Line
breaks, indentation and blank context lines were reconstructed from the
Python structure and the hunk line counts; verify context whitespace against
the target tree (or apply with whitespace tolerance) before relying on it.
NOTE(review): the added `logging.warn(` for `val_metric` was changed to
`logging.warning(` to match the call added just above it (`warn` is a
deprecated alias), so the post-image blob hash on the torch_trainer.py
`index` line no longer describes the patched file exactly.

diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py
index 61f6a8bb..6cf57469 100644
--- a/libmultilabel/nn/plt.py
+++ b/libmultilabel/nn/plt.py
@@ -373,7 +373,21 @@ def fit(self, datasets):
         logger.info(f"Best model loaded from {best_model_path}")
         logger.info(f"Finish training level 1")
 
-    def test(self, dataset, classes):
+    def test(self, dataset):
+        # retrieve word_dict from model_1
+        # prediction starts from level 0
+        model_0 = Model.load_from_checkpoint(
+            self.get_best_model_path(level=0),
+            top_k=self.predict_top_k,
+        )
+        model_1 = PLTModel.load_from_checkpoint(
+            self.get_best_model_path(level=1),
+            top_k=self.predict_top_k,
+            metrics=self.metrics,
+        )
+        self.word_dict = model_1.word_dict
+        classes = model_1.classes
+
         test_x = self.reformat_text(dataset)
 
         if self.binarizer is None:
@@ -389,27 +403,15 @@ def test(self, dataset, classes):
             logger=False,
         )
 
-        # prediction starts from level 0
-        model = Model.load_from_checkpoint(
-            self.get_best_model_path(level=0),
-            top_k=self.predict_top_k,
-        )
-
         test_dataloader = self.eval_dataloader(PlainDataset(test_x))
 
         logger.info(f"Predicting level 0, Top: {self.predict_top_k}")
-        test_pred = trainer.predict(model, test_dataloader)
+        test_pred = trainer.predict(model_0, test_dataloader)
         test_pred_scores = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred]))
         test_pred_cluters = np.vstack([i["top_k_pred"] for i in test_pred])
 
         clusters = np.load(self.get_cluster_path(), allow_pickle=True)
 
-        model = PLTModel.load_from_checkpoint(
-            self.get_best_model_path(level=1),
-            top_k=self.predict_top_k,
-            metrics=self.metrics,
-        )
-
         test_dataloader = self.eval_dataloader(
             PLTDataset(
                 test_x,
@@ -422,7 +424,7 @@ def test(self, dataset, classes):
             )
         )
         logger.info(f"Testing on level 1")
-        trainer.test(model, test_dataloader)
+        trainer.test(model_1, test_dataloader)
         logger.info("Testing process finished")
 
     def reformat_text(self, dataset):
diff --git a/torch_trainer.py b/torch_trainer.py
index 8e23216f..81bf765c 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -37,7 +37,6 @@ def __init__(
         self.run_name = config.run_name
         self.checkpoint_dir = config.checkpoint_dir
         self.log_path = config.log_path
-        self.classes = classes
         os.makedirs(self.checkpoint_dir, exist_ok=True)
 
         # Set up seed & device
@@ -64,26 +63,41 @@ def __init__(
         else:
             self.datasets = datasets
 
-        if config.embed_file is not None:
-            word_dict, embed_vecs = data_utils.load_or_build_text_dict(
-                dataset=self.datasets["train"]
-                if config.model_name.lower() != "attentionxml"
-                else self.datasets["train"] + self.datasets["val"],
-                vocab_file=config.vocab_file,
-                min_vocab_freq=config.min_vocab_freq,
-                embed_file=config.embed_file,
-                silent=config.silent,
-                normalize_embed=config.normalize_embed,
-                embed_cache_dir=config.embed_cache_dir,
-            )
-
-        if not classes:
-            self.classes = data_utils.load_or_build_label(self.datasets, config.label_file, config.include_test_labels)
-
         self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list()))
 
         if self.config.model_name.lower() == "attentionxml":
-            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=embed_vecs, word_dict=word_dict)
+            # Note that AttentionXML produces two models. checkpoint_path directs to model_1
+            if config.checkpoint_path is None:
+                if self.config.embed_file is not None:
+                    word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                        dataset=self.datasets["train"] + self.datasets["val"],
+                        vocab_file=config.vocab_file,
+                        min_vocab_freq=config.min_vocab_freq,
+                        embed_file=config.embed_file,
+                        silent=config.silent,
+                        normalize_embed=config.normalize_embed,
+                        embed_cache_dir=config.embed_cache_dir,
+                    )
+
+                if not classes:
+                    classes = data_utils.load_or_build_label(
+                        self.datasets, self.config.label_file, self.config.include_test_labels
+                    )
+
+                if self.config.early_stopping_metric not in self.config.monitor_metrics:
+                    logging.warning(
+                        f"{self.config.early_stopping_metric} is not in `monitor_metrics`. "
+                        f"Add {self.config.early_stopping_metric} to `monitor_metrics`."
+                    )
+                    self.config.monitor_metrics += [self.config.early_stopping_metric]
+
+                if self.config.val_metric not in self.config.monitor_metrics:
+                    logging.warning(
+                        f"{self.config.val_metric} is not in `monitor_metrics`. "
+                        f"Add {self.config.val_metric} to `monitor_metrics`."
+                    )
+                    self.config.monitor_metrics += [self.config.val_metric]
+            self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
         else:
             self._setup_model(
                 word_dict=word_dict,
@@ -109,6 +123,7 @@ def __init__(
 
     def _setup_model(
         self,
+        classes: list = None,
         word_dict: dict = None,
         embed_vecs=None,
         log_path: str = None,
@@ -134,6 +149,21 @@ def _setup_model(
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
         else:
             logging.info("Initialize model from scratch.")
+            if self.config.embed_file is not None:
+                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                    dataset=self.datasets["train"],
+                    vocab_file=self.config.vocab_file,
+                    min_vocab_freq=self.config.min_vocab_freq,
+                    embed_file=self.config.embed_file,
+                    silent=self.config.silent,
+                    normalize_embed=self.config.normalize_embed,
+                    embed_cache_dir=self.config.embed_cache_dir,
+                )
+
+            if not classes:
+                classes = data_utils.load_or_build_label(
+                    self.datasets, self.config.label_file, self.config.include_test_labels
+                )
 
             if self.config.early_stopping_metric not in self.config.monitor_metrics:
                 logging.warn(
@@ -152,7 +182,7 @@ def _setup_model(
             self.model = init_model(
                 model_name=self.config.model_name,
                 network_config=dict(self.config.network_config),
-                classes=self.classes,
+                classes=classes,
                 word_dict=word_dict,
                 embed_vecs=embed_vecs,
                 init_weight=self.config.init_weight,
@@ -201,6 +231,8 @@ def train(self):
         """
         if self.config.model_name.lower() == "attentionxml":
             self.trainer.fit(self.datasets)
+
+            dump_log(self.log_path, config=self.config)
         else:
             assert (
                 self.trainer is not None