
Commit c0ebf36: minor fixes
donglihe-hub committed Feb 29, 2024 (1 parent: 3de4019)
Showing 11 changed files with 69 additions and 79 deletions.
2 changes: 1 addition & 1 deletion example_config/AmazonCat-13K/attentionxml.yml
@@ -29,7 +29,7 @@ val_metric: nDCG@5
 seed: 1337
 epochs: 10
 # https://github.com/Lightning-AI/lightning/issues/8826
-optimizer: Adam
+optimizer: adam
 optimizer_config:
   lr: 0.001
 # early stopping
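
Note: this commit lowercases the optimizer name in the three attentionxml.yml configs. As a rough illustration only (not LibMultiLabel's actual code), such a config string is typically resolved case-insensitively to a torch.optim class:

    import torch.optim as optim

    def resolve_optimizer(name):
        # Hypothetical helper: map a config string such as "adam" to an optimizer class.
        table = {"sgd": optim.SGD, "adam": optim.Adam, "adamw": optim.AdamW, "adamax": optim.Adamax}
        return table[name.lower()]

    optimizer_cls = resolve_optimizer("adam")  # -> torch.optim.Adam
    # optimizer = optimizer_cls(model.parameters(), lr=0.001)
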
2 changes: 1 addition & 1 deletion example_config/EUR-Lex/attentionxml.yml
@@ -30,7 +30,7 @@ seed: 1337
 epochs: 30
 silent: true
 # https://github.com/Lightning-AI/lightning/issues/8826
-optimizer: Adam
+optimizer: adam
 # early stopping
 patience: 5

2 changes: 1 addition & 1 deletion example_config/Wiki10-31K/attentionxml.yml
@@ -30,7 +30,7 @@ seed: 1337
 epochs: 30
 silent: true
 # https://github.com/Lightning-AI/lightning/issues/8826
-optimizer: Adam
+optimizer: adam
 # early stopping
 patience: 5

2 changes: 1 addition & 1 deletion libmultilabel/nn/data_utils.py
@@ -195,7 +195,6 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data

 def load_datasets(
     training_data=None,
-    training_sparse_data=None,
     test_data=None,
     val_data=None,
     val_size=0.2,
@@ -230,6 +229,7 @@ def load_datasets(
         datasets["train"] = _load_raw_data(
             training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
         )
+
     if val_data is not None:
         datasets["val"] = _load_raw_data(
             val_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
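
Note: load_datasets no longer accepts training_sparse_data. A usage sketch, assuming the remaining keyword arguments match the signature shown above and using hypothetical file paths:

    from libmultilabel.nn.data_utils import load_datasets

    datasets = load_datasets(
        training_data="data/train.txt",  # hypothetical path
        test_data="data/test.txt",       # hypothetical path
        val_size=0.2,                    # hold out 20% of training data for validation
    )
    # Expected result: a dict with "train", "val", and "test" splits of raw examples.
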
2 changes: 1 addition & 1 deletion libmultilabel/nn/model.py
@@ -197,7 +197,7 @@ def __init__(
     ):
         super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
         self.save_hyperparameters(
-            ignore=["log_path"]
+            ignore=["log_path", "embed_vecs", "word_dict", "metrics"]
         )  # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir).
         self.word_dict = word_dict
         self.embed_vecs = embed_vecs
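
Note: the extra entries in ignore= keep bulky or run-specific objects (embedding matrix, vocabulary, metric collection) out of the checkpoint's hyperparameters. A minimal, generic LightningModule sketch of the same pattern (not the class above):

    import torch
    import pytorch_lightning as pl

    class TinyModule(pl.LightningModule):
        def __init__(self, hidden_dim, embed_vecs, log_path=None):
            super().__init__()
            # Exclude run-specific or bulky objects so checkpoints stay small and reloadable.
            self.save_hyperparameters(ignore=["embed_vecs", "log_path"])
            self.embedding = torch.nn.Embedding.from_pretrained(embed_vecs, freeze=False)
            self.proj = torch.nn.Linear(embed_vecs.shape[1], hidden_dim)
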
24 changes: 12 additions & 12 deletions libmultilabel/nn/model_AttentionXML.py
@@ -45,33 +45,33 @@ def multilabel_binarize(
     def training_step(self, batch, batch_idx):
         x = batch["text"]
         y = batch["label"]
-        samples = batch["samples"]
-        logits = self.network(x, samples=samples)
-        loss = self.loss_func(logits, torch.take_along_dim(y.float(), samples, dim=1))
+        labels_selected = batch["labels_selected"]
+        logits = self.network(x, samples=labels_selected)
+        loss = self.loss_function(logits, torch.take_along_dim(y.float(), labels_selected, dim=1))
         return loss

     def validation_step(self, batch, batch_idx):
         x = batch["text"]
         y = batch["label"]
-        samples = batch["samples"]
+        labels_selected = batch["labels_selected"]
         label_scores = batch["label_scores"]
-        logits = self.network(x, samples=samples)
-        y_pred = self.multilabel_binarize(logits, samples, label_scores)
+        logits = self.network(x, samples=labels_selected)
+        y_pred = self.multilabel_binarize(logits, labels_selected, label_scores)
         self.val_metric.update(y_pred, y.long())

     def test_step(self, batch, batch_idx):
         x = batch["text"]
         y = batch["label"]
-        samples = batch["samples"]
+        labels_selected = batch["labels_selected"]
         label_scores = batch["label_scores"]
-        logits = self.network(x, samples=samples)
-        y_pred = self.multilabel_binarize(logits, samples, label_scores)
+        logits = self.network(x, samples=labels_selected)
+        y_pred = self.multilabel_binarize(logits, labels_selected, label_scores)
         self.test_metrics.update(y_pred, y.long())

     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         x = batch["text"]
-        samples = batch["samples"]
+        labels_selected = batch["labels_selected"]
         label_scores = batch["label_scores"]
-        logits = self.network(x, samples=samples)
+        logits = self.network(x, samples=labels_selected)
         scores, labels = torch.topk(torch.sigmoid(logits) * label_scores, self.top_k)
-        return scores, torch.take_along_dim(samples, labels, dim=1)
+        return scores, torch.take_along_dim(labels_selected, labels, dim=1)
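
Note: the renamed labels_selected tensor holds, per document, the global ids of the labels scored at this level; torch.take_along_dim restricts the dense target matrix to those columns and maps top-k positions back to global label ids. A small self-contained illustration:

    import torch

    y = torch.tensor([[0., 1., 0., 1., 1.]])        # dense targets over 5 labels
    labels_selected = torch.tensor([[3, 0, 4]])     # label ids sampled for this document
    print(torch.take_along_dim(y, labels_selected, dim=1))          # tensor([[1., 0., 1.]])

    logits = torch.tensor([[2.0, -1.0, 0.5]])       # scores for the selected labels
    scores, local_idx = torch.topk(torch.sigmoid(logits), k=2)
    print(torch.take_along_dim(labels_selected, local_idx, dim=1))  # tensor([[3, 4]])
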
3 changes: 1 addition & 2 deletions libmultilabel/nn/networks/__init__.py
@@ -9,8 +9,7 @@
 from .labelwise_attention_networks import BiLSTMLWAN
 from .labelwise_attention_networks import BiLSTMLWMHAN
 from .labelwise_attention_networks import CNNLWAN
-from .labelwise_attention_networks import AttentionXML_0
-from .labelwise_attention_networks import AttentionXML_1
+from .labelwise_attention_networks import AttentionXML_0, AttentionXML_1


 def get_init_weight_func(init_weight):
14 changes: 8 additions & 6 deletions libmultilabel/nn/networks/labelwise_attention_networks.py
@@ -10,7 +10,7 @@
     LabelwiseAttention,
     LabelwiseMultiHeadAttention,
     LabelwiseLinearOutput,
-    FastLabelwiseAttention,
+    PartialLabelwiseAttention,
     MultilayerLinearOutput,
 )

@@ -291,13 +291,14 @@ def __init__(

     def forward(self, inputs):
         # the index of padding is 0
+        inputs = inputs["text"]
         masks = inputs != 0
         lengths = masks.sum(dim=1)
         masks = masks[:, : lengths.max()]

         x = self.embedding(inputs)[:, : lengths.max()]  # batch_size, length, embedding_size
         x = self.encoder(x, lengths)  # batch_size, length, hidden_size
-        x, _ = self.attention(x, masks)  # batch_size, num_classes, hidden_size
+        x, _ = self.attention(x)  # batch_size, num_classes, hidden_size
         x = self.output(x)  # batch_size, num_classes
         return {"logits": x}

@@ -318,17 +319,18 @@ def __init__(
         super().__init__()
         self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
         self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
-        self.attention = FastLabelwiseAttention(rnn_dim, num_classes)
+        self.attention = PartialLabelwiseAttention(rnn_dim, num_classes)
         self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)

-    def forward(self, inputs, samples):
+    def forward(self, inputs, labels_selected):
         # the index of padding is 0
+        inputs = inputs["text"]
         masks = inputs != 0
         lengths = masks.sum(dim=1)
         masks = masks[:, : lengths.max()]

         x = self.embedding(inputs)[:, : lengths.max()]  # batch_size, length, embedding_size
         x = self.encoder(x, lengths)  # batch_size, length, hidden_size
-        x, _ = self.attention(x, masks, samples)  # batch_size, candidate_size, hidden_size
-        x = self.output(x)  # batch_size, candidate_size
+        x, _ = self.attention(x, labels_selected)  # batch_size, sample_size, hidden_size
+        x = self.output(x)  # batch_size, sample_size
         return {"logits": x}
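
Note: AttentionXML_1.forward now takes the batch dict plus labels_selected and returns one logit per selected label. A shape-only sketch, with the instantiation left as a comment because the constructor arguments above may not be the full signature:

    import torch

    # net = AttentionXML_1(embed_vecs=embed_vecs, num_classes=num_classes, ...)  # hypothetical args
    batch = {"text": torch.randint(1, 1000, (4, 60))}    # 4 documents, 60 token ids each
    labels_selected = torch.randint(0, 30000, (4, 128))  # 128 candidate labels per document
    # logits = net(batch, labels_selected)["logits"]     # expected shape: (4, 128)
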
50 changes: 19 additions & 31 deletions libmultilabel/nn/networks/modules.py
@@ -72,7 +72,7 @@ def _get_rnn(self, input_size, hidden_size, num_layers, dropout):
         return nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True)


-class LSTMEncoder(nn.Module):
+class LSTMEncoder(RNNEncoder):
     """Bi-directional LSTM encoder with dropout

     Args:
@@ -84,23 +84,10 @@ class LSTMEncoder(nn.Module):
     """

     def __init__(self, input_size, hidden_size, num_layers, encoder_dropout=0, post_encoder_dropout=0):
-        super().__init__()
-        self.rnn = nn.LSTM(
-            input_size, hidden_size, num_layers, batch_first=True, dropout=encoder_dropout, bidirectional=True
-        )
-        self.h0 = nn.Parameter(torch.zeros(2 * num_layers, 1, hidden_size))
-        self.c0 = nn.Parameter(torch.zeros(2 * num_layers, 1, hidden_size))
-        self.post_encoder_dropout = nn.Dropout(post_encoder_dropout)
+        super(LSTMEncoder, self).__init__(input_size, hidden_size, num_layers, encoder_dropout, post_encoder_dropout)

-    def forward(self, inputs, length):
-        self.rnn.flatten_parameters()
-        idx = torch.argsort(length, descending=True)
-        h0 = self.h0.repeat([1, inputs.size(0), 1])
-        c0 = self.c0.repeat([1, inputs.size(0), 1])
-        length_clamped = length[idx].cpu().clamp(min=1)  # avoid the empty text with length 0
-        packed_input = pack_padded_sequence(inputs[idx], length_clamped, batch_first=True)
-        outputs, _ = pad_packed_sequence(self.rnn(packed_input, (h0, c0))[0], batch_first=True)
-        return self.post_encoder_dropout(outputs[torch.argsort(idx)])
+    def _get_rnn(self, input_size, hidden_size, num_layers, dropout):
+        return nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True)


 class CNNEncoder(nn.Module):
@@ -174,12 +161,9 @@ def __init__(self, input_size, num_classes):
         super(LabelwiseAttention, self).__init__()
         self.attention = nn.Linear(input_size, num_classes, bias=False)

-    def forward(self, input, masks):
+    def forward(self, input):
         # (batch_size, num_classes, sequence_length)
         attention = self.attention(input).transpose(1, 2)
-        if masks is not None:
-            masks = torch.unsqueeze(masks, 1)  # batch_size, 1, length
-            attention = attention.masked_fill(~masks, -torch.inf)  # batch_size, num_classes, length
         attention = F.softmax(attention, -1)
         # (batch_size, num_classes, hidden_dim)
         logits = torch.bmm(attention, input)
@@ -228,17 +212,21 @@ def forward(self, input):
         return (self.output.weight * input).sum(dim=-1) + self.output.bias


-class FastLabelwiseAttention(nn.Module):
-    def __init__(self, hidden_size, num_labels):
+class PartialLabelwiseAttention(nn.Module):
+    """Similar to LabelwiseAttention.
+    What makes the class different from LabelwiseAttention is that only the weights of selected labels will be
+    updated in a single iteration.
+    """
+
+    def __init__(self, hidden_size, num_classes):
         super().__init__()
-        self.attention = nn.Embedding(num_labels + 1, hidden_size)
-
-    def forward(self, inputs, masks, candidates):
-        masks = torch.unsqueeze(masks, 1)  # batch_size, 1, length
-        attn_inputs = inputs.transpose(1, 2)  # batch_size, hidden, length
-        attn_weights = self.attention(candidates)  # batch_size, sample_size, hidden
-        attention = (attn_weights @ attn_inputs).masked_fill(~masks, -torch.inf)  # batch_size, sampled_size, length
-        attention = F.softmax(attention, -1)  # batch_size, sampled_size, length
+        self.attention = nn.Embedding(num_classes + 1, hidden_size)
+
+    def forward(self, inputs, labels_selected):
+        attn_inputs = inputs.transpose(1, 2)  # batch_size, hidden_dim, length
+        attn_weights = self.attention(labels_selected)  # batch_size, sample_size, hidden_dim
+        attention = attn_weights @ attn_inputs  # batch_size, sample_size, length
+        attention = F.softmax(attention, -1)  # batch_size, sample_size, length
         logits = attention @ inputs  # batch_size, sample_size, hidden_dim
         return logits, attention

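
Note: the PartialLabelwiseAttention docstring says that only the selected labels' weights are updated in an iteration; storing the per-label attention vectors in an nn.Embedding gives exactly that behavior, since only the indexed rows receive gradients. A small self-contained check of that property:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10 + 1, 8)                 # 10 labels plus one extra (padding-style) slot
    labels_selected = torch.tensor([[2, 5, 7]])   # only these rows should receive gradients
    doc = torch.randn(1, 6, 8)                    # (batch, length, hidden)

    attn_weights = emb(labels_selected)                                    # (1, 3, 8)
    attention = torch.softmax(attn_weights @ doc.transpose(1, 2), dim=-1)  # (1, 3, 6)
    out = attention @ doc                                                  # (1, 3, 8)
    out.sum().backward()

    touched = emb.weight.grad.abs().sum(dim=1).nonzero().flatten()
    print(touched)  # tensor([2, 5, 7])
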
44 changes: 23 additions & 21 deletions libmultilabel/nn/plt.py
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import logging
-import time
 from functools import partial
 from pathlib import Path
 from typing import Generator, Optional
@@ -59,7 +58,7 @@ def __init__(
         self.embed_vecs = embed_vecs
         self.word_dict = word_dict
         self.classes = classes
-        self.num_labels = len(classes)
+        self.num_classes = len(classes)

         # preprocessor of the datasets
         self.preprocessor = None
@@ -70,11 +69,10 @@ def __init__(
         # network parameters
         self.network_config = config.network_config
         self.init_weight = "xavier_uniform"  # AttentionXML-specific setting
-        self.loss_func = config.loss_func
+        self.loss_function = config.loss_function

         # optimizer parameters
         self.optimizer = config.optimizer
-        self.optimizer_config = config.optimizer_config

         # Trainer parameters
         self.use_cpu = config.cpu
@@ -132,7 +130,7 @@ def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]:
         Returns:
             Generator[csr_matrix]: the mapped labels (ancestor clusters) for train and/or valid datasets.
         """
-        mapping = np.empty(self.num_labels, dtype=np.uint64)
+        mapping = np.empty(self.num_classes, dtype=np.uint64)
         for idx, clusters in enumerate(cluster_mapping):
             mapping[clusters] = idx

@@ -163,7 +161,10 @@ def fit(self, datasets):
         # Convert training texts and labels to a matrix of tfidf features and a binary sparse matrix indicating the
         # presence of a class label, respectively
         train_val_dataset = datasets["train"] + datasets["val"]
-        train_val_dataset = {"x": (i["text"] for i in train_val_dataset), "y": (i["label"] for i in train_val_dataset)}
+        train_val_dataset = {
+            "x": [" ".join(i["text"]) for i in train_val_dataset],
+            "y": [i["label"] for i in train_val_dataset],
+        }

         self.preprocessor = Preprocessor()
         datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes}
@@ -203,7 +204,7 @@ def fit(self, datasets):
             limit_test_batches=self.limit_test_batches,
             save_checkpoints=True,
         )
-        trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}0{ModelCheckpoint.FILE_EXTENSION}"
+        trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}0{ModelCheckpoint.FILE_EXTENSION}"

         best_model_path = self.get_best_model_path(level=0)
         if not best_model_path.exists():
@@ -212,9 +213,9 @@ def fit(self, datasets):
             val_dataloader = self.dataloader(PlainDataset(val_x, val_y_clustered))

             model_0 = init_model(
-                model_name="AttentionXML",
+                model_name="AttentionXML_0",
                 network_config=self.config.network_config,
-                classes=self.classes,
+                classes=clusters,
                 word_dict=self.word_dict,
                 embed_vecs=self.embed_vecs,
                 init_weight=self.config.init_weight,
@@ -229,8 +230,8 @@ def fit(self, datasets):
                 metric_threshold=self.config.metric_threshold,
                 monitor_metrics=self.config.monitor_metrics,
                 multiclass=self.config.multiclass,
-                loss_function=self.config.loss_function,
-                silent=self.config.silent,
+                loss_function=self.loss_function,
+                silent=self.silent,
                 save_k_predictions=self.config.save_k_predictions,
             )

@@ -313,7 +314,7 @@ def fit(self, datasets):
             PLTDataset(
                 train_x,
                 train_y,
-                num_classes=self.num_labels,
+                num_classes=self.num_classes,
                 mapping=clusters,
                 clusters_selected=clusters_selected,
             ),
@@ -323,15 +324,15 @@ def fit(self, datasets):
             PLTDataset(
                 val_x,
                 val_y,
-                num_classes=self.num_labels,
+                num_classes=self.num_classes,
                 mapping=clusters,
                 clusters_selected=val_clusters_pred,
                 cluster_scores=val_scores_pred,
             ),
         )

         try:
-            network = getattr(networks, "FastAttentionXML")(
+            network = getattr(networks, "AttentionXML_1")(
                 embed_vecs=self.embed_vecs, num_classes=len(self.classes), **dict(self.network_config)
             )
         except Exception:
@@ -341,14 +342,13 @@ def fit(self, datasets):
             network=network,
             network_config=self.network_config,
             embed_vecs=self.embed_vecs,
-            num_labels=self.num_labels,
+            num_classes=self.num_classes,
             optimizer=self.optimizer,
             metrics=self.metrics,
-            top_k=self.predict_top_k,
+            save_k_predictions=self.predict_top_k,
             val_metric=self.val_metric,
             is_multiclass=self.is_multiclass,
-            loss_func=self.loss_func,
-            optimizer_params=self.optimizer_config,
+            loss_function=self.loss_function,
         )
         torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight)

@@ -357,8 +357,10 @@ def fit(self, datasets):
             model_1.network.encoder.load_state_dict(model_0.network.encoder)
             model_1.network.output.load_state_dict(model_0.network.output)

+        del model_0
+
         logger.info(
-            f"Training level 1. Number of labels: {self.num_labels}."
+            f"Training level 1. Number of labels: {self.num_classes}."
             f"Number of clusters selected: {train_dataloader.dataset.num_clusters_selected}"
         )
         trainer.fit(model_1, train_dataloader, valid_dataloader)
@@ -376,7 +378,7 @@ def test(self, dataset):
         )

         # prediction starts from level 0
-        model = BaseModel.load_from_checkpoint(
+        model = Model.load_from_checkpoint(
             self.get_best_model_path(level=0),
             embed_vecs=self.embed_vecs,
             top_k=self.predict_top_k,
@@ -402,7 +404,7 @@ def test(self, dataset):
             PLTDataset(
                 test_x,
                 test_y,
-                num_classes=self.num_labels,
+                num_classes=self.num_classes,
                 mapping=clusters,
                 clusters_selected=node_label_pred,
                 cluster_scores=node_score_pred,
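
Note: the test path above scores clusters with the level-0 model, keeps the top clusters per document, and lets the level-1 model score only the labels inside them. A minimal sketch of that candidate-expansion step (illustrative only; names do not match plt.py):

    import numpy as np

    clusters = [[0, 1], [2, 3, 4], [5]]             # label ids grouped per cluster
    cluster_scores = np.array([0.1, 0.7, 0.2])      # level-0 scores for one document
    top_clusters = np.argsort(-cluster_scores)[:2]  # keep the 2 best clusters

    labels_selected = [label for c in top_clusters for label in clusters[c]]
    print(labels_selected)  # [2, 3, 4, 5] -> these labels are scored by level 1
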
(The diff for the remaining changed file was not loaded in this view.)
