diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py
index f2f31d7e..e1e3b50f 100644
--- a/libmultilabel/nn/datasets_AttentionXML.py
+++ b/libmultilabel/nn/datasets_AttentionXML.py
@@ -79,16 +79,15 @@ def __init__(
         self.cluster_scores = cluster_scores
         self.label_scores = None
 
-        # labels_selected are positive clusters at the current level. shape: (len(x), cluster_size * top_k)
-        # look like [[0, 1, 2, 4, 5, 18, 19,...], ...]
+        # labels_selected are labels extracted from clusters_selected.
+        # shape: (len(x), len(clusters_selected) * cluster_size)
         self.labels_selected = [
             np.concatenate(self.mapping[labels])
             for labels in tqdm(self.clusters_selected, leave=False, desc="Retrieving labels from selected clusters")
         ]
         if self.cluster_scores is not None:
-            # label_scores are corresponding scores for selected labels and
-            # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k)
-            # notice how scores repeat for each cluster.
+            # label_scores are probability scores corresponding to labels_selected.
+            # shape: (len(x), len(clusters_selected) * cluster_size)
             self.label_scores = [
                 np.repeat(scores, [len(i) for i in self.mapping[labels]])
                 for labels, scores in zip(self.clusters_selected, self.cluster_scores)
diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py
index 2114aa8b..88b9685f 100644
--- a/libmultilabel/nn/model_AttentionXML.py
+++ b/libmultilabel/nn/model_AttentionXML.py
@@ -31,7 +31,8 @@ def scatter_logits(
         labels_selected: Tensor,
         label_scores: Tensor,
     ) -> Tensor:
-        """map predictions from sample space to label space. The scores of unsampled labels are set to 0."""
+        """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to
+        the whole label space. The scores of unsampled labels are set to 0."""
         src = torch.sigmoid(logits.detach()) * label_scores
         preds = torch.zeros(
             labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py
index f64c06d3..b91a1368 100644
--- a/libmultilabel/nn/plt.py
+++ b/libmultilabel/nn/plt.py
@@ -278,7 +278,10 @@ def fit(self, datasets):
         logger.info(f"Best model loaded from {best_model_path}")
         model_0 = Model.load_from_checkpoint(best_model_path)
 
-        logger.info(f"Generating predictions for level 1. Will use the top {self.predict_top_k} predictions")
+        logger.info(
+            f"Predicting clusters with the level 0 model. We then select {self.predict_top_k} clusters and use them "
+            f"to extract labels for level 1 training."
+        )
 
         # load training and validation data and predict corresponding level 0 clusters
         train_pred = trainer.predict(model_0, train_dataloader)
@@ -288,15 +291,12 @@ def fit(self, datasets):
         val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred]))
         val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred])
 
-        logger.info(
-            "Selecting relevant/irrelevant clusters of each instance for generating labels for level 1 training"
-        )
         clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64)
         for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")):
             # relevant clusters are positive
             pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]])
             # Select relevant clusters first. Then from top-predicted clusters, sequentially include them until
-            # clusters reach top_k
+            # the cluster number reaches predict_top_k
             if len(pos) <= self.predict_top_k:
                 selected = pos
                 for y in ys:
@@ -383,9 +383,9 @@ def fit(self, datasets):
             silent=self.silent,
             save_k_predictions=self.predict_top_k,
         )
+        logger.info(f"Initialize model with weights from level 0")
+        # For weights not initialized by the level-0 model, use Xavier uniform initialization
         torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight)
-
-        logger.info(f"Initialize model with weights from the last level")
         # As the attention layer of model 1 is different from model 0, each layer needs to be initialized separately
         model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict())
         model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict())
@@ -456,6 +456,7 @@ def test(self, dataset):
         logger.info("Testing process finished")
 
     def reformat_text(self, dataset):
+        # Convert each word to its index in word_dict, then pad each instance to a fixed length.
         encoded_text = list(
             map(
                 lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64)