Commit
improve test process
donglihe-hub committed Mar 13, 2024
1 parent 88055ad commit 6b908f9
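In short, this commit moves checkpoint loading to the top of PLTTrainer.test(), so the vocabulary (word_dict) and label set (classes) are recovered from the saved level-1 model instead of being passed in, and it reworks TorchTrainer.__init__ so that vocabulary and labels are only built when training from scratch. A sketch of the caller-side effect; plt_trainer and datasets are illustrative stand-ins, not repository variables:

    # before: the caller had to supply the label set
    #     plt_trainer.test(datasets["test"], classes)
    # after: the level-1 checkpoint carries word_dict and classes
    #     plt_trainer.test(datasets["test"])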
Showing 2 changed files with 68 additions and 34 deletions.
32 changes: 17 additions & 15 deletions libmultilabel/nn/plt.py
@@ -373,7 +373,21 @@ def fit(self, datasets):
        logger.info(f"Best model loaded from {best_model_path}")
        logger.info(f"Finish training level 1")

-    def test(self, dataset, classes):
+    def test(self, dataset):
+        # word_dict and classes are recovered from the level-1 checkpoint;
+        # prediction starts from level 0
+        model_0 = Model.load_from_checkpoint(
+            self.get_best_model_path(level=0),
+            top_k=self.predict_top_k,
+        )
+        model_1 = PLTModel.load_from_checkpoint(
+            self.get_best_model_path(level=1),
+            top_k=self.predict_top_k,
+            metrics=self.metrics,
+        )
+        self.word_dict = model_1.word_dict
+        classes = model_1.classes
+
        test_x = self.reformat_text(dataset)

        if self.binarizer is None:
@@ -389,27 +403,15 @@ def test(self, dataset, classes):
            logger=False,
        )

-        # prediction starts from level 0
-        model = Model.load_from_checkpoint(
-            self.get_best_model_path(level=0),
-            top_k=self.predict_top_k,
-        )
-
        test_dataloader = self.eval_dataloader(PlainDataset(test_x))

        logger.info(f"Predicting level 0, Top: {self.predict_top_k}")
-        test_pred = trainer.predict(model, test_dataloader)
+        test_pred = trainer.predict(model_0, test_dataloader)
        test_pred_scores = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred]))
        test_pred_cluters = np.vstack([i["top_k_pred"] for i in test_pred])

        clusters = np.load(self.get_cluster_path(), allow_pickle=True)

-        model = PLTModel.load_from_checkpoint(
-            self.get_best_model_path(level=1),
-            top_k=self.predict_top_k,
-            metrics=self.metrics,
-        )
-
        test_dataloader = self.eval_dataloader(
            PLTDataset(
                test_x,
@@ -422,7 +424,7 @@ def test(self, dataset, classes):
            )
        )

        logger.info(f"Testing on level 1")
-        trainer.test(model, test_dataloader)
+        trainer.test(model_1, test_dataloader)
        logger.info("Testing process finished")

    def reformat_text(self, dataset):
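The reordered test() above makes the two-level AttentionXML prediction flow explicit: the level-0 model scores meta-labels (clusters), expit turns those logits into probabilities, and the labels inside the top-k clusters become the candidate set the level-1 model is tested on. A minimal sketch of that selection step with made-up logits and cluster assignments; the real pipeline gets top_k_pred and top_k_pred_scores from trainer.predict and loads the cluster array from get_cluster_path():

    import numpy as np
    from scipy.special import expit

    top_k = 2
    logits_0 = np.array([[2.0, -1.0, 0.5, -3.0]])  # level-0 logits for 4 clusters
    clusters = [np.array([0, 1]), np.array([2]), np.array([3, 4]), np.array([5])]

    # pick the top-k clusters per example, as trainer.predict(model_0, ...) does
    top_k_pred = np.argsort(-logits_0, axis=1)[:, :top_k]            # [[0, 2]]
    top_k_pred_scores = expit(np.take_along_axis(logits_0, top_k_pred, axis=1))

    # level-1 candidates are the labels belonging to the selected clusters
    candidates = np.concatenate([clusters[c] for c in top_k_pred[0]])  # [0 1 3 4]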
70 changes: 51 additions & 19 deletions torch_trainer.py
@@ -37,7 +37,6 @@ def __init__(
        self.run_name = config.run_name
        self.checkpoint_dir = config.checkpoint_dir
        self.log_path = config.log_path
-        self.classes = classes
        os.makedirs(self.checkpoint_dir, exist_ok=True)

# Set up seed & device
@@ -64,26 +63,41 @@ def __init__(
        else:
            self.datasets = datasets

-        if config.embed_file is not None:
-            word_dict, embed_vecs = data_utils.load_or_build_text_dict(
-                dataset=self.datasets["train"]
-                if config.model_name.lower() != "attentionxml"
-                else self.datasets["train"] + self.datasets["val"],
-                vocab_file=config.vocab_file,
-                min_vocab_freq=config.min_vocab_freq,
-                embed_file=config.embed_file,
-                silent=config.silent,
-                normalize_embed=config.normalize_embed,
-                embed_cache_dir=config.embed_cache_dir,
-            )
-
-        if not classes:
-            self.classes = data_utils.load_or_build_label(self.datasets, config.label_file, config.include_test_labels)
-
        self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list()))

        if self.config.model_name.lower() == "attentionxml":
-            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=embed_vecs, word_dict=word_dict)
+            # AttentionXML produces two models; checkpoint_path points to model_1
+            if config.checkpoint_path is None:
+                if self.config.embed_file is not None:
+                    word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                        dataset=self.datasets["train"] + self.datasets["val"],
+                        vocab_file=config.vocab_file,
+                        min_vocab_freq=config.min_vocab_freq,
+                        embed_file=config.embed_file,
+                        silent=config.silent,
+                        normalize_embed=config.normalize_embed,
+                        embed_cache_dir=config.embed_cache_dir,
+                    )
+
+                if not classes:
+                    classes = data_utils.load_or_build_label(
+                        self.datasets, self.config.label_file, self.config.include_test_labels
+                    )
+
+                if self.config.early_stopping_metric not in self.config.monitor_metrics:
+                    logging.warning(
+                        f"{self.config.early_stopping_metric} is not in `monitor_metrics`. "
+                        f"Adding {self.config.early_stopping_metric} to `monitor_metrics`."
+                    )
+                    self.config.monitor_metrics += [self.config.early_stopping_metric]
+
+                if self.config.val_metric not in self.config.monitor_metrics:
+                    logging.warning(
+                        f"{self.config.val_metric} is not in `monitor_metrics`. "
+                        f"Adding {self.config.val_metric} to `monitor_metrics`."
+                    )
+                    self.config.monitor_metrics += [self.config.val_metric]
+                self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
        else:
            self._setup_model(
                word_dict=word_dict,
@@ -109,6 +123,7 @@ def __init__(

    def _setup_model(
        self,
+        classes: list = None,
        word_dict: dict = None,
        embed_vecs=None,
        log_path: str = None,
@@ -134,6 +149,21 @@ def _setup_model(
            self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
        else:
            logging.info("Initialize model from scratch.")
+            if self.config.embed_file is not None:
+                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                    dataset=self.datasets["train"],
+                    vocab_file=self.config.vocab_file,
+                    min_vocab_freq=self.config.min_vocab_freq,
+                    embed_file=self.config.embed_file,
+                    silent=self.config.silent,
+                    normalize_embed=self.config.normalize_embed,
+                    embed_cache_dir=self.config.embed_cache_dir,
+                )
+
+            if not classes:
+                classes = data_utils.load_or_build_label(
+                    self.datasets, self.config.label_file, self.config.include_test_labels
+                )

if self.config.early_stopping_metric not in self.config.monitor_metrics:
logging.warn(
@@ -152,7 +182,7 @@ def _setup_model(
self.model = init_model(
model_name=self.config.model_name,
network_config=dict(self.config.network_config),
-            classes=self.classes,
+            classes=classes,
word_dict=word_dict,
embed_vecs=embed_vecs,
init_weight=self.config.init_weight,
@@ -201,6 +231,8 @@ def train(self):
        """
        if self.config.model_name.lower() == "attentionxml":
            self.trainer.fit(self.datasets)
+
+            dump_log(self.log_path, config=self.config)
        else:
            assert (
                self.trainer is not None
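The torch_trainer.py side mirrors the plt.py change: word_dict and classes are built only when checkpoint_path is None, because a trained model_1 already carries them. A minimal sketch of the underlying mechanism, assuming the models persist their __init__ arguments via PyTorch Lightning's save_hyperparameters(); TinyModel and best.ckpt are illustrative, not repository code:

    import pytorch_lightning as pl
    import torch

    class TinyModel(pl.LightningModule):
        def __init__(self, word_dict=None, classes=None):
            super().__init__()
            self.save_hyperparameters()          # word_dict/classes go into the ckpt
            self.word_dict, self.classes = word_dict, classes
            self.linear = torch.nn.Linear(4, 2)  # placeholder network

    # After fit() saves "best.ckpt", a later test-only process can do:
    #     model_1 = TinyModel.load_from_checkpoint("best.ckpt")
    #     model_1.word_dict, model_1.classes    # restored without the training data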
