Commit

fix according to feedback

donglihe-hub committed Mar 18, 2024
1 parent 669746c commit 15bab60
Showing 3 changed files with 14 additions and 13 deletions.
9 changes: 4 additions & 5 deletions libmultilabel/nn/datasets_AttentionXML.py
@@ -79,16 +79,15 @@ def __init__(
self.cluster_scores = cluster_scores
self.label_scores = None

-# labels_selected are positive clusters at the current level. shape: (len(x), cluster_size * top_k)
-# look like [[0, 1, 2, 4, 5, 18, 19,...], ...]
+# labels_selected are labels extracted from clusters_selected.
+# shape: (len(x), len(clusters_selected) * cluster_size)
self.labels_selected = [
np.concatenate(self.mapping[labels])
for labels in tqdm(self.clusters_selected, leave=False, desc="Retrieving labels from selected clusters")
]
if self.cluster_scores is not None:
-# label_scores are corresponding scores for selected labels and
-# look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k)
-# notice how scores repeat for each cluster.
+# label_scores are probability scores corresponding to labels_selected.
+# shape: (len(x), len(clusters_selected) * cluster_size)
self.label_scores = [
np.repeat(scores, [len(i) for i in self.mapping[labels]])
for labels, scores in zip(self.clusters_selected, self.cluster_scores)
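A toy example makes the two comprehensions above concrete. This is a minimal sketch with a made-up cluster-to-label mapping; the values (and the fixed cluster size of 2) are illustrative, not taken from the repository:

```python
import numpy as np

# Hypothetical mapping: cluster i contains the labels mapping[i] (cluster size 2 here).
mapping = np.array([[0, 1], [2, 3], [4, 5]])
clusters_selected = np.array([0, 2])   # clusters picked for a single instance
cluster_scores = np.array([0.9, 0.4])  # one probability score per selected cluster

# labels_selected: flatten the labels of every selected cluster.
labels_selected = np.concatenate(mapping[clusters_selected])
print(labels_selected)  # [0 1 4 5]

# label_scores: every label inherits its cluster's score, so scores repeat per cluster.
label_scores = np.repeat(cluster_scores, [len(i) for i in mapping[clusters_selected]])
print(label_scores)  # [0.9 0.9 0.4 0.4]
```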
3 changes: 2 additions & 1 deletion libmultilabel/nn/model_AttentionXML.py
@@ -31,7 +31,8 @@ def scatter_logits(
labels_selected: Tensor,
label_scores: Tensor,
) -> Tensor:
"""map predictions from sample space to label space. The scores of unsampled labels are set to 0."""
"""For each instance, we only have predictions on selected labels. This subroutine maps these predictions to
the whole label space. The scores of unsampled labels are set to 0."""
src = torch.sigmoid(logits.detach()) * label_scores
preds = torch.zeros(
labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
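The hunk is cut off before the scatter itself. A minimal sketch of the overall idea — the shapes, the trailing padding column, and the final scatter_ call are assumptions inferred from the visible lines, not the repository's exact code:

```python
import torch

batch_size, k, num_classes = 2, 3, 6
logits = torch.randn(batch_size, k)                   # predictions on the k selected labels
labels_selected = torch.randint(0, num_classes, (batch_size, k))
label_scores = torch.rand(batch_size, k)              # scores inherited from the clusters

src = torch.sigmoid(logits) * label_scores
# One extra column gives padding indices (== num_classes) a slot to land in.
preds = torch.zeros(batch_size, num_classes + 1, dtype=src.dtype)
preds.scatter_(1, labels_selected, src)
preds = preds[:, :-1]  # drop the padding column; unsampled labels keep score 0
```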
15 changes: 8 additions & 7 deletions libmultilabel/nn/plt.py
@@ -278,7 +278,10 @@ def fit(self, datasets):
logger.info(f"Best model loaded from {best_model_path}")
model_0 = Model.load_from_checkpoint(best_model_path)

logger.info(f"Generating predictions for level 1. Will use the top {self.predict_top_k} predictions")
logger.info(
f"Predicting clusters by level 0 model. We then select {self.predict_top_k} clusters and use then "
f"to extract labels for level 1 training."
)
# load training and validation data and predict corresponding level 0 clusters

train_pred = trainer.predict(model_0, train_dataloader)
@@ -288,15 +291,12 @@
val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred]))
val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred])
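What trainer.predict returns and what the vstack/expit lines do can be illustrated with dummy batches. A minimal sketch: the batch contents are made-up values, and only the dict keys top_k_pred and top_k_pred_scores come from the diff:

```python
import numpy as np
from scipy.special import expit

# Hypothetical output of trainer.predict: one dict per batch with the
# top-k predicted cluster ids and their raw (pre-sigmoid) scores.
train_pred = [
    {"top_k_pred": np.array([[3, 0], [1, 2]]), "top_k_pred_scores": np.array([[2.0, -1.0], [0.5, 0.0]])},
    {"top_k_pred": np.array([[0, 2]]), "top_k_pred_scores": np.array([[1.5, -0.5]])},
]

# vstack merges the per-batch arrays; expit maps raw scores to probabilities in (0, 1).
train_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in train_pred]))
train_clusters_pred = np.vstack([i["top_k_pred"] for i in train_pred])
print(train_scores_pred.shape, train_clusters_pred.shape)  # (3, 2) (3, 2)
```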

-logger.info(
-"Selecting relevant/irrelevant clusters of each instance for generating labels for level 1 training"
-)
clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64)
for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")):
# relevant clusters are positive
pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]])
# Select relevant clusters first. Then from top-predicted clusters, sequentially include them until
-# clusters reach top_k
+# the number of clusters reaches predict_top_k
if len(pos) <= self.predict_top_k:
selected = pos
for y in ys:
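The loop body is truncated above. A minimal sketch of the selection rule the comments describe — keep all positive clusters, then fill from the score-ordered predictions until predict_top_k is reached. The overflow branch and the handling of rows shorter than predict_top_k are guesses, not the repository's code:

```python
import numpy as np

def select_clusters(pos: set, ys: np.ndarray, top_k: int) -> list:
    """pos: relevant (positive) clusters; ys: predicted clusters ordered by score."""
    if len(pos) <= top_k:
        selected = set(pos)
        for y in ys:  # add top-predicted clusters until top_k is reached
            selected.add(int(y))
            if len(selected) == top_k:
                break
    else:
        selected = set(list(pos)[:top_k])  # guess: truncate when positives overflow
    return sorted(selected)

print(select_clusters({3, 7}, np.array([5, 3, 1, 9, 0]), top_k=4))  # [1, 3, 5, 7]
```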
@@ -383,9 +383,9 @@ def fit(self, datasets):
silent=self.silent,
save_k_predictions=self.predict_top_k,
)
logger.info(f"Initialize model with weights from level 0")
# For weights not initialized by the level-0 model, use xavier uniform initialization
torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight)

logger.info(f"Initialize model with weights from the last level")
# As the attention layer of model 1 is different from model 0, each layer needs to be initialized separately
model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict())
model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict())
@@ -456,6 +456,7 @@ def test(self, dataset):
logger.info("Testing process finished")

def reformat_text(self, dataset):
+# Convert each word to its index in word_dict, then pad every instance to a fixed length.
encoded_text = list(
map(
lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64)
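The new comment summarizes a word-to-index plus padding pipeline. A minimal sketch of that pattern; the vocabulary, the batch, and the use of pad_sequence are illustrative assumptions, not the function's actual body:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

word_dict = {"<pad>": 0, "this": 1, "is": 2, "multilabel": 3}  # hypothetical vocabulary
texts = [["this", "is", "multilabel"], ["multilabel"]]

# Convert each word to its index, then pad the batch to the longest instance.
encoded_text = [torch.tensor([word_dict[w] for w in t], dtype=torch.int64) for t in texts]
padded = pad_sequence(encoded_text, batch_first=True, padding_value=word_dict["<pad>"])
print(padded)  # tensor([[1, 2, 3], [3, 0, 0]])
```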
