From 9d226414f76d8dde071d1d151430c30e9b782177 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Mon, 8 Jan 2024 18:46:52 +0400 Subject: [PATCH 01/29] add AttentionXML --- .../AmazonCat-13K/fastattentionxml.yml | 52 ++ example_config/EUR-Lex/fastattentionxml.yml | 50 ++ .../Wiki10-31K/fastattentionxml.yml | 50 ++ libmultilabel/nn/cluster.py | 124 +++++ libmultilabel/nn/data_utils.py | 16 +- libmultilabel/nn/datasets.py | 149 ++++++ libmultilabel/nn/model.py | 219 +++++++- libmultilabel/nn/networks/__init__.py | 2 + .../networks/labelwise_attention_networks.py | 66 +++ libmultilabel/nn/networks/modules.py | 54 +- libmultilabel/nn/plt.py | 478 ++++++++++++++++++ main.py | 18 + torch_trainer.py | 156 +++--- 13 files changed, 1355 insertions(+), 79 deletions(-) create mode 100644 example_config/AmazonCat-13K/fastattentionxml.yml create mode 100644 example_config/EUR-Lex/fastattentionxml.yml create mode 100644 example_config/Wiki10-31K/fastattentionxml.yml create mode 100644 libmultilabel/nn/cluster.py create mode 100644 libmultilabel/nn/datasets.py create mode 100644 libmultilabel/nn/plt.py diff --git a/example_config/AmazonCat-13K/fastattentionxml.yml b/example_config/AmazonCat-13K/fastattentionxml.yml new file mode 100644 index 000000000..7f84202b4 --- /dev/null +++ b/example_config/AmazonCat-13K/fastattentionxml.yml @@ -0,0 +1,52 @@ +data_name: AmazonCat-13K +training_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/train.txt +training_sparse_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/train_ver3.svm +test_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/test.txt +# pretrained embeddings +embed_file: glove.840B.300d +embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding +# save path +result_dir: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/results + +# preprocessing +min_vocab_freq: 1 +max_seq_length: 500 + +# label tree related parameters +cluster_size: 8 +top_k: 64 + +# data +batch_size: 200 +val_size: 4000 +shuffle: true + +# eval +eval_batch_size: 200 +monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] +val_metric: nDCG@5 + +# trainer params +accelerator: gpu + +# train +seed: 1337 +epochs: 10 +# https://github.com/Lightning-AI/lightning/issues/8826 +optimizer: Adam +optimizer_config: + lr: 0.001 +# early stopping +patience: 5 +silent: true + +# model +model_name: FastAttentionXML +loss_func: binary_cross_entropy_with_logits +network_config: + embed_dropout: 0.2 + post_encoder_dropout: 0.5 + rnn_dim: 1024 + rnn_layers: 1 + linear_size: [512, 256] + freeze_embed_training: false diff --git a/example_config/EUR-Lex/fastattentionxml.yml b/example_config/EUR-Lex/fastattentionxml.yml new file mode 100644 index 000000000..186814e5d --- /dev/null +++ b/example_config/EUR-Lex/fastattentionxml.yml @@ -0,0 +1,50 @@ +data_name: EUR-Lex +training_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/train.txt +training_sparse_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/train.svm +test_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/test.txt +# pretrained embeddings +embed_file: glove.840B.300d +embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding +# save path +result_dir: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/results + +# preprocessing +min_vocab_freq: 1 +max_seq_length: 500 + +# AttentionXML-related parameters +cluster_size: 8 +top_k: 64 + +# dataloader +batch_size: 40 +val_size: 200 +shuffle: true + +# eval +eval_batch_size: 40 +monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] +val_metric: nDCG@5 + +# trainer params +accelerator: gpu + +# train +seed: 1337 +epochs: 30 +silent: true +# https://github.com/Lightning-AI/lightning/issues/8826 +optimizer: Adam +# early stopping +patience: 5 + +# model +model_name: FastAttentionXML +loss_func: binary_cross_entropy_with_logits +network_config: + embed_dropout: 0.2 + post_encoder_dropout: 0.5 + rnn_dim: 512 + rnn_layers: 1 + linear_size: [256] + freeze_embed_training: True diff --git a/example_config/Wiki10-31K/fastattentionxml.yml b/example_config/Wiki10-31K/fastattentionxml.yml new file mode 100644 index 000000000..169beef8c --- /dev/null +++ b/example_config/Wiki10-31K/fastattentionxml.yml @@ -0,0 +1,50 @@ +data_name: Wiki10-31K +training_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/train.txt +training_sparse_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/train.svm +test_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/test.txt +# pretrained embeddings +embed_file: glove.840B.300d +embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding +# save path +result_dir: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/results + +# preprocessing +min_vocab_freq: 1 +max_seq_length: 500 + +# label tree related parameters +cluster_size: 8 +top_k: 64 + +# dataloader +batch_size: 40 +val_size: 200 +shuffle: true + +# eval +eval_batch_size: 40 +monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] +val_metric: nDCG@5 + +# trainer params +accelerator: gpu + +# train +seed: 1337 +epochs: 30 +silent: true +# https://github.com/Lightning-AI/lightning/issues/8826 +optimizer: Adam +# early stopping +patience: 5 + +# model +model_name: FastAttentionXML +loss_func: binary_cross_entropy_with_logits +network_config: + embed_dropout: 0.2 + encoder_dropout: 0.5 + rnn_dim: 512 + rnn_layers: 1 + linear_size: [256] + freeze_embed_training: True diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py new file mode 100644 index 000000000..58ca77c3e --- /dev/null +++ b/libmultilabel/nn/cluster.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np +from scipy.sparse import csr_matrix, csc_matrix +from sklearn.preprocessing import normalize +from numpy import ndarray + +__all__ = ["CLUSTER_NAME", "FILE_EXTENSION", "build_label_tree"] + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +CLUSTER_NAME = "label_clusters" +FILE_EXTENSION = ".npy" + + +def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path): + """Build label tree described in AttentionXML + + Args: + sparse_x: features extracted from texts in CSR sparse format + sparse_y: labels in CSR sparse format + cluster_size: the maximal number of labels inside a cluster + output_dir: directory to save the clusters + """ + # skip label clustering if the clustering file already exists + output_dir = output_dir if isinstance(output_dir, Path) else Path(output_dir) + cluster_path = output_dir / f"{CLUSTER_NAME}{FILE_EXTENSION}" + if cluster_path.exists(): + logger.info("Clustering has finished in a previous run") + return + + # cluster meta info + logger.info("Performing label clustering") + logger.info(f"Cluster size: {cluster_size}") + num_labels = sparse_y.shape[1] + # the height of the tree satisfies the following inequation: + # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size + height = int(np.ceil(np.log2(num_labels / cluster_size))) + logger.info(f"Labels will be grouped into {2**height} clusters") + + output_dir.mkdir(parents=True, exist_ok=True) + + # generate label representations + label_repr = normalize(sparse_y.T @ csc_matrix(sparse_x)) + + # clustering process + rng = np.random.default_rng() + clusters = [np.arange(num_labels)] + for _ in range(height): + assert sum(map(len, clusters)) == num_labels + + next_clusters = [] + for idx_in_cluster in clusters: + next_clusters.extend(_cluster_split(idx_in_cluster, label_repr, rng)) + clusters = next_clusters + logger.info(f"Having grouped {len(clusters)} clusters") + + np.save(cluster_path, np.asarray(clusters, dtype=object)) + logger.info(f"Finish clustering, saving cluster to '{cluster_path}'") + + +def _cluster_split( + idx_in_cluster: ndarray, label_repr: csr_matrix, rng: np.random.Generator +) -> tuple[ndarray, ndarray]: + """A variant of KMeans implemented in AttentionXML. Its main differences with sklearn.KMeans are: + 1. the distance metric is cosine similarity as all label representations are normalized. + 2. the end-of-loop criterion is the difference between the new and old average in-cluster distance to centroids. + Possible drawbacks: + Random initialization. + cluster_size matters. + """ + # tol is a possible hyperparameter + tol = 1e-4 + if tol <= 0 or tol > 1: + raise ValueError(f"tol should be a positive number that is less than 1, got {repr(tol)} instead.") + + # the corresponding label representations in the node + tgt_repr = label_repr[idx_in_cluster] + + # the number of leaf labels in the node + n = len(idx_in_cluster) + + # randomly choose two points as initial centroids + centroids = tgt_repr[rng.choice(n, size=2, replace=False)].toarray() + + # initialize distances (cosine similarity) + old_dist = -2.0 + new_dist = -1.0 + + # "c" denotes clusters + c0_idx = None + c1_idx = None + + while new_dist - old_dist >= tol: + # each points' distance (cosine similarity) to the two centroids + dist = tgt_repr @ centroids.T # shape: (n, 2) + + # generate clusters + # let a = dist[:, 1] - dist[:, 0], the larger the element in a is, the closer the point is to the c1 + k = n // 2 + c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k) + c0_idx = c_idx[:k] + c1_idx = c_idx[k:] + + # update distances + # the distance is the average in-cluster distance to the centroids + old_dist = new_dist + new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / n + + # update centroids + # the new centroid is the average of the points in the cluster + centroids = normalize( + np.asarray( + [ + np.squeeze(np.asarray(tgt_repr[c0_idx].sum(axis=0))), + np.squeeze(np.asarray(tgt_repr[c1_idx].sum(axis=0))), + ] + ) + ) + return idx_in_cluster[c0_idx], idx_in_cluster[c1_idx] diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py index d558737de..8efd27c43 100644 --- a/libmultilabel/nn/data_utils.py +++ b/libmultilabel/nn/data_utils.py @@ -7,8 +7,10 @@ import torch import transformers from nltk.tokenize import RegexpTokenizer +from scipy.sparse import issparse +from sklearn.datasets import load_svmlight_file from sklearn.model_selection import train_test_split -from sklearn.preprocessing import MultiLabelBinarizer +from sklearn.preprocessing import MultiLabelBinarizer, normalize from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab @@ -17,6 +19,7 @@ transformers.logging.set_verbosity_error() warnings.simplefilter(action="ignore", category=FutureWarning) +# selection of UNK: https://groups.google.com/g/globalvectors/c/9w8ZADXJclA/m/hRdn4prm-XUJ UNK = "" PAD = "" @@ -194,6 +197,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data def load_datasets( training_data=None, + training_sparse_data=None, test_data=None, val_data=None, val_size=0.2, @@ -206,6 +210,7 @@ def load_datasets( Args: training_data (Union[str, pandas,.Dataframe], optional): Path to training data or a dataframe. + training_sparse_data (Union[str, pandas,.Dataframe], optional): Path to training sparse data or a dataframe in libsvm format. test_data (Union[str, pandas,.Dataframe], optional): Path to test data or a dataframe. val_data (Union[str, pandas,.Dataframe], optional): Path to validation data or a dataframe. val_size (float, optional): Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set. @@ -229,11 +234,16 @@ def load_datasets( training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data ) + if training_sparse_data is not None: + logging.info(f"Loading sparse training data") + datasets["train_sparse_x"] = normalize(load_svmlight_file(training_sparse_data, multilabel=True)[0]) + if val_data is not None: datasets["val"] = _load_raw_data( val_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data ) elif val_size > 0: + datasets["train_full"] = datasets["train"] datasets["train"], datasets["val"] = train_test_split(datasets["train"], test_size=val_size, random_state=42) if test_data is not None: @@ -249,7 +259,7 @@ def load_datasets( del datasets["val"] gc.collect() - msg = " / ".join(f"{k}: {len(v)}" for k, v in datasets.items()) + msg = " / ".join(f"{k}: {v.shape[0] if issparse(v) else len(v)}" for k, v in datasets.items()) logging.info(f"Finish loading dataset ({msg})") return datasets @@ -335,7 +345,7 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False): classes = set() for split, data in datasets.items(): - if split == "test" and not include_test_labels: + if (split == "test" and not include_test_labels) or split == "train_sparse_x": continue for instance in data: classes.update(instance["label"]) diff --git a/libmultilabel/nn/datasets.py b/libmultilabel/nn/datasets.py new file mode 100644 index 000000000..6dbc589d1 --- /dev/null +++ b/libmultilabel/nn/datasets.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import Sequence, Optional + +import numpy as np +import torch +from lightning.pytorch.utilities.rank_zero import rank_zero_only +from numpy import ndarray +from scipy.sparse import csr_matrix, issparse +from torch import Tensor, is_tensor +from torch.utils.data import Dataset +from tqdm import tqdm + + +class MultiLabelDataset(Dataset): + """Basic class for multi-label dataset.""" + + def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): + """General dataset class for multi-label dataset. + + Args: + x: text. + y: labels. + """ + if y is not None: + assert len(x) == y.shape[0], "Sizes mismatch between x and y" + self.x = x + self.y = y + + def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: + x = self.x[idx] + + # train/valid/test + if self.y is not None: + if issparse(self.y): + y = self.y[idx].toarray().squeeze(0) + elif is_tensor(self.y) or isinstance(self.y, (ndarray, torch.Tensor)): + y = self.y[idx] + else: + raise TypeError( + "The type of y should be one of scipy.csr_matrix, torch.Tensor, and numpy.ndarry." + f"Instead, got {type(self.y)}." + ) + return x, y + # predict + return x + + def __len__(self): + return len(self.x) + + +class PLTDataset(MultiLabelDataset): + """Dataset class for AttentionXML.""" + + def __init__( + self, + x, + y: Optional[csr_matrix | ndarray] = None, + *, + num_labels: int, + mapping: ndarray, + node_label: ndarray | Tensor, + node_score: Optional[ndarray | Tensor] = None, + ): + """Dataset for FastAttentionXML. + ~ means variable length. + + Args: + x: text + y: labels + num_labels: number of nodes at the current level. + mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), ~cluster_size). parent nodes to child nodes. + Cluster size will only vary at the last level. + node_label: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes + from last level. + node_score: corresponding scores. shape: (len(x), top_k) + """ + super().__init__(x, y) + self.num_labels = num_labels + self.mapping = mapping + self.node_label = node_label + self.node_score = node_score + self.candidate_scores = None + + # candidate are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k) + # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] + prog = rank_zero_only(tqdm)(self.node_label, leave=False, desc="Generating candidates") + if prog is None: + prog = self.node_label + self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] + if self.node_score is not None: + # candidate_scores are corresponding scores for candidates and + # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k) + # notice how scores repeat for each cluster. + self.candidate_scores = [ + np.repeat(scores, [len(i) for i in self.mapping[labels]]) + for labels, scores in zip(self.node_label, self.node_score) + ] + + # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. + self.num_candidates = self.node_label.shape[1] * max(len(node) for node in self.mapping) + + def __getitem__(self, idx: int): + x = self.x[idx] + candidates = np.asarray(self.candidates[idx], dtype=np.int64) + + # train/valid/test + if self.y is not None: + # squeezing is necessary here because csr_matrix.toarray() always returns a 2d array + # e.g., np.ndarray([[0, 1, 2]]) + y = self.y[idx].toarray().squeeze(0) + + # train + if self.candidate_scores is None: + # randomly select nodes as candidates when less than required + if len(candidates) < self.num_candidates: + sample = np.random.randint(self.num_labels, size=self.num_candidates - len(candidates)) + candidates = np.concatenate([candidates, sample]) + return x, y, candidates + + # valid/test + else: + candidate_scores = self.candidate_scores[idx] + + # add dummy elements when less than required + if len(candidates) < self.num_candidates: + candidate_scores = np.concatenate( + [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] + ) + candidates = np.concatenate( + [candidates, [self.num_labels] * (self.num_candidates - len(candidates))] + ) + + candidate_scores = np.asarray(candidate_scores, dtype=np.float32) + return x, y, candidates, candidate_scores + + # predict + else: + candidate_scores = self.candidate_scores[idx] + + # add dummy elements when less than required + if len(candidates) < self.num_candidates: + candidate_scores = np.concatenate( + [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] + ) + candidates = np.concatenate([candidates, [self.num_labels] * (self.num_candidates - len(candidates))]) + + candidate_scores = np.asarray(candidate_scores, dtype=np.float32) + return x, candidates, candidate_scores diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 040d35fb8..d4bf6201e 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -1,13 +1,19 @@ from abc import abstractmethod +from typing import Optional import lightning as L import numpy as np import torch import torch.nn.functional as F import torch.optim as optim +from lightning import LightningModule +from torch import nn, Tensor +from torch.nn import Module +from torch.optim import Optimizer from ..common_utils import argsort_top_k, dump_log from ..nn.metrics import get_metrics, tabulate_metrics +from libmultilabel.nn import networks class MultiLabelModel(L.LightningModule): @@ -43,7 +49,7 @@ def __init__( multiclass=False, silent=False, save_k_predictions=0, - **kwargs + **kwargs, ): super().__init__() @@ -197,7 +203,7 @@ def __init__( network, loss_function="binary_cross_entropy_with_logits", log_path=None, - **kwargs + **kwargs, ): super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) self.save_hyperparameters( @@ -232,3 +238,212 @@ def shared_step(self, batch): loss = self.loss_function(pred_logits, target_labels.float()) return loss, pred_logits + + +class BaseModel(LightningModule): + def __init__( + self, + network: str, + network_config: dict, + embed_vecs: Tensor, + num_labels: int, + optimizer: str, + metrics: list[str], + val_metric: str, + top_k: int, + is_multiclass: bool, + init_weight: Optional[str] = None, + loss_func: str = "binary_cross_entropy_with_logits", + optimizer_params: Optional[dict] = None, + lr_scheduler: Optional[str] = None, + metric_threshold: int = 0.5, + ): + super().__init__() + self.save_hyperparameters(ignore="embed_vecs") + + self.network = getattr(networks, network)(embed_vecs=embed_vecs, num_classes=num_labels, **network_config) + self.init_weight = init_weight + if init_weight is not None: + init_weight = networks.get_init_weight_func(init_weight=init_weight) + self.network.apply(init_weight) + + self.loss_func = self.configure_loss_func(loss_func) + + # optimizer config + self.optimizer_name = optimizer.lower() + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + + self.lr_scheduler_name = lr_scheduler + + self.top_k = top_k + + self.num_labels = num_labels + + self.metric_list = metrics + self.val_metric_name = val_metric + self.test_metric_names = metrics + self.is_multiclass = is_multiclass + self.metric_threshold = metric_threshold + + @staticmethod + def configure_loss_func(loss_func: str) -> Module: + try: + loss_func = getattr(F, loss_func) + except AttributeError: + raise AttributeError(f"Invalid loss function name: {loss_func}") + return loss_func + + def configure_optimizers(self) -> Optimizer: + parameters = [p for p in self.parameters() if p.requires_grad] + + if self.optimizer_name == "sgd": + optimizer = optim.SGD + elif self.optimizer_name == "adam": + optimizer = optim.Adam + elif self.optimizer_name == "adamw": + optimizer = optim.AdamW + elif self.optimizer_name == "adamax": + optimizer = optim.Adamax + else: + raise ValueError(f"Unsupported optimizer: {self.optimizer}") + + optimizer = optimizer(parameters, **self.optimizer_params) + if self.lr_scheduler_name is None: + return optimizer + + if self.lr_scheduler_name is not None and self.lr_scheduler_name.lower() == "reducelronplateau": + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min" if self.val_metric_name == "loss" else "max", + ) + else: + raise ValueError(f"Unsupported learning rate scheduler: {self.lr_scheduler}") + + lr_scheduler_config = { + "scheduler": lr_scheduler, + "monitor": self.val_metric_name, + } + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} + + def on_fit_start(self): + self.val_metric = get_metrics( + metric_threshold=self.metric_threshold, + monitor_metrics=[self.val_metric_name], + num_classes=self.num_labels, + top_k=1 if self.is_multiclass else None, + ).to(self.device) + + def training_step(self, batch: Tensor, batch_idx: int): + x, y = batch + logits = self.network(x) + loss = self.loss_func(logits, y.float()) + return loss + + def validation_step(self, batch: Tensor, batch_idx: int): + x, y = batch + logits = self.network(x) + self.val_metric.update(torch.sigmoid(logits), y.long()) + + def on_validation_epoch_end(self): + self.log_dict(self.val_metric.compute(), prog_bar=True) + self.val_metric.reset() + + def on_test_start(self): + self.test_metrics = get_metrics( + metric_threshold=self.metric_threshold, + monitor_metrics=self.test_metric_names, + num_classes=self.num_labels, + top_k=1 if self.is_multiclass else None, + ).to(self.device) + + def test_step(self, batch: Tensor, batch_idx: int): + x, y = batch + logits = self.network(x) + self.test_metrics.update(torch.sigmoid(logits), y.long()) + + def on_test_epoch_end(self): + self.log_dict(self.test_metrics.compute()) + self.test_metrics.reset() + + def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0): + # lightning will put tensors on cpu + x = batch + logits = self.network(x) + scores, labels = torch.topk(torch.sigmoid(logits), self.top_k) + return scores, labels + + def forward(self, x): + return self.network(x) + + +class PLTModel(BaseModel): + def __init__( + self, + network: str, + network_config: dict, + embed_vecs: Tensor, + num_labels: int, + optimizer: str, + metrics: list[str], + val_metric: str, + top_k: int, + is_multiclass: bool, + loss_func: str = "binary_cross_entropy_with_logits", + optimizer_params: Optional[dict] = None, + lr_scheduler: Optional[str] = None, + ): + super().__init__( + network=network, + network_config=network_config, + embed_vecs=embed_vecs, + num_labels=num_labels, + optimizer=optimizer, + metrics=metrics, + val_metric=val_metric, + top_k=top_k, + is_multiclass=is_multiclass, + loss_func=loss_func, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + ) + + def multilabel_binarize( + self, + logits: Tensor, + candidates: Tensor, + candidate_scores: Tensor, + ) -> Tensor: + """self-implemented MultiLabelBinarizer for AttentionXML""" + src = torch.sigmoid(logits.detach()) * candidate_scores + # make sure preds and src use the same precision, e.g., either float16 or float32 + preds = torch.zeros(candidates.size(0), self.num_labels + 1, device=candidates.device, dtype=src.dtype) + preds.scatter_(dim=1, index=candidates, src=src) + # remove dummy samples + preds = preds[:, :-1] + return preds + + def training_step(self, batch, batch_idx): + x, y, candidates = batch + logits = self.network(x, candidates=candidates) + loss = self.loss_func(logits, torch.take_along_dim(y.float(), candidates, dim=1)) + return loss + + def validation_step(self, batch, batch_idx): + x, y, candidates, candidate_scores = batch + logits = self.network(x, candidates=candidates) + # FIXME: Cannot calculate loss, candidates might contain element whose value is self.num_labels (see dataset.py) + # loss = self.loss_func(logits, torch.from_numpy(np.concatenate([y[:, candidates], offset]))) + y_pred = self.multilabel_binarize(logits, candidates, candidate_scores) + self.val_metric.update(y_pred, y.long()) + + def test_step(self, batch, batch_idx): + x, y, candidates, candidate_scores = batch + logits = self.network(x, candidates=candidates) + y_pred = self.multilabel_binarize(logits, candidates, candidate_scores) + self.test_metrics.update(y_pred, y.long()) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + x, candidates, candidate_scores = batch + logits = self.network(x, candidates=candidates) + scores, labels = torch.topk(torch.sigmoid(logits) * candidate_scores, self.top_k) + return scores, torch.take_along_dim(candidates, labels, dim=1) diff --git a/libmultilabel/nn/networks/__init__.py b/libmultilabel/nn/networks/__init__.py index a56678565..a27a5c2a2 100644 --- a/libmultilabel/nn/networks/__init__.py +++ b/libmultilabel/nn/networks/__init__.py @@ -9,6 +9,8 @@ from .labelwise_attention_networks import BiLSTMLWAN from .labelwise_attention_networks import BiLSTMLWMHAN from .labelwise_attention_networks import CNNLWAN +from .labelwise_attention_networks import AttentionRNN as AttentionXML +from .labelwise_attention_networks import FastAttentionRNN as FastAttentionXML def get_init_weight_func(init_weight): diff --git a/libmultilabel/nn/networks/labelwise_attention_networks.py b/libmultilabel/nn/networks/labelwise_attention_networks.py index 357ec7e2a..ab8799770 100644 --- a/libmultilabel/nn/networks/labelwise_attention_networks.py +++ b/libmultilabel/nn/networks/labelwise_attention_networks.py @@ -10,6 +10,8 @@ LabelwiseAttention, LabelwiseMultiHeadAttention, LabelwiseLinearOutput, + FastLabelwiseAttention, + MultilayerLinearOutput, ) @@ -266,3 +268,67 @@ def forward(self, input): x, _ = self.attention(x) # (batch_size, num_classes, hidden_dim) x = self.output(x) # (batch_size, num_classes) return {"logits": x} + + +class AttentionRNN(nn.Module): + def __init__( + self, + embed_vecs, + num_classes: int, + rnn_dim: int, + linear_size: list[int, ...], + freeze_embed_training: bool = False, + rnn_layers: int = 1, + embed_dropout: float = 0.2, + encoder_dropout: float = 0, + post_encoder_dropout: float = 0.5, + ): + super().__init__() + self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout) + self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout) + self.attention = LabelwiseAttention(rnn_dim, num_classes) + self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) + + def forward(self, inputs): + # the index of padding is 0 + masks = inputs != 0 + lengths = masks.sum(dim=1) + masks = masks[:, : lengths.max()] + + x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size + x = self.encoder(x, lengths) # batch_size, length, hidden_size + x, _ = self.attention(x, masks) # batch_size, num_classes, hidden_size + x = self.output(x) # batch_size, num_classes + return x + + +class FastAttentionRNN(nn.Module): + def __init__( + self, + embed_vecs, + num_classes: int, + rnn_dim: int, + linear_size: list[int], + freeze_embed_training: bool = False, + rnn_layers: int = 1, + embed_dropout: float = 0.2, + encoder_dropout: float = 0, + post_encoder_dropout: float = 0.5, + ): + super().__init__() + self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout) + self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout) + self.attention = FastLabelwiseAttention(rnn_dim, num_classes) + self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) + + def forward(self, inputs, candidates): + # the index of padding is 0 + masks = inputs != 0 + lengths = masks.sum(dim=1) + masks = masks[:, : lengths.max()] + + x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size + x = self.encoder(x, lengths) # batch_size, length, hidden_size + x, _ = self.attention(x, masks, candidates) # batch_size, candidate_size, hidden_size + x = self.output(x) # batch_size, candidate_size + return x diff --git a/libmultilabel/nn/networks/modules.py b/libmultilabel/nn/networks/modules.py index 34b8c0dc3..4eda3962d 100644 --- a/libmultilabel/nn/networks/modules.py +++ b/libmultilabel/nn/networks/modules.py @@ -72,7 +72,7 @@ def _get_rnn(self, input_size, hidden_size, num_layers, dropout): return nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True) -class LSTMEncoder(RNNEncoder): +class LSTMEncoder(nn.Module): """Bi-directional LSTM encoder with dropout Args: @@ -84,10 +84,23 @@ class LSTMEncoder(RNNEncoder): """ def __init__(self, input_size, hidden_size, num_layers, encoder_dropout=0, post_encoder_dropout=0): - super(LSTMEncoder, self).__init__(input_size, hidden_size, num_layers, encoder_dropout, post_encoder_dropout) + super().__init__() + self.rnn = nn.LSTM( + input_size, hidden_size, num_layers, batch_first=True, dropout=encoder_dropout, bidirectional=True + ) + self.h0 = nn.Parameter(torch.zeros(2 * num_layers, 1, hidden_size)) + self.c0 = nn.Parameter(torch.zeros(2 * num_layers, 1, hidden_size)) + self.post_encoder_dropout = nn.Dropout(post_encoder_dropout) - def _get_rnn(self, input_size, hidden_size, num_layers, dropout): - return nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True) + def forward(self, inputs, length): + self.rnn.flatten_parameters() + idx = torch.argsort(length, descending=True) + h0 = self.h0.repeat([1, inputs.size(0), 1]) + c0 = self.c0.repeat([1, inputs.size(0), 1]) + length_clamped = length[idx].cpu().clamp(min=1) # avoid the empty text with length 0 + packed_input = pack_padded_sequence(inputs[idx], length_clamped, batch_first=True) + outputs, _ = pad_packed_sequence(self.rnn(packed_input, (h0, c0))[0], batch_first=True) + return self.post_encoder_dropout(outputs[torch.argsort(idx)]) class CNNEncoder(nn.Module): @@ -161,9 +174,12 @@ def __init__(self, input_size, num_classes): super(LabelwiseAttention, self).__init__() self.attention = nn.Linear(input_size, num_classes, bias=False) - def forward(self, input): + def forward(self, input, masks): # (batch_size, num_classes, sequence_length) attention = self.attention(input).transpose(1, 2) + if masks is not None: + masks = torch.unsqueeze(masks, 1) # batch_size, 1, length + attention = attention.masked_fill(~masks, -torch.inf) # batch_size, num_classes, length attention = F.softmax(attention, -1) # (batch_size, num_classes, hidden_dim) logits = torch.bmm(attention, input) @@ -210,3 +226,31 @@ def __init__(self, input_size, num_classes): def forward(self, input): return (self.output.weight * input).sum(dim=-1) + self.output.bias + + +class FastLabelwiseAttention(nn.Module): + def __init__(self, hidden_size, num_labels): + super().__init__() + self.attention = nn.Embedding(num_labels + 1, hidden_size) + + def forward(self, inputs, masks, candidates): + masks = torch.unsqueeze(masks, 1) # batch_size, 1, length + attn_inputs = inputs.transpose(1, 2) # batch_size, hidden, length + attn_weights = self.attention(candidates) # batch_size, sample_size, hidden + attention = (attn_weights @ attn_inputs).masked_fill(~masks, -torch.inf) # batch_size, sampled_size, length + attention = F.softmax(attention, -1) # batch_size, sampled_size, length + logits = attention @ inputs # batch_size, sample_size, hidden_dim + return logits, attention + + +class MultilayerLinearOutput(nn.Module): + def __init__(self, linear_size: list[int], output_size: int): + super().__init__() + self.linears = nn.ModuleList(nn.Linear(in_s, out_s) for in_s, out_s in zip(linear_size[:-1], linear_size[1:])) + self.output = nn.Linear(linear_size[-1], output_size) + + def forward(self, inputs): + linear_out = inputs + for linear in self.linears: + linear_out = F.relu(linear(linear_out)) + return torch.squeeze(self.output(linear_out), -1) diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py new file mode 100644 index 000000000..02116fcd0 --- /dev/null +++ b/libmultilabel/nn/plt.py @@ -0,0 +1,478 @@ +from __future__ import annotations + +import logging +import time +from concurrent.futures import ThreadPoolExecutor +from functools import reduce, partial +from pathlib import Path +from typing import Generator, Optional + +import numpy as np +import torch +import torch.distributed as dist +from lightning import Trainer +from scipy.sparse import csr_matrix +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader +from tqdm import tqdm +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.utilities import rank_zero_only, rank_zero_info + +from .cluster import CLUSTER_NAME, FILE_EXTENSION as CLUSTER_FILE_EXTENSION, build_label_tree +from .data_utils import UNK +from .datasets import MultiLabelDataset, PLTDataset +from .model import PLTModel, BaseModel + +__all__ = ["PLTTrainer"] + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class PLTTrainer: + CHECKPOINT_NAME = "model-level-" + + def __init__( + self, + config, + classes: Optional[list] = None, # TODO: removed in the future + embed_vecs: Optional[Tensor] = None, + word_dict: Optional[dict] = None, # TODO: removed in the future + mlb=None, # TODO: removed in the future + ): + # The number of levels is set to 2 + # In other words, there will be 2 models + + # cluster + self.cluster_size = config.cluster_size + # predict the top k labels + self.top_k = config.top_k + + # dataset meta info + self.embed_vecs = embed_vecs + self.word_dict = word_dict + self.mlb = mlb + self.num_labels = len(classes) + self.is_multiclass = config.multiclass + + # cluster meta info + self.cluster_size = config.cluster_size + + # network parameters + self.network_config = config.network_config + self.init_weight = "xavier_uniform" # AttentionXML-specific setting + self.loss_func = config.loss_func + + # optimizer parameters + self.optimizer = config.optimizer + self.optimizer_config = config.optimizer_config + + # Trainer parameters + self.accelerator = config.accelerator + self.devices = 1 + self.num_nodes = 1 + self.max_epochs = config.epochs + # callbacks + self.val_metric = config.val_metric + self.verbose = not config.silent + # EarlyStopping + self.patience = config.patience + # ModelCheckpoint + self.result_dir = Path(config.result_dir) + # SWA/EMA + # to understand how SWA work, see the pytorch doc and the following link + # https://stackoverflow.com/questions/68726290/setting-learning-rate-for-stochastic-weight-averaging-in-pytorch + # self.swa = config.get("swa") + # + # if (swa_config := config.get("swa")) is not None: + # self.swa_lr = swa_config.get("swa_lr", 5e-2) + # self.swa_epoch_start = swa_config.get("swa_epoch_start") + # self.annealing_epochs = swa_config.get("annealing_epochs", 10) + # self.annealing_strategy = swa_config.get("annealing_strategy", "cos") + # # TODO: SWA or EMA? + # + # # self.avg_fn = None # None == SWA + # def ema_avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: + # decay = 1.0 - 1.0 / num_averaged + # return torch.optim.swa_utils.get_ema_avg_fn(decay=decay) + # + # self.avg_fn = ema_avg_fn + + self.metrics = config.monitor_metrics + + # dataloader parameters + # whether shuffle the training dataset or not during the training process + self.shuffle = config.shuffle + pin_memory = True if self.accelerator == "gpu" else False + # training DataLoader + self.dataloader = partial( + DataLoader, + batch_size=config.batch_size, + num_workers=config.data_workers, + pin_memory=pin_memory, + ) + # evaluation DataLoader + self.eval_dataloader = partial( + DataLoader, + batch_size=config.eval_batch_size, + num_workers=config.data_workers, + pin_memory=pin_memory, + ) + + # save path + self.config = config + + def label2node(self, nodes, *ys) -> Generator[csr_matrix, ...]: + """Map labels (leaf nodes) to ancestor nodes at a certain level. + + If num_labels is 8 and nodes is [(0, 1), (2, 3), (4, 6), (5, 7)]. + Then the mapping is as follows: [0, 0, 1, 1, 2, 3, 2, 3] + Suppose one element of ys is [0, 1, 7]. The results after mapping is [0, 3]. + + Args: + nodes: the nodes generated at a pre-defined level. + *ys: true labels (leaf nodes) for train and/or valid datasets. + + Returns: + Generator[csr_matrix]: the mapped labels (ancestor nodes) for train and/or valid datasets. + """ + mapping = np.empty(self.num_labels, dtype=np.uint64) + for idx, node_labels in enumerate(nodes): + mapping[node_labels] = idx + + def _label2node(y: csr_matrix) -> csr_matrix: + row = [] + col = [] + data = [] + for i in range(y.shape[0]): + # n include all mapped ancestor nodes + n = np.unique(mapping[y.indices[y.indptr[i] : y.indptr[i + 1]]]) + row += [i] * len(n) + col += n.tolist() + data += [1] * len(n) + return csr_matrix((data, (row, col)), shape=(y.shape[0], len(nodes))) + + return (_label2node(y) for y in ys) + + def configure_trainer(self, level) -> Trainer: + callbacks = [] + monitor = self.val_metric + # loss cannot be calculated for PLTModel + mode = "max" + + # ModelCheckpoint + callbacks.append( + ModelCheckpoint( + dirpath=self.result_dir, + filename=f"{self.CHECKPOINT_NAME}{level}", + monitor=monitor, + verbose=self.verbose, + mode=mode, + enable_version_counter=False, + save_on_train_epoch_end=True, + ) + ) + + callbacks.append( + EarlyStopping( + monitor=monitor, + patience=self.patience, + mode=mode, + verbose=self.verbose, + ) + ) + + trainer = Trainer( + accelerator=self.accelerator, + devices=self.devices, + num_nodes=self.num_nodes, + callbacks=callbacks, + max_epochs=self.max_epochs, + # TODO: Decide whether to keep these parameters + enable_progress_bar=True, + default_root_dir=self.result_dir, + ) + return trainer + + def fit(self, datasets): + """fit model to the training dataset + + Args: + datasets: dict of train, val, and test + """ + if self.get_best_model_path(level=1).exists(): + return + + train_data_full = datasets["train_full"] + train_sparse_x = datasets["train_sparse_x"] + # sparse training labels + # TODO: remove workaround in future PR + train_sparse_y_full = self.mlb.transform((i["label"] for i in train_data_full)) + + train_x = self.reformat_text(datasets["train"]) + val_x = self.reformat_text(datasets["val"]) + + train_y = self.mlb.transform((i["label"] for i in datasets["train"])) + val_y = self.mlb.transform((i["label"] for i in datasets["val"])) + + # only do clustering on GPU 0 + @rank_zero_only + def start_cluster(): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit( + build_label_tree, + sparse_x=train_sparse_x, + sparse_y=train_sparse_y_full, + cluster_size=self.cluster_size, + output_dir=self.result_dir, + ) + future.result() + + start_cluster() + + # wait until the clustering process finishes + cluster_path = self.get_cluster_path() + while not cluster_path.exists(): + time.sleep(15) + clusters = np.load(cluster_path, allow_pickle=True) + + # each y has been mapped to the node indices of its parent + train_y_cluster, val_y_cluster = self.label2node(clusters, train_y, val_y) + # regard each internal nodes as a "labels" + num_labels = len(clusters) + + # trainer + trainer = self.configure_trainer(level=0) + + best_model_path = self.get_best_model_path(level=0) + if not best_model_path.exists(): + # train & valid dataloaders for training + train_dataloader = self.dataloader(MultiLabelDataset(train_x, train_y_cluster), shuffle=self.shuffle) + val_dataloader = self.dataloader(MultiLabelDataset(val_x, val_y_cluster)) + + model = BaseModel( + network="AttentionXML", + network_config=self.network_config, + embed_vecs=self.embed_vecs, + num_labels=num_labels, + optimizer=self.optimizer, + metrics=self.metrics, + val_metric=self.val_metric, + top_k=self.top_k, + is_multiclass=self.is_multiclass, + init_weight=self.init_weight, + loss_func=self.loss_func, + optimizer_params=self.optimizer_config, + ) + + rank_zero_info(f"Training level 0. Number of labels: {num_labels}") + trainer.fit(model, train_dataloader, val_dataloader) + rank_zero_info(f"Finish training level 0") + + rank_zero_info(f"Best model loaded from {best_model_path}") + model = BaseModel.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs) + + # Utilize single GPU to predict + trainer = Trainer( + num_nodes=1, + devices=1, + accelerator=self.accelerator, + logger=False, + ) + rank_zero_info( + f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.top_k}" + ) + # train & val dataloaders for prediction (without labels) + train_dataloader = self.eval_dataloader(MultiLabelDataset(train_x)) + val_dataloader = self.eval_dataloader(MultiLabelDataset(val_x)) + + # returned labels have been clustered into nodes (groups) + train_node_pred = trainer.predict(model, train_dataloader) + valid_node_pred = trainer.predict(model, val_dataloader) + + # shape of node_pred: (n, 2, ~batch_size, top_k). n is floor(num_x / batch_size) + # new shape: (2, num_x, top_k) + _, train_node_y_pred = map(torch.vstack, list(zip(*train_node_pred))) + valid_node_score_pred, valid_node_y_pred = map(torch.vstack, list(zip(*valid_node_pred))) + + # The following process can be simplified using method from LightXML + rank_zero_info("Getting Candidates") + node_candidates = np.empty((len(train_x), self.top_k), dtype=np.int64) + prog = rank_zero_only(tqdm)(train_node_y_pred, leave=False, desc="Parents") + if prog is None: + prog = train_node_y_pred + for i, ys in enumerate(prog): + # true nodes/labels are positive + positive = set(train_y_cluster.indices[train_y_cluster.indptr[i] : train_y_cluster.indptr[i + 1]]) + # Regard positive nodes and predicted training nodes that are not in positive as candidates + # until reaching top_k if the number of positive labels is less than top_k. + if len(positive) <= self.top_k: + candidates = positive + for y in ys: + y = y.item() + if len(candidates) == self.top_k: + break + candidates.add(y) + # Regard positive (true) label as candidates iff they appear in the predicted labels + # if the number of positive labels is more than top_k. If candidates are not of length top_k + # add unseen predicted labels until reaching top_k. + else: + candidates = set() + for y in ys: + y = y.item() + if y in positive: + candidates.add(y) + if len(candidates) == self.top_k: + break + if len(candidates) < self.top_k: + candidates = (list(candidates) + list(positive - candidates))[: self.top_k] + node_candidates[i] = np.asarray(list(candidates)) + + # mapping from the current nodes to leaf nodes. + assert reduce(lambda a, b: a + len(b), clusters, 0) == self.num_labels + + # trainer + trainer = self.configure_trainer(level=1) + + # train & valid dataloaders for training + train_dataloader = self.dataloader( + PLTDataset( + train_x, + train_y, + num_labels=self.num_labels, + mapping=clusters, + node_label=node_candidates, + ), + shuffle=self.shuffle, + ) + valid_dataloader = self.dataloader( + PLTDataset( + val_x, + val_y, + num_labels=self.num_labels, + mapping=clusters, + node_label=valid_node_y_pred, + node_score=valid_node_score_pred, + ), + ) + + model = PLTModel( + network="FastAttentionXML", + network_config=self.network_config, + embed_vecs=self.embed_vecs, + num_labels=self.num_labels, + optimizer=self.optimizer, + metrics=self.metrics, + top_k=self.top_k, + val_metric=self.val_metric, + is_multiclass=self.is_multiclass, + loss_func=self.loss_func, + optimizer_params=self.optimizer_config, + ) + torch.nn.init.xavier_uniform_(model.network.attention.attention.weight) + + # initialize model with weights from level 0 + rank_zero_info(f"Loading parameters of level 1 from level 0") + state_dict = torch.load(self.get_best_model_path(level=0))["state_dict"] + + # remove the name prefix in state_dict starting with "network.xxx" + embedding_state_dict = {} + encoder_state_dict = {} + output_state_dict = {} + for n, p in state_dict.items(): + truncated_n = n.split(".", 2)[-1] + if n.startswith("network.embedding"): + embedding_state_dict[truncated_n] = p + elif n.startswith("network.encoder"): + encoder_state_dict[truncated_n] = p + elif n.startswith("network.output"): + output_state_dict[truncated_n] = p + model.network.embedding.load_state_dict(embedding_state_dict) + model.network.encoder.load_state_dict(encoder_state_dict) + model.network.output.load_state_dict(output_state_dict) + + rank_zero_info( + f"Training level 1, Number of labels: {self.num_labels}, " + f"Number of candidates: {train_dataloader.dataset.num_candidates}" + ) + trainer.fit(model, train_dataloader, valid_dataloader) + rank_zero_info(f"Best model loaded from {best_model_path}") + rank_zero_info(f"Finish training level 1") + + # testing will hang forever without destroying process group + if dist.is_initialized(): + dist.destroy_process_group() + + # why we want to test on a single GPU? + # https://lightning.ai/docs/pytorch/stable/common/evaluation_intermediate.html + @rank_zero_only + def test(self, dataset): + test_x = self.reformat_text(dataset) + test_y = self.mlb.transform((i["label"] for i in dataset)) + rank_zero_info("Start predicting process.") + trainer = Trainer( + devices=1, + accelerator=self.accelerator, + logger=False, + ) + + # prediction starts from level 0 + model = BaseModel.load_from_checkpoint( + self.get_best_model_path(level=0), + embed_vecs=self.embed_vecs, + top_k=self.top_k, + metrics=self.metrics, + ) + + test_dataloader = self.eval_dataloader(MultiLabelDataset(test_x)) + + rank_zero_info(f"Predicting level 0, Top: {self.top_k}") + node_pred = trainer.predict(model, test_dataloader) + node_score_pred, node_label_pred = map(torch.vstack, list(zip(*node_pred))) + + clusters = np.load(self.get_cluster_path(), allow_pickle=True) + + model = PLTModel.load_from_checkpoint( + self.get_best_model_path(level=1), embed_vecs=self.embed_vecs, top_k=self.top_k, metrics=self.metrics + ) + + test_dataloader = self.eval_dataloader( + PLTDataset( + test_x, + test_y, + num_labels=self.num_labels, + mapping=clusters, + node_label=node_label_pred, + node_score=node_score_pred, + ), + ) + + rank_zero_info(f"Testing on level 1") + trainer.test(model, test_dataloader) + rank_zero_info("Testing process finished") + + def reformat_text(self, dataset): + encoded_text = list( + map( + lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64) + if text + else torch.tensor([self.word_dict[UNK]], dtype=torch.int64), + [instance["text"][: self.config["max_seq_length"]] for instance in dataset], + ) + ) + # pad the first entry to be of length 500 if necessary + encoded_text[0] = torch.cat( + ( + encoded_text[0], + torch.tensor(0, dtype=torch.int64).repeat(self.config["max_seq_length"] - encoded_text[0].shape[0]), + ) + ) + encoded_text = pad_sequence(encoded_text, batch_first=True) + return encoded_text + + def get_best_model_path(self, level: int) -> Path: + return self.result_dir / f"{self.CHECKPOINT_NAME}{level}{ModelCheckpoint.FILE_EXTENSION}" + + def get_cluster_path(self) -> Path: + return self.result_dir / f"{CLUSTER_NAME}{CLUSTER_FILE_EXTENSION}" diff --git a/main.py b/main.py index 3b8ba3824..a6094a1e5 100644 --- a/main.py +++ b/main.py @@ -19,6 +19,7 @@ def add_all_arguments(parser): # data parser.add_argument("--data_name", default="unnamed_data", help="Dataset name (default: %(default)s)") parser.add_argument("--training_file", help="Path to training data (default: %(default)s)") + parser.add_argument("--training_sparse_file", help="Path to training sparse data (default: %(default)s)") parser.add_argument("--val_file", help="Path to validation data (default: %(default)s)") parser.add_argument("--test_file", help="Path to test data (default: %(default)s") parser.add_argument( @@ -79,6 +80,10 @@ def add_all_arguments(parser): choices=["adam", "adamw", "adamax", "sgd"], help="Optimizer (default: %(default)s)", ) + parser.add_argument( + "--optimizer_config", + help="Optimizer parameters", + ) parser.add_argument( "--learning_rate", type=float, default=0.0001, help="Learning rate for optimizer (default: %(default)s)" ) @@ -223,6 +228,19 @@ def add_all_arguments(parser): parser.add_argument( "--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)" ) + # AttentionXML + parser.add_argument( + "--cluster_size", + type=int, + default=8, + help="the maximal number of labels inside a cluster (default: %(default)s)", + ) + parser.add_argument( + "--top_k", + type=int, + default=64, + help="sample top-k clusters and use them to train tree model (default: %(default)s)", + ) parser.add_argument( "--beam_width", type=int, diff --git a/torch_trainer.py b/torch_trainer.py index a6c7d87ad..73b118676 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -3,12 +3,14 @@ import numpy as np from lightning.pytorch.callbacks import ModelCheckpoint +from sklearn.preprocessing import MultiLabelBinarizer from transformers import AutoTokenizer from libmultilabel.common_utils import dump_log, is_multiclass_dataset from libmultilabel.nn import data_utils from libmultilabel.nn.model import Model from libmultilabel.nn.nn_utils import init_device, init_model, init_trainer, set_seed +from libmultilabel.nn.plt import PLTTrainer class TorchTrainer: @@ -36,6 +38,7 @@ def __init__( self.run_name = config.run_name self.checkpoint_dir = config.checkpoint_dir self.log_path = config.log_path + self.classes = classes os.makedirs(self.checkpoint_dir, exist_ok=True) # Set up seed & device @@ -52,6 +55,7 @@ def __init__( if datasets is None: self.datasets = data_utils.load_datasets( training_data=config.training_file, + training_sparse_data=config.training_sparse_file, test_data=config.test_file, val_data=config.val_file, val_size=config.val_size, @@ -62,33 +66,56 @@ def __init__( else: self.datasets = datasets + if config.embed_file is not None: + word_dict, embed_vecs = data_utils.load_or_build_text_dict( + dataset=self.datasets["train"] + if config.model_name.lower() != "fastattentionxml" + else self.datasets["train_full"], + vocab_file=config.vocab_file, + min_vocab_freq=config.min_vocab_freq, + embed_file=config.embed_file, + silent=config.silent, + normalize_embed=config.normalize_embed, + embed_cache_dir=config.embed_cache_dir, + ) + + if not classes: + self.classes = data_utils.load_or_build_label(self.datasets, config.label_file, config.include_test_labels) + + mlb = MultiLabelBinarizer(classes=self.classes, sparse_output=True) + mlb.fit(None) + self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list())) - self._setup_model( - classes=classes, - word_dict=word_dict, - embed_vecs=embed_vecs, - log_path=self.log_path, - checkpoint_path=config.checkpoint_path, - ) - self.trainer = init_trainer( - checkpoint_dir=self.checkpoint_dir, - epochs=config.epochs, - patience=config.patience, - early_stopping_metric=config.early_stopping_metric, - val_metric=config.val_metric, - silent=config.silent, - use_cpu=config.cpu, - limit_train_batches=config.limit_train_batches, - limit_val_batches=config.limit_val_batches, - limit_test_batches=config.limit_test_batches, - save_checkpoints=save_checkpoints, - ) - callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)] - self.checkpoint_callback = callbacks[0] if callbacks else None + + if self.config.model_name.lower() == "fastattentionxml": + self.trainer = PLTTrainer( + self.config, classes=self.classes, embed_vecs=embed_vecs, word_dict=word_dict, mlb=mlb + ) + else: + self._setup_model( + word_dict=word_dict, + embed_vecs=embed_vecs, + log_path=self.log_path, + checkpoint_path=config.checkpoint_path, + ) + self.trainer = init_trainer( + checkpoint_dir=self.checkpoint_dir, + epochs=config.epochs, + patience=config.patience, + early_stopping_metric=config.early_stopping_metric, + val_metric=config.val_metric, + silent=config.silent, + use_cpu=config.cpu, + limit_train_batches=config.limit_train_batches, + limit_val_batches=config.limit_val_batches, + limit_test_batches=config.limit_test_batches, + save_checkpoints=save_checkpoints, + ) + callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)] + self.checkpoint_callback = callbacks[0] if callbacks else None def _setup_model( self, - classes: list = None, word_dict: dict = None, embed_vecs=None, log_path: str = None, @@ -114,21 +141,6 @@ def _setup_model( self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path) else: logging.info("Initialize model from scratch.") - if self.config.embed_file is not None: - logging.info("Load word dictionary ") - word_dict, embed_vecs = data_utils.load_or_build_text_dict( - dataset=self.datasets["train"], - vocab_file=self.config.vocab_file, - min_vocab_freq=self.config.min_vocab_freq, - embed_file=self.config.embed_file, - silent=self.config.silent, - normalize_embed=self.config.normalize_embed, - embed_cache_dir=self.config.embed_cache_dir, - ) - if not classes: - classes = data_utils.load_or_build_label( - self.datasets, self.config.label_file, self.config.include_test_labels - ) if self.config.early_stopping_metric not in self.config.monitor_metrics: logging.warn( @@ -147,7 +159,7 @@ def _setup_model( self.model = init_model( model_name=self.config.model_name, network_config=dict(self.config.network_config), - classes=classes, + classes=self.classes, word_dict=word_dict, embed_vecs=embed_vecs, init_weight=self.config.init_weight, @@ -194,31 +206,34 @@ def train(self): """Train model with pytorch lightning trainer. Set model to the best model after the training process is finished. """ - assert ( - self.trainer is not None - ), "Please make sure the trainer is successfully initialized by `self._setup_trainer()`." - train_loader = self._get_dataset_loader(split="train", shuffle=self.config.shuffle) - - if "val" not in self.datasets: - logging.info("No validation dataset is provided. Train without vaildation.") - self.trainer.fit(self.model, train_loader) - else: - val_loader = self._get_dataset_loader(split="val") - self.trainer.fit(self.model, train_loader, val_loader) - - # Set model to the best model. If the validation process is skipped during - # training (i.e., val_size=0), the model is set to the last model. - model_path = self.checkpoint_callback.best_model_path or self.checkpoint_callback.last_model_path - if model_path: - logging.info(f"Finished training. Load best model from {model_path}.") - self._setup_model(checkpoint_path=model_path, log_path=self.log_path) + if self.config.model_name.lower() == "fastattentionxml": + self.trainer.fit(self.datasets) else: - logging.info( - "No model is saved during training. \ - If you want to save the best and the last model, please set `save_checkpoints` to True." - ) + assert ( + self.trainer is not None + ), "Please make sure the trainer is successfully initialized by `self._setup_trainer()`." + train_loader = self._get_dataset_loader(split="train", shuffle=self.config.shuffle) + + if "val" not in self.datasets: + logging.info("No validation dataset is provided. Train without vaildation.") + self.trainer.fit(self.model, train_loader) + else: + val_loader = self._get_dataset_loader(split="val") + self.trainer.fit(self.model, train_loader, val_loader) - dump_log(self.log_path, config=self.config) + # Set model to the best model. If the validation process is skipped during + # training (i.e., val_size=0), the model is set to the last model. + model_path = self.checkpoint_callback.best_model_path or self.checkpoint_callback.last_model_path + if model_path: + logging.info(f"Finished training. Load best model from {model_path}.") + self._setup_model(checkpoint_path=model_path, log_path=self.log_path) + else: + logging.info( + "No model is saved during training. \ + If you want to save the best and the last model, please set `save_checkpoints` to True." + ) + + dump_log(self.log_path, config=self.config) # return best model score for ray return self.checkpoint_callback.best_model_score.item() if self.checkpoint_callback.best_model_score else None @@ -235,15 +250,18 @@ def test(self, split="test"): """ assert "test" in self.datasets and self.trainer is not None - logging.info(f"Testing on {split} set.") - test_loader = self._get_dataset_loader(split=split) - metric_dict = self.trainer.test(self.model, dataloaders=test_loader, verbose=False)[0] + if self.config.model_name.lower() == "fastattentionxml": + self.trainer.test(self.datasets["test"]) + else: + logging.info(f"Testing on {split} set.") + test_loader = self._get_dataset_loader(split=split) + metric_dict = self.trainer.test(self.model, dataloaders=test_loader, verbose=False)[0] - if self.config.save_k_predictions > 0: - self._save_predictions(test_loader, self.config.predict_out_path) + if self.config.save_k_predictions > 0: + self._save_predictions(test_loader, self.config.predict_out_path) - dump_log(self.log_path, config=self.config) - return metric_dict + dump_log(self.log_path, config=self.config) + return metric_dict def _save_predictions(self, dataloader, predict_out_path): """Save top k label results. From 5b41782beac408efa9250e1b0f4883c3589b7d89 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Fri, 26 Jan 2024 14:54:10 +0400 Subject: [PATCH 02/29] improve variable naming, comments, and remove threading --- ...{fastattentionxml.yml => attentionxml.yml} | 8 +- ...{fastattentionxml.yml => attentionxml.yml} | 8 +- ...{fastattentionxml.yml => attentionxml.yml} | 6 +- libmultilabel/nn/cluster.py | 83 +++-- libmultilabel/nn/datasets.py | 149 -------- libmultilabel/nn/datasets_AttentionXML.py | 127 +++++++ libmultilabel/nn/model.py | 215 +----------- libmultilabel/nn/model_AttentionXML.py | 77 ++++ .../networks/labelwise_attention_networks.py | 8 +- libmultilabel/nn/plt.py | 332 +++++++++--------- torch_trainer.py | 15 +- 11 files changed, 419 insertions(+), 609 deletions(-) rename example_config/AmazonCat-13K/{fastattentionxml.yml => attentionxml.yml} (90%) rename example_config/EUR-Lex/{fastattentionxml.yml => attentionxml.yml} (89%) rename example_config/Wiki10-31K/{fastattentionxml.yml => attentionxml.yml} (90%) delete mode 100644 libmultilabel/nn/datasets.py create mode 100644 libmultilabel/nn/datasets_AttentionXML.py create mode 100644 libmultilabel/nn/model_AttentionXML.py diff --git a/example_config/AmazonCat-13K/fastattentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml similarity index 90% rename from example_config/AmazonCat-13K/fastattentionxml.yml rename to example_config/AmazonCat-13K/attentionxml.yml index 7f84202b4..0b6484b37 100644 --- a/example_config/AmazonCat-13K/fastattentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -14,7 +14,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -top_k: 64 +save_k_predictions: 64 # data batch_size: 200 @@ -26,9 +26,6 @@ eval_batch_size: 200 monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] val_metric: nDCG@5 -# trainer params -accelerator: gpu - # train seed: 1337 epochs: 10 @@ -41,8 +38,7 @@ patience: 5 silent: true # model -model_name: FastAttentionXML -loss_func: binary_cross_entropy_with_logits +model_name: AttentionXML network_config: embed_dropout: 0.2 post_encoder_dropout: 0.5 diff --git a/example_config/EUR-Lex/fastattentionxml.yml b/example_config/EUR-Lex/attentionxml.yml similarity index 89% rename from example_config/EUR-Lex/fastattentionxml.yml rename to example_config/EUR-Lex/attentionxml.yml index 186814e5d..5a2a57732 100644 --- a/example_config/EUR-Lex/fastattentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -14,7 +14,7 @@ max_seq_length: 500 # AttentionXML-related parameters cluster_size: 8 -top_k: 64 +save_k_predictions: 64 # dataloader batch_size: 40 @@ -26,9 +26,6 @@ eval_batch_size: 40 monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] val_metric: nDCG@5 -# trainer params -accelerator: gpu - # train seed: 1337 epochs: 30 @@ -39,8 +36,7 @@ optimizer: Adam patience: 5 # model -model_name: FastAttentionXML -loss_func: binary_cross_entropy_with_logits +model_name: AttentionXML network_config: embed_dropout: 0.2 post_encoder_dropout: 0.5 diff --git a/example_config/Wiki10-31K/fastattentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml similarity index 90% rename from example_config/Wiki10-31K/fastattentionxml.yml rename to example_config/Wiki10-31K/attentionxml.yml index 169beef8c..1abb780cc 100644 --- a/example_config/Wiki10-31K/fastattentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -26,9 +26,6 @@ eval_batch_size: 40 monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] val_metric: nDCG@5 -# trainer params -accelerator: gpu - # train seed: 1337 epochs: 30 @@ -39,8 +36,7 @@ optimizer: Adam patience: 5 # model -model_name: FastAttentionXML -loss_func: binary_cross_entropy_with_logits +model_name: AttentionXML network_config: embed_dropout: 0.2 encoder_dropout: 0.5 diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index 58ca77c3e..a481c8c66 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -4,90 +4,86 @@ from pathlib import Path import numpy as np -from scipy.sparse import csr_matrix, csc_matrix -from sklearn.preprocessing import normalize from numpy import ndarray +from scipy.sparse import csc_matrix, csr_matrix +from sklearn.preprocessing import normalize __all__ = ["CLUSTER_NAME", "FILE_EXTENSION", "build_label_tree"] logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) CLUSTER_NAME = "label_clusters" FILE_EXTENSION = ".npy" def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path): - """Build label tree described in AttentionXML + """Group labels into clusters that contain up tp cluster_size labels. + Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters looks something like: + ((0, 2), (1, 3), (4, 5)). Args: sparse_x: features extracted from texts in CSR sparse format - sparse_y: labels in CSR sparse format - cluster_size: the maximal number of labels inside a cluster - output_dir: directory to save the clusters + sparse_y: binarized labels in CSR sparse format + cluster_size: the maximum number of labels within each cluster + output_dir: directory to store the clustering file """ - # skip label clustering if the clustering file already exists + # skip constructing label tree if the output file already exists output_dir = output_dir if isinstance(output_dir, Path) else Path(output_dir) cluster_path = output_dir / f"{CLUSTER_NAME}{FILE_EXTENSION}" if cluster_path.exists(): logger.info("Clustering has finished in a previous run") return - # cluster meta info - logger.info("Performing label clustering") + # meta info + logger.info("Label clustering started") logger.info(f"Cluster size: {cluster_size}") num_labels = sparse_y.shape[1] - # the height of the tree satisfies the following inequation: + # The height of the tree satisfies the following inequation: # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size height = int(np.ceil(np.log2(num_labels / cluster_size))) logger.info(f"Labels will be grouped into {2**height} clusters") output_dir.mkdir(parents=True, exist_ok=True) - # generate label representations + # the normalized representations of the relationship between labels and texts label_repr = normalize(sparse_y.T @ csc_matrix(sparse_x)) # clustering process - rng = np.random.default_rng() clusters = [np.arange(num_labels)] for _ in range(height): - assert sum(map(len, clusters)) == num_labels - next_clusters = [] - for idx_in_cluster in clusters: - next_clusters.extend(_cluster_split(idx_in_cluster, label_repr, rng)) + for cluster in clusters: + next_clusters.extend(_split_cluster(cluster, label_repr)) clusters = next_clusters logger.info(f"Having grouped {len(clusters)} clusters") np.save(cluster_path, np.asarray(clusters, dtype=object)) - logger.info(f"Finish clustering, saving cluster to '{cluster_path}'") - - -def _cluster_split( - idx_in_cluster: ndarray, label_repr: csr_matrix, rng: np.random.Generator -) -> tuple[ndarray, ndarray]: - """A variant of KMeans implemented in AttentionXML. Its main differences with sklearn.KMeans are: - 1. the distance metric is cosine similarity as all label representations are normalized. - 2. the end-of-loop criterion is the difference between the new and old average in-cluster distance to centroids. - Possible drawbacks: - Random initialization. - cluster_size matters. + logger.info(f"Label clustering finished. Saving results to {repr(cluster_path)}") + + +def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]: + """A variant of KMeans implemented in AttentionXML. The cluster is partitioned into two groups, each with + approximately equal size. Its main differences with the KMeans algorithm in scikit-learn are: + 1. the distance metric is cosine similarity. + 2. the end-of-loop criterion is the difference between the new and old average in-cluster distances to centroids. + + Args: + cluster: a subset of labels + label_repr: the normalized representations of the relationship between labels and texts """ - # tol is a possible hyperparameter tol = 1e-4 - if tol <= 0 or tol > 1: - raise ValueError(f"tol should be a positive number that is less than 1, got {repr(tol)} instead.") - # the corresponding label representations in the node - tgt_repr = label_repr[idx_in_cluster] + # the normalized label representations corresponding to the cluster + tgt_repr = label_repr[cluster] - # the number of leaf labels in the node - n = len(idx_in_cluster) + # the number of labels in the cluster + n = len(cluster) - # randomly choose two points as initial centroids - centroids = tgt_repr[rng.choice(n, size=2, replace=False)].toarray() + # Randomly choose two points as initial centroids and obtain their label representations + centroids = tgt_repr[np.random.choice(n, size=2, replace=False)].toarray() - # initialize distances (cosine similarity) + # Initialize distances (cosine similarity) + # The cosine similarity always belongs to the interval [-1, 1] old_dist = -2.0 new_dist = -1.0 @@ -96,7 +92,8 @@ def _cluster_split( c1_idx = None while new_dist - old_dist >= tol: - # each points' distance (cosine similarity) to the two centroids + # each point's distances (cosine similarity) to the two centroids + # tgs_repr and centroids.T have been normalized dist = tgt_repr @ centroids.T # shape: (n, 2) # generate clusters @@ -107,12 +104,12 @@ def _cluster_split( c1_idx = c_idx[k:] # update distances - # the distance is the average in-cluster distance to the centroids + # the new distance is the average of in-cluster distances to the centroids old_dist = new_dist new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / n # update centroids - # the new centroid is the average of the points in the cluster + # the new centroid is the normalized average of the points in the cluster centroids = normalize( np.asarray( [ @@ -121,4 +118,4 @@ def _cluster_split( ] ) ) - return idx_in_cluster[c0_idx], idx_in_cluster[c1_idx] + return cluster[c0_idx], cluster[c1_idx] diff --git a/libmultilabel/nn/datasets.py b/libmultilabel/nn/datasets.py deleted file mode 100644 index 6dbc589d1..000000000 --- a/libmultilabel/nn/datasets.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -from typing import Sequence, Optional - -import numpy as np -import torch -from lightning.pytorch.utilities.rank_zero import rank_zero_only -from numpy import ndarray -from scipy.sparse import csr_matrix, issparse -from torch import Tensor, is_tensor -from torch.utils.data import Dataset -from tqdm import tqdm - - -class MultiLabelDataset(Dataset): - """Basic class for multi-label dataset.""" - - def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): - """General dataset class for multi-label dataset. - - Args: - x: text. - y: labels. - """ - if y is not None: - assert len(x) == y.shape[0], "Sizes mismatch between x and y" - self.x = x - self.y = y - - def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: - x = self.x[idx] - - # train/valid/test - if self.y is not None: - if issparse(self.y): - y = self.y[idx].toarray().squeeze(0) - elif is_tensor(self.y) or isinstance(self.y, (ndarray, torch.Tensor)): - y = self.y[idx] - else: - raise TypeError( - "The type of y should be one of scipy.csr_matrix, torch.Tensor, and numpy.ndarry." - f"Instead, got {type(self.y)}." - ) - return x, y - # predict - return x - - def __len__(self): - return len(self.x) - - -class PLTDataset(MultiLabelDataset): - """Dataset class for AttentionXML.""" - - def __init__( - self, - x, - y: Optional[csr_matrix | ndarray] = None, - *, - num_labels: int, - mapping: ndarray, - node_label: ndarray | Tensor, - node_score: Optional[ndarray | Tensor] = None, - ): - """Dataset for FastAttentionXML. - ~ means variable length. - - Args: - x: text - y: labels - num_labels: number of nodes at the current level. - mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), ~cluster_size). parent nodes to child nodes. - Cluster size will only vary at the last level. - node_label: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes - from last level. - node_score: corresponding scores. shape: (len(x), top_k) - """ - super().__init__(x, y) - self.num_labels = num_labels - self.mapping = mapping - self.node_label = node_label - self.node_score = node_score - self.candidate_scores = None - - # candidate are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k) - # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] - prog = rank_zero_only(tqdm)(self.node_label, leave=False, desc="Generating candidates") - if prog is None: - prog = self.node_label - self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] - if self.node_score is not None: - # candidate_scores are corresponding scores for candidates and - # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k) - # notice how scores repeat for each cluster. - self.candidate_scores = [ - np.repeat(scores, [len(i) for i in self.mapping[labels]]) - for labels, scores in zip(self.node_label, self.node_score) - ] - - # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. - self.num_candidates = self.node_label.shape[1] * max(len(node) for node in self.mapping) - - def __getitem__(self, idx: int): - x = self.x[idx] - candidates = np.asarray(self.candidates[idx], dtype=np.int64) - - # train/valid/test - if self.y is not None: - # squeezing is necessary here because csr_matrix.toarray() always returns a 2d array - # e.g., np.ndarray([[0, 1, 2]]) - y = self.y[idx].toarray().squeeze(0) - - # train - if self.candidate_scores is None: - # randomly select nodes as candidates when less than required - if len(candidates) < self.num_candidates: - sample = np.random.randint(self.num_labels, size=self.num_candidates - len(candidates)) - candidates = np.concatenate([candidates, sample]) - return x, y, candidates - - # valid/test - else: - candidate_scores = self.candidate_scores[idx] - - # add dummy elements when less than required - if len(candidates) < self.num_candidates: - candidate_scores = np.concatenate( - [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] - ) - candidates = np.concatenate( - [candidates, [self.num_labels] * (self.num_candidates - len(candidates))] - ) - - candidate_scores = np.asarray(candidate_scores, dtype=np.float32) - return x, y, candidates, candidate_scores - - # predict - else: - candidate_scores = self.candidate_scores[idx] - - # add dummy elements when less than required - if len(candidates) < self.num_candidates: - candidate_scores = np.concatenate( - [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] - ) - candidates = np.concatenate([candidates, [self.num_labels] * (self.num_candidates - len(candidates))]) - - candidate_scores = np.asarray(candidate_scores, dtype=np.float32) - return x, candidates, candidate_scores diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py new file mode 100644 index 000000000..eb198795e --- /dev/null +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from typing import Sequence, Optional + +import numpy as np +import torch +from lightning.pytorch.utilities.rank_zero import rank_zero_only +from numpy import ndarray +from scipy.sparse import csr_matrix, issparse +from torch import Tensor, is_tensor +from torch.utils.data import Dataset +from tqdm import tqdm + + +class MultiLabelDataset(Dataset): + """Basic class for multi-label dataset.""" + + def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): + """General dataset class for multi-label dataset. + + Args: + x: texts + y: labels + """ + if y is not None: + assert len(x) == y.shape[0], "Sizes mismatch between x and y" + self.x = x + self.y = y + + def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: + item = {"text": self.x[idx]} + + # train/valid/test + if self.y is not None: + if issparse(self.y): + y = self.y[idx].toarray().squeeze(0) + elif is_tensor(self.y) or isinstance(self.y, (ndarray, torch.Tensor)): + y = self.y[idx] + else: + raise TypeError( + "The type of y should be one of scipy.csr_matrix, torch.Tensor, and numpy.ndarry." + f"Instead, got {type(self.y)}." + ) + item["label"] = y + return item + + def __len__(self): + return len(self.x) + + +class PLTDataset(MultiLabelDataset): + """Dataset class for AttentionXML.""" + + def __init__( + self, + x, + y: Optional[csr_matrix | ndarray] = None, + *, + num_classes: int, + mapping: ndarray, + cluster_samples: ndarray | Tensor, + cluster_scores: Optional[ndarray | Tensor] = None, + ): + """Dataset for AttentionXML. + + Args: + x: texts + y: labels + num_classes: number of nodes at the current level. + mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), cluster_size). Map from clusters to labels. + cluster_samples: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes + from last level. + cluster_scores: corresponding scores. shape: (len(x), top_k) + """ + super().__init__(x, y) + self.num_classes = num_classes + self.mapping = mapping + self.cluster_samples = cluster_samples + self.cluster_scores = cluster_scores + self.label_scores = None + + # candidate are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k) + # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] + prog = rank_zero_only(tqdm)(self.cluster_samples, leave=False, desc="Generating candidates") + if prog is None: + prog = self.cluster_samples + self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] + if self.cluster_scores is not None: + # label_scores are corresponding scores for candidates and + # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k) + # notice how scores repeat for each cluster. + self.label_scores = [ + np.repeat(scores, [len(i) for i in self.mapping[labels]]) + for labels, scores in zip(self.cluster_samples, self.cluster_scores) + ] + + # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. + self.num_candidates = self.cluster_samples.shape[1] * max(len(node) for node in self.mapping) + + def __getitem__(self, idx: int): + item = {"text": self.x[idx], "candidates": np.asarray(self.candidates[idx], dtype=np.int64)} + + # train/valid/test + if self.y is not None: + item["label"] = self.y[idx].toarray().squeeze(0) + + # train + if self.label_scores is None: + # randomly select clusters as candidates when less than required + if len(item["candidates"]) < self.num_candidates: + sample = np.random.randint(self.num_classes, size=self.num_candidates - len(item["candidates"])) + item["candidates"] = np.concatenate([item["candidates"], sample]) + # valid/test + else: + item["label_scores"] = self.label_scores[idx] + + # add dummy elements when less than required + if len(item["candidates"]) < self.num_candidates: + item["label_scores"] = np.concatenate( + [item["label_scores"], [-np.inf] * (self.num_candidates - len(item["candidates"]))] + ) + item["candidates"] = np.concatenate( + [item["candidates"], [self.num_classes] * (self.num_candidates - len(item["candidates"]))] + ) + + item["label_scores"] = np.asarray(item["label_scores"], dtype=np.float32) + return item diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index d4bf6201e..b6088417a 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -6,14 +6,10 @@ import torch import torch.nn.functional as F import torch.optim as optim -from lightning import LightningModule -from torch import nn, Tensor -from torch.nn import Module -from torch.optim import Optimizer +from torch import nn from ..common_utils import argsort_top_k, dump_log from ..nn.metrics import get_metrics, tabulate_metrics -from libmultilabel.nn import networks class MultiLabelModel(L.LightningModule): @@ -238,212 +234,3 @@ def shared_step(self, batch): loss = self.loss_function(pred_logits, target_labels.float()) return loss, pred_logits - - -class BaseModel(LightningModule): - def __init__( - self, - network: str, - network_config: dict, - embed_vecs: Tensor, - num_labels: int, - optimizer: str, - metrics: list[str], - val_metric: str, - top_k: int, - is_multiclass: bool, - init_weight: Optional[str] = None, - loss_func: str = "binary_cross_entropy_with_logits", - optimizer_params: Optional[dict] = None, - lr_scheduler: Optional[str] = None, - metric_threshold: int = 0.5, - ): - super().__init__() - self.save_hyperparameters(ignore="embed_vecs") - - self.network = getattr(networks, network)(embed_vecs=embed_vecs, num_classes=num_labels, **network_config) - self.init_weight = init_weight - if init_weight is not None: - init_weight = networks.get_init_weight_func(init_weight=init_weight) - self.network.apply(init_weight) - - self.loss_func = self.configure_loss_func(loss_func) - - # optimizer config - self.optimizer_name = optimizer.lower() - self.optimizer_params = optimizer_params if optimizer_params is not None else {} - - self.lr_scheduler_name = lr_scheduler - - self.top_k = top_k - - self.num_labels = num_labels - - self.metric_list = metrics - self.val_metric_name = val_metric - self.test_metric_names = metrics - self.is_multiclass = is_multiclass - self.metric_threshold = metric_threshold - - @staticmethod - def configure_loss_func(loss_func: str) -> Module: - try: - loss_func = getattr(F, loss_func) - except AttributeError: - raise AttributeError(f"Invalid loss function name: {loss_func}") - return loss_func - - def configure_optimizers(self) -> Optimizer: - parameters = [p for p in self.parameters() if p.requires_grad] - - if self.optimizer_name == "sgd": - optimizer = optim.SGD - elif self.optimizer_name == "adam": - optimizer = optim.Adam - elif self.optimizer_name == "adamw": - optimizer = optim.AdamW - elif self.optimizer_name == "adamax": - optimizer = optim.Adamax - else: - raise ValueError(f"Unsupported optimizer: {self.optimizer}") - - optimizer = optimizer(parameters, **self.optimizer_params) - if self.lr_scheduler_name is None: - return optimizer - - if self.lr_scheduler_name is not None and self.lr_scheduler_name.lower() == "reducelronplateau": - lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="min" if self.val_metric_name == "loss" else "max", - ) - else: - raise ValueError(f"Unsupported learning rate scheduler: {self.lr_scheduler}") - - lr_scheduler_config = { - "scheduler": lr_scheduler, - "monitor": self.val_metric_name, - } - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} - - def on_fit_start(self): - self.val_metric = get_metrics( - metric_threshold=self.metric_threshold, - monitor_metrics=[self.val_metric_name], - num_classes=self.num_labels, - top_k=1 if self.is_multiclass else None, - ).to(self.device) - - def training_step(self, batch: Tensor, batch_idx: int): - x, y = batch - logits = self.network(x) - loss = self.loss_func(logits, y.float()) - return loss - - def validation_step(self, batch: Tensor, batch_idx: int): - x, y = batch - logits = self.network(x) - self.val_metric.update(torch.sigmoid(logits), y.long()) - - def on_validation_epoch_end(self): - self.log_dict(self.val_metric.compute(), prog_bar=True) - self.val_metric.reset() - - def on_test_start(self): - self.test_metrics = get_metrics( - metric_threshold=self.metric_threshold, - monitor_metrics=self.test_metric_names, - num_classes=self.num_labels, - top_k=1 if self.is_multiclass else None, - ).to(self.device) - - def test_step(self, batch: Tensor, batch_idx: int): - x, y = batch - logits = self.network(x) - self.test_metrics.update(torch.sigmoid(logits), y.long()) - - def on_test_epoch_end(self): - self.log_dict(self.test_metrics.compute()) - self.test_metrics.reset() - - def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0): - # lightning will put tensors on cpu - x = batch - logits = self.network(x) - scores, labels = torch.topk(torch.sigmoid(logits), self.top_k) - return scores, labels - - def forward(self, x): - return self.network(x) - - -class PLTModel(BaseModel): - def __init__( - self, - network: str, - network_config: dict, - embed_vecs: Tensor, - num_labels: int, - optimizer: str, - metrics: list[str], - val_metric: str, - top_k: int, - is_multiclass: bool, - loss_func: str = "binary_cross_entropy_with_logits", - optimizer_params: Optional[dict] = None, - lr_scheduler: Optional[str] = None, - ): - super().__init__( - network=network, - network_config=network_config, - embed_vecs=embed_vecs, - num_labels=num_labels, - optimizer=optimizer, - metrics=metrics, - val_metric=val_metric, - top_k=top_k, - is_multiclass=is_multiclass, - loss_func=loss_func, - optimizer_params=optimizer_params, - lr_scheduler=lr_scheduler, - ) - - def multilabel_binarize( - self, - logits: Tensor, - candidates: Tensor, - candidate_scores: Tensor, - ) -> Tensor: - """self-implemented MultiLabelBinarizer for AttentionXML""" - src = torch.sigmoid(logits.detach()) * candidate_scores - # make sure preds and src use the same precision, e.g., either float16 or float32 - preds = torch.zeros(candidates.size(0), self.num_labels + 1, device=candidates.device, dtype=src.dtype) - preds.scatter_(dim=1, index=candidates, src=src) - # remove dummy samples - preds = preds[:, :-1] - return preds - - def training_step(self, batch, batch_idx): - x, y, candidates = batch - logits = self.network(x, candidates=candidates) - loss = self.loss_func(logits, torch.take_along_dim(y.float(), candidates, dim=1)) - return loss - - def validation_step(self, batch, batch_idx): - x, y, candidates, candidate_scores = batch - logits = self.network(x, candidates=candidates) - # FIXME: Cannot calculate loss, candidates might contain element whose value is self.num_labels (see dataset.py) - # loss = self.loss_func(logits, torch.from_numpy(np.concatenate([y[:, candidates], offset]))) - y_pred = self.multilabel_binarize(logits, candidates, candidate_scores) - self.val_metric.update(y_pred, y.long()) - - def test_step(self, batch, batch_idx): - x, y, candidates, candidate_scores = batch - logits = self.network(x, candidates=candidates) - y_pred = self.multilabel_binarize(logits, candidates, candidate_scores) - self.test_metrics.update(y_pred, y.long()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - x, candidates, candidate_scores = batch - logits = self.network(x, candidates=candidates) - scores, labels = torch.topk(torch.sigmoid(logits) * candidate_scores, self.top_k) - return scores, torch.take_along_dim(candidates, labels, dim=1) diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py new file mode 100644 index 000000000..7cb4e0b89 --- /dev/null +++ b/libmultilabel/nn/model_AttentionXML.py @@ -0,0 +1,77 @@ +from typing import Optional + +import torch +from torch import Tensor + +from .model import MultiLabelModel + + +class PLTModel(MultiLabelModel): + def __init__( + self, + classes, + word_dict, + embed_vecs, + network, + loss_function="binary_cross_entropy_with_logits", + log_path=None, + **kwargs, + ): + super().__init__( + classes=classes, + word_dict=word_dict, + embed_vecs=embed_vecs, + network=network, + loss_function=loss_function, + log_path=log_path, + **kwargs, + ) + + def multilabel_binarize( + self, + logits: Tensor, + samples: Tensor, + label_scores: Tensor, + ) -> Tensor: + """self-implemented MultiLabelBinarizer for AttentionXML""" + src = torch.sigmoid(logits.detach()) * label_scores + # make sure preds and src use the same precision, e.g., either float16 or float32 + preds = torch.zeros(samples.size(0), self.num_labels + 1, device=samples.device, dtype=src.dtype) + preds.scatter_(dim=1, index=samples, src=src) + # remove dummy samples + preds = preds[:, :-1] + return preds + + def training_step(self, batch, batch_idx): + x = batch["text"] + y = batch["label"] + samples = batch["samples"] + logits = self.network(x, samples=samples) + loss = self.loss_func(logits, torch.take_along_dim(y.float(), samples, dim=1)) + return loss + + def validation_step(self, batch, batch_idx): + x = batch["text"] + y = batch["label"] + samples = batch["samples"] + label_scores = batch["label_scores"] + logits = self.network(x, samples=samples) + y_pred = self.multilabel_binarize(logits, samples, label_scores) + self.val_metric.update(y_pred, y.long()) + + def test_step(self, batch, batch_idx): + x = batch["text"] + y = batch["label"] + samples = batch["samples"] + label_scores = batch["label_scores"] + logits = self.network(x, samples=samples) + y_pred = self.multilabel_binarize(logits, samples, label_scores) + self.test_metrics.update(y_pred, y.long()) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + x = batch["text"] + samples = batch["samples"] + label_scores = batch["label_scores"] + logits = self.network(x, samples=samples) + scores, labels = torch.topk(torch.sigmoid(logits) * label_scores, self.top_k) + return scores, torch.take_along_dim(samples, labels, dim=1) diff --git a/libmultilabel/nn/networks/labelwise_attention_networks.py b/libmultilabel/nn/networks/labelwise_attention_networks.py index ab8799770..365ddb276 100644 --- a/libmultilabel/nn/networks/labelwise_attention_networks.py +++ b/libmultilabel/nn/networks/labelwise_attention_networks.py @@ -299,7 +299,7 @@ def forward(self, inputs): x = self.encoder(x, lengths) # batch_size, length, hidden_size x, _ = self.attention(x, masks) # batch_size, num_classes, hidden_size x = self.output(x) # batch_size, num_classes - return x + return {"logits": x} class FastAttentionRNN(nn.Module): @@ -321,7 +321,7 @@ def __init__( self.attention = FastLabelwiseAttention(rnn_dim, num_classes) self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) - def forward(self, inputs, candidates): + def forward(self, inputs, samples): # the index of padding is 0 masks = inputs != 0 lengths = masks.sum(dim=1) @@ -329,6 +329,6 @@ def forward(self, inputs, candidates): x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size x = self.encoder(x, lengths) # batch_size, length, hidden_size - x, _ = self.attention(x, masks, candidates) # batch_size, candidate_size, hidden_size + x, _ = self.attention(x, masks, samples) # batch_size, candidate_size, hidden_size x = self.output(x) # batch_size, candidate_size - return x + return {"logits": x} diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 02116fcd0..bc7c2dc68 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -2,31 +2,32 @@ import logging import time -from concurrent.futures import ThreadPoolExecutor from functools import reduce, partial from pathlib import Path from typing import Generator, Optional import numpy as np import torch -import torch.distributed as dist from lightning import Trainer from scipy.sparse import csr_matrix +from sklearn.preprocessing import MultiLabelBinarizer from torch import Tensor from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from tqdm import tqdm -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from lightning.pytorch.utilities import rank_zero_only, rank_zero_info +from lightning.pytorch.callbacks import ModelCheckpoint from .cluster import CLUSTER_NAME, FILE_EXTENSION as CLUSTER_FILE_EXTENSION, build_label_tree from .data_utils import UNK -from .datasets import MultiLabelDataset, PLTDataset -from .model import PLTModel, BaseModel +from .datasets_AttentionXML import MultiLabelDataset, PLTDataset +from .model_AttentionXML import PLTModel +from ..nn import networks +from ..nn.model import Model __all__ = ["PLTTrainer"] -logging.basicConfig(level=logging.INFO) +from .nn_utils import init_trainer, init_model + logger = logging.getLogger(__name__) @@ -36,25 +37,31 @@ class PLTTrainer: def __init__( self, config, - classes: Optional[list] = None, # TODO: removed in the future + classes: Optional[list] = None, embed_vecs: Optional[Tensor] = None, - word_dict: Optional[dict] = None, # TODO: removed in the future - mlb=None, # TODO: removed in the future + word_dict: Optional[dict] = None, ): # The number of levels is set to 2 # In other words, there will be 2 models + if config.multiclass: + raise ValueError( + "The label space of multi-class datasets are usually not large enough for PLT training." + "Please consider other methods." + "If you think this statement is false. Please comment the exception in plt.py" + ) + self.is_multiclass = config.multiclass + # cluster self.cluster_size = config.cluster_size # predict the top k labels - self.top_k = config.top_k + self.predict_top_k = config.top_k # dataset meta info self.embed_vecs = embed_vecs self.word_dict = word_dict - self.mlb = mlb + self.classes = classes self.num_labels = len(classes) - self.is_multiclass = config.multiclass # cluster meta info self.cluster_size = config.cluster_size @@ -70,34 +77,21 @@ def __init__( # Trainer parameters self.accelerator = config.accelerator + self.use_cpu = config.cpu self.devices = 1 self.num_nodes = 1 - self.max_epochs = config.epochs + self.epochs = config.epochs + self.limit_train_batches = config.limit_train_batches + self.limit_val_batches = config.limit_val_batches + self.limit_test_batches = config.limit_test_batches + # callbacks self.val_metric = config.val_metric - self.verbose = not config.silent + self.silent = config.silent # EarlyStopping self.patience = config.patience # ModelCheckpoint self.result_dir = Path(config.result_dir) - # SWA/EMA - # to understand how SWA work, see the pytorch doc and the following link - # https://stackoverflow.com/questions/68726290/setting-learning-rate-for-stochastic-weight-averaging-in-pytorch - # self.swa = config.get("swa") - # - # if (swa_config := config.get("swa")) is not None: - # self.swa_lr = swa_config.get("swa_lr", 5e-2) - # self.swa_epoch_start = swa_config.get("swa_epoch_start") - # self.annealing_epochs = swa_config.get("annealing_epochs", 10) - # self.annealing_strategy = swa_config.get("annealing_strategy", "cos") - # # TODO: SWA or EMA? - # - # # self.avg_fn = None # None == SWA - # def ema_avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: - # decay = 1.0 - 1.0 / num_averaged - # return torch.optim.swa_utils.get_ema_avg_fn(decay=decay) - # - # self.avg_fn = ema_avg_fn self.metrics = config.monitor_metrics @@ -123,12 +117,11 @@ def __init__( # save path self.config = config - def label2node(self, nodes, *ys) -> Generator[csr_matrix, ...]: - """Map labels (leaf nodes) to ancestor nodes at a certain level. + def label2cluster(self, nodes, *ys) -> Generator[csr_matrix, ...]: + """Map labels to their corresponding clusters in CSR sparse format. - If num_labels is 8 and nodes is [(0, 1), (2, 3), (4, 6), (5, 7)]. - Then the mapping is as follows: [0, 0, 1, 1, 2, 3, 2, 3] - Suppose one element of ys is [0, 1, 7]. The results after mapping is [0, 3]. + Suppose there are 6 labels and clusters are [(0, 1), (2, 3), (4, 5)] and ys is [0, 1, 4]. + The clustered labels of the given instance are [0, 2]. Args: nodes: the nodes generated at a pre-defined level. @@ -141,7 +134,7 @@ def label2node(self, nodes, *ys) -> Generator[csr_matrix, ...]: for idx, node_labels in enumerate(nodes): mapping[node_labels] = idx - def _label2node(y: csr_matrix) -> csr_matrix: + def _label2cluster(y: csr_matrix) -> csr_matrix: row = [] col = [] data = [] @@ -153,47 +146,7 @@ def _label2node(y: csr_matrix) -> csr_matrix: data += [1] * len(n) return csr_matrix((data, (row, col)), shape=(y.shape[0], len(nodes))) - return (_label2node(y) for y in ys) - - def configure_trainer(self, level) -> Trainer: - callbacks = [] - monitor = self.val_metric - # loss cannot be calculated for PLTModel - mode = "max" - - # ModelCheckpoint - callbacks.append( - ModelCheckpoint( - dirpath=self.result_dir, - filename=f"{self.CHECKPOINT_NAME}{level}", - monitor=monitor, - verbose=self.verbose, - mode=mode, - enable_version_counter=False, - save_on_train_epoch_end=True, - ) - ) - - callbacks.append( - EarlyStopping( - monitor=monitor, - patience=self.patience, - mode=mode, - verbose=self.verbose, - ) - ) - - trainer = Trainer( - accelerator=self.accelerator, - devices=self.devices, - num_nodes=self.num_nodes, - callbacks=callbacks, - max_epochs=self.max_epochs, - # TODO: Decide whether to keep these parameters - enable_progress_bar=True, - default_root_dir=self.result_dir, - ) - return trainer + return (_label2cluster(y) for y in ys) def fit(self, datasets): """fit model to the training dataset @@ -208,28 +161,23 @@ def fit(self, datasets): train_sparse_x = datasets["train_sparse_x"] # sparse training labels # TODO: remove workaround in future PR - train_sparse_y_full = self.mlb.transform((i["label"] for i in train_data_full)) + self.binarizer = MultiLabelBinarizer(classes=self.classes, sparse_output=True) + self.binarizer.fit(None) + train_sparse_y_full = self.binarizer.transform((i["label"] for i in train_data_full)) train_x = self.reformat_text(datasets["train"]) val_x = self.reformat_text(datasets["val"]) - train_y = self.mlb.transform((i["label"] for i in datasets["train"])) - val_y = self.mlb.transform((i["label"] for i in datasets["val"])) + train_y = self.binarizer.transform((i["label"] for i in datasets["train"])) + val_y = self.binarizer.transform((i["label"] for i in datasets["val"])) - # only do clustering on GPU 0 - @rank_zero_only - def start_cluster(): - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit( - build_label_tree, - sparse_x=train_sparse_x, - sparse_y=train_sparse_y_full, - cluster_size=self.cluster_size, - output_dir=self.result_dir, - ) - future.result() - - start_cluster() + # clustering + build_label_tree( + sparse_x=train_sparse_x, + sparse_y=train_sparse_y_full, + cluster_size=self.cluster_size, + output_dir=self.result_dir, + ) # wait until the clustering process finishes cluster_path = self.get_cluster_path() @@ -238,40 +186,62 @@ def start_cluster(): clusters = np.load(cluster_path, allow_pickle=True) # each y has been mapped to the node indices of its parent - train_y_cluster, val_y_cluster = self.label2node(clusters, train_y, val_y) + train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y) # regard each internal nodes as a "labels" num_labels = len(clusters) # trainer - trainer = self.configure_trainer(level=0) + trainer = init_trainer( + self.result_dir, + epochs=self.epochs, + patience=self.patience, + early_stopping_metric=self.val_metric, + val_metric=self.val_metric, + silent=self.silent, + use_cpu=self.use_cpu, + limit_train_batches=self.limit_train_batches, + limit_val_batches=self.limit_val_batches, + limit_test_batches=self.limit_test_batches, + search_params=False, + save_checkpoints=True, + ) + trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}0{ModelCheckpoint.FILE_EXTENSION}" best_model_path = self.get_best_model_path(level=0) if not best_model_path.exists(): # train & valid dataloaders for training - train_dataloader = self.dataloader(MultiLabelDataset(train_x, train_y_cluster), shuffle=self.shuffle) - val_dataloader = self.dataloader(MultiLabelDataset(val_x, val_y_cluster)) - - model = BaseModel( - network="AttentionXML", - network_config=self.network_config, + train_dataloader = self.dataloader(MultiLabelDataset(train_x, train_y_clustered), shuffle=self.shuffle) + val_dataloader = self.dataloader(MultiLabelDataset(val_x, val_y_clustered)) + + model = init_model( + model_name="AttentionXML", + network_config=self.config.network_config, + classes=self.classes, + word_dict=self.word_dict, embed_vecs=self.embed_vecs, - num_labels=num_labels, - optimizer=self.optimizer, - metrics=self.metrics, - val_metric=self.val_metric, - top_k=self.top_k, - is_multiclass=self.is_multiclass, - init_weight=self.init_weight, - loss_func=self.loss_func, - optimizer_params=self.optimizer_config, + init_weight=self.config.init_weight, + log_path=self.config.log_path, + learning_rate=self.config.learning_rate, + optimizer=self.config.optimizer, + momentum=self.config.momentum, + weight_decay=self.config.weight_decay, + lr_scheduler=self.config.lr_scheduler, + scheduler_config=self.config.scheduler_config, + val_metric=self.config.val_metric, + metric_threshold=self.config.metric_threshold, + monitor_metrics=self.config.monitor_metrics, + multiclass=self.config.multiclass, + loss_function=self.config.loss_function, + silent=self.config.silent, + save_k_predictions=self.config.save_k_predictions, ) - rank_zero_info(f"Training level 0. Number of labels: {num_labels}") + logger.info(f"Training level 0. Number of labels: {num_labels}") trainer.fit(model, train_dataloader, val_dataloader) - rank_zero_info(f"Finish training level 0") + logger.info(f"Finish training level 0") - rank_zero_info(f"Best model loaded from {best_model_path}") - model = BaseModel.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs) + logger.info(f"Best model loaded from {best_model_path}") + model = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs) # Utilize single GPU to predict trainer = Trainer( @@ -280,8 +250,8 @@ def start_cluster(): accelerator=self.accelerator, logger=False, ) - rank_zero_info( - f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.top_k}" + logger.info( + f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.predict_top_k}" ) # train & val dataloaders for prediction (without labels) train_dataloader = self.eval_dataloader(MultiLabelDataset(train_x)) @@ -291,58 +261,73 @@ def start_cluster(): train_node_pred = trainer.predict(model, train_dataloader) valid_node_pred = trainer.predict(model, val_dataloader) + import pdb + + pdb.set_trace() + # shape of node_pred: (n, 2, ~batch_size, top_k). n is floor(num_x / batch_size) # new shape: (2, num_x, top_k) _, train_node_y_pred = map(torch.vstack, list(zip(*train_node_pred))) valid_node_score_pred, valid_node_y_pred = map(torch.vstack, list(zip(*valid_node_pred))) # The following process can be simplified using method from LightXML - rank_zero_info("Getting Candidates") - node_candidates = np.empty((len(train_x), self.top_k), dtype=np.int64) - prog = rank_zero_only(tqdm)(train_node_y_pred, leave=False, desc="Parents") - if prog is None: - prog = train_node_y_pred - for i, ys in enumerate(prog): + logger.info("Getting samples") + cluster_samples = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) + for i, ys in enumerate(tqdm(train_node_y_pred, leave=False, desc="Sampling")): # true nodes/labels are positive - positive = set(train_y_cluster.indices[train_y_cluster.indptr[i] : train_y_cluster.indptr[i + 1]]) - # Regard positive nodes and predicted training nodes that are not in positive as candidates + pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) + # Regard positive nodes and predicted training nodes that are not in positive as samples # until reaching top_k if the number of positive labels is less than top_k. - if len(positive) <= self.top_k: - candidates = positive + if len(pos) <= self.predict_top_k: + samples = pos for y in ys: y = y.item() - if len(candidates) == self.top_k: + if len(samples) == self.predict_top_k: break - candidates.add(y) - # Regard positive (true) label as candidates iff they appear in the predicted labels - # if the number of positive labels is more than top_k. If candidates are not of length top_k + samples.add(y) + # Regard positive (true) label as samples iff they appear in the predicted labels + # if the number of positive labels is more than top_k. If samples are not of length top_k # add unseen predicted labels until reaching top_k. else: - candidates = set() + samples = set() for y in ys: y = y.item() - if y in positive: - candidates.add(y) - if len(candidates) == self.top_k: + if y in pos: + samples.add(y) + if len(samples) == self.predict_top_k: break - if len(candidates) < self.top_k: - candidates = (list(candidates) + list(positive - candidates))[: self.top_k] - node_candidates[i] = np.asarray(list(candidates)) + if len(samples) < self.predict_top_k: + samples = (list(samples) + list(pos - samples))[: self.predict_top_k] + cluster_samples[i] = np.asarray(list(samples)) # mapping from the current nodes to leaf nodes. assert reduce(lambda a, b: a + len(b), clusters, 0) == self.num_labels # trainer - trainer = self.configure_trainer(level=1) + trainer = init_trainer( + self.result_dir, + epochs=self.epochs, + patience=self.patience, + early_stopping_metric=self.val_metric, + val_metric=self.val_metric, + silent=self.silent, + use_cpu=self.use_cpu, + limit_train_batches=self.limit_train_batches, + limit_val_batches=self.limit_val_batches, + limit_test_batches=self.limit_test_batches, + search_params=False, + save_checkpoints=True, + ) + trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}1{ModelCheckpoint.FILE_EXTENSION}" # train & valid dataloaders for training train_dataloader = self.dataloader( PLTDataset( train_x, train_y, - num_labels=self.num_labels, + num_classes=self.num_labels, mapping=clusters, - node_label=node_candidates, + cluster_samples=cluster_samples, ), shuffle=self.shuffle, ) @@ -350,21 +335,28 @@ def start_cluster(): PLTDataset( val_x, val_y, - num_labels=self.num_labels, + num_classes=self.num_labels, mapping=clusters, - node_label=valid_node_y_pred, - node_score=valid_node_score_pred, + cluster_samples=valid_node_y_pred, + cluster_scores=valid_node_score_pred, ), ) + try: + network = getattr(networks, "FastAttentionXML")( + embed_vecs=self.embed_vecs, num_classes=len(self.classes), **dict(self.network_config) + ) + except: + raise AttributeError("Failed to initialize AttentionXML") + model = PLTModel( - network="FastAttentionXML", + network=network, network_config=self.network_config, embed_vecs=self.embed_vecs, num_labels=self.num_labels, optimizer=self.optimizer, metrics=self.metrics, - top_k=self.top_k, + top_k=self.predict_top_k, val_metric=self.val_metric, is_multiclass=self.is_multiclass, loss_func=self.loss_func, @@ -373,7 +365,7 @@ def start_cluster(): torch.nn.init.xavier_uniform_(model.network.attention.attention.weight) # initialize model with weights from level 0 - rank_zero_info(f"Loading parameters of level 1 from level 0") + logger.info(f"Loading parameters of level 1 from level 0") state_dict = torch.load(self.get_best_model_path(level=0))["state_dict"] # remove the name prefix in state_dict starting with "network.xxx" @@ -392,25 +384,18 @@ def start_cluster(): model.network.encoder.load_state_dict(encoder_state_dict) model.network.output.load_state_dict(output_state_dict) - rank_zero_info( - f"Training level 1, Number of labels: {self.num_labels}, " - f"Number of candidates: {train_dataloader.dataset.num_candidates}" + logger.info( + f"Training level 1. Number of labels: {self.num_labels}." + f"Number of samples: {train_dataloader.dataset.num_samples}" ) trainer.fit(model, train_dataloader, valid_dataloader) - rank_zero_info(f"Best model loaded from {best_model_path}") - rank_zero_info(f"Finish training level 1") - - # testing will hang forever without destroying process group - if dist.is_initialized(): - dist.destroy_process_group() + logger.info(f"Best model loaded from {best_model_path}") + logger.info(f"Finish training level 1") - # why we want to test on a single GPU? - # https://lightning.ai/docs/pytorch/stable/common/evaluation_intermediate.html - @rank_zero_only def test(self, dataset): test_x = self.reformat_text(dataset) - test_y = self.mlb.transform((i["label"] for i in dataset)) - rank_zero_info("Start predicting process.") + test_y = self.binarize.transform((i["label"] for i in dataset)) + logger.info("Testing process started") trainer = Trainer( devices=1, accelerator=self.accelerator, @@ -421,36 +406,39 @@ def test(self, dataset): model = BaseModel.load_from_checkpoint( self.get_best_model_path(level=0), embed_vecs=self.embed_vecs, - top_k=self.top_k, + top_k=self.predict_top_k, metrics=self.metrics, ) test_dataloader = self.eval_dataloader(MultiLabelDataset(test_x)) - rank_zero_info(f"Predicting level 0, Top: {self.top_k}") + logger.info(f"Predicting level 0, Top: {self.predict_top_k}") node_pred = trainer.predict(model, test_dataloader) node_score_pred, node_label_pred = map(torch.vstack, list(zip(*node_pred))) clusters = np.load(self.get_cluster_path(), allow_pickle=True) model = PLTModel.load_from_checkpoint( - self.get_best_model_path(level=1), embed_vecs=self.embed_vecs, top_k=self.top_k, metrics=self.metrics + self.get_best_model_path(level=1), + embed_vecs=self.embed_vecs, + top_k=self.predict_top_k, + metrics=self.metrics, ) test_dataloader = self.eval_dataloader( PLTDataset( test_x, test_y, - num_labels=self.num_labels, + num_classes=self.num_labels, mapping=clusters, - node_label=node_label_pred, - node_score=node_score_pred, + cluster_samples=node_label_pred, + cluster_scores=node_score_pred, ), ) - rank_zero_info(f"Testing on level 1") + logger.info(f"Testing on level 1") trainer.test(model, test_dataloader) - rank_zero_info("Testing process finished") + logger.info("Testing process finished") def reformat_text(self, dataset): encoded_text = list( diff --git a/torch_trainer.py b/torch_trainer.py index 73b118676..4899411b9 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -69,7 +69,7 @@ def __init__( if config.embed_file is not None: word_dict, embed_vecs = data_utils.load_or_build_text_dict( dataset=self.datasets["train"] - if config.model_name.lower() != "fastattentionxml" + if config.model_name.lower() != "attentionxml" else self.datasets["train_full"], vocab_file=config.vocab_file, min_vocab_freq=config.min_vocab_freq, @@ -82,15 +82,10 @@ def __init__( if not classes: self.classes = data_utils.load_or_build_label(self.datasets, config.label_file, config.include_test_labels) - mlb = MultiLabelBinarizer(classes=self.classes, sparse_output=True) - mlb.fit(None) - self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list())) - if self.config.model_name.lower() == "fastattentionxml": - self.trainer = PLTTrainer( - self.config, classes=self.classes, embed_vecs=embed_vecs, word_dict=word_dict, mlb=mlb - ) + if self.config.model_name.lower() == "attentionxml": + self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=embed_vecs, word_dict=word_dict) else: self._setup_model( word_dict=word_dict, @@ -206,7 +201,7 @@ def train(self): """Train model with pytorch lightning trainer. Set model to the best model after the training process is finished. """ - if self.config.model_name.lower() == "fastattentionxml": + if self.config.model_name.lower() == "attentionxml": self.trainer.fit(self.datasets) else: assert ( @@ -250,7 +245,7 @@ def test(self, split="test"): """ assert "test" in self.datasets and self.trainer is not None - if self.config.model_name.lower() == "fastattentionxml": + if self.config.model_name.lower() == "attentionxml": self.trainer.test(self.datasets["test"]) else: logging.info(f"Testing on {split} set.") From 92f3d562d5f2703ec5f2537ff91f5ee7d3b1abb6 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 13 Feb 2024 16:52:33 +0800 Subject: [PATCH 03/29] fix according to the feedback --- libmultilabel/nn/cluster.py | 22 ++++--- libmultilabel/nn/datasets_AttentionXML.py | 24 ++++---- libmultilabel/nn/model.py | 2 - libmultilabel/nn/plt.py | 75 ++++++++++------------- 4 files changed, 56 insertions(+), 67 deletions(-) diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index a481c8c66..ced4972a9 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -17,8 +17,9 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path): - """Group labels into clusters that contain up tp cluster_size labels. - Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters looks something like: + """Build a binary tree to group labels into clusters, each of which contains up tp cluster_size labels. The tree has + several layers; nodes in the last layer correspond to the output clusters. + Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters look something like: ((0, 2), (1, 3), (4, 5)). Args: @@ -38,17 +39,18 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i logger.info("Label clustering started") logger.info(f"Cluster size: {cluster_size}") num_labels = sparse_y.shape[1] - # The height of the tree satisfies the following inequation: + # The height of the tree satisfies the following inequality: # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size height = int(np.ceil(np.log2(num_labels / cluster_size))) logger.info(f"Labels will be grouped into {2**height} clusters") output_dir.mkdir(parents=True, exist_ok=True) - # the normalized representations of the relationship between labels and texts + # For each label, sum up instances relevant to the label and normalize to get the label representation label_repr = normalize(sparse_y.T @ csc_matrix(sparse_x)) - # clustering process + # clustering by a binary tree: + # at each layer split each cluster to two. Leave nodes correspond to the obtained clusters. clusters = [np.arange(num_labels)] for _ in range(height): next_clusters = [] @@ -62,8 +64,8 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]: - """A variant of KMeans implemented in AttentionXML. The cluster is partitioned into two groups, each with - approximately equal size. Its main differences with the KMeans algorithm in scikit-learn are: + """A variant of KMeans implemented in AttentionXML. Here K = 2. The cluster is partitioned into two groups, each + with approximately equal size. Its main differences with the KMeans algorithm in scikit-learn are: 1. the distance metric is cosine similarity. 2. the end-of-loop criterion is the difference between the new and old average in-cluster distances to centroids. @@ -83,7 +85,7 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n centroids = tgt_repr[np.random.choice(n, size=2, replace=False)].toarray() # Initialize distances (cosine similarity) - # The cosine similarity always belongs to the interval [-1, 1] + # Cosine similarity always falls to the interval [-1, 1] old_dist = -2.0 new_dist = -1.0 @@ -93,11 +95,11 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n while new_dist - old_dist >= tol: # each point's distances (cosine similarity) to the two centroids - # tgs_repr and centroids.T have been normalized + # tgt_repr and centroids.T have been normalized dist = tgt_repr @ centroids.T # shape: (n, 2) # generate clusters - # let a = dist[:, 1] - dist[:, 0], the larger the element in a is, the closer the point is to the c1 + # let a = dist[:, 1] - dist[:, 0], the larger the element in a is, the closer the point is to c1 k = n // 2 c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k) c0_idx = c_idx[:k] diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index eb198795e..c8f249174 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -58,7 +58,7 @@ def __init__( *, num_classes: int, mapping: ndarray, - cluster_samples: ndarray | Tensor, + clusters_selected: ndarray | Tensor, cluster_scores: Optional[ndarray | Tensor] = None, ): """Dataset for AttentionXML. @@ -68,22 +68,22 @@ def __init__( y: labels num_classes: number of nodes at the current level. mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), cluster_size). Map from clusters to labels. - cluster_samples: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes + clusters_selected: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes from last level. cluster_scores: corresponding scores. shape: (len(x), top_k) """ super().__init__(x, y) self.num_classes = num_classes self.mapping = mapping - self.cluster_samples = cluster_samples + self.clusters_selected = clusters_selected self.cluster_scores = cluster_scores self.label_scores = None # candidate are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k) # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] - prog = rank_zero_only(tqdm)(self.cluster_samples, leave=False, desc="Generating candidates") + prog = rank_zero_only(tqdm)(self.clusters_selected, leave=False, desc="Generating candidates") if prog is None: - prog = self.cluster_samples + prog = self.clusters_selected self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] if self.cluster_scores is not None: # label_scores are corresponding scores for candidates and @@ -91,11 +91,11 @@ def __init__( # notice how scores repeat for each cluster. self.label_scores = [ np.repeat(scores, [len(i) for i in self.mapping[labels]]) - for labels, scores in zip(self.cluster_samples, self.cluster_scores) + for labels, scores in zip(self.clusters_selected, self.cluster_scores) ] # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. - self.num_candidates = self.cluster_samples.shape[1] * max(len(node) for node in self.mapping) + self.num_clusters_selected = self.clusters_selected.shape[1] * max(len(node) for node in self.mapping) def __getitem__(self, idx: int): item = {"text": self.x[idx], "candidates": np.asarray(self.candidates[idx], dtype=np.int64)} @@ -107,20 +107,20 @@ def __getitem__(self, idx: int): # train if self.label_scores is None: # randomly select clusters as candidates when less than required - if len(item["candidates"]) < self.num_candidates: - sample = np.random.randint(self.num_classes, size=self.num_candidates - len(item["candidates"])) + if len(item["candidates"]) < self.num_clusters_selected: + sample = np.random.randint(self.num_classes, size=self.num_clusters_selected - len(item["candidates"])) item["candidates"] = np.concatenate([item["candidates"], sample]) # valid/test else: item["label_scores"] = self.label_scores[idx] # add dummy elements when less than required - if len(item["candidates"]) < self.num_candidates: + if len(item["candidates"]) < self.num_clusters_selected: item["label_scores"] = np.concatenate( - [item["label_scores"], [-np.inf] * (self.num_candidates - len(item["candidates"]))] + [item["label_scores"], [-np.inf] * (self.num_clusters_selected - len(item["candidates"]))] ) item["candidates"] = np.concatenate( - [item["candidates"], [self.num_classes] * (self.num_candidates - len(item["candidates"]))] + [item["candidates"], [self.num_classes] * (self.num_clusters_selected - len(item["candidates"]))] ) item["label_scores"] = np.asarray(item["label_scores"], dtype=np.float32) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index b6088417a..4c7e2fda7 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -1,12 +1,10 @@ from abc import abstractmethod -from typing import Optional import lightning as L import numpy as np import torch import torch.nn.functional as F import torch.optim as optim -from torch import nn from ..common_utils import argsort_top_k, dump_log from ..nn.metrics import get_metrics, tabulate_metrics diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index bc7c2dc68..169cda28c 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -46,15 +46,15 @@ def __init__( if config.multiclass: raise ValueError( - "The label space of multi-class datasets are usually not large enough for PLT training." + "The label space of multi-class datasets is usually not large, so PLT training is unnecessary." "Please consider other methods." - "If you think this statement is false. Please comment the exception in plt.py" + "If you have a multi-class set with numerous labels, please let us know" ) self.is_multiclass = config.multiclass # cluster self.cluster_size = config.cluster_size - # predict the top k labels + # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level 1 model training self.predict_top_k = config.top_k # dataset meta info @@ -120,8 +120,8 @@ def __init__( def label2cluster(self, nodes, *ys) -> Generator[csr_matrix, ...]: """Map labels to their corresponding clusters in CSR sparse format. - Suppose there are 6 labels and clusters are [(0, 1), (2, 3), (4, 5)] and ys is [0, 1, 4]. - The clustered labels of the given instance are [0, 2]. + Suppose there are 6 labels and clusters are [(0, 1), (2, 3), (4, 5)] and ys of a given instance is [0, 1, 4]. + The clusters of the instance are [0, 2]. Args: nodes: the nodes generated at a pre-defined level. @@ -202,7 +202,6 @@ def fit(self, datasets): limit_train_batches=self.limit_train_batches, limit_val_batches=self.limit_val_batches, limit_test_batches=self.limit_test_batches, - search_params=False, save_checkpoints=True, ) trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}0{ModelCheckpoint.FILE_EXTENSION}" @@ -253,55 +252,46 @@ def fit(self, datasets): logger.info( f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.predict_top_k}" ) - # train & val dataloaders for prediction (without labels) + # load training and validation data and predict corresponding level 0 nodes train_dataloader = self.eval_dataloader(MultiLabelDataset(train_x)) val_dataloader = self.eval_dataloader(MultiLabelDataset(val_x)) - # returned labels have been clustered into nodes (groups) - train_node_pred = trainer.predict(model, train_dataloader) - valid_node_pred = trainer.predict(model, val_dataloader) - - import pdb - - pdb.set_trace() + train_pred = trainer.predict(model, train_dataloader) + val_pred = trainer.predict(model, val_dataloader) # shape of node_pred: (n, 2, ~batch_size, top_k). n is floor(num_x / batch_size) # new shape: (2, num_x, top_k) - _, train_node_y_pred = map(torch.vstack, list(zip(*train_node_pred))) - valid_node_score_pred, valid_node_y_pred = map(torch.vstack, list(zip(*valid_node_pred))) - - # The following process can be simplified using method from LightXML - logger.info("Getting samples") - cluster_samples = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) - for i, ys in enumerate(tqdm(train_node_y_pred, leave=False, desc="Sampling")): - # true nodes/labels are positive + _, train_clusters_pred = map(torch.vstack, list(zip(*train_pred))) + val_scores_pred, val_clusters_pred = map(torch.vstack, list(zip(*val_pred))) + + logger.info("Selecting relevant/irrelevant clusters of each instance for level 1 training") + clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) + for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")): + # relevant clusters are positive pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) - # Regard positive nodes and predicted training nodes that are not in positive as samples + # Select relevant clusters first. Then from top-predicted clusters, sequentially include their labels until the total # of # until reaching top_k if the number of positive labels is less than top_k. if len(pos) <= self.predict_top_k: - samples = pos + selected = pos for y in ys: y = y.item() - if len(samples) == self.predict_top_k: + if len(selected) == self.predict_top_k: break - samples.add(y) + selected.add(y) # Regard positive (true) label as samples iff they appear in the predicted labels # if the number of positive labels is more than top_k. If samples are not of length top_k # add unseen predicted labels until reaching top_k. else: - samples = set() + selected = set() for y in ys: y = y.item() if y in pos: - samples.add(y) - if len(samples) == self.predict_top_k: + selected.add(y) + if len(selected) == self.predict_top_k: break - if len(samples) < self.predict_top_k: - samples = (list(samples) + list(pos - samples))[: self.predict_top_k] - cluster_samples[i] = np.asarray(list(samples)) - - # mapping from the current nodes to leaf nodes. - assert reduce(lambda a, b: a + len(b), clusters, 0) == self.num_labels + if len(selected) < self.predict_top_k: + selected = (list(selected) + list(pos - selected))[: self.predict_top_k] + clusters_selected[i] = np.asarray(list(selected)) # trainer trainer = init_trainer( @@ -315,7 +305,6 @@ def fit(self, datasets): limit_train_batches=self.limit_train_batches, limit_val_batches=self.limit_val_batches, limit_test_batches=self.limit_test_batches, - search_params=False, save_checkpoints=True, ) trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}1{ModelCheckpoint.FILE_EXTENSION}" @@ -327,7 +316,7 @@ def fit(self, datasets): train_y, num_classes=self.num_labels, mapping=clusters, - cluster_samples=cluster_samples, + clusters_selected=clusters_selected, ), shuffle=self.shuffle, ) @@ -337,8 +326,8 @@ def fit(self, datasets): val_y, num_classes=self.num_labels, mapping=clusters, - cluster_samples=valid_node_y_pred, - cluster_scores=valid_node_score_pred, + clusters_selected=val_clusters_pred, + cluster_scores=val_scores_pred, ), ) @@ -346,7 +335,7 @@ def fit(self, datasets): network = getattr(networks, "FastAttentionXML")( embed_vecs=self.embed_vecs, num_classes=len(self.classes), **dict(self.network_config) ) - except: + except Exception: raise AttributeError("Failed to initialize AttentionXML") model = PLTModel( @@ -386,7 +375,7 @@ def fit(self, datasets): logger.info( f"Training level 1. Number of labels: {self.num_labels}." - f"Number of samples: {train_dataloader.dataset.num_samples}" + f"Number of clusters selected: {train_dataloader.dataset.num_clusters_selected}" ) trainer.fit(model, train_dataloader, valid_dataloader) logger.info(f"Best model loaded from {best_model_path}") @@ -394,7 +383,7 @@ def fit(self, datasets): def test(self, dataset): test_x = self.reformat_text(dataset) - test_y = self.binarize.transform((i["label"] for i in dataset)) + test_y = self.binarizer.transform((i["label"] for i in dataset)) logger.info("Testing process started") trainer = Trainer( devices=1, @@ -431,7 +420,7 @@ def test(self, dataset): test_y, num_classes=self.num_labels, mapping=clusters, - cluster_samples=node_label_pred, + clusters_selected=node_label_pred, cluster_scores=node_score_pred, ), ) From 966489ad32f772efb38ece6c6c936d4bb1a2dfcf Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 28 Feb 2024 12:35:58 +0400 Subject: [PATCH 04/29] Use preprocessor from linear --- example_config/AmazonCat-13K/attentionxml.yml | 1 - example_config/EUR-Lex/attentionxml.yml | 1 - example_config/Wiki10-31K/attentionxml.yml | 1 - libmultilabel/nn/cluster.py | 2 +- libmultilabel/nn/data_utils.py | 15 ++------ libmultilabel/nn/datasets_AttentionXML.py | 4 +-- libmultilabel/nn/networks/__init__.py | 4 +-- .../networks/labelwise_attention_networks.py | 4 +-- libmultilabel/nn/plt.py | 36 +++++++++++-------- main.py | 5 --- 10 files changed, 31 insertions(+), 42 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index 0b6484b37..04b418807 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -1,6 +1,5 @@ data_name: AmazonCat-13K training_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/train.txt -training_sparse_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/train_ver3.svm test_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/test.txt # pretrained embeddings embed_file: glove.840B.300d diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index 5a2a57732..c5405d80e 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -1,6 +1,5 @@ data_name: EUR-Lex training_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/train.txt -training_sparse_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/train.svm test_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/test.txt # pretrained embeddings embed_file: glove.840B.300d diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 1abb780cc..479b8b6f0 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -1,6 +1,5 @@ data_name: Wiki10-31K training_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/train.txt -training_sparse_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/train.svm test_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/test.txt # pretrained embeddings embed_file: glove.840B.300d diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index ced4972a9..c24e8dab7 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -47,7 +47,7 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i output_dir.mkdir(parents=True, exist_ok=True) # For each label, sum up instances relevant to the label and normalize to get the label representation - label_repr = normalize(sparse_y.T @ csc_matrix(sparse_x)) + label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x))) # clustering by a binary tree: # at each layer split each cluster to two. Leave nodes correspond to the obtained clusters. diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py index 8efd27c43..1bcd45a21 100644 --- a/libmultilabel/nn/data_utils.py +++ b/libmultilabel/nn/data_utils.py @@ -7,10 +7,8 @@ import torch import transformers from nltk.tokenize import RegexpTokenizer -from scipy.sparse import issparse -from sklearn.datasets import load_svmlight_file from sklearn.model_selection import train_test_split -from sklearn.preprocessing import MultiLabelBinarizer, normalize +from sklearn.preprocessing import MultiLabelBinarizer from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab @@ -210,7 +208,6 @@ def load_datasets( Args: training_data (Union[str, pandas,.Dataframe], optional): Path to training data or a dataframe. - training_sparse_data (Union[str, pandas,.Dataframe], optional): Path to training sparse data or a dataframe in libsvm format. test_data (Union[str, pandas,.Dataframe], optional): Path to test data or a dataframe. val_data (Union[str, pandas,.Dataframe], optional): Path to validation data or a dataframe. val_size (float, optional): Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set. @@ -233,17 +230,11 @@ def load_datasets( datasets["train"] = _load_raw_data( training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data ) - - if training_sparse_data is not None: - logging.info(f"Loading sparse training data") - datasets["train_sparse_x"] = normalize(load_svmlight_file(training_sparse_data, multilabel=True)[0]) - if val_data is not None: datasets["val"] = _load_raw_data( val_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data ) elif val_size > 0: - datasets["train_full"] = datasets["train"] datasets["train"], datasets["val"] = train_test_split(datasets["train"], test_size=val_size, random_state=42) if test_data is not None: @@ -259,7 +250,7 @@ def load_datasets( del datasets["val"] gc.collect() - msg = " / ".join(f"{k}: {v.shape[0] if issparse(v) else len(v)}" for k, v in datasets.items()) + msg = " / ".join(f"{k}: {len(v)}" for k, v in datasets.items()) logging.info(f"Finish loading dataset ({msg})") return datasets @@ -345,7 +336,7 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False): classes = set() for split, data in datasets.items(): - if (split == "test" and not include_test_labels) or split == "train_sparse_x": + if split == "test" and not include_test_labels: continue for instance in data: classes.update(instance["label"]) diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index c8f249174..b8736e1b5 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -79,7 +79,7 @@ def __init__( self.cluster_scores = cluster_scores self.label_scores = None - # candidate are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k) + # candidate are positive nodes at the current level. shape: (len(x), cluster_size * top_k) # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] prog = rank_zero_only(tqdm)(self.clusters_selected, leave=False, desc="Generating candidates") if prog is None: @@ -87,7 +87,7 @@ def __init__( self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] if self.cluster_scores is not None: # label_scores are corresponding scores for candidates and - # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k) + # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k) # notice how scores repeat for each cluster. self.label_scores = [ np.repeat(scores, [len(i) for i in self.mapping[labels]]) diff --git a/libmultilabel/nn/networks/__init__.py b/libmultilabel/nn/networks/__init__.py index a27a5c2a2..b78963f68 100644 --- a/libmultilabel/nn/networks/__init__.py +++ b/libmultilabel/nn/networks/__init__.py @@ -9,8 +9,8 @@ from .labelwise_attention_networks import BiLSTMLWAN from .labelwise_attention_networks import BiLSTMLWMHAN from .labelwise_attention_networks import CNNLWAN -from .labelwise_attention_networks import AttentionRNN as AttentionXML -from .labelwise_attention_networks import FastAttentionRNN as FastAttentionXML +from .labelwise_attention_networks import AttentionXML_0 +from .labelwise_attention_networks import AttentionXML_1 def get_init_weight_func(init_weight): diff --git a/libmultilabel/nn/networks/labelwise_attention_networks.py b/libmultilabel/nn/networks/labelwise_attention_networks.py index 365ddb276..b2a86fdf9 100644 --- a/libmultilabel/nn/networks/labelwise_attention_networks.py +++ b/libmultilabel/nn/networks/labelwise_attention_networks.py @@ -270,7 +270,7 @@ def forward(self, input): return {"logits": x} -class AttentionRNN(nn.Module): +class AttentionXML_0(nn.Module): def __init__( self, embed_vecs, @@ -302,7 +302,7 @@ def forward(self, inputs): return {"logits": x} -class FastAttentionRNN(nn.Module): +class AttentionXML_1(nn.Module): def __init__( self, embed_vecs, diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 169cda28c..a94e3c8f6 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -2,7 +2,7 @@ import logging import time -from functools import reduce, partial +from functools import partial from pathlib import Path from typing import Generator, Optional @@ -10,7 +10,7 @@ import torch from lightning import Trainer from scipy.sparse import csr_matrix -from sklearn.preprocessing import MultiLabelBinarizer +from sklearn.preprocessing import normalize from torch import Tensor from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader @@ -21,6 +21,7 @@ from .data_utils import UNK from .datasets_AttentionXML import MultiLabelDataset, PLTDataset from .model_AttentionXML import PLTModel +from ..linear.preprocessor import Preprocessor from ..nn import networks from ..nn.model import Model @@ -63,6 +64,9 @@ def __init__( self.classes = classes self.num_labels = len(classes) + # preprocessor of the datasets + self.preprocessor = None + # cluster meta info self.cluster_size = config.cluster_size @@ -157,24 +161,26 @@ def fit(self, datasets): if self.get_best_model_path(level=1).exists(): return - train_data_full = datasets["train_full"] - train_sparse_x = datasets["train_sparse_x"] - # sparse training labels - # TODO: remove workaround in future PR - self.binarizer = MultiLabelBinarizer(classes=self.classes, sparse_output=True) - self.binarizer.fit(None) - train_sparse_y_full = self.binarizer.transform((i["label"] for i in train_data_full)) + # datasets preprocessing + # Convert training texts and labels to a matrix of tfidf features and a binary sparse matrix indicating the + # presence of a class label, respectively + train_val_dataset = datasets["train"] + datasets["val"] + train_val_dataset = {"x": (i["text"] for i in train_val_dataset), "y": (i["label"] for i in train_val_dataset)} + + self.preprocessor = Preprocessor() + datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes} + datasets_temp_tf = self.preprocessor.fit_transform(datasets_temp) train_x = self.reformat_text(datasets["train"]) val_x = self.reformat_text(datasets["val"]) - train_y = self.binarizer.transform((i["label"] for i in datasets["train"])) - val_y = self.binarizer.transform((i["label"] for i in datasets["val"])) + train_y = datasets_temp_tf["train"]["y"][: len(datasets["train"])] + val_y = datasets_temp_tf["train"]["y"][len(datasets["train"]) :] # clustering build_label_tree( - sparse_x=train_sparse_x, - sparse_y=train_sparse_y_full, + sparse_x=datasets_temp_tf["train"]["x"], + sparse_y=datasets_temp_tf["train"]["y"], cluster_size=self.cluster_size, output_dir=self.result_dir, ) @@ -259,7 +265,7 @@ def fit(self, datasets): train_pred = trainer.predict(model, train_dataloader) val_pred = trainer.predict(model, val_dataloader) - # shape of node_pred: (n, 2, ~batch_size, top_k). n is floor(num_x / batch_size) + # shape of node_pred: (n, 2, batch_size, top_k). n is floor(num_x / batch_size) # new shape: (2, num_x, top_k) _, train_clusters_pred = map(torch.vstack, list(zip(*train_pred))) val_scores_pred, val_clusters_pred = map(torch.vstack, list(zip(*val_pred))) @@ -383,7 +389,7 @@ def fit(self, datasets): def test(self, dataset): test_x = self.reformat_text(dataset) - test_y = self.binarizer.transform((i["label"] for i in dataset)) + test_y = self.preprocessor.binarizer.transform((i["label"] for i in dataset)) logger.info("Testing process started") trainer = Trainer( devices=1, diff --git a/main.py b/main.py index a6094a1e5..b784ab33e 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,6 @@ def add_all_arguments(parser): # data parser.add_argument("--data_name", default="unnamed_data", help="Dataset name (default: %(default)s)") parser.add_argument("--training_file", help="Path to training data (default: %(default)s)") - parser.add_argument("--training_sparse_file", help="Path to training sparse data (default: %(default)s)") parser.add_argument("--val_file", help="Path to validation data (default: %(default)s)") parser.add_argument("--test_file", help="Path to test data (default: %(default)s") parser.add_argument( @@ -80,10 +79,6 @@ def add_all_arguments(parser): choices=["adam", "adamw", "adamax", "sgd"], help="Optimizer (default: %(default)s)", ) - parser.add_argument( - "--optimizer_config", - help="Optimizer parameters", - ) parser.add_argument( "--learning_rate", type=float, default=0.0001, help="Learning rate for optimizer (default: %(default)s)" ) From 90555934504a3c17559452a8c99340a1b482d905 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 28 Feb 2024 14:16:43 +0400 Subject: [PATCH 05/29] fix issues in plt.py --- example_config/Wiki10-31K/attentionxml.yml | 2 +- libmultilabel/nn/plt.py | 90 ++++++++-------------- 2 files changed, 35 insertions(+), 57 deletions(-) diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 479b8b6f0..0e59378f2 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -top_k: 64 +save_k_predictions: 64 # dataloader batch_size: 40 diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index a94e3c8f6..964abccb9 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -10,7 +10,6 @@ import torch from lightning import Trainer from scipy.sparse import csr_matrix -from sklearn.preprocessing import normalize from torch import Tensor from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader @@ -33,7 +32,7 @@ class PLTTrainer: - CHECKPOINT_NAME = "model-level-" + CHECKPOINT_NAME = "model_" def __init__( self, @@ -42,9 +41,7 @@ def __init__( embed_vecs: Optional[Tensor] = None, word_dict: Optional[dict] = None, ): - # The number of levels is set to 2 - # In other words, there will be 2 models - + # The number of levels is set to 2. In other words, there will be 2 models if config.multiclass: raise ValueError( "The label space of multi-class datasets is usually not large, so PLT training is unnecessary." @@ -56,7 +53,7 @@ def __init__( # cluster self.cluster_size = config.cluster_size # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level 1 model training - self.predict_top_k = config.top_k + self.predict_top_k = config.save_k_predictions # dataset meta info self.embed_vecs = embed_vecs @@ -80,8 +77,8 @@ def __init__( self.optimizer_config = config.optimizer_config # Trainer parameters - self.accelerator = config.accelerator self.use_cpu = config.cpu + self.accelerator = "cpu" if self.use_cpu else "gpu" self.devices = 1 self.num_nodes = 1 self.epochs = config.epochs @@ -90,11 +87,12 @@ def __init__( self.limit_test_batches = config.limit_test_batches # callbacks - self.val_metric = config.val_metric self.silent = config.silent # EarlyStopping + self.early_stopping_metric = config.early_stopping_metric self.patience = config.patience # ModelCheckpoint + self.val_metric = config.val_metric self.result_dir = Path(config.result_dir) self.metrics = config.monitor_metrics @@ -121,34 +119,34 @@ def __init__( # save path self.config = config - def label2cluster(self, nodes, *ys) -> Generator[csr_matrix, ...]: + def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]: """Map labels to their corresponding clusters in CSR sparse format. Suppose there are 6 labels and clusters are [(0, 1), (2, 3), (4, 5)] and ys of a given instance is [0, 1, 4]. The clusters of the instance are [0, 2]. Args: - nodes: the nodes generated at a pre-defined level. - *ys: true labels (leaf nodes) for train and/or valid datasets. + cluster_mapping: the clusters generated at a pre-defined level. + *ys: labels for train and/or valid datasets. Returns: - Generator[csr_matrix]: the mapped labels (ancestor nodes) for train and/or valid datasets. + Generator[csr_matrix]: the mapped labels (ancestor clusters) for train and/or valid datasets. """ mapping = np.empty(self.num_labels, dtype=np.uint64) - for idx, node_labels in enumerate(nodes): - mapping[node_labels] = idx + for idx, clusters in enumerate(cluster_mapping): + mapping[clusters] = idx def _label2cluster(y: csr_matrix) -> csr_matrix: row = [] col = [] data = [] for i in range(y.shape[0]): - # n include all mapped ancestor nodes + # n include all mapped ancestor clusters n = np.unique(mapping[y.indices[y.indptr[i] : y.indptr[i + 1]]]) row += [i] * len(n) col += n.tolist() data += [1] * len(n) - return csr_matrix((data, (row, col)), shape=(y.shape[0], len(nodes))) + return csr_matrix((data, (row, col)), shape=(y.shape[0], len(cluster_mapping))) return (_label2cluster(y) for y in ys) @@ -184,16 +182,11 @@ def fit(self, datasets): cluster_size=self.cluster_size, output_dir=self.result_dir, ) + clusters = np.load(self.get_cluster_path(), allow_pickle=True) - # wait until the clustering process finishes - cluster_path = self.get_cluster_path() - while not cluster_path.exists(): - time.sleep(15) - clusters = np.load(cluster_path, allow_pickle=True) - - # each y has been mapped to the node indices of its parent + # each y has been mapped to the cluster indices of its parent train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y) - # regard each internal nodes as a "labels" + # regard each internal clusters as a "labels" num_labels = len(clusters) # trainer @@ -201,7 +194,7 @@ def fit(self, datasets): self.result_dir, epochs=self.epochs, patience=self.patience, - early_stopping_metric=self.val_metric, + early_stopping_metric=self.early_stopping_metric, val_metric=self.val_metric, silent=self.silent, use_cpu=self.use_cpu, @@ -218,7 +211,7 @@ def fit(self, datasets): train_dataloader = self.dataloader(MultiLabelDataset(train_x, train_y_clustered), shuffle=self.shuffle) val_dataloader = self.dataloader(MultiLabelDataset(val_x, val_y_clustered)) - model = init_model( + model_0 = init_model( model_name="AttentionXML", network_config=self.config.network_config, classes=self.classes, @@ -242,11 +235,11 @@ def fit(self, datasets): ) logger.info(f"Training level 0. Number of labels: {num_labels}") - trainer.fit(model, train_dataloader, val_dataloader) + trainer.fit(model_0, train_dataloader, val_dataloader) logger.info(f"Finish training level 0") logger.info(f"Best model loaded from {best_model_path}") - model = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs) + model_0 = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs) # Utilize single GPU to predict trainer = Trainer( @@ -258,15 +251,15 @@ def fit(self, datasets): logger.info( f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.predict_top_k}" ) - # load training and validation data and predict corresponding level 0 nodes + # load training and validation data and predict corresponding level 0 clusters train_dataloader = self.eval_dataloader(MultiLabelDataset(train_x)) val_dataloader = self.eval_dataloader(MultiLabelDataset(val_x)) - train_pred = trainer.predict(model, train_dataloader) - val_pred = trainer.predict(model, val_dataloader) + train_pred = trainer.predict(model_0, train_dataloader) + val_pred = trainer.predict(model_0, val_dataloader) - # shape of node_pred: (n, 2, batch_size, top_k). n is floor(num_x / batch_size) - # new shape: (2, num_x, top_k) + # shape of old pred: (n, 2, batch_size, top_k). n is floor(num_x / batch_size) + # shape of new pred: (2, num_x, top_k) _, train_clusters_pred = map(torch.vstack, list(zip(*train_pred))) val_scores_pred, val_clusters_pred = map(torch.vstack, list(zip(*val_pred))) @@ -344,7 +337,7 @@ def fit(self, datasets): except Exception: raise AttributeError("Failed to initialize AttentionXML") - model = PLTModel( + model_1 = PLTModel( network=network, network_config=self.network_config, embed_vecs=self.embed_vecs, @@ -357,33 +350,18 @@ def fit(self, datasets): loss_func=self.loss_func, optimizer_params=self.optimizer_config, ) - torch.nn.init.xavier_uniform_(model.network.attention.attention.weight) - - # initialize model with weights from level 0 - logger.info(f"Loading parameters of level 1 from level 0") - state_dict = torch.load(self.get_best_model_path(level=0))["state_dict"] - - # remove the name prefix in state_dict starting with "network.xxx" - embedding_state_dict = {} - encoder_state_dict = {} - output_state_dict = {} - for n, p in state_dict.items(): - truncated_n = n.split(".", 2)[-1] - if n.startswith("network.embedding"): - embedding_state_dict[truncated_n] = p - elif n.startswith("network.encoder"): - encoder_state_dict[truncated_n] = p - elif n.startswith("network.output"): - output_state_dict[truncated_n] = p - model.network.embedding.load_state_dict(embedding_state_dict) - model.network.encoder.load_state_dict(encoder_state_dict) - model.network.output.load_state_dict(output_state_dict) + torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight) + + logger.info(f"Initialize model with weights from the last level") + model_1.network.embedding.load_state_dict(model_0.network.embedding) + model_1.network.encoder.load_state_dict(model_0.network.encoder) + model_1.network.output.load_state_dict(model_0.network.output) logger.info( f"Training level 1. Number of labels: {self.num_labels}." f"Number of clusters selected: {train_dataloader.dataset.num_clusters_selected}" ) - trainer.fit(model, train_dataloader, valid_dataloader) + trainer.fit(model_1, train_dataloader, valid_dataloader) logger.info(f"Best model loaded from {best_model_path}") logger.info(f"Finish training level 1") From 475e2138d5dfa95b9da3fb61ba5cb86ec09fbac1 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 28 Feb 2024 14:22:49 +0400 Subject: [PATCH 06/29] distinguish new Dataset class from TextDataset --- libmultilabel/nn/datasets_AttentionXML.py | 2 +- libmultilabel/nn/plt.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index b8736e1b5..ef81135a3 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -12,7 +12,7 @@ from tqdm import tqdm -class MultiLabelDataset(Dataset): +class PlainDataset(Dataset): """Basic class for multi-label dataset.""" def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 964abccb9..971810ebe 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -18,7 +18,7 @@ from .cluster import CLUSTER_NAME, FILE_EXTENSION as CLUSTER_FILE_EXTENSION, build_label_tree from .data_utils import UNK -from .datasets_AttentionXML import MultiLabelDataset, PLTDataset +from .datasets_AttentionXML import PlainDataset, PLTDataset from .model_AttentionXML import PLTModel from ..linear.preprocessor import Preprocessor from ..nn import networks @@ -208,8 +208,8 @@ def fit(self, datasets): best_model_path = self.get_best_model_path(level=0) if not best_model_path.exists(): # train & valid dataloaders for training - train_dataloader = self.dataloader(MultiLabelDataset(train_x, train_y_clustered), shuffle=self.shuffle) - val_dataloader = self.dataloader(MultiLabelDataset(val_x, val_y_clustered)) + train_dataloader = self.dataloader(PlainDataset(train_x, train_y_clustered), shuffle=self.shuffle) + val_dataloader = self.dataloader(PlainDataset(val_x, val_y_clustered)) model_0 = init_model( model_name="AttentionXML", @@ -252,8 +252,8 @@ def fit(self, datasets): f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.predict_top_k}" ) # load training and validation data and predict corresponding level 0 clusters - train_dataloader = self.eval_dataloader(MultiLabelDataset(train_x)) - val_dataloader = self.eval_dataloader(MultiLabelDataset(val_x)) + train_dataloader = self.eval_dataloader(PlainDataset(train_x)) + val_dataloader = self.eval_dataloader(PlainDataset(val_x)) train_pred = trainer.predict(model_0, train_dataloader) val_pred = trainer.predict(model_0, val_dataloader) @@ -383,7 +383,7 @@ def test(self, dataset): metrics=self.metrics, ) - test_dataloader = self.eval_dataloader(MultiLabelDataset(test_x)) + test_dataloader = self.eval_dataloader(PlainDataset(test_x)) logger.info(f"Predicting level 0, Top: {self.predict_top_k}") node_pred = trainer.predict(model, test_dataloader) From 87d75259939da8ba7d9db64388b90660b4e930ae Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 28 Feb 2024 14:29:05 +0400 Subject: [PATCH 07/29] rephrase the comments on cosine similarity --- libmultilabel/nn/cluster.py | 4 +- libmultilabel/nn/datasets_AttentionXML.py | 47 +++++++++++++---------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index c24e8dab7..4a99056b5 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -94,8 +94,8 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n c1_idx = None while new_dist - old_dist >= tol: - # each point's distances (cosine similarity) to the two centroids - # tgt_repr and centroids.T have been normalized + # Notice that tgt_repr and centroids.T have been normalized + # Thus, dist indicates the cosine similarity between points and centroids. dist = tgt_repr @ centroids.T # shape: (n, 2) # generate clusters diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index ef81135a3..cce95a735 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -48,7 +48,7 @@ def __len__(self): return len(self.x) -class PLTDataset(MultiLabelDataset): +class PLTDataset(PlainDataset): """Dataset class for AttentionXML.""" def __init__( @@ -66,9 +66,9 @@ def __init__( Args: x: texts y: labels - num_classes: number of nodes at the current level. - mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), cluster_size). Map from clusters to labels. - clusters_selected: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes + num_classes: number of clusters at the current level. + mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(clusters), cluster_size). Map from clusters to labels. + clusters_selected: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted clusters from last level. cluster_scores: corresponding scores. shape: (len(x), top_k) """ @@ -79,14 +79,14 @@ def __init__( self.cluster_scores = cluster_scores self.label_scores = None - # candidate are positive nodes at the current level. shape: (len(x), cluster_size * top_k) + # labels_selected are positive clusters at the current level. shape: (len(x), cluster_size * top_k) # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] - prog = rank_zero_only(tqdm)(self.clusters_selected, leave=False, desc="Generating candidates") - if prog is None: - prog = self.clusters_selected - self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] + self.labels_selected = [ + np.concatenate(self.mapping[labels]) + for labels in tqdm(self.clusters_selected, leave=False, desc="Retrieving labels from selected clusters") + ] if self.cluster_scores is not None: - # label_scores are corresponding scores for candidates and + # label_scores are corresponding scores for selected labels and # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k) # notice how scores repeat for each cluster. self.label_scores = [ @@ -94,11 +94,11 @@ def __init__( for labels, scores in zip(self.clusters_selected, self.cluster_scores) ] - # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. - self.num_clusters_selected = self.clusters_selected.shape[1] * max(len(node) for node in self.mapping) + # top_k * n (n <= cluster_size). number of maximum possible number selected labels at the current level. + self.num_labels_selected = self.clusters_selected.shape[1] * max(len(clusters) for clusters in self.mapping) def __getitem__(self, idx: int): - item = {"text": self.x[idx], "candidates": np.asarray(self.candidates[idx], dtype=np.int64)} + item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx], dtype=np.int64)} # train/valid/test if self.y is not None: @@ -106,21 +106,26 @@ def __getitem__(self, idx: int): # train if self.label_scores is None: - # randomly select clusters as candidates when less than required - if len(item["candidates"]) < self.num_clusters_selected: - sample = np.random.randint(self.num_classes, size=self.num_clusters_selected - len(item["candidates"])) - item["candidates"] = np.concatenate([item["candidates"], sample]) + # randomly select clusters as selected labels when less than required + if len(item["labels_selected"]) < self.num_labels_selected: + sample = np.random.randint( + self.num_classes, size=self.num_labels_selected - len(item["labels_selected"]) + ) + item["labels_selected"] = np.concatenate([item["labels_selected"], sample]) # valid/test else: item["label_scores"] = self.label_scores[idx] # add dummy elements when less than required - if len(item["candidates"]) < self.num_clusters_selected: + if len(item["labels_selected"]) < self.num_labels_selected: item["label_scores"] = np.concatenate( - [item["label_scores"], [-np.inf] * (self.num_clusters_selected - len(item["candidates"]))] + [item["label_scores"], [-np.inf] * (self.num_labels_selected - len(item["labels_selected"]))] ) - item["candidates"] = np.concatenate( - [item["candidates"], [self.num_classes] * (self.num_clusters_selected - len(item["candidates"]))] + item["labels_selected"] = np.concatenate( + [ + item["labels_selected"], + [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])), + ] ) item["label_scores"] = np.asarray(item["label_scores"], dtype=np.float32) From 9eea03ea7488380c800fc5f175a01aca09b9f3f5 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 28 Feb 2024 15:08:18 +0400 Subject: [PATCH 08/29] minor fixes --- example_config/AmazonCat-13K/attentionxml.yml | 3 +- example_config/EUR-Lex/attentionxml.yml | 3 +- example_config/Wiki10-31K/attentionxml.yml | 3 +- libmultilabel/nn/cluster.py | 2 +- libmultilabel/nn/data_utils.py | 2 +- libmultilabel/nn/datasets_AttentionXML.py | 1 - libmultilabel/nn/model.py | 2 +- libmultilabel/nn/model_AttentionXML.py | 58 ++++---- libmultilabel/nn/networks/__init__.py | 3 +- .../networks/labelwise_attention_networks.py | 13 +- libmultilabel/nn/networks/modules.py | 50 +++---- libmultilabel/nn/plt.py | 139 ++++++++++-------- main.py | 12 -- torch_trainer.py | 12 +- 14 files changed, 145 insertions(+), 158 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index 04b418807..ba0758ab0 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -29,12 +29,11 @@ val_metric: nDCG@5 seed: 1337 epochs: 10 # https://github.com/Lightning-AI/lightning/issues/8826 -optimizer: Adam +optimizer: adam optimizer_config: lr: 0.001 # early stopping patience: 5 -silent: true # model model_name: AttentionXML diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index c5405d80e..e4440dd10 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -28,9 +28,8 @@ val_metric: nDCG@5 # train seed: 1337 epochs: 30 -silent: true # https://github.com/Lightning-AI/lightning/issues/8826 -optimizer: Adam +optimizer: adam # early stopping patience: 5 diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 0e59378f2..3657b3083 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -28,9 +28,8 @@ val_metric: nDCG@5 # train seed: 1337 epochs: 30 -silent: true # https://github.com/Lightning-AI/lightning/issues/8826 -optimizer: Adam +optimizer: adam # early stopping patience: 5 diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index 4a99056b5..23b823848 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -60,7 +60,7 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i logger.info(f"Having grouped {len(clusters)} clusters") np.save(cluster_path, np.asarray(clusters, dtype=object)) - logger.info(f"Label clustering finished. Saving results to {repr(cluster_path)}") + logger.info(f"Label clustering finished. Saving results to {cluster_path}") def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]: diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py index 1bcd45a21..863a54036 100644 --- a/libmultilabel/nn/data_utils.py +++ b/libmultilabel/nn/data_utils.py @@ -195,7 +195,6 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data def load_datasets( training_data=None, - training_sparse_data=None, test_data=None, val_data=None, val_size=0.2, @@ -230,6 +229,7 @@ def load_datasets( datasets["train"] = _load_raw_data( training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data ) + if val_data is not None: datasets["val"] = _load_raw_data( val_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index cce95a735..dbb87af36 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -4,7 +4,6 @@ import numpy as np import torch -from lightning.pytorch.utilities.rank_zero import rank_zero_only from numpy import ndarray from scipy.sparse import csr_matrix, issparse from torch import Tensor, is_tensor diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 4c7e2fda7..0cc405287 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -201,7 +201,7 @@ def __init__( ): super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) self.save_hyperparameters( - ignore=["log_path"] + ignore=["log_path", "embed_vecs", "word_dict", "metrics"] ) # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir). self.word_dict = word_dict self.embed_vecs = embed_vecs diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index 7cb4e0b89..a841f0c8b 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -1,12 +1,10 @@ -from typing import Optional - import torch from torch import Tensor -from .model import MultiLabelModel +from .model import Model -class PLTModel(MultiLabelModel): +class PLTModel(Model): def __init__( self, classes, @@ -30,48 +28,50 @@ def __init__( def multilabel_binarize( self, logits: Tensor, - samples: Tensor, + labels_selected: Tensor, label_scores: Tensor, ) -> Tensor: """self-implemented MultiLabelBinarizer for AttentionXML""" src = torch.sigmoid(logits.detach()) * label_scores # make sure preds and src use the same precision, e.g., either float16 or float32 - preds = torch.zeros(samples.size(0), self.num_labels + 1, device=samples.device, dtype=src.dtype) - preds.scatter_(dim=1, index=samples, src=src) - # remove dummy samples + preds = torch.zeros( + labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype + ) + preds.scatter_(dim=1, index=labels_selected, src=src) + # remove dummy labels preds = preds[:, :-1] return preds - def training_step(self, batch, batch_idx): - x = batch["text"] - y = batch["label"] - samples = batch["samples"] - logits = self.network(x, samples=samples) - loss = self.loss_func(logits, torch.take_along_dim(y.float(), samples, dim=1)) - return loss + def shared_step(self, batch): + """Return loss and predicted logits of the network. + + Args: + batch (dict): A batch of text and label. - def validation_step(self, batch, batch_idx): + Returns: + loss (torch.Tensor): Loss between target and predict logits. + pred_logits (torch.Tensor): The predict logits (batch_size, num_classes). + """ x = batch["text"] y = batch["label"] - samples = batch["samples"] - label_scores = batch["label_scores"] - logits = self.network(x, samples=samples) - y_pred = self.multilabel_binarize(logits, samples, label_scores) - self.val_metric.update(y_pred, y.long()) + labels_selected = batch["labels_selected"] + logits = self.network(x, labels_selected=labels_selected)["logits"] + loss = self.loss_function(logits, torch.take_along_dim(y.float(), labels_selected, dim=1)) + return loss, logits - def test_step(self, batch, batch_idx): + def _shared_eval_step(self, batch, batch_idx): x = batch["text"] y = batch["label"] - samples = batch["samples"] + labels_selected = batch["labels_selected"] label_scores = batch["label_scores"] - logits = self.network(x, samples=samples) - y_pred = self.multilabel_binarize(logits, samples, label_scores) - self.test_metrics.update(y_pred, y.long()) + logits = self.network(x, labels_selected=labels_selected)["logits"] + y_pred = self.multilabel_binarize(logits, labels_selected, label_scores) + self.eval_metric.update(y_pred, y.long()) def predict_step(self, batch, batch_idx, dataloader_idx=0): x = batch["text"] - samples = batch["samples"] + labels_selected = batch["labels_selected"] label_scores = batch["label_scores"] - logits = self.network(x, samples=samples) + logits = self.network(x, labels_selected=labels_selected)["logits"] scores, labels = torch.topk(torch.sigmoid(logits) * label_scores, self.top_k) - return scores, torch.take_along_dim(samples, labels, dim=1) + return scores.numpy(force=True), torch.take_along_dim(labels_selected, labels, dim=1).numpy(force=True) diff --git a/libmultilabel/nn/networks/__init__.py b/libmultilabel/nn/networks/__init__.py index b78963f68..52206aae6 100644 --- a/libmultilabel/nn/networks/__init__.py +++ b/libmultilabel/nn/networks/__init__.py @@ -9,8 +9,7 @@ from .labelwise_attention_networks import BiLSTMLWAN from .labelwise_attention_networks import BiLSTMLWMHAN from .labelwise_attention_networks import CNNLWAN -from .labelwise_attention_networks import AttentionXML_0 -from .labelwise_attention_networks import AttentionXML_1 +from .labelwise_attention_networks import AttentionXML_0, AttentionXML_1 def get_init_weight_func(init_weight): diff --git a/libmultilabel/nn/networks/labelwise_attention_networks.py b/libmultilabel/nn/networks/labelwise_attention_networks.py index b2a86fdf9..2941972e6 100644 --- a/libmultilabel/nn/networks/labelwise_attention_networks.py +++ b/libmultilabel/nn/networks/labelwise_attention_networks.py @@ -10,7 +10,7 @@ LabelwiseAttention, LabelwiseMultiHeadAttention, LabelwiseLinearOutput, - FastLabelwiseAttention, + PartialLabelwiseAttention, MultilayerLinearOutput, ) @@ -291,13 +291,14 @@ def __init__( def forward(self, inputs): # the index of padding is 0 + inputs = inputs["text"] masks = inputs != 0 lengths = masks.sum(dim=1) masks = masks[:, : lengths.max()] x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size x = self.encoder(x, lengths) # batch_size, length, hidden_size - x, _ = self.attention(x, masks) # batch_size, num_classes, hidden_size + x, _ = self.attention(x) # batch_size, num_classes, hidden_size x = self.output(x) # batch_size, num_classes return {"logits": x} @@ -318,10 +319,10 @@ def __init__( super().__init__() self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout) self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout) - self.attention = FastLabelwiseAttention(rnn_dim, num_classes) + self.attention = PartialLabelwiseAttention(rnn_dim, num_classes) self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) - def forward(self, inputs, samples): + def forward(self, inputs, labels_selected): # the index of padding is 0 masks = inputs != 0 lengths = masks.sum(dim=1) @@ -329,6 +330,6 @@ def forward(self, inputs, samples): x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size x = self.encoder(x, lengths) # batch_size, length, hidden_size - x, _ = self.attention(x, masks, samples) # batch_size, candidate_size, hidden_size - x = self.output(x) # batch_size, candidate_size + x, _ = self.attention(x, labels_selected) # batch_size, sample_size, hidden_size + x = self.output(x) # batch_size, sample_size return {"logits": x} diff --git a/libmultilabel/nn/networks/modules.py b/libmultilabel/nn/networks/modules.py index 4eda3962d..755792a12 100644 --- a/libmultilabel/nn/networks/modules.py +++ b/libmultilabel/nn/networks/modules.py @@ -72,7 +72,7 @@ def _get_rnn(self, input_size, hidden_size, num_layers, dropout): return nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True) -class LSTMEncoder(nn.Module): +class LSTMEncoder(RNNEncoder): """Bi-directional LSTM encoder with dropout Args: @@ -84,23 +84,10 @@ class LSTMEncoder(nn.Module): """ def __init__(self, input_size, hidden_size, num_layers, encoder_dropout=0, post_encoder_dropout=0): - super().__init__() - self.rnn = nn.LSTM( - input_size, hidden_size, num_layers, batch_first=True, dropout=encoder_dropout, bidirectional=True - ) - self.h0 = nn.Parameter(torch.zeros(2 * num_layers, 1, hidden_size)) - self.c0 = nn.Parameter(torch.zeros(2 * num_layers, 1, hidden_size)) - self.post_encoder_dropout = nn.Dropout(post_encoder_dropout) + super(LSTMEncoder, self).__init__(input_size, hidden_size, num_layers, encoder_dropout, post_encoder_dropout) - def forward(self, inputs, length): - self.rnn.flatten_parameters() - idx = torch.argsort(length, descending=True) - h0 = self.h0.repeat([1, inputs.size(0), 1]) - c0 = self.c0.repeat([1, inputs.size(0), 1]) - length_clamped = length[idx].cpu().clamp(min=1) # avoid the empty text with length 0 - packed_input = pack_padded_sequence(inputs[idx], length_clamped, batch_first=True) - outputs, _ = pad_packed_sequence(self.rnn(packed_input, (h0, c0))[0], batch_first=True) - return self.post_encoder_dropout(outputs[torch.argsort(idx)]) + def _get_rnn(self, input_size, hidden_size, num_layers, dropout): + return nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True) class CNNEncoder(nn.Module): @@ -174,12 +161,9 @@ def __init__(self, input_size, num_classes): super(LabelwiseAttention, self).__init__() self.attention = nn.Linear(input_size, num_classes, bias=False) - def forward(self, input, masks): + def forward(self, input): # (batch_size, num_classes, sequence_length) attention = self.attention(input).transpose(1, 2) - if masks is not None: - masks = torch.unsqueeze(masks, 1) # batch_size, 1, length - attention = attention.masked_fill(~masks, -torch.inf) # batch_size, num_classes, length attention = F.softmax(attention, -1) # (batch_size, num_classes, hidden_dim) logits = torch.bmm(attention, input) @@ -228,17 +212,21 @@ def forward(self, input): return (self.output.weight * input).sum(dim=-1) + self.output.bias -class FastLabelwiseAttention(nn.Module): - def __init__(self, hidden_size, num_labels): +class PartialLabelwiseAttention(nn.Module): + """Similar to LabelwiseAttention. + What makes the class different from LabelwiseAttention is that only the weights of selected labels will be + updated in a single iteration. + """ + + def __init__(self, hidden_size, num_classes): super().__init__() - self.attention = nn.Embedding(num_labels + 1, hidden_size) - - def forward(self, inputs, masks, candidates): - masks = torch.unsqueeze(masks, 1) # batch_size, 1, length - attn_inputs = inputs.transpose(1, 2) # batch_size, hidden, length - attn_weights = self.attention(candidates) # batch_size, sample_size, hidden - attention = (attn_weights @ attn_inputs).masked_fill(~masks, -torch.inf) # batch_size, sampled_size, length - attention = F.softmax(attention, -1) # batch_size, sampled_size, length + self.attention = nn.Embedding(num_classes + 1, hidden_size) + + def forward(self, inputs, labels_selected): + attn_inputs = inputs.transpose(1, 2) # batch_size, hidden_dim, length + attn_weights = self.attention(labels_selected) # batch_size, sample_size, hidden_dim + attention = attn_weights @ attn_inputs # batch_size, sample_size, length + attention = F.softmax(attention, -1) # batch_size, sample_size, length logits = attention @ inputs # batch_size, sample_size, hidden_dim return logits, attention diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 971810ebe..c3212ef63 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import time from functools import partial from pathlib import Path from typing import Generator, Optional @@ -10,6 +9,7 @@ import torch from lightning import Trainer from scipy.sparse import csr_matrix +from sklearn.preprocessing import MultiLabelBinarizer from torch import Tensor from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader @@ -59,7 +59,7 @@ def __init__( self.embed_vecs = embed_vecs self.word_dict = word_dict self.classes = classes - self.num_labels = len(classes) + self.num_classes = len(classes) # preprocessor of the datasets self.preprocessor = None @@ -70,11 +70,10 @@ def __init__( # network parameters self.network_config = config.network_config self.init_weight = "xavier_uniform" # AttentionXML-specific setting - self.loss_func = config.loss_func + self.loss_function = config.loss_function # optimizer parameters self.optimizer = config.optimizer - self.optimizer_config = config.optimizer_config # Trainer parameters self.use_cpu = config.cpu @@ -110,10 +109,8 @@ def __init__( ) # evaluation DataLoader self.eval_dataloader = partial( - DataLoader, + self.dataloader, batch_size=config.eval_batch_size, - num_workers=config.data_workers, - pin_memory=pin_memory, ) # save path @@ -132,7 +129,7 @@ def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]: Returns: Generator[csr_matrix]: the mapped labels (ancestor clusters) for train and/or valid datasets. """ - mapping = np.empty(self.num_labels, dtype=np.uint64) + mapping = np.empty(self.num_classes, dtype=np.uint64) for idx, clusters in enumerate(cluster_mapping): mapping[clusters] = idx @@ -150,6 +147,22 @@ def _label2cluster(y: csr_matrix) -> csr_matrix: return (_label2cluster(y) for y in ys) + def preprocess(self, datasets): + # datasets preprocessing + # Convert training texts and labels to a matrix of tfidf features and a binary sparse matrix indicating the + # presence of a class label, respectively + train_val_dataset = datasets["train"] + datasets["val"] + train_val_dataset = { + "x": [" ".join(i["text"]) for i in train_val_dataset], + "y": [i["label"] for i in train_val_dataset], + } + + self.preprocessor = Preprocessor() + datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes} + datasets_temp_tf = self.preprocessor.fit_transform(datasets_temp) + + return datasets_temp_tf + def fit(self, datasets): """fit model to the training dataset @@ -159,15 +172,7 @@ def fit(self, datasets): if self.get_best_model_path(level=1).exists(): return - # datasets preprocessing - # Convert training texts and labels to a matrix of tfidf features and a binary sparse matrix indicating the - # presence of a class label, respectively - train_val_dataset = datasets["train"] + datasets["val"] - train_val_dataset = {"x": (i["text"] for i in train_val_dataset), "y": (i["label"] for i in train_val_dataset)} - - self.preprocessor = Preprocessor() - datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes} - datasets_temp_tf = self.preprocessor.fit_transform(datasets_temp) + datasets_temp_tf = self.preprocess(datasets) train_x = self.reformat_text(datasets["train"]) val_x = self.reformat_text(datasets["val"]) @@ -189,7 +194,6 @@ def fit(self, datasets): # regard each internal clusters as a "labels" num_labels = len(clusters) - # trainer trainer = init_trainer( self.result_dir, epochs=self.epochs, @@ -203,18 +207,17 @@ def fit(self, datasets): limit_test_batches=self.limit_test_batches, save_checkpoints=True, ) - trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}0{ModelCheckpoint.FILE_EXTENSION}" + trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}0" + + train_dataloader = self.dataloader(PlainDataset(train_x, train_y_clustered), shuffle=self.shuffle) + val_dataloader = self.dataloader(PlainDataset(val_x, val_y_clustered)) best_model_path = self.get_best_model_path(level=0) if not best_model_path.exists(): - # train & valid dataloaders for training - train_dataloader = self.dataloader(PlainDataset(train_x, train_y_clustered), shuffle=self.shuffle) - val_dataloader = self.dataloader(PlainDataset(val_x, val_y_clustered)) - model_0 = init_model( - model_name="AttentionXML", + model_name="AttentionXML_0", network_config=self.config.network_config, - classes=self.classes, + classes=clusters, word_dict=self.word_dict, embed_vecs=self.embed_vecs, init_weight=self.config.init_weight, @@ -229,8 +232,8 @@ def fit(self, datasets): metric_threshold=self.config.metric_threshold, monitor_metrics=self.config.monitor_metrics, multiclass=self.config.multiclass, - loss_function=self.config.loss_function, - silent=self.config.silent, + loss_function=self.loss_function, + silent=self.silent, save_k_predictions=self.config.save_k_predictions, ) @@ -239,7 +242,7 @@ def fit(self, datasets): logger.info(f"Finish training level 0") logger.info(f"Best model loaded from {best_model_path}") - model_0 = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs) + model_0 = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs, word_dict=self.word_dict) # Utilize single GPU to predict trainer = Trainer( @@ -252,16 +255,13 @@ def fit(self, datasets): f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.predict_top_k}" ) # load training and validation data and predict corresponding level 0 clusters - train_dataloader = self.eval_dataloader(PlainDataset(train_x)) - val_dataloader = self.eval_dataloader(PlainDataset(val_x)) train_pred = trainer.predict(model_0, train_dataloader) val_pred = trainer.predict(model_0, val_dataloader) - # shape of old pred: (n, 2, batch_size, top_k). n is floor(num_x / batch_size) - # shape of new pred: (2, num_x, top_k) - _, train_clusters_pred = map(torch.vstack, list(zip(*train_pred))) - val_scores_pred, val_clusters_pred = map(torch.vstack, list(zip(*val_pred))) + train_clusters_pred = np.vstack([i["top_k_pred"] for i in train_pred]) + val_scores_pred = np.vstack([i["top_k_pred_scores"] for i in val_pred]) + val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred]) logger.info("Selecting relevant/irrelevant clusters of each instance for level 1 training") clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) @@ -292,7 +292,6 @@ def fit(self, datasets): selected = (list(selected) + list(pos - selected))[: self.predict_top_k] clusters_selected[i] = np.asarray(list(selected)) - # trainer trainer = init_trainer( self.result_dir, epochs=self.epochs, @@ -306,14 +305,14 @@ def fit(self, datasets): limit_test_batches=self.limit_test_batches, save_checkpoints=True, ) - trainer.checkpoint_callback.file_name = f"{self.CHECKPOINT_NAME}1{ModelCheckpoint.FILE_EXTENSION}" + trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}1" # train & valid dataloaders for training train_dataloader = self.dataloader( PLTDataset( train_x, train_y, - num_classes=self.num_labels, + num_classes=self.num_classes, mapping=clusters, clusters_selected=clusters_selected, ), @@ -323,7 +322,7 @@ def fit(self, datasets): PLTDataset( val_x, val_y, - num_classes=self.num_labels, + num_classes=self.num_classes, mapping=clusters, clusters_selected=val_clusters_pred, cluster_scores=val_scores_pred, @@ -331,43 +330,58 @@ def fit(self, datasets): ) try: - network = getattr(networks, "FastAttentionXML")( + network = getattr(networks, "AttentionXML_1")( embed_vecs=self.embed_vecs, num_classes=len(self.classes), **dict(self.network_config) ) except Exception: raise AttributeError("Failed to initialize AttentionXML") model_1 = PLTModel( - network=network, - network_config=self.network_config, + classes=self.classes, + word_dict=self.word_dict, embed_vecs=self.embed_vecs, - num_labels=self.num_labels, - optimizer=self.optimizer, - metrics=self.metrics, - top_k=self.predict_top_k, - val_metric=self.val_metric, - is_multiclass=self.is_multiclass, - loss_func=self.loss_func, - optimizer_params=self.optimizer_config, + network=network, + log_path=self.config.log_path, + learning_rate=self.config.learning_rate, + optimizer=self.config.optimizer, + momentum=self.config.momentum, + weight_decay=self.config.weight_decay, + lr_scheduler=self.config.lr_scheduler, + scheduler_config=self.config.scheduler_config, + val_metric=self.config.val_metric, + metric_threshold=self.config.metric_threshold, + monitor_metrics=self.config.monitor_metrics, + multiclass=self.config.multiclass, + loss_function=self.loss_function, + silent=self.silent, + save_k_predictions=self.config.save_k_predictions, ) torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight) logger.info(f"Initialize model with weights from the last level") - model_1.network.embedding.load_state_dict(model_0.network.embedding) - model_1.network.encoder.load_state_dict(model_0.network.encoder) - model_1.network.output.load_state_dict(model_0.network.output) + model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict()) + model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict()) + model_1.network.output.load_state_dict(model_0.network.output.state_dict()) + + del model_0 logger.info( - f"Training level 1. Number of labels: {self.num_labels}." - f"Number of clusters selected: {train_dataloader.dataset.num_clusters_selected}" + f"Training level 1. Number of labels: {self.num_classes}." + f"Number of labels selected: {train_dataloader.dataset.num_labels_selected}" ) trainer.fit(model_1, train_dataloader, valid_dataloader) logger.info(f"Best model loaded from {best_model_path}") logger.info(f"Finish training level 1") - def test(self, dataset): + def test(self, dataset, classes): test_x = self.reformat_text(dataset) - test_y = self.preprocessor.binarizer.transform((i["label"] for i in dataset)) + + if self.preprocessor is None: + binarizer = MultiLabelBinarizer(classes=classes, sparse_output=True) + binarizer.fit(None) + test_y = binarizer.transform((i["label"] for i in dataset)) + else: + test_y = self.preprocessor.binarizer.transform((i["label"] for i in dataset)) logger.info("Testing process started") trainer = Trainer( devices=1, @@ -376,9 +390,10 @@ def test(self, dataset): ) # prediction starts from level 0 - model = BaseModel.load_from_checkpoint( + model = Model.load_from_checkpoint( self.get_best_model_path(level=0), embed_vecs=self.embed_vecs, + word_dict=self.word_dict, top_k=self.predict_top_k, metrics=self.metrics, ) @@ -386,8 +401,8 @@ def test(self, dataset): test_dataloader = self.eval_dataloader(PlainDataset(test_x)) logger.info(f"Predicting level 0, Top: {self.predict_top_k}") - node_pred = trainer.predict(model, test_dataloader) - node_score_pred, node_label_pred = map(torch.vstack, list(zip(*node_pred))) + test_pred = trainer.predict(model, test_dataloader) + test_score_pred, test_label_pred = map(torch.vstack, list(zip(*test_pred))) clusters = np.load(self.get_cluster_path(), allow_pickle=True) @@ -402,10 +417,10 @@ def test(self, dataset): PLTDataset( test_x, test_y, - num_classes=self.num_labels, + num_classes=self.num_classes, mapping=clusters, - clusters_selected=node_label_pred, - cluster_scores=node_score_pred, + clusters_selected=test_label_pred, + cluster_scores=test_score_pred, ), ) diff --git a/main.py b/main.py index b784ab33e..80d8d256e 100644 --- a/main.py +++ b/main.py @@ -230,18 +230,6 @@ def add_all_arguments(parser): default=8, help="the maximal number of labels inside a cluster (default: %(default)s)", ) - parser.add_argument( - "--top_k", - type=int, - default=64, - help="sample top-k clusters and use them to train tree model (default: %(default)s)", - ) - parser.add_argument( - "--beam_width", - type=int, - default=10, - help="The width of the beam search (default: %(default)s)", - ) parser.add_argument( "-h", "--help", diff --git a/torch_trainer.py b/torch_trainer.py index 4899411b9..8e23216fd 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -3,7 +3,6 @@ import numpy as np from lightning.pytorch.callbacks import ModelCheckpoint -from sklearn.preprocessing import MultiLabelBinarizer from transformers import AutoTokenizer from libmultilabel.common_utils import dump_log, is_multiclass_dataset @@ -55,7 +54,6 @@ def __init__( if datasets is None: self.datasets = data_utils.load_datasets( training_data=config.training_file, - training_sparse_data=config.training_sparse_file, test_data=config.test_file, val_data=config.val_file, val_size=config.val_size, @@ -70,7 +68,7 @@ def __init__( word_dict, embed_vecs = data_utils.load_or_build_text_dict( dataset=self.datasets["train"] if config.model_name.lower() != "attentionxml" - else self.datasets["train_full"], + else self.datasets["train"] + self.datasets["val"], vocab_file=config.vocab_file, min_vocab_freq=config.min_vocab_freq, embed_file=config.embed_file, @@ -230,8 +228,10 @@ def train(self): dump_log(self.log_path, config=self.config) - # return best model score for ray - return self.checkpoint_callback.best_model_score.item() if self.checkpoint_callback.best_model_score else None + # return best model score for ray + return ( + self.checkpoint_callback.best_model_score.item() if self.checkpoint_callback.best_model_score else None + ) def test(self, split="test"): """Test model with pytorch lightning trainer. Top-k predictions are saved @@ -246,7 +246,7 @@ def test(self, split="test"): assert "test" in self.datasets and self.trainer is not None if self.config.model_name.lower() == "attentionxml": - self.trainer.test(self.datasets["test"]) + self.trainer.test(self.datasets["test"], self.classes) else: logging.info(f"Testing on {split} set.") test_loader = self._get_dataset_loader(split=split) From 4d55cd703da9b0acf5d925fecd04aabd30ebec0f Mon Sep 17 00:00:00 2001 From: Dongli He Date: Fri, 1 Mar 2024 16:22:59 +0400 Subject: [PATCH 09/29] fix issues in testing --- libmultilabel/nn/plt.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index c3212ef63..a75c8d906 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -172,7 +172,16 @@ def fit(self, datasets): if self.get_best_model_path(level=1).exists(): return - datasets_temp_tf = self.preprocess(datasets) + # AttentionXML-specific data preprocessing + train_val_dataset = datasets["train"] + datasets["val"] + train_val_dataset = { + "x": [" ".join(i["text"]) for i in train_val_dataset], + "y": [i["label"] for i in train_val_dataset], + } + + self.preprocessor = Preprocessor() + datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes} + datasets_temp_tf = self.preprocessor.fit_transform(datasets_temp) train_x = self.reformat_text(datasets["train"]) val_x = self.reformat_text(datasets["val"]) @@ -402,13 +411,15 @@ def test(self, dataset, classes): logger.info(f"Predicting level 0, Top: {self.predict_top_k}") test_pred = trainer.predict(model, test_dataloader) - test_score_pred, test_label_pred = map(torch.vstack, list(zip(*test_pred))) + test_score_pred = np.vstack([i["top_k_pred_scores"] for i in test_pred]) + test_label_pred = np.vstack([i["top_k_pred"] for i in test_pred]) clusters = np.load(self.get_cluster_path(), allow_pickle=True) model = PLTModel.load_from_checkpoint( self.get_best_model_path(level=1), embed_vecs=self.embed_vecs, + word_dict=self.word_dict, top_k=self.predict_top_k, metrics=self.metrics, ) From 3b98c177c6638afc1a8ab9383b5f531e56e7d8f7 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 5 Mar 2024 15:10:32 +0400 Subject: [PATCH 10/29] fix learning rate --- example_config/AmazonCat-13K/attentionxml.yml | 3 +-- example_config/EUR-Lex/attentionxml.yml | 1 + example_config/Wiki10-31K/attentionxml.yml | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index ba0758ab0..92f1b9719 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -30,8 +30,7 @@ seed: 1337 epochs: 10 # https://github.com/Lightning-AI/lightning/issues/8826 optimizer: adam -optimizer_config: - lr: 0.001 +learning_rate: 0.001 # early stopping patience: 5 diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index e4440dd10..3f40c0c26 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -30,6 +30,7 @@ seed: 1337 epochs: 30 # https://github.com/Lightning-AI/lightning/issues/8826 optimizer: adam +learning_rate: 0.001 # early stopping patience: 5 diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 3657b3083..3ca617047 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -30,6 +30,7 @@ seed: 1337 epochs: 30 # https://github.com/Lightning-AI/lightning/issues/8826 optimizer: adam +learning_rate: 0.001 # early stopping patience: 5 From 3644dbb57c93c8d0702de2059bd9e282c8580e6b Mon Sep 17 00:00:00 2001 From: Dongli He Date: Thu, 7 Mar 2024 13:38:49 +0400 Subject: [PATCH 11/29] fix misused prob & improve readibility --- example_config/AmazonCat-13K/attentionxml.yml | 11 +- example_config/EUR-Lex/attentionxml.yml | 9 +- example_config/Wiki10-31K/attentionxml.yml | 9 +- libmultilabel/nn/cluster.py | 35 ++--- libmultilabel/nn/datasets_AttentionXML.py | 44 +++--- libmultilabel/nn/model_AttentionXML.py | 7 +- libmultilabel/nn/networks/modules.py | 2 +- libmultilabel/nn/plt.py | 148 ++++++++---------- 8 files changed, 121 insertions(+), 144 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index 92f1b9719..9c416ed29 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -1,11 +1,9 @@ data_name: AmazonCat-13K -training_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/train.txt -test_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/test.txt +# will change path later +training_file: data/AmazonCat-13K/train.txt +test_file: data/AmazonCat-13K/test.txt # pretrained embeddings embed_file: glove.840B.300d -embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding -# save path -result_dir: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/results # preprocessing min_vocab_freq: 1 @@ -28,11 +26,10 @@ val_metric: nDCG@5 # train seed: 1337 epochs: 10 -# https://github.com/Lightning-AI/lightning/issues/8826 optimizer: adam learning_rate: 0.001 # early stopping -patience: 5 +patience: 10 # model model_name: AttentionXML diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index 3f40c0c26..87299e2d6 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -1,10 +1,10 @@ data_name: EUR-Lex -training_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/train.txt -test_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/test.txt +training_file: data/EUR-Lex/train.txt +test_file: data/EUR-Lex/test.txt # pretrained embeddings embed_file: glove.840B.300d -embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding # save path +# will change path later result_dir: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/results # preprocessing @@ -28,11 +28,10 @@ val_metric: nDCG@5 # train seed: 1337 epochs: 30 -# https://github.com/Lightning-AI/lightning/issues/8826 optimizer: adam learning_rate: 0.001 # early stopping -patience: 5 +patience: 30 # model model_name: AttentionXML diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 3ca617047..30fab98f5 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -1,9 +1,9 @@ data_name: Wiki10-31K -training_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/train.txt -test_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/test.txt +# will change path later +training_file: data/Wiki10-31K/train.txt +test_file: data/Wiki10-31K/test.txt # pretrained embeddings embed_file: glove.840B.300d -embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding # save path result_dir: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/results @@ -28,11 +28,10 @@ val_metric: nDCG@5 # train seed: 1337 epochs: 30 -# https://github.com/Lightning-AI/lightning/issues/8826 optimizer: adam learning_rate: 0.001 # early stopping -patience: 5 +patience: 30 # model model_name: AttentionXML diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index 23b823848..8cbcc658b 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -38,24 +38,23 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i # meta info logger.info("Label clustering started") logger.info(f"Cluster size: {cluster_size}") - num_labels = sparse_y.shape[1] # The height of the tree satisfies the following inequality: # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size - height = int(np.ceil(np.log2(num_labels / cluster_size))) + height = int(np.ceil(np.log2(sparse_y.shape[1] / cluster_size))) logger.info(f"Labels will be grouped into {2**height} clusters") output_dir.mkdir(parents=True, exist_ok=True) - # For each label, sum up instances relevant to the label and normalize to get the label representation + # For each label, sum up normalized instances relevant to the label and normalize to get the label representation label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x))) # clustering by a binary tree: # at each layer split each cluster to two. Leave nodes correspond to the obtained clusters. - clusters = [np.arange(num_labels)] + clusters = [np.arange(sparse_y.shape[1])] for _ in range(height): next_clusters = [] for cluster in clusters: - next_clusters.extend(_split_cluster(cluster, label_repr)) + next_clusters.extend(_split_cluster(cluster, label_repr[cluster])) clusters = next_clusters logger.info(f"Having grouped {len(clusters)} clusters") @@ -71,18 +70,10 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n Args: cluster: a subset of labels - label_repr: the normalized representations of the relationship between labels and texts + label_repr: the normalized representations of the relationship between labels and texts of the given cluster """ - tol = 1e-4 - - # the normalized label representations corresponding to the cluster - tgt_repr = label_repr[cluster] - - # the number of labels in the cluster - n = len(cluster) - # Randomly choose two points as initial centroids and obtain their label representations - centroids = tgt_repr[np.random.choice(n, size=2, replace=False)].toarray() + centroids = label_repr[np.random.choice(len(cluster), size=2, replace=False)].toarray() # Initialize distances (cosine similarity) # Cosine similarity always falls to the interval [-1, 1] @@ -93,14 +84,14 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n c0_idx = None c1_idx = None - while new_dist - old_dist >= tol: - # Notice that tgt_repr and centroids.T have been normalized + while new_dist - old_dist >= 1e-4: + # Notice that label_repr and centroids.T have been normalized # Thus, dist indicates the cosine similarity between points and centroids. - dist = tgt_repr @ centroids.T # shape: (n, 2) + dist = label_repr @ centroids.T # shape: (n, 2) # generate clusters # let a = dist[:, 1] - dist[:, 0], the larger the element in a is, the closer the point is to c1 - k = n // 2 + k = len(cluster) // 2 c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k) c0_idx = c_idx[:k] c1_idx = c_idx[k:] @@ -108,15 +99,15 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n # update distances # the new distance is the average of in-cluster distances to the centroids old_dist = new_dist - new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / n + new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / len(cluster) # update centroids # the new centroid is the normalized average of the points in the cluster centroids = normalize( np.asarray( [ - np.squeeze(np.asarray(tgt_repr[c0_idx].sum(axis=0))), - np.squeeze(np.asarray(tgt_repr[c1_idx].sum(axis=0))), + np.squeeze(np.asarray(label_repr[c0_idx].sum(axis=0))), + np.squeeze(np.asarray(label_repr[c1_idx].sum(axis=0))), ] ) ) diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index dbb87af36..f2f31d7ea 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -12,17 +12,20 @@ class PlainDataset(Dataset): - """Basic class for multi-label dataset.""" + """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label dataset. + WHY EXISTS: The reason why this class is necessary is that it can process labels in sparse format, while TextDataset + does not. + Moreover, TextDataset implements multilabel binarization in a mandatory way. Nevertheless, AttentionXML already does + this while generating clusters. There is no need to do multilabel binarization again. - def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): - """General dataset class for multi-label dataset. + Args: + x: texts + y: labels + """ - Args: - x: texts - y: labels - """ + def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): if y is not None: - assert len(x) == y.shape[0], "Sizes mismatch between x and y" + assert len(x) == y.shape[0], "Sizes mismatch between texts and labels" self.x = x self.y = y @@ -48,7 +51,16 @@ def __len__(self): class PLTDataset(PlainDataset): - """Dataset class for AttentionXML.""" + """Dataset for model_1 of AttentionXML. + + Args: + x: texts + y: labels + num_classes: number of classes. + mapping: mapping from clusters to labels. Shape: (len(clusters), cluster_size). + clusters_selected: sampled predicted clusters from model_0. Shape: (len(x), predict_top_k). + cluster_scores: corresponding scores. Shape: (len(x), predict_top_k) + """ def __init__( self, @@ -60,17 +72,6 @@ def __init__( clusters_selected: ndarray | Tensor, cluster_scores: Optional[ndarray | Tensor] = None, ): - """Dataset for AttentionXML. - - Args: - x: texts - y: labels - num_classes: number of clusters at the current level. - mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(clusters), cluster_size). Map from clusters to labels. - clusters_selected: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted clusters - from last level. - cluster_scores: corresponding scores. shape: (len(x), top_k) - """ super().__init__(x, y) self.num_classes = num_classes self.mapping = mapping @@ -105,7 +106,8 @@ def __getitem__(self, idx: int): # train if self.label_scores is None: - # randomly select clusters as selected labels when less than required + # As networks require input to be of fixed shape, randomly select labels when the number of selected label + # is not enough if len(item["labels_selected"]) < self.num_labels_selected: sample = np.random.randint( self.num_classes, size=self.num_labels_selected - len(item["labels_selected"]) diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index a841f0c8b..45673a569 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -25,15 +25,14 @@ def __init__( **kwargs, ) - def multilabel_binarize( + def scatter_preds( self, logits: Tensor, labels_selected: Tensor, label_scores: Tensor, ) -> Tensor: - """self-implemented MultiLabelBinarizer for AttentionXML""" + """map predictions from sample space to label space. The scores of unsampled labels are set to 0.""" src = torch.sigmoid(logits.detach()) * label_scores - # make sure preds and src use the same precision, e.g., either float16 or float32 preds = torch.zeros( labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype ) @@ -65,7 +64,7 @@ def _shared_eval_step(self, batch, batch_idx): labels_selected = batch["labels_selected"] label_scores = batch["label_scores"] logits = self.network(x, labels_selected=labels_selected)["logits"] - y_pred = self.multilabel_binarize(logits, labels_selected, label_scores) + y_pred = self.scatter_preds(logits, labels_selected, label_scores) self.eval_metric.update(y_pred, y.long()) def predict_step(self, batch, batch_idx, dataloader_idx=0): diff --git a/libmultilabel/nn/networks/modules.py b/libmultilabel/nn/networks/modules.py index 755792a12..e7e3fc892 100644 --- a/libmultilabel/nn/networks/modules.py +++ b/libmultilabel/nn/networks/modules.py @@ -232,7 +232,7 @@ def forward(self, inputs, labels_selected): class MultilayerLinearOutput(nn.Module): - def __init__(self, linear_size: list[int], output_size: int): + def __init__(self, linear_size, output_size): super().__init__() self.linears = nn.ModuleList(nn.Linear(in_s, out_s) for in_s, out_s in zip(linear_size[:-1], linear_size[1:])) self.output = nn.Linear(linear_size[-1], output_size) diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index a75c8d906..6050f7da0 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -7,6 +7,7 @@ import numpy as np import torch +from scipy.special import expit from lightning import Trainer from scipy.sparse import csr_matrix from sklearn.preprocessing import MultiLabelBinarizer @@ -42,13 +43,13 @@ def __init__( word_dict: Optional[dict] = None, ): # The number of levels is set to 2. In other words, there will be 2 models - if config.multiclass: + self.multiclass = config.multiclass + if self.multiclass: raise ValueError( "The label space of multi-class datasets is usually not large, so PLT training is unnecessary." "Please consider other methods." "If you have a multi-class set with numerous labels, please let us know" ) - self.is_multiclass = config.multiclass # cluster self.cluster_size = config.cluster_size @@ -59,10 +60,11 @@ def __init__( self.embed_vecs = embed_vecs self.word_dict = word_dict self.classes = classes + self.max_seq_length = config.max_seq_length self.num_classes = len(classes) - # preprocessor of the datasets - self.preprocessor = None + # multilabel binarizer fitted to the datasets + self.binarizer = None # cluster meta info self.cluster_size = config.cluster_size @@ -74,6 +76,12 @@ def __init__( # optimizer parameters self.optimizer = config.optimizer + self.learning_rate = config.learning_rate + self.momentum = config.momentum + self.weight_decay = config.weight_decay + # learning rate scheduler + self.lr_scheduler = config.lr_scheduler + self.scheduler_config = config.scheduler_config # Trainer parameters self.use_cpu = config.cpu @@ -95,6 +103,8 @@ def __init__( self.result_dir = Path(config.result_dir) self.metrics = config.monitor_metrics + self.metric_threshold = config.metric_threshold + self.monitor_metrics = config.monitor_metrics # dataloader parameters # whether shuffle the training dataset or not during the training process @@ -114,7 +124,7 @@ def __init__( ) # save path - self.config = config + self.log_path = config.log_path def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]: """Map labels to their corresponding clusters in CSR sparse format. @@ -147,22 +157,6 @@ def _label2cluster(y: csr_matrix) -> csr_matrix: return (_label2cluster(y) for y in ys) - def preprocess(self, datasets): - # datasets preprocessing - # Convert training texts and labels to a matrix of tfidf features and a binary sparse matrix indicating the - # presence of a class label, respectively - train_val_dataset = datasets["train"] + datasets["val"] - train_val_dataset = { - "x": [" ".join(i["text"]) for i in train_val_dataset], - "y": [i["label"] for i in train_val_dataset], - } - - self.preprocessor = Preprocessor() - datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes} - datasets_temp_tf = self.preprocessor.fit_transform(datasets_temp) - - return datasets_temp_tf - def fit(self, datasets): """fit model to the training dataset @@ -179,29 +173,34 @@ def fit(self, datasets): "y": [i["label"] for i in train_val_dataset], } - self.preprocessor = Preprocessor() + # Preprocessor does tf-idf vectorization and multilabel binarization + # For details, see libmultilabel.linear.preprocessor.Preprocessor + preprocessor = Preprocessor() datasets_temp = {"data_format": "txt", "train": train_val_dataset, "classes": self.classes} - datasets_temp_tf = self.preprocessor.fit_transform(datasets_temp) + # Preprocessor requires the input dictionary to has a key named "train" and will return a new dictionary with + # the same key. + train_val_dataset_tf = preprocessor.fit_transform(datasets_temp)["train"] + # save binarizer for testing + self.binarizer = preprocessor.binarizer train_x = self.reformat_text(datasets["train"]) val_x = self.reformat_text(datasets["val"]) - train_y = datasets_temp_tf["train"]["y"][: len(datasets["train"])] - val_y = datasets_temp_tf["train"]["y"][len(datasets["train"]) :] + train_y = train_val_dataset_tf["y"][: len(datasets["train"])] + val_y = train_val_dataset_tf["y"][len(datasets["train"]) :] # clustering build_label_tree( - sparse_x=datasets_temp_tf["train"]["x"], - sparse_y=datasets_temp_tf["train"]["y"], + sparse_x=train_val_dataset_tf["x"], + sparse_y=train_val_dataset_tf["y"], cluster_size=self.cluster_size, output_dir=self.result_dir, ) + clusters = np.load(self.get_cluster_path(), allow_pickle=True) # each y has been mapped to the cluster indices of its parent train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y) - # regard each internal clusters as a "labels" - num_labels = len(clusters) trainer = init_trainer( self.result_dir, @@ -225,51 +224,42 @@ def fit(self, datasets): if not best_model_path.exists(): model_0 = init_model( model_name="AttentionXML_0", - network_config=self.config.network_config, + network_config=self.network_config, classes=clusters, word_dict=self.word_dict, embed_vecs=self.embed_vecs, - init_weight=self.config.init_weight, - log_path=self.config.log_path, - learning_rate=self.config.learning_rate, - optimizer=self.config.optimizer, - momentum=self.config.momentum, - weight_decay=self.config.weight_decay, - lr_scheduler=self.config.lr_scheduler, - scheduler_config=self.config.scheduler_config, - val_metric=self.config.val_metric, - metric_threshold=self.config.metric_threshold, - monitor_metrics=self.config.monitor_metrics, - multiclass=self.config.multiclass, + init_weight=self.init_weight, + log_path=self.log_path, + learning_rate=self.learning_rate, + optimizer=self.optimizer, + momentum=self.momentum, + weight_decay=self.weight_decay, + lr_scheduler=self.lr_scheduler, + scheduler_config=self.scheduler_config, + val_metric=self.val_metric, + metric_threshold=self.metric_threshold, + monitor_metrics=self.monitor_metrics, + multiclass=self.multiclass, loss_function=self.loss_function, silent=self.silent, - save_k_predictions=self.config.save_k_predictions, + save_k_predictions=self.predict_top_k, ) - logger.info(f"Training level 0. Number of labels: {num_labels}") + logger.info(f"Training level 0. Number of clusters: {len(clusters)}") trainer.fit(model_0, train_dataloader, val_dataloader) logger.info(f"Finish training level 0") logger.info(f"Best model loaded from {best_model_path}") model_0 = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs, word_dict=self.word_dict) - # Utilize single GPU to predict - trainer = Trainer( - num_nodes=1, - devices=1, - accelerator=self.accelerator, - logger=False, - ) - logger.info( - f"Generating predictions for level 1. Number of possible predictions: {num_labels}. Top k: {self.predict_top_k}" - ) + logger.info(f"Generating predictions for level 1. Will use the top {self.predict_top_k} predictions") # load training and validation data and predict corresponding level 0 clusters train_pred = trainer.predict(model_0, train_dataloader) val_pred = trainer.predict(model_0, val_dataloader) train_clusters_pred = np.vstack([i["top_k_pred"] for i in train_pred]) - val_scores_pred = np.vstack([i["top_k_pred_scores"] for i in val_pred]) + val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred])) val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred]) logger.info("Selecting relevant/irrelevant clusters of each instance for level 1 training") @@ -277,8 +267,8 @@ def fit(self, datasets): for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")): # relevant clusters are positive pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) - # Select relevant clusters first. Then from top-predicted clusters, sequentially include their labels until the total # of - # until reaching top_k if the number of positive labels is less than top_k. + # Select relevant clusters first. Then from top-predicted clusters, sequentially include their labels until + # reaching self.predict_top_k if the number of positive labels is less than self.predict_top_k if len(pos) <= self.predict_top_k: selected = pos for y in ys: @@ -327,7 +317,7 @@ def fit(self, datasets): ), shuffle=self.shuffle, ) - valid_dataloader = self.dataloader( + val_dataloader = self.dataloader( PLTDataset( val_x, val_y, @@ -350,20 +340,20 @@ def fit(self, datasets): word_dict=self.word_dict, embed_vecs=self.embed_vecs, network=network, - log_path=self.config.log_path, - learning_rate=self.config.learning_rate, - optimizer=self.config.optimizer, - momentum=self.config.momentum, - weight_decay=self.config.weight_decay, - lr_scheduler=self.config.lr_scheduler, - scheduler_config=self.config.scheduler_config, - val_metric=self.config.val_metric, - metric_threshold=self.config.metric_threshold, - monitor_metrics=self.config.monitor_metrics, - multiclass=self.config.multiclass, + log_path=self.log_path, + learning_rate=self.learning_rate, + optimizer=self.optimizer, + momentum=self.momentum, + weight_decay=self.weight_decay, + lr_scheduler=self.lr_scheduler, + scheduler_config=self.scheduler_config, + val_metric=self.val_metric, + metric_threshold=self.metric_threshold, + monitor_metrics=self.monitor_metrics, + multiclass=self.multiclass, loss_function=self.loss_function, silent=self.silent, - save_k_predictions=self.config.save_k_predictions, + save_k_predictions=self.predict_top_k, ) torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight) @@ -378,19 +368,19 @@ def fit(self, datasets): f"Training level 1. Number of labels: {self.num_classes}." f"Number of labels selected: {train_dataloader.dataset.num_labels_selected}" ) - trainer.fit(model_1, train_dataloader, valid_dataloader) + trainer.fit(model_1, train_dataloader, val_dataloader) logger.info(f"Best model loaded from {best_model_path}") logger.info(f"Finish training level 1") def test(self, dataset, classes): test_x = self.reformat_text(dataset) - if self.preprocessor is None: + if self.binarizer is None: binarizer = MultiLabelBinarizer(classes=classes, sparse_output=True) binarizer.fit(None) test_y = binarizer.transform((i["label"] for i in dataset)) else: - test_y = self.preprocessor.binarizer.transform((i["label"] for i in dataset)) + test_y = self.binarizer.transform((i["label"] for i in dataset)) logger.info("Testing process started") trainer = Trainer( devices=1, @@ -411,8 +401,8 @@ def test(self, dataset, classes): logger.info(f"Predicting level 0, Top: {self.predict_top_k}") test_pred = trainer.predict(model, test_dataloader) - test_score_pred = np.vstack([i["top_k_pred_scores"] for i in test_pred]) - test_label_pred = np.vstack([i["top_k_pred"] for i in test_pred]) + test_pred_scores = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred])) + test_pred_cluters = np.vstack([i["top_k_pred"] for i in test_pred]) clusters = np.load(self.get_cluster_path(), allow_pickle=True) @@ -430,8 +420,8 @@ def test(self, dataset, classes): test_y, num_classes=self.num_classes, mapping=clusters, - clusters_selected=test_label_pred, - cluster_scores=test_score_pred, + clusters_selected=test_pred_cluters, + cluster_scores=test_pred_scores, ), ) @@ -445,14 +435,14 @@ def reformat_text(self, dataset): lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64) if text else torch.tensor([self.word_dict[UNK]], dtype=torch.int64), - [instance["text"][: self.config["max_seq_length"]] for instance in dataset], + [instance["text"][: self.max_seq_length] for instance in dataset], ) ) # pad the first entry to be of length 500 if necessary encoded_text[0] = torch.cat( ( encoded_text[0], - torch.tensor(0, dtype=torch.int64).repeat(self.config["max_seq_length"] - encoded_text[0].shape[0]), + torch.tensor(0, dtype=torch.int64).repeat(self.max_seq_length - encoded_text[0].shape[0]), ) ) encoded_text = pad_sequence(encoded_text, batch_first=True) From 33f520e058beb92aa651df038349d5e3984f132f Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 12 Mar 2024 11:58:58 +0400 Subject: [PATCH 12/29] Revert "add function get_logits" This reverts commit 7438bcb6c9ad898adabc340b1c1a8c4ee318c5f5. --- libmultilabel/nn/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 0cc405287..4ceb99d73 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -73,6 +73,11 @@ def shared_step(self, batch): """Return loss and predicted logits""" return NotImplemented + @abstractmethod + def get_logits(self, batch): + """Return predicted logits""" + return NotImplemented + def configure_optimizers(self): """Initialize an optimizer for the free parameters of the network.""" parameters = [p for p in self.parameters() if p.requires_grad] @@ -232,3 +237,6 @@ def shared_step(self, batch): loss = self.loss_function(pred_logits, target_labels.float()) return loss, pred_logits + + def get_logits(self, batch): + return self.network(batch)["logits"] From efc67fa9d499bba60eccc4309431bce37cab7e33 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 12 Mar 2024 15:18:54 +0400 Subject: [PATCH 13/29] improve predicting function of AttentionXML Model --- libmultilabel/nn/model_AttentionXML.py | 10 +++++++--- .../nn/networks/labelwise_attention_networks.py | 6 ++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index 45673a569..eb052c172 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -68,9 +68,13 @@ def _shared_eval_step(self, batch, batch_idx): self.eval_metric.update(y_pred, y.long()) def predict_step(self, batch, batch_idx, dataloader_idx=0): - x = batch["text"] labels_selected = batch["labels_selected"] label_scores = batch["label_scores"] - logits = self.network(x, labels_selected=labels_selected)["logits"] + logits = self.network(batch)["logits"] scores, labels = torch.topk(torch.sigmoid(logits) * label_scores, self.top_k) - return scores.numpy(force=True), torch.take_along_dim(labels_selected, labels, dim=1).numpy(force=True) + # This calculation is to align with LibMultiLabel class where logits rather than probabilities are returned + logits = torch.logit(scores) + return { + "top_k_pred": torch.take_along_dim(labels_selected, labels, dim=1).numpy(force=True), + "top_k_pred_scores": logits.numpy(force=True), + } diff --git a/libmultilabel/nn/networks/labelwise_attention_networks.py b/libmultilabel/nn/networks/labelwise_attention_networks.py index 2941972e6..d90256959 100644 --- a/libmultilabel/nn/networks/labelwise_attention_networks.py +++ b/libmultilabel/nn/networks/labelwise_attention_networks.py @@ -290,8 +290,8 @@ def __init__( self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) def forward(self, inputs): - # the index of padding is 0 inputs = inputs["text"] + # the index of padding is 0 masks = inputs != 0 lengths = masks.sum(dim=1) masks = masks[:, : lengths.max()] @@ -322,7 +322,9 @@ def __init__( self.attention = PartialLabelwiseAttention(rnn_dim, num_classes) self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) - def forward(self, inputs, labels_selected): + def forward(self, inputs): + inputs = inputs["text"] + labels_selected = inputs["labels_selected"] # the index of padding is 0 masks = inputs != 0 lengths = masks.sum(dim=1) From 834a8ae90cc6d54a593c2808ed6570e1fa716248 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 12 Mar 2024 15:33:03 +0400 Subject: [PATCH 14/29] add explanatory comments --- libmultilabel/nn/plt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 6050f7da0..2e73aa7c1 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -358,6 +358,7 @@ def fit(self, datasets): torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight) logger.info(f"Initialize model with weights from the last level") + # As the attention layer of model 1 is different from model 0, each layer needs to be initialized separately model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict()) model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict()) model_1.network.output.load_state_dict(model_0.network.output.state_dict()) From eb5dfaa5f35695c943bba6a845c6792346195390 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 12 Mar 2024 15:40:12 +0400 Subject: [PATCH 15/29] revert changes to save_hyperparameters in model.py --- libmultilabel/nn/model.py | 2 +- libmultilabel/nn/plt.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 4ceb99d73..09589d215 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -206,7 +206,7 @@ def __init__( ): super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) self.save_hyperparameters( - ignore=["log_path", "embed_vecs", "word_dict", "metrics"] + ignore=["log_path"] ) # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir). self.word_dict = word_dict self.embed_vecs = embed_vecs diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 2e73aa7c1..61f6a8bb9 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -250,7 +250,7 @@ def fit(self, datasets): logger.info(f"Finish training level 0") logger.info(f"Best model loaded from {best_model_path}") - model_0 = Model.load_from_checkpoint(best_model_path, embed_vecs=self.embed_vecs, word_dict=self.word_dict) + model_0 = Model.load_from_checkpoint(best_model_path) logger.info(f"Generating predictions for level 1. Will use the top {self.predict_top_k} predictions") # load training and validation data and predict corresponding level 0 clusters @@ -392,10 +392,7 @@ def test(self, dataset, classes): # prediction starts from level 0 model = Model.load_from_checkpoint( self.get_best_model_path(level=0), - embed_vecs=self.embed_vecs, - word_dict=self.word_dict, top_k=self.predict_top_k, - metrics=self.metrics, ) test_dataloader = self.eval_dataloader(PlainDataset(test_x)) @@ -409,8 +406,6 @@ def test(self, dataset, classes): model = PLTModel.load_from_checkpoint( self.get_best_model_path(level=1), - embed_vecs=self.embed_vecs, - word_dict=self.word_dict, top_k=self.predict_top_k, metrics=self.metrics, ) From 2458e0dbddc8466ac93d134707315a15f159073a Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 13 Mar 2024 13:39:38 +0400 Subject: [PATCH 16/29] improve test process --- libmultilabel/nn/plt.py | 32 ++++++++++--------- torch_trainer.py | 71 ++++++++++++++++++++++++++++++----------- 2 files changed, 69 insertions(+), 34 deletions(-) diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 61f6a8bb9..6cf574691 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -373,7 +373,21 @@ def fit(self, datasets): logger.info(f"Best model loaded from {best_model_path}") logger.info(f"Finish training level 1") - def test(self, dataset, classes): + def test(self, dataset): + # retrieve word_dict from model_1 + # prediction starts from level 0 + model_0 = Model.load_from_checkpoint( + self.get_best_model_path(level=0), + top_k=self.predict_top_k, + ) + model_1 = PLTModel.load_from_checkpoint( + self.get_best_model_path(level=1), + top_k=self.predict_top_k, + metrics=self.metrics, + ) + self.word_dict = model_1.word_dict + classes = model_1.classes + test_x = self.reformat_text(dataset) if self.binarizer is None: @@ -389,27 +403,15 @@ def test(self, dataset, classes): logger=False, ) - # prediction starts from level 0 - model = Model.load_from_checkpoint( - self.get_best_model_path(level=0), - top_k=self.predict_top_k, - ) - test_dataloader = self.eval_dataloader(PlainDataset(test_x)) logger.info(f"Predicting level 0, Top: {self.predict_top_k}") - test_pred = trainer.predict(model, test_dataloader) + test_pred = trainer.predict(model_0, test_dataloader) test_pred_scores = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred])) test_pred_cluters = np.vstack([i["top_k_pred"] for i in test_pred]) clusters = np.load(self.get_cluster_path(), allow_pickle=True) - model = PLTModel.load_from_checkpoint( - self.get_best_model_path(level=1), - top_k=self.predict_top_k, - metrics=self.metrics, - ) - test_dataloader = self.eval_dataloader( PLTDataset( test_x, @@ -422,7 +424,7 @@ def test(self, dataset, classes): ) logger.info(f"Testing on level 1") - trainer.test(model, test_dataloader) + trainer.test(model_1, test_dataloader) logger.info("Testing process finished") def reformat_text(self, dataset): diff --git a/torch_trainer.py b/torch_trainer.py index 8e23216fd..7b8c57848 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -37,7 +37,6 @@ def __init__( self.run_name = config.run_name self.checkpoint_dir = config.checkpoint_dir self.log_path = config.log_path - self.classes = classes os.makedirs(self.checkpoint_dir, exist_ok=True) # Set up seed & device @@ -64,26 +63,42 @@ def __init__( else: self.datasets = datasets - if config.embed_file is not None: - word_dict, embed_vecs = data_utils.load_or_build_text_dict( - dataset=self.datasets["train"] - if config.model_name.lower() != "attentionxml" - else self.datasets["train"] + self.datasets["val"], - vocab_file=config.vocab_file, - min_vocab_freq=config.min_vocab_freq, - embed_file=config.embed_file, - silent=config.silent, - normalize_embed=config.normalize_embed, - embed_cache_dir=config.embed_cache_dir, - ) - - if not classes: - self.classes = data_utils.load_or_build_label(self.datasets, config.label_file, config.include_test_labels) - self.config.multiclass = is_multiclass_dataset(self.datasets["train"] + self.datasets.get("val", list())) if self.config.model_name.lower() == "attentionxml": - self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=embed_vecs, word_dict=word_dict) + # Note that AttentionXML produces two models. checkpoint_path directs to model_1 + if config.checkpoint_path is None: + if self.config.embed_file is not None: + logging.info("Load word dictionary ") + word_dict, embed_vecs = data_utils.load_or_build_text_dict( + dataset=self.datasets["train"] + self.datasets["val"], + vocab_file=config.vocab_file, + min_vocab_freq=config.min_vocab_freq, + embed_file=config.embed_file, + silent=config.silent, + normalize_embed=config.normalize_embed, + embed_cache_dir=config.embed_cache_dir, + ) + + if not classes: + classes = data_utils.load_or_build_label( + self.datasets, self.config.label_file, self.config.include_test_labels + ) + + if self.config.early_stopping_metric not in self.config.monitor_metrics: + logging.warning( + f"{self.config.early_stopping_metric} is not in `monitor_metrics`. " + f"Add {self.config.early_stopping_metric} to `monitor_metrics`." + ) + self.config.monitor_metrics += [self.config.early_stopping_metric] + + if self.config.val_metric not in self.config.monitor_metrics: + logging.warn( + f"{self.config.val_metric} is not in `monitor_metrics`. " + f"Add {self.config.val_metric} to `monitor_metrics`." + ) + self.config.monitor_metrics += [self.config.val_metric] + self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict) else: self._setup_model( word_dict=word_dict, @@ -109,6 +124,7 @@ def __init__( def _setup_model( self, + classes: list = None, word_dict: dict = None, embed_vecs=None, log_path: str = None, @@ -134,6 +150,21 @@ def _setup_model( self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path) else: logging.info("Initialize model from scratch.") + if self.config.embed_file is not None: + logging.info("Load word dictionary ") + word_dict, embed_vecs = data_utils.load_or_build_text_dict( + dataset=self.datasets["train"], + vocab_file=self.config.vocab_file, + min_vocab_freq=self.config.min_vocab_freq, + embed_file=self.config.embed_file, + silent=self.config.silent, + normalize_embed=self.config.normalize_embed, + embed_cache_dir=self.config.embed_cache_dir, + ) + if not classes: + classes = data_utils.load_or_build_label( + self.datasets, self.config.label_file, self.config.include_test_labels + ) if self.config.early_stopping_metric not in self.config.monitor_metrics: logging.warn( @@ -152,7 +183,7 @@ def _setup_model( self.model = init_model( model_name=self.config.model_name, network_config=dict(self.config.network_config), - classes=self.classes, + classes=classes, word_dict=word_dict, embed_vecs=embed_vecs, init_weight=self.config.init_weight, @@ -201,6 +232,8 @@ def train(self): """ if self.config.model_name.lower() == "attentionxml": self.trainer.fit(self.datasets) + + dump_log(self.log_path, config=self.config) else: assert ( self.trainer is not None From bed9855db7c2e99ddb681683aa02bcd7a3ef9dd2 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 13 Mar 2024 14:44:52 +0400 Subject: [PATCH 17/29] align with the last PR --- libmultilabel/nn/model.py | 12 ++--------- libmultilabel/nn/model_AttentionXML.py | 28 +++++++++----------------- 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 09589d215..040d35fb8 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -43,7 +43,7 @@ def __init__( multiclass=False, silent=False, save_k_predictions=0, - **kwargs, + **kwargs ): super().__init__() @@ -73,11 +73,6 @@ def shared_step(self, batch): """Return loss and predicted logits""" return NotImplemented - @abstractmethod - def get_logits(self, batch): - """Return predicted logits""" - return NotImplemented - def configure_optimizers(self): """Initialize an optimizer for the free parameters of the network.""" parameters = [p for p in self.parameters() if p.requires_grad] @@ -202,7 +197,7 @@ def __init__( network, loss_function="binary_cross_entropy_with_logits", log_path=None, - **kwargs, + **kwargs ): super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) self.save_hyperparameters( @@ -237,6 +232,3 @@ def shared_step(self, batch): loss = self.loss_function(pred_logits, target_labels.float()) return loss, pred_logits - - def get_logits(self, batch): - return self.network(batch)["logits"] diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index eb052c172..412bf5377 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -25,7 +25,7 @@ def __init__( **kwargs, ) - def scatter_preds( + def scatter_logits( self, logits: Tensor, labels_selected: Tensor, @@ -51,30 +51,22 @@ def shared_step(self, batch): loss (torch.Tensor): Loss between target and predict logits. pred_logits (torch.Tensor): The predict logits (batch_size, num_classes). """ - x = batch["text"] - y = batch["label"] - labels_selected = batch["labels_selected"] - logits = self.network(x, labels_selected=labels_selected)["logits"] - loss = self.loss_function(logits, torch.take_along_dim(y.float(), labels_selected, dim=1)) + y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1) + logits = self(batch) + loss = self.loss_function(logits, y) return loss, logits def _shared_eval_step(self, batch, batch_idx): - x = batch["text"] - y = batch["label"] - labels_selected = batch["labels_selected"] - label_scores = batch["label_scores"] - logits = self.network(x, labels_selected=labels_selected)["logits"] - y_pred = self.scatter_preds(logits, labels_selected, label_scores) - self.eval_metric.update(y_pred, y.long()) + logits = self(batch) + logits = self.scatter_logits(logits, batch["labels_selected"], batch["label_scores"]) + self.eval_metric.update(logits, batch["label"].long()) def predict_step(self, batch, batch_idx, dataloader_idx=0): - labels_selected = batch["labels_selected"] - label_scores = batch["label_scores"] - logits = self.network(batch)["logits"] - scores, labels = torch.topk(torch.sigmoid(logits) * label_scores, self.top_k) + logits = self(batch) + scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.top_k) # This calculation is to align with LibMultiLabel class where logits rather than probabilities are returned logits = torch.logit(scores) return { - "top_k_pred": torch.take_along_dim(labels_selected, labels, dim=1).numpy(force=True), + "top_k_pred": torch.take_along_dim(batch["labels_selected"], labels, dim=1).numpy(force=True), "top_k_pred_scores": logits.numpy(force=True), } From 8de834caecf332ab0f2181f7bb47a8f5596a487a Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 13 Mar 2024 16:53:33 +0400 Subject: [PATCH 18/29] minor fix --- libmultilabel/nn/model_AttentionXML.py | 2 +- .../networks/labelwise_attention_networks.py | 16 +-- libmultilabel/nn/plt.py | 60 ++++++--- torch_trainer.py | 117 +++++++++--------- 4 files changed, 111 insertions(+), 84 deletions(-) diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index 412bf5377..2114aa8b4 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -63,7 +63,7 @@ def _shared_eval_step(self, batch, batch_idx): def predict_step(self, batch, batch_idx, dataloader_idx=0): logits = self(batch) - scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.top_k) + scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.save_k_predictions) # This calculation is to align with LibMultiLabel class where logits rather than probabilities are returned logits = torch.logit(scores) return { diff --git a/libmultilabel/nn/networks/labelwise_attention_networks.py b/libmultilabel/nn/networks/labelwise_attention_networks.py index d90256959..738a274df 100644 --- a/libmultilabel/nn/networks/labelwise_attention_networks.py +++ b/libmultilabel/nn/networks/labelwise_attention_networks.py @@ -276,7 +276,7 @@ def __init__( embed_vecs, num_classes: int, rnn_dim: int, - linear_size: list[int, ...], + linear_size: list, freeze_embed_training: bool = False, rnn_layers: int = 1, embed_dropout: float = 0.2, @@ -290,13 +290,13 @@ def __init__( self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) def forward(self, inputs): - inputs = inputs["text"] + x = inputs["text"] # the index of padding is 0 - masks = inputs != 0 + masks = x != 0 lengths = masks.sum(dim=1) masks = masks[:, : lengths.max()] - x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size + x = self.embedding(x)[:, : lengths.max()] # batch_size, length, embedding_size x = self.encoder(x, lengths) # batch_size, length, hidden_size x, _ = self.attention(x) # batch_size, num_classes, hidden_size x = self.output(x) # batch_size, num_classes @@ -309,7 +309,7 @@ def __init__( embed_vecs, num_classes: int, rnn_dim: int, - linear_size: list[int], + linear_size: list, freeze_embed_training: bool = False, rnn_layers: int = 1, embed_dropout: float = 0.2, @@ -323,14 +323,14 @@ def __init__( self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1) def forward(self, inputs): - inputs = inputs["text"] + x = inputs["text"] labels_selected = inputs["labels_selected"] # the index of padding is 0 - masks = inputs != 0 + masks = x != 0 lengths = masks.sum(dim=1) masks = masks[:, : lengths.max()] - x = self.embedding(inputs)[:, : lengths.max()] # batch_size, length, embedding_size + x = self.embedding(x)[:, : lengths.max()] # batch_size, length, embedding_size x = self.encoder(x, lengths) # batch_size, length, hidden_size x, _ = self.attention(x, labels_selected) # batch_size, sample_size, hidden_size x = self.output(x) # batch_size, sample_size diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 6cf574691..f64c06d34 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -100,7 +100,7 @@ def __init__( self.patience = config.patience # ModelCheckpoint self.val_metric = config.val_metric - self.result_dir = Path(config.result_dir) + self.checkpoint_dir = Path(config.checkpoint_dir) self.metrics = config.monitor_metrics self.metric_threshold = config.metric_threshold @@ -133,11 +133,11 @@ def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]: The clusters of the instance are [0, 2]. Args: - cluster_mapping: the clusters generated at a pre-defined level. - *ys: labels for train and/or valid datasets. + cluster_mapping: mapping from clusters to labels. + *ys: sparse labels. Returns: - Generator[csr_matrix]: the mapped labels (ancestor clusters) for train and/or valid datasets. + Generator[csr_matrix]: clusters generated from labels """ mapping = np.empty(self.num_classes, dtype=np.uint64) for idx, clusters in enumerate(cluster_mapping): @@ -157,11 +157,36 @@ def _label2cluster(y: csr_matrix) -> csr_matrix: return (_label2cluster(y) for y in ys) + # def cluster2label(self, cluster_mapping, *ys): + # """Map clusters to their corresponding labels. Notice this function only deals with dense matrix. + # + # Args: + # cluster_mapping: mapping from clusters to labels. + # *ys: sparse clusters. + # + # Returns: + # Generator[csr_matrix]: labels generated from clusters + # """ + # + # def _cluster2label(y: csr_matrix) -> csr_matrix: + # self.labels_selected = [np.concatenate(cluster_mapping[labels]) for labels in y] + # return (_cluster2label(y) for y in ys) + + # def generate_goals(self, cluster_scores, y): + # if cluster_scores is not None: + # # label_scores are corresponding scores for selected labels and + # # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k) + # # notice how scores repeat for each cluster. + # self.label_scores = [ + # np.repeat(scores, [len(i) for i in cluster_mapping[labels]]) + # for labels, scores in zip(y, cluster_scores) + # ] + def fit(self, datasets): """fit model to the training dataset Args: - datasets: dict of train, val, and test + datasets: dict containing training, validation, and/or test datasets """ if self.get_best_model_path(level=1).exists(): return @@ -189,12 +214,13 @@ def fit(self, datasets): train_y = train_val_dataset_tf["y"][: len(datasets["train"])] val_y = train_val_dataset_tf["y"][len(datasets["train"]) :] - # clustering + # clusters are saved to the disk so that users doesn't need to provide the original training data when they want + # to do predicting solely build_label_tree( sparse_x=train_val_dataset_tf["x"], sparse_y=train_val_dataset_tf["y"], cluster_size=self.cluster_size, - output_dir=self.result_dir, + output_dir=self.checkpoint_dir, ) clusters = np.load(self.get_cluster_path(), allow_pickle=True) @@ -203,7 +229,7 @@ def fit(self, datasets): train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y) trainer = init_trainer( - self.result_dir, + self.checkpoint_dir, epochs=self.epochs, patience=self.patience, early_stopping_metric=self.early_stopping_metric, @@ -262,13 +288,15 @@ def fit(self, datasets): val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred])) val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred]) - logger.info("Selecting relevant/irrelevant clusters of each instance for level 1 training") + logger.info( + "Selecting relevant/irrelevant clusters of each instance for generating labels for level 1 training" + ) clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")): # relevant clusters are positive pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) - # Select relevant clusters first. Then from top-predicted clusters, sequentially include their labels until - # reaching self.predict_top_k if the number of positive labels is less than self.predict_top_k + # Select relevant clusters first. Then from top-predicted clusters, sequentially include them until + # clusters reach top_k if len(pos) <= self.predict_top_k: selected = pos for y in ys: @@ -292,7 +320,7 @@ def fit(self, datasets): clusters_selected[i] = np.asarray(list(selected)) trainer = init_trainer( - self.result_dir, + self.checkpoint_dir, epochs=self.epochs, patience=self.patience, early_stopping_metric=self.val_metric, @@ -378,11 +406,11 @@ def test(self, dataset): # prediction starts from level 0 model_0 = Model.load_from_checkpoint( self.get_best_model_path(level=0), - top_k=self.predict_top_k, + save_k_predictions=self.predict_top_k, ) model_1 = PLTModel.load_from_checkpoint( self.get_best_model_path(level=1), - top_k=self.predict_top_k, + save_k_predictions=self.predict_top_k, metrics=self.metrics, ) self.word_dict = model_1.word_dict @@ -447,7 +475,7 @@ def reformat_text(self, dataset): return encoded_text def get_best_model_path(self, level: int) -> Path: - return self.result_dir / f"{self.CHECKPOINT_NAME}{level}{ModelCheckpoint.FILE_EXTENSION}" + return self.checkpoint_dir / f"{self.CHECKPOINT_NAME}{level}{ModelCheckpoint.FILE_EXTENSION}" def get_cluster_path(self) -> Path: - return self.result_dir / f"{CLUSTER_NAME}{CLUSTER_FILE_EXTENSION}" + return self.checkpoint_dir / f"{CLUSTER_NAME}{CLUSTER_FILE_EXTENSION}" diff --git a/torch_trainer.py b/torch_trainer.py index 7b8c57848..349edd36f 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -99,28 +99,29 @@ def __init__( ) self.config.monitor_metrics += [self.config.val_metric] self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict) - else: - self._setup_model( - word_dict=word_dict, - embed_vecs=embed_vecs, - log_path=self.log_path, - checkpoint_path=config.checkpoint_path, - ) - self.trainer = init_trainer( - checkpoint_dir=self.checkpoint_dir, - epochs=config.epochs, - patience=config.patience, - early_stopping_metric=config.early_stopping_metric, - val_metric=config.val_metric, - silent=config.silent, - use_cpu=config.cpu, - limit_train_batches=config.limit_train_batches, - limit_val_batches=config.limit_val_batches, - limit_test_batches=config.limit_test_batches, - save_checkpoints=save_checkpoints, - ) - callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)] - self.checkpoint_callback = callbacks[0] if callbacks else None + return + self._setup_model( + classes=classes, + word_dict=word_dict, + embed_vecs=embed_vecs, + log_path=self.log_path, + checkpoint_path=config.checkpoint_path, + ) + self.trainer = init_trainer( + checkpoint_dir=self.checkpoint_dir, + epochs=config.epochs, + patience=config.patience, + early_stopping_metric=config.early_stopping_metric, + val_metric=config.val_metric, + silent=config.silent, + use_cpu=config.cpu, + limit_train_batches=config.limit_train_batches, + limit_val_batches=config.limit_val_batches, + limit_test_batches=config.limit_test_batches, + save_checkpoints=save_checkpoints, + ) + callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)] + self.checkpoint_callback = callbacks[0] if callbacks else None def _setup_model( self, @@ -234,37 +235,35 @@ def train(self): self.trainer.fit(self.datasets) dump_log(self.log_path, config=self.config) + return + assert ( + self.trainer is not None + ), "Please make sure the trainer is successfully initialized by `self._setup_trainer()`." + train_loader = self._get_dataset_loader(split="train", shuffle=self.config.shuffle) + + if "val" not in self.datasets: + logging.info("No validation dataset is provided. Train without vaildation.") + self.trainer.fit(self.model, train_loader) else: - assert ( - self.trainer is not None - ), "Please make sure the trainer is successfully initialized by `self._setup_trainer()`." - train_loader = self._get_dataset_loader(split="train", shuffle=self.config.shuffle) - - if "val" not in self.datasets: - logging.info("No validation dataset is provided. Train without vaildation.") - self.trainer.fit(self.model, train_loader) - else: - val_loader = self._get_dataset_loader(split="val") - self.trainer.fit(self.model, train_loader, val_loader) - - # Set model to the best model. If the validation process is skipped during - # training (i.e., val_size=0), the model is set to the last model. - model_path = self.checkpoint_callback.best_model_path or self.checkpoint_callback.last_model_path - if model_path: - logging.info(f"Finished training. Load best model from {model_path}.") - self._setup_model(checkpoint_path=model_path, log_path=self.log_path) - else: - logging.info( - "No model is saved during training. \ - If you want to save the best and the last model, please set `save_checkpoints` to True." - ) + val_loader = self._get_dataset_loader(split="val") + self.trainer.fit(self.model, train_loader, val_loader) + + # Set model to the best model. If the validation process is skipped during + # training (i.e., val_size=0), the model is set to the last model. + model_path = self.checkpoint_callback.best_model_path or self.checkpoint_callback.last_model_path + if model_path: + logging.info(f"Finished training. Load best model from {model_path}.") + self._setup_model(checkpoint_path=model_path, log_path=self.log_path) + else: + logging.info( + "No model is saved during training. \ + If you want to save the best and the last model, please set `save_checkpoints` to True." + ) - dump_log(self.log_path, config=self.config) + dump_log(self.log_path, config=self.config) - # return best model score for ray - return ( - self.checkpoint_callback.best_model_score.item() if self.checkpoint_callback.best_model_score else None - ) + # return best model score for ray + return self.checkpoint_callback.best_model_score.item() if self.checkpoint_callback.best_model_score else None def test(self, split="test"): """Test model with pytorch lightning trainer. Top-k predictions are saved @@ -279,17 +278,17 @@ def test(self, split="test"): assert "test" in self.datasets and self.trainer is not None if self.config.model_name.lower() == "attentionxml": - self.trainer.test(self.datasets["test"], self.classes) - else: - logging.info(f"Testing on {split} set.") - test_loader = self._get_dataset_loader(split=split) - metric_dict = self.trainer.test(self.model, dataloaders=test_loader, verbose=False)[0] + self.trainer.test(self.datasets["test"]) + return + logging.info(f"Testing on {split} set.") + test_loader = self._get_dataset_loader(split=split) + metric_dict = self.trainer.test(self.model, dataloaders=test_loader, verbose=False)[0] - if self.config.save_k_predictions > 0: - self._save_predictions(test_loader, self.config.predict_out_path) + if self.config.save_k_predictions > 0: + self._save_predictions(test_loader, self.config.predict_out_path) - dump_log(self.log_path, config=self.config) - return metric_dict + dump_log(self.log_path, config=self.config) + return metric_dict def _save_predictions(self, dataloader, predict_out_path): """Save top k label results. From 34fa22ad52fa88ac877ad61ec23a872ec1ceb9b1 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Mon, 18 Mar 2024 14:42:19 +0400 Subject: [PATCH 19/29] fix according to feedback --- libmultilabel/nn/datasets_AttentionXML.py | 9 ++++----- libmultilabel/nn/model_AttentionXML.py | 3 ++- libmultilabel/nn/plt.py | 15 ++++++++------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index f2f31d7ea..e1e3b50f5 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -79,16 +79,15 @@ def __init__( self.cluster_scores = cluster_scores self.label_scores = None - # labels_selected are positive clusters at the current level. shape: (len(x), cluster_size * top_k) - # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] + # labels_selected are labels extracted from cluster_selected. + # shape: (len(x), len(clusters_selected) * cluster_size) self.labels_selected = [ np.concatenate(self.mapping[labels]) for labels in tqdm(self.clusters_selected, leave=False, desc="Retrieving labels from selected clusters") ] if self.cluster_scores is not None: - # label_scores are corresponding scores for selected labels and - # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k) - # notice how scores repeat for each cluster. + # label_scores are probability scores corresponding to labels_selected. + # shape: (len(x), len(clusters_selected) * cluster_size) self.label_scores = [ np.repeat(scores, [len(i) for i in self.mapping[labels]]) for labels, scores in zip(self.clusters_selected, self.cluster_scores) diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index 2114aa8b4..88b9685fb 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -31,7 +31,8 @@ def scatter_logits( labels_selected: Tensor, label_scores: Tensor, ) -> Tensor: - """map predictions from sample space to label space. The scores of unsampled labels are set to 0.""" + """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to + the whole label space. The scores of unsampled labels are set to 0.""" src = torch.sigmoid(logits.detach()) * label_scores preds = torch.zeros( labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index f64c06d34..b91a1368c 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -278,7 +278,10 @@ def fit(self, datasets): logger.info(f"Best model loaded from {best_model_path}") model_0 = Model.load_from_checkpoint(best_model_path) - logger.info(f"Generating predictions for level 1. Will use the top {self.predict_top_k} predictions") + logger.info( + f"Predicting clusters by level 0 model. We then select {self.predict_top_k} clusters and use then " + f"to extract labels for level 1 training." + ) # load training and validation data and predict corresponding level 0 clusters train_pred = trainer.predict(model_0, train_dataloader) @@ -288,15 +291,12 @@ def fit(self, datasets): val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred])) val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred]) - logger.info( - "Selecting relevant/irrelevant clusters of each instance for generating labels for level 1 training" - ) clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")): # relevant clusters are positive pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) # Select relevant clusters first. Then from top-predicted clusters, sequentially include them until - # clusters reach top_k + # cluster number reaches predict_top_k if len(pos) <= self.predict_top_k: selected = pos for y in ys: @@ -383,9 +383,9 @@ def fit(self, datasets): silent=self.silent, save_k_predictions=self.predict_top_k, ) + logger.info(f"Initialize model with weights from level 0") + # For weights not initialized by the level-0 model, use xavier uniform initialization torch.nn.init.xavier_uniform_(model_1.network.attention.attention.weight) - - logger.info(f"Initialize model with weights from the last level") # As the attention layer of model 1 is different from model 0, each layer needs to be initialized separately model_1.network.embedding.load_state_dict(model_0.network.embedding.state_dict()) model_1.network.encoder.load_state_dict(model_0.network.encoder.state_dict()) @@ -456,6 +456,7 @@ def test(self, dataset): logger.info("Testing process finished") def reformat_text(self, dataset): + # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length. encoded_text = list( map( lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64) From 8cdd36b1c6901a21322b8c3ff8d7b24197c6d28a Mon Sep 17 00:00:00 2001 From: Dongli He Date: Thu, 21 Mar 2024 12:25:07 +0400 Subject: [PATCH 20/29] minor fix --- example_config/AmazonCat-13K/attentionxml.yml | 42 ------- libmultilabel/nn/cluster.py | 2 +- libmultilabel/nn/data_utils.py | 1 - libmultilabel/nn/datasets_AttentionXML.py | 91 ++++++-------- libmultilabel/nn/model_AttentionXML.py | 6 +- libmultilabel/nn/plt.py | 118 ++++++++++-------- 6 files changed, 107 insertions(+), 153 deletions(-) delete mode 100644 example_config/AmazonCat-13K/attentionxml.yml diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml deleted file mode 100644 index 9c416ed29..000000000 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ /dev/null @@ -1,42 +0,0 @@ -data_name: AmazonCat-13K -# will change path later -training_file: data/AmazonCat-13K/train.txt -test_file: data/AmazonCat-13K/test.txt -# pretrained embeddings -embed_file: glove.840B.300d - -# preprocessing -min_vocab_freq: 1 -max_seq_length: 500 - -# label tree related parameters -cluster_size: 8 -save_k_predictions: 64 - -# data -batch_size: 200 -val_size: 4000 -shuffle: true - -# eval -eval_batch_size: 200 -monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] -val_metric: nDCG@5 - -# train -seed: 1337 -epochs: 10 -optimizer: adam -learning_rate: 0.001 -# early stopping -patience: 10 - -# model -model_name: AttentionXML -network_config: - embed_dropout: 0.2 - post_encoder_dropout: 0.5 - rnn_dim: 1024 - rnn_layers: 1 - linear_size: [512, 256] - freeze_embed_training: false diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py index 8cbcc658b..35617531d 100644 --- a/libmultilabel/nn/cluster.py +++ b/libmultilabel/nn/cluster.py @@ -54,7 +54,7 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i for _ in range(height): next_clusters = [] for cluster in clusters: - next_clusters.extend(_split_cluster(cluster, label_repr[cluster])) + next_clusters += _split_cluster(cluster, label_repr[cluster]) clusters = next_clusters logger.info(f"Having grouped {len(clusters)} clusters") diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py index 863a54036..d558737de 100644 --- a/libmultilabel/nn/data_utils.py +++ b/libmultilabel/nn/data_utils.py @@ -17,7 +17,6 @@ transformers.logging.set_verbosity_error() warnings.simplefilter(action="ignore", category=FutureWarning) -# selection of UNK: https://groups.google.com/g/globalvectors/c/9w8ZADXJclA/m/hRdn4prm-XUJ UNK = "" PAD = "" diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index e1e3b50f5..8a91e4bad 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -3,12 +3,10 @@ from typing import Sequence, Optional import numpy as np -import torch from numpy import ndarray from scipy.sparse import csr_matrix, issparse from torch import Tensor, is_tensor from torch.utils.data import Dataset -from tqdm import tqdm class PlainDataset(Dataset): @@ -19,11 +17,11 @@ class PlainDataset(Dataset): this while generating clusters. There is no need to do multilabel binarization again. Args: - x: texts - y: labels + x (list | ndarray | Tensor): texts. + y (Optional: csr_matrix | ndarray | Tensor): labels. """ - def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None): + def __init__(self, x, y=None): if y is not None: assert len(x) == y.shape[0], "Sizes mismatch between texts and labels" self.x = x @@ -32,16 +30,18 @@ def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: item = {"text": self.x[idx]} - # train/valid/test + # train/val/test if self.y is not None: if issparse(self.y): - y = self.y[idx].toarray().squeeze(0) - elif is_tensor(self.y) or isinstance(self.y, (ndarray, torch.Tensor)): - y = self.y[idx] + y = self.y[idx].toarray().squeeze(0).astype(np.float32) + elif isinstance(self.y, ndarray): + y = self.y[idx].astype(np.float32) + elif is_tensor(self.y): + y = self.y[idx].float() else: raise TypeError( - "The type of y should be one of scipy.csr_matrix, torch.Tensor, and numpy.ndarry." - f"Instead, got {type(self.y)}." + "The type of y should be one of scipy.csr_matrix, numpy.ndarry, and torch.Tensor." + f"But got {type(self.y)} instead." ) item["label"] = y return item @@ -54,12 +54,12 @@ class PLTDataset(PlainDataset): """Dataset for model_1 of AttentionXML. Args: - x: texts - y: labels + x: texts. + y: labels. num_classes: number of classes. - mapping: mapping from clusters to labels. Shape: (len(clusters), cluster_size). - clusters_selected: sampled predicted clusters from model_0. Shape: (len(x), predict_top_k). - cluster_scores: corresponding scores. Shape: (len(x), predict_top_k) + num_labels_selected: the number of selected labels. Pad any labels that fail to reach this number. + labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k). + label_scores: scores for each label. Shape: (len(x), predict_top_k). """ def __init__( @@ -68,58 +68,43 @@ def __init__( y: Optional[csr_matrix | ndarray] = None, *, num_classes: int, - mapping: ndarray, - clusters_selected: ndarray | Tensor, - cluster_scores: Optional[ndarray | Tensor] = None, + num_labels_selected: int, + labels_selected: ndarray | Tensor, + label_scores: Optional[ndarray | Tensor] = None, ): super().__init__(x, y) self.num_classes = num_classes - self.mapping = mapping - self.clusters_selected = clusters_selected - self.cluster_scores = cluster_scores - self.label_scores = None - - # labels_selected are labels extracted from cluster_selected. - # shape: (len(x), len(clusters_selected) * cluster_size) - self.labels_selected = [ - np.concatenate(self.mapping[labels]) - for labels in tqdm(self.clusters_selected, leave=False, desc="Retrieving labels from selected clusters") - ] - if self.cluster_scores is not None: - # label_scores are probability scores corresponding to labels_selected. - # shape: (len(x), len(clusters_selected) * cluster_size) - self.label_scores = [ - np.repeat(scores, [len(i) for i in self.mapping[labels]]) - for labels, scores in zip(self.clusters_selected, self.cluster_scores) - ] - - # top_k * n (n <= cluster_size). number of maximum possible number selected labels at the current level. - self.num_labels_selected = self.clusters_selected.shape[1] * max(len(clusters) for clusters in self.mapping) + self.num_labels_selected = num_labels_selected + self.labels_selected = labels_selected + self.label_scores = label_scores def __getitem__(self, idx: int): - item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx], dtype=np.int64)} + item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])} - # train/valid/test if self.y is not None: - item["label"] = self.y[idx].toarray().squeeze(0) + item["label"] = self.y[idx].toarray().squeeze(0).astype(np.float32) + # PyTorch requires inputs to be of the same shape. Pad any instances whose length is below num_labels_selected # train if self.label_scores is None: - # As networks require input to be of fixed shape, randomly select labels when the number of selected label - # is not enough + # add real labels when the number is below num_labels_selected if len(item["labels_selected"]) < self.num_labels_selected: - sample = np.random.randint( - self.num_classes, size=self.num_labels_selected - len(item["labels_selected"]) + samples = np.random.randint( + self.num_classes, + size=self.num_labels_selected - len(item["labels_selected"]), ) - item["labels_selected"] = np.concatenate([item["labels_selected"], sample]) - # valid/test + item["labels_selected"] = np.concatenate([item["labels_selected"], samples]) + + # val/test/pred else: item["label_scores"] = self.label_scores[idx] - - # add dummy elements when less than required + # add dummy labels when the number is below num_labels_selected if len(item["labels_selected"]) < self.num_labels_selected: item["label_scores"] = np.concatenate( - [item["label_scores"], [-np.inf] * (self.num_labels_selected - len(item["labels_selected"]))] + [ + item["label_scores"], + [-np.inf] * (self.num_labels_selected - len(item["labels_selected"])), + ] ) item["labels_selected"] = np.concatenate( [ @@ -127,6 +112,4 @@ def __getitem__(self, idx: int): [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])), ] ) - - item["label_scores"] = np.asarray(item["label_scores"], dtype=np.float32) return item diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index 88b9685fb..60c2c8147 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -53,9 +53,9 @@ def shared_step(self, batch): pred_logits (torch.Tensor): The predict logits (batch_size, num_classes). """ y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1) - logits = self(batch) - loss = self.loss_function(logits, y) - return loss, logits + pred_logits = self(batch) + loss = self.loss_function(pred_logits, y) + return loss, pred_logits def _shared_eval_step(self, batch, batch_idx): logits = self(batch) diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index b91a1368c..54cbfc36f 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -126,61 +126,66 @@ def __init__( # save path self.log_path = config.log_path - def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]: + def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, ...]: """Map labels to their corresponding clusters in CSR sparse format. - - Suppose there are 6 labels and clusters are [(0, 1), (2, 3), (4, 5)] and ys of a given instance is [0, 1, 4]. - The clusters of the instance are [0, 2]. + Notice that this function deals with SPARSE matrix. + Assume there are 6 labels clustered as [(0, 1), (2, 3), (4, 5)]. Here (0, 1) is cluster with index 0 and so on. + Given the ground-truth labels, [0, 1, 4], the resulting clusters are [0, 2]. Args: - cluster_mapping: mapping from clusters to labels. - *ys: sparse labels. + cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree. + *labels (csr_matrix): labels in CSR sparse format. Returns: - Generator[csr_matrix]: clusters generated from labels + Generator[csr_matrix]: resulting clusters converted from labels in CSR sparse format """ - mapping = np.empty(self.num_classes, dtype=np.uint64) + mapping = np.empty(self.num_classes, dtype=np.uint32) for idx, clusters in enumerate(cluster_mapping): mapping[clusters] = idx - def _label2cluster(y: csr_matrix) -> csr_matrix: + def _label2cluster(label: csr_matrix) -> csr_matrix: row = [] col = [] data = [] - for i in range(y.shape[0]): + for i in range(label.shape[0]): # n include all mapped ancestor clusters - n = np.unique(mapping[y.indices[y.indptr[i] : y.indptr[i + 1]]]) + n = np.unique(mapping[label.indices[label.indptr[i] : label.indptr[i + 1]]]) row += [i] * len(n) col += n.tolist() data += [1] * len(n) - return csr_matrix((data, (row, col)), shape=(y.shape[0], len(cluster_mapping))) - - return (_label2cluster(y) for y in ys) - - # def cluster2label(self, cluster_mapping, *ys): - # """Map clusters to their corresponding labels. Notice this function only deals with dense matrix. - # - # Args: - # cluster_mapping: mapping from clusters to labels. - # *ys: sparse clusters. - # - # Returns: - # Generator[csr_matrix]: labels generated from clusters - # """ - # - # def _cluster2label(y: csr_matrix) -> csr_matrix: - # self.labels_selected = [np.concatenate(cluster_mapping[labels]) for labels in y] - # return (_cluster2label(y) for y in ys) - - # def generate_goals(self, cluster_scores, y): - # if cluster_scores is not None: - # # label_scores are corresponding scores for selected labels and - # # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k) - # # notice how scores repeat for each cluster. - # self.label_scores = [ - # np.repeat(scores, [len(i) for i in cluster_mapping[labels]]) - # for labels, scores in zip(y, cluster_scores) - # ] + return csr_matrix((data, (row, col)), shape=(label.shape[0], len(cluster_mapping))) + + return (_label2cluster(label) for label in labels) + + @staticmethod + def cluster2label(cluster_mapping, clusters, cluster_scores=None): + """Expand clusters to their corresponding labels and, if available, assign scores to each label. + Labels inside the same cluster have the same scores. This function is applied to predictions from model 0. + Notice that the behaviors of this function are different from label2cluster. + Also notice that this function deals with DENSE matrix. + + Args: + cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree. + clusters (np.ndarray): predicted clusters from model 0. + cluster_scores (Optional: np.ndarray): predicted scores of each cluster from model 0. + + Returns: + Generator[np.ndarray]: resulting labels expanded from clusters + """ + + labels_selected = [] + + if cluster_scores is not None: + # label_scores are corresponding scores for selected labels and + # shape: (len(x), cluster_size * top_k) + label_scores = [] + for score, cluster in zip(cluster_scores, clusters): + label_scores += [np.repeat(score, [len(labels) for labels in cluster_mapping[cluster]])] + labels_selected += [np.concatenate(cluster_mapping[cluster])] + return labels_selected, label_scores + else: + labels_selected = [np.concatenate(cluster_mapping[cluster]) for cluster in clusters] + return labels_selected def fit(self, datasets): """fit model to the training dataset @@ -284,6 +289,9 @@ def fit(self, datasets): ) # load training and validation data and predict corresponding level 0 clusters + train_dataloader = self.dataloader(PlainDataset(train_x)) + val_dataloader = self.dataloader(PlainDataset(val_x)) + train_pred = trainer.predict(model_0, train_dataloader) val_pred = trainer.predict(model_0, val_dataloader) @@ -291,7 +299,7 @@ def fit(self, datasets): val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred])) val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred]) - clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64) + train_clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.uint) for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")): # relevant clusters are positive pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) @@ -317,7 +325,11 @@ def fit(self, datasets): break if len(selected) < self.predict_top_k: selected = (list(selected) + list(pos - selected))[: self.predict_top_k] - clusters_selected[i] = np.asarray(list(selected)) + train_clusters_selected[i] = np.asarray(list(selected)) + + train_labels_selected = PLTTrainer.cluster2label(clusters, train_clusters_selected) + val_labels_pred, val_scores_pred = PLTTrainer.cluster2label(clusters, val_clusters_pred, val_scores_pred) + num_labels_selected = self.predict_top_k * max(len(c) for c in clusters) trainer = init_trainer( self.checkpoint_dir, @@ -334,14 +346,14 @@ def fit(self, datasets): ) trainer.checkpoint_callback.filename = f"{self.CHECKPOINT_NAME}1" - # train & valid dataloaders for training + # train & val dataloaders for training train_dataloader = self.dataloader( PLTDataset( train_x, train_y, num_classes=self.num_classes, - mapping=clusters, - clusters_selected=clusters_selected, + num_labels_selected=num_labels_selected, + labels_selected=train_labels_selected, ), shuffle=self.shuffle, ) @@ -350,9 +362,9 @@ def fit(self, datasets): val_x, val_y, num_classes=self.num_classes, - mapping=clusters, - clusters_selected=val_clusters_pred, - cluster_scores=val_scores_pred, + num_labels_selected=num_labels_selected, + labels_selected=val_labels_pred, + label_scores=val_scores_pred, ), ) @@ -435,19 +447,21 @@ def test(self, dataset): logger.info(f"Predicting level 0, Top: {self.predict_top_k}") test_pred = trainer.predict(model_0, test_dataloader) - test_pred_scores = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred])) - test_pred_cluters = np.vstack([i["top_k_pred"] for i in test_pred]) + test_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred])) + test_clusters_pred = np.vstack([i["top_k_pred"] for i in test_pred]) clusters = np.load(self.get_cluster_path(), allow_pickle=True) + test_labels_pred, test_scores_pred = PLTTrainer.cluster2label(clusters, test_clusters_pred, test_scores_pred) + num_labels_selected = self.predict_top_k * max(len(c) for c in clusters) test_dataloader = self.eval_dataloader( PLTDataset( test_x, test_y, num_classes=self.num_classes, - mapping=clusters, - clusters_selected=test_pred_cluters, - cluster_scores=test_pred_scores, + num_labels_selected=num_labels_selected, + labels_selected=test_labels_pred, + label_scores=test_scores_pred, ), ) From b56506647d89b870c8343fffc4c92b0c1cd42643 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 2 Apr 2024 13:41:26 +0400 Subject: [PATCH 21/29] add new argument beam_width --- example_config/EUR-Lex/attentionxml.yml | 2 +- example_config/Wiki10-31K/attentionxml.yml | 2 +- libmultilabel/nn/plt.py | 38 +++++++++++----------- main.py | 6 ++++ 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index 87299e2d6..adcc6d8fc 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # AttentionXML-related parameters cluster_size: 8 -save_k_predictions: 64 +beam_width: 64 # dataloader batch_size: 40 diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 30fab98f5..dac9fddb1 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -save_k_predictions: 64 +beam_width: 64 # dataloader batch_size: 40 diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 54cbfc36f..c942f2ab5 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -54,7 +54,8 @@ def __init__( # cluster self.cluster_size = config.cluster_size # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level 1 model training - self.predict_top_k = config.save_k_predictions + self.beam_width = config.beam_width + self.save_k_predictions = config.save_k_predictions # dataset meta info self.embed_vecs = embed_vecs @@ -273,7 +274,7 @@ def fit(self, datasets): multiclass=self.multiclass, loss_function=self.loss_function, silent=self.silent, - save_k_predictions=self.predict_top_k, + save_k_predictions=self.beam_width, ) logger.info(f"Training level 0. Number of clusters: {len(clusters)}") @@ -284,11 +285,10 @@ def fit(self, datasets): model_0 = Model.load_from_checkpoint(best_model_path) logger.info( - f"Predicting clusters by level 0 model. We then select {self.predict_top_k} clusters and use then " - f"to extract labels for level 1 training." + f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters and " + f"extract labels from them for level 1 training." ) # load training and validation data and predict corresponding level 0 clusters - train_dataloader = self.dataloader(PlainDataset(train_x)) val_dataloader = self.dataloader(PlainDataset(val_x)) @@ -299,17 +299,17 @@ def fit(self, datasets): val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred])) val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred]) - train_clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.uint) + train_clusters_selected = np.empty((len(train_x), self.beam_width), dtype=np.uint) for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")): # relevant clusters are positive pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]]) # Select relevant clusters first. Then from top-predicted clusters, sequentially include them until - # cluster number reaches predict_top_k - if len(pos) <= self.predict_top_k: + # cluster number reaches beam width + if len(pos) <= self.beam_width: selected = pos for y in ys: y = y.item() - if len(selected) == self.predict_top_k: + if len(selected) == self.beam_width: break selected.add(y) # Regard positive (true) label as samples iff they appear in the predicted labels @@ -321,15 +321,15 @@ def fit(self, datasets): y = y.item() if y in pos: selected.add(y) - if len(selected) == self.predict_top_k: + if len(selected) == self.beam_width: break - if len(selected) < self.predict_top_k: - selected = (list(selected) + list(pos - selected))[: self.predict_top_k] + if len(selected) < self.beam_width: + selected = (list(selected) + list(pos - selected))[: self.beam_width] train_clusters_selected[i] = np.asarray(list(selected)) train_labels_selected = PLTTrainer.cluster2label(clusters, train_clusters_selected) val_labels_pred, val_scores_pred = PLTTrainer.cluster2label(clusters, val_clusters_pred, val_scores_pred) - num_labels_selected = self.predict_top_k * max(len(c) for c in clusters) + num_labels_selected = self.beam_width * max(len(c) for c in clusters) trainer = init_trainer( self.checkpoint_dir, @@ -393,7 +393,7 @@ def fit(self, datasets): multiclass=self.multiclass, loss_function=self.loss_function, silent=self.silent, - save_k_predictions=self.predict_top_k, + save_k_predictions=self.save_k_predictions, ) logger.info(f"Initialize model with weights from level 0") # For weights not initialized by the level-0 model, use xavier uniform initialization @@ -418,11 +418,11 @@ def test(self, dataset): # prediction starts from level 0 model_0 = Model.load_from_checkpoint( self.get_best_model_path(level=0), - save_k_predictions=self.predict_top_k, + save_k_predictions=self.beam_width, ) model_1 = PLTModel.load_from_checkpoint( self.get_best_model_path(level=1), - save_k_predictions=self.predict_top_k, + save_k_predictions=self.save_k_predictions, metrics=self.metrics, ) self.word_dict = model_1.word_dict @@ -445,14 +445,14 @@ def test(self, dataset): test_dataloader = self.eval_dataloader(PlainDataset(test_x)) - logger.info(f"Predicting level 0, Top: {self.predict_top_k}") + logger.info(f"Predicting level 0. Number of clusters: {self.beam_width}") test_pred = trainer.predict(model_0, test_dataloader) test_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred])) test_clusters_pred = np.vstack([i["top_k_pred"] for i in test_pred]) clusters = np.load(self.get_cluster_path(), allow_pickle=True) test_labels_pred, test_scores_pred = PLTTrainer.cluster2label(clusters, test_clusters_pred, test_scores_pred) - num_labels_selected = self.predict_top_k * max(len(c) for c in clusters) + num_labels_selected = self.beam_width * max(len(c) for c in clusters) test_dataloader = self.eval_dataloader( PLTDataset( @@ -465,7 +465,7 @@ def test(self, dataset): ), ) - logger.info(f"Testing on level 1") + logger.info(f"Testing level 1") trainer.test(model_1, test_dataloader) logger.info("Testing process finished") diff --git a/main.py b/main.py index 80d8d256e..5342bc50c 100644 --- a/main.py +++ b/main.py @@ -223,6 +223,12 @@ def add_all_arguments(parser): parser.add_argument( "--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)" ) + parser.add_argument( + "--beam_width", + type=int, + default=100, + help="The top k clusters predicted by the last model to be fed to the next one (default: %(default)s)", + ) # AttentionXML parser.add_argument( "--cluster_size", From 690e53e1ba4d9b5d0e99951988978da1cefbb791 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 2 Apr 2024 13:50:51 +0400 Subject: [PATCH 22/29] add save_predictions --- example_config/AmazonCat-13K/attentionxml.yml | 42 +++++++++++++++++++ libmultilabel/nn/plt.py | 17 ++++++++ 2 files changed, 59 insertions(+) create mode 100644 example_config/AmazonCat-13K/attentionxml.yml diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml new file mode 100644 index 000000000..575e5ef12 --- /dev/null +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -0,0 +1,42 @@ +data_name: AmazonCat-13K +# will change path later +training_file: data/AmazonCat-13K/train.txt +test_file: data/AmazonCat-13K/test.txt +# pretrained embeddings +embed_file: glove.840B.300d + +# preprocessing +min_vocab_freq: 1 +max_seq_length: 500 + +# label tree related parameters +cluster_size: 8 +beam_width: 64 + +# data +batch_size: 200 +val_size: 4000 +shuffle: true + +# eval +eval_batch_size: 200 +monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5] +val_metric: nDCG@5 + +# train +seed: 1337 +epochs: 10 +optimizer: adam +learning_rate: 0.001 +# early stopping +patience: 10 + +# model +model_name: AttentionXML +network_config: + embed_dropout: 0.2 + post_encoder_dropout: 0.5 + rnn_dim: 1024 + rnn_layers: 1 + linear_size: [512, 256] + freeze_embed_training: false diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index c942f2ab5..d91f2f848 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -21,6 +21,7 @@ from .data_utils import UNK from .datasets_AttentionXML import PlainDataset, PLTDataset from .model_AttentionXML import PLTModel +from ..common_utils import dump_log from ..linear.preprocessor import Preprocessor from ..nn import networks from ..nn.model import Model @@ -126,6 +127,8 @@ def __init__( # save path self.log_path = config.log_path + self.predict_out_path = config.predict_out_path + self.config = config def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, ...]: """Map labels to their corresponding clusters in CSR sparse format. @@ -469,6 +472,20 @@ def test(self, dataset): trainer.test(model_1, test_dataloader) logger.info("Testing process finished") + if self.save_k_predictions > 0: + batch_predictions = trainer.predict(model_1, test_dataloader) + pred_labels = np.vstack([batch["top_k_pred"] for batch in batch_predictions]) + pred_scores = np.vstack([batch["top_k_pred_scores"] for batch in batch_predictions]) + with open(self.predict_out_path, "w") as fp: + for pred_label, pred_score in zip(pred_labels, pred_scores): + out_str = " ".join( + [f"{model_1.classes[label]}:{score:.4}" for label, score in zip(pred_label, pred_score)] + ) + fp.write(out_str + "\n") + logging.info(f"Saved predictions to: {self.predict_out_path}") + + dump_log(self.log_path, config=self.config) + def reformat_text(self, dataset): # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length. encoded_text = list( From 2ab951c238dcf001a3d1b07f8f7d13b2a2b3640a Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 2 Apr 2024 15:39:55 +0400 Subject: [PATCH 23/29] change the returned dtype of labels by Dataset to int32 --- libmultilabel/nn/datasets_AttentionXML.py | 8 ++++---- libmultilabel/nn/model_AttentionXML.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py index 8a91e4bad..d6b935b0b 100644 --- a/libmultilabel/nn/datasets_AttentionXML.py +++ b/libmultilabel/nn/datasets_AttentionXML.py @@ -33,11 +33,11 @@ def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: # train/val/test if self.y is not None: if issparse(self.y): - y = self.y[idx].toarray().squeeze(0).astype(np.float32) + y = self.y[idx].toarray().squeeze(0).astype(np.int32) elif isinstance(self.y, ndarray): - y = self.y[idx].astype(np.float32) + y = self.y[idx].astype(np.int32) elif is_tensor(self.y): - y = self.y[idx].float() + y = self.y[idx].int() else: raise TypeError( "The type of y should be one of scipy.csr_matrix, numpy.ndarry, and torch.Tensor." @@ -82,7 +82,7 @@ def __getitem__(self, idx: int): item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])} if self.y is not None: - item["label"] = self.y[idx].toarray().squeeze(0).astype(np.float32) + item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32) # PyTorch requires inputs to be of the same shape. Pad any instances whose length is below num_labels_selected # train diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py index 60c2c8147..feb2e2f8f 100644 --- a/libmultilabel/nn/model_AttentionXML.py +++ b/libmultilabel/nn/model_AttentionXML.py @@ -54,7 +54,7 @@ def shared_step(self, batch): """ y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1) pred_logits = self(batch) - loss = self.loss_function(pred_logits, y) + loss = self.loss_function(pred_logits, y.float()) return loss, pred_logits def _shared_eval_step(self, batch, batch_idx): From c0b97807fe8925306f531e010206cccf5cb78563 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Tue, 2 Apr 2024 18:25:34 +0400 Subject: [PATCH 24/29] workaround test --- example_config/AmazonCat-13K/attentionxml.yml | 2 +- example_config/EUR-Lex/attentionxml.yml | 2 +- example_config/Wiki10-31K/attentionxml.yml | 2 +- libmultilabel/nn/plt.py | 7 +++++-- main.py | 6 ------ 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index 575e5ef12..9c416ed29 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -11,7 +11,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -beam_width: 64 +save_k_predictions: 64 # data batch_size: 200 diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index adcc6d8fc..87299e2d6 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # AttentionXML-related parameters cluster_size: 8 -beam_width: 64 +save_k_predictions: 64 # dataloader batch_size: 40 diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index dac9fddb1..30fab98f5 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -beam_width: 64 +save_k_predictions: 64 # dataloader batch_size: 40 diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index d91f2f848..213bfb34d 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -55,8 +55,8 @@ def __init__( # cluster self.cluster_size = config.cluster_size # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level 1 model training - self.beam_width = config.beam_width - self.save_k_predictions = config.save_k_predictions + # self.beam_width = config.beam_width + self.beam_width = config.save_k_predictions # dataset meta info self.embed_vecs = embed_vecs @@ -125,6 +125,9 @@ def __init__( batch_size=config.eval_batch_size, ) + # predict + self.save_k_predictions = config.save_k_predictions + # save path self.log_path = config.log_path self.predict_out_path = config.predict_out_path diff --git a/main.py b/main.py index 5342bc50c..80d8d256e 100644 --- a/main.py +++ b/main.py @@ -223,12 +223,6 @@ def add_all_arguments(parser): parser.add_argument( "--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)" ) - parser.add_argument( - "--beam_width", - type=int, - default=100, - help="The top k clusters predicted by the last model to be fed to the next one (default: %(default)s)", - ) # AttentionXML parser.add_argument( "--cluster_size", From 5330b8b3e5f1c3c1fa59888184ee9b0ada81be96 Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 3 Apr 2024 17:36:05 +0400 Subject: [PATCH 25/29] use beam_width --- example_config/AmazonCat-13K/attentionxml.yml | 2 +- example_config/EUR-Lex/attentionxml.yml | 2 +- example_config/Wiki10-31K/attentionxml.yml | 2 +- libmultilabel/nn/plt.py | 3 +-- main.py | 6 ++++++ 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index 9c416ed29..575e5ef12 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -11,7 +11,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -save_k_predictions: 64 +beam_width: 64 # data batch_size: 200 diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index 87299e2d6..adcc6d8fc 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # AttentionXML-related parameters cluster_size: 8 -save_k_predictions: 64 +beam_width: 64 # dataloader batch_size: 40 diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index 30fab98f5..dac9fddb1 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -13,7 +13,7 @@ max_seq_length: 500 # label tree related parameters cluster_size: 8 -save_k_predictions: 64 +beam_width: 64 # dataloader batch_size: 40 diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 213bfb34d..61aef1d5e 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -55,8 +55,7 @@ def __init__( # cluster self.cluster_size = config.cluster_size # predict the top k clusters for deciding relevant/irrelevant labels of each instance in level 1 model training - # self.beam_width = config.beam_width - self.beam_width = config.save_k_predictions + self.beam_width = config.beam_width # dataset meta info self.embed_vecs = embed_vecs diff --git a/main.py b/main.py index 80d8d256e..26538d365 100644 --- a/main.py +++ b/main.py @@ -223,6 +223,12 @@ def add_all_arguments(parser): parser.add_argument( "--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)" ) + parser.add_argument( + "--beam_width", + type=int, + default=10, + help="The width of the beam search (default: %(default)s)", + ) # AttentionXML parser.add_argument( "--cluster_size", From 322d8376ac3dbb123565028f04e7a10f5b86092c Mon Sep 17 00:00:00 2001 From: Dongli He Date: Wed, 3 Apr 2024 17:44:47 +0400 Subject: [PATCH 26/29] lowercase filenames --- .../nn/{datasets_AttentionXML.py => dataset_attentionxml.py} | 0 .../nn/{model_AttentionXML.py => model_attentionxml.py} | 0 libmultilabel/nn/plt.py | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) rename libmultilabel/nn/{datasets_AttentionXML.py => dataset_attentionxml.py} (100%) rename libmultilabel/nn/{model_AttentionXML.py => model_attentionxml.py} (100%) diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/dataset_attentionxml.py similarity index 100% rename from libmultilabel/nn/datasets_AttentionXML.py rename to libmultilabel/nn/dataset_attentionxml.py diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_attentionxml.py similarity index 100% rename from libmultilabel/nn/model_AttentionXML.py rename to libmultilabel/nn/model_attentionxml.py diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 61aef1d5e..1ae40b9eb 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -19,8 +19,8 @@ from .cluster import CLUSTER_NAME, FILE_EXTENSION as CLUSTER_FILE_EXTENSION, build_label_tree from .data_utils import UNK -from .datasets_AttentionXML import PlainDataset, PLTDataset -from .model_AttentionXML import PLTModel +from .dataset_attentionxml import PlainDataset, PLTDataset +from .model_attentionxml import PLTModel from ..common_utils import dump_log from ..linear.preprocessor import Preprocessor from ..nn import networks From 48de2edeb417297ea11a63dc470b7577bfeba2ac Mon Sep 17 00:00:00 2001 From: Dongli He Date: Fri, 5 Apr 2024 15:27:12 +0400 Subject: [PATCH 27/29] combine into single file --- docs/tutorial.rst | 2 + libmultilabel/nn/{plt.py => attentionxml.py} | 299 ++++++++++++++++++- libmultilabel/nn/cluster.py | 114 ------- libmultilabel/nn/dataset_attentionxml.py | 115 ------- libmultilabel/nn/model_attentionxml.py | 73 ----- torch_trainer.py | 2 +- 6 files changed, 293 insertions(+), 312 deletions(-) rename libmultilabel/nn/{plt.py => attentionxml.py} (64%) delete mode 100644 libmultilabel/nn/cluster.py delete mode 100644 libmultilabel/nn/dataset_attentionxml.py delete mode 100644 libmultilabel/nn/model_attentionxml.py diff --git a/docs/tutorial.rst b/docs/tutorial.rst index a5bc458c9..d6119d07c 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -5,6 +5,7 @@ Tutorials * `Feature Generation and Parameter Selection for Linear Methods <./auto_examples/plot_linear_gridsearch_tutorial.html>`_ * `Parameter Selection for Neural Networks `_ * `Handling Data with Many Labels <./auto_examples/plot_linear_tree_tutorial.html>`_ +* `Implement Extreme Multi-Label Text Classification with AttentionXML <./auto_examples/plot_AttentionXML_tutorial.html>`_ .. toctree:: @@ -15,4 +16,5 @@ Tutorials ../auto_examples/plot_linear_gridsearch_tutorial tutorials/Parameter_Selection_for_Neural_Networks ../auto_examples/plot_linear_tree_tutorial + ../auto_examples/plot_AttentionXML_tutorial \ No newline at end of file diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/attentionxml.py similarity index 64% rename from libmultilabel/nn/plt.py rename to libmultilabel/nn/attentionxml.py index 1ae40b9eb..54d19509c 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/attentionxml.py @@ -3,24 +3,22 @@ import logging from functools import partial from pathlib import Path -from typing import Generator, Optional +from typing import Generator, Sequence, Optional import numpy as np import torch -from scipy.special import expit from lightning import Trainer -from scipy.sparse import csr_matrix -from sklearn.preprocessing import MultiLabelBinarizer -from torch import Tensor +from numpy import ndarray +from scipy.sparse import csr_matrix, csc_matrix, issparse +from scipy.special import expit +from sklearn.preprocessing import MultiLabelBinarizer, normalize +from torch import Tensor, is_tensor from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from lightning.pytorch.callbacks import ModelCheckpoint -from .cluster import CLUSTER_NAME, FILE_EXTENSION as CLUSTER_FILE_EXTENSION, build_label_tree from .data_utils import UNK -from .dataset_attentionxml import PlainDataset, PLTDataset -from .model_attentionxml import PLTModel from ..common_utils import dump_log from ..linear.preprocessor import Preprocessor from ..nn import networks @@ -513,3 +511,286 @@ def get_best_model_path(self, level: int) -> Path: def get_cluster_path(self) -> Path: return self.checkpoint_dir / f"{CLUSTER_NAME}{CLUSTER_FILE_EXTENSION}" + + +###################################### Model ###################################### + + +class PLTModel(Model): + def __init__( + self, + classes, + word_dict, + embed_vecs, + network, + loss_function="binary_cross_entropy_with_logits", + log_path=None, + **kwargs, + ): + super().__init__( + classes=classes, + word_dict=word_dict, + embed_vecs=embed_vecs, + network=network, + loss_function=loss_function, + log_path=log_path, + **kwargs, + ) + + def scatter_logits( + self, + logits: Tensor, + labels_selected: Tensor, + label_scores: Tensor, + ) -> Tensor: + """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to + the whole label space. The scores of unsampled labels are set to 0.""" + src = torch.sigmoid(logits.detach()) * label_scores + preds = torch.zeros( + labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype + ) + preds.scatter_(dim=1, index=labels_selected, src=src) + # remove dummy labels + preds = preds[:, :-1] + return preds + + def shared_step(self, batch): + """Return loss and predicted logits of the network. + + Args: + batch (dict): A batch of text and label. + + Returns: + loss (torch.Tensor): Loss between target and predict logits. + pred_logits (torch.Tensor): The predict logits (batch_size, num_classes). + """ + y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1) + pred_logits = self(batch) + loss = self.loss_function(pred_logits, y.float()) + return loss, pred_logits + + def _shared_eval_step(self, batch, batch_idx): + logits = self(batch) + logits = self.scatter_logits(logits, batch["labels_selected"], batch["label_scores"]) + self.eval_metric.update(logits, batch["label"].long()) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + logits = self(batch) + scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.save_k_predictions) + # This calculation is to align with LibMultiLabel class where logits rather than probabilities are returned + logits = torch.logit(scores) + return { + "top_k_pred": torch.take_along_dim(batch["labels_selected"], labels, dim=1).numpy(force=True), + "top_k_pred_scores": logits.numpy(force=True), + } + + +###################################### Dataset ###################################### +class PlainDataset(Dataset): + """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label dataset. + WHY EXISTS: The reason why this class is necessary is that it can process labels in sparse format, while TextDataset + does not. + Moreover, TextDataset implements multilabel binarization in a mandatory way. Nevertheless, AttentionXML already does + this while generating clusters. There is no need to do multilabel binarization again. + + Args: + x (list | ndarray | Tensor): texts. + y (Optional: csr_matrix | ndarray | Tensor): labels. + """ + + def __init__(self, x, y=None): + if y is not None: + assert len(x) == y.shape[0], "Sizes mismatch between texts and labels" + self.x = x + self.y = y + + def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: + item = {"text": self.x[idx]} + + # train/val/test + if self.y is not None: + if issparse(self.y): + y = self.y[idx].toarray().squeeze(0).astype(np.int32) + elif isinstance(self.y, ndarray): + y = self.y[idx].astype(np.int32) + elif is_tensor(self.y): + y = self.y[idx].int() + else: + raise TypeError( + "The type of y should be one of scipy.csr_matrix, numpy.ndarry, and torch.Tensor." + f"But got {type(self.y)} instead." + ) + item["label"] = y + return item + + def __len__(self): + return len(self.x) + + +class PLTDataset(PlainDataset): + """Dataset for model_1 of AttentionXML. + + Args: + x: texts. + y: labels. + num_classes: number of classes. + num_labels_selected: the number of selected labels. Pad any labels that fail to reach this number. + labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k). + label_scores: scores for each label. Shape: (len(x), predict_top_k). + """ + + def __init__( + self, + x, + y: Optional[csr_matrix | ndarray] = None, + *, + num_classes: int, + num_labels_selected: int, + labels_selected: ndarray | Tensor, + label_scores: Optional[ndarray | Tensor] = None, + ): + super().__init__(x, y) + self.num_classes = num_classes + self.num_labels_selected = num_labels_selected + self.labels_selected = labels_selected + self.label_scores = label_scores + + def __getitem__(self, idx: int): + item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])} + + if self.y is not None: + item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32) + + # PyTorch requires inputs to be of the same shape. Pad any instances whose length is below num_labels_selected + # train + if self.label_scores is None: + # add real labels when the number is below num_labels_selected + if len(item["labels_selected"]) < self.num_labels_selected: + samples = np.random.randint( + self.num_classes, + size=self.num_labels_selected - len(item["labels_selected"]), + ) + item["labels_selected"] = np.concatenate([item["labels_selected"], samples]) + + # val/test/pred + else: + item["label_scores"] = self.label_scores[idx] + # add dummy labels when the number is below num_labels_selected + if len(item["labels_selected"]) < self.num_labels_selected: + item["label_scores"] = np.concatenate( + [ + item["label_scores"], + [-np.inf] * (self.num_labels_selected - len(item["labels_selected"])), + ] + ) + item["labels_selected"] = np.concatenate( + [ + item["labels_selected"], + [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])), + ] + ) + return item + + +###################################### Cluster ###################################### + +CLUSTER_FILE_EXTENSION = CLUSTER_NAME = "label_clusters" +FILE_EXTENSION = ".npy" + + +def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path): + """Build a binary tree to group labels into clusters, each of which contains up tp cluster_size labels. The tree has + several layers; nodes in the last layer correspond to the output clusters. + Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters look something like: + ((0, 2), (1, 3), (4, 5)). + + Args: + sparse_x: features extracted from texts in CSR sparse format + sparse_y: binarized labels in CSR sparse format + cluster_size: the maximum number of labels within each cluster + output_dir: directory to store the clustering file + """ + # skip constructing label tree if the output file already exists + output_dir = output_dir if isinstance(output_dir, Path) else Path(output_dir) + cluster_path = output_dir / f"{CLUSTER_NAME}{FILE_EXTENSION}" + if cluster_path.exists(): + logger.info("Clustering has finished in a previous run") + return + + # meta info + logger.info("Label clustering started") + logger.info(f"Cluster size: {cluster_size}") + # The height of the tree satisfies the following inequality: + # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size + height = int(np.ceil(np.log2(sparse_y.shape[1] / cluster_size))) + logger.info(f"Labels will be grouped into {2 ** height} clusters") + + output_dir.mkdir(parents=True, exist_ok=True) + + # For each label, sum up normalized instances relevant to the label and normalize to get the label representation + label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x))) + + # clustering by a binary tree: + # at each layer split each cluster to two. Leave nodes correspond to the obtained clusters. + clusters = [np.arange(sparse_y.shape[1])] + for _ in range(height): + next_clusters = [] + for cluster in clusters: + next_clusters += _split_cluster(cluster, label_repr[cluster]) + clusters = next_clusters + logger.info(f"Having grouped {len(clusters)} clusters") + + np.save(cluster_path, np.asarray(clusters, dtype=object)) + logger.info(f"Label clustering finished. Saving results to {cluster_path}") + + +def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]: + """A variant of KMeans implemented in AttentionXML. Here K = 2. The cluster is partitioned into two groups, each + with approximately equal size. Its main differences with the KMeans algorithm in scikit-learn are: + 1. the distance metric is cosine similarity. + 2. the end-of-loop criterion is the difference between the new and old average in-cluster distances to centroids. + + Args: + cluster: a subset of labels + label_repr: the normalized representations of the relationship between labels and texts of the given cluster + """ + # Randomly choose two points as initial centroids and obtain their label representations + centroids = label_repr[np.random.choice(len(cluster), size=2, replace=False)].toarray() + + # Initialize distances (cosine similarity) + # Cosine similarity always falls to the interval [-1, 1] + old_dist = -2.0 + new_dist = -1.0 + + # "c" denotes clusters + c0_idx = None + c1_idx = None + + while new_dist - old_dist >= 1e-4: + # Notice that label_repr and centroids.T have been normalized + # Thus, dist indicates the cosine similarity between points and centroids. + dist = label_repr @ centroids.T # shape: (n, 2) + + # generate clusters + # let a = dist[:, 1] - dist[:, 0], the larger the element in a is, the closer the point is to c1 + k = len(cluster) // 2 + c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k) + c0_idx = c_idx[:k] + c1_idx = c_idx[k:] + + # update distances + # the new distance is the average of in-cluster distances to the centroids + old_dist = new_dist + new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / len(cluster) + + # update centroids + # the new centroid is the normalized average of the points in the cluster + centroids = normalize( + np.asarray( + [ + np.squeeze(np.asarray(label_repr[c0_idx].sum(axis=0))), + np.squeeze(np.asarray(label_repr[c1_idx].sum(axis=0))), + ] + ) + ) + return cluster[c0_idx], cluster[c1_idx] diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py deleted file mode 100644 index 35617531d..000000000 --- a/libmultilabel/nn/cluster.py +++ /dev/null @@ -1,114 +0,0 @@ -from __future__ import annotations - -import logging -from pathlib import Path - -import numpy as np -from numpy import ndarray -from scipy.sparse import csc_matrix, csr_matrix -from sklearn.preprocessing import normalize - -__all__ = ["CLUSTER_NAME", "FILE_EXTENSION", "build_label_tree"] - -logger = logging.getLogger(__name__) - -CLUSTER_NAME = "label_clusters" -FILE_EXTENSION = ".npy" - - -def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path): - """Build a binary tree to group labels into clusters, each of which contains up tp cluster_size labels. The tree has - several layers; nodes in the last layer correspond to the output clusters. - Given a set of labels (0, 1, 2, 3, 4, 5) and a cluster size of 2, the resulting clusters look something like: - ((0, 2), (1, 3), (4, 5)). - - Args: - sparse_x: features extracted from texts in CSR sparse format - sparse_y: binarized labels in CSR sparse format - cluster_size: the maximum number of labels within each cluster - output_dir: directory to store the clustering file - """ - # skip constructing label tree if the output file already exists - output_dir = output_dir if isinstance(output_dir, Path) else Path(output_dir) - cluster_path = output_dir / f"{CLUSTER_NAME}{FILE_EXTENSION}" - if cluster_path.exists(): - logger.info("Clustering has finished in a previous run") - return - - # meta info - logger.info("Label clustering started") - logger.info(f"Cluster size: {cluster_size}") - # The height of the tree satisfies the following inequality: - # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size - height = int(np.ceil(np.log2(sparse_y.shape[1] / cluster_size))) - logger.info(f"Labels will be grouped into {2**height} clusters") - - output_dir.mkdir(parents=True, exist_ok=True) - - # For each label, sum up normalized instances relevant to the label and normalize to get the label representation - label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x))) - - # clustering by a binary tree: - # at each layer split each cluster to two. Leave nodes correspond to the obtained clusters. - clusters = [np.arange(sparse_y.shape[1])] - for _ in range(height): - next_clusters = [] - for cluster in clusters: - next_clusters += _split_cluster(cluster, label_repr[cluster]) - clusters = next_clusters - logger.info(f"Having grouped {len(clusters)} clusters") - - np.save(cluster_path, np.asarray(clusters, dtype=object)) - logger.info(f"Label clustering finished. Saving results to {cluster_path}") - - -def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, ndarray]: - """A variant of KMeans implemented in AttentionXML. Here K = 2. The cluster is partitioned into two groups, each - with approximately equal size. Its main differences with the KMeans algorithm in scikit-learn are: - 1. the distance metric is cosine similarity. - 2. the end-of-loop criterion is the difference between the new and old average in-cluster distances to centroids. - - Args: - cluster: a subset of labels - label_repr: the normalized representations of the relationship between labels and texts of the given cluster - """ - # Randomly choose two points as initial centroids and obtain their label representations - centroids = label_repr[np.random.choice(len(cluster), size=2, replace=False)].toarray() - - # Initialize distances (cosine similarity) - # Cosine similarity always falls to the interval [-1, 1] - old_dist = -2.0 - new_dist = -1.0 - - # "c" denotes clusters - c0_idx = None - c1_idx = None - - while new_dist - old_dist >= 1e-4: - # Notice that label_repr and centroids.T have been normalized - # Thus, dist indicates the cosine similarity between points and centroids. - dist = label_repr @ centroids.T # shape: (n, 2) - - # generate clusters - # let a = dist[:, 1] - dist[:, 0], the larger the element in a is, the closer the point is to c1 - k = len(cluster) // 2 - c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k) - c0_idx = c_idx[:k] - c1_idx = c_idx[k:] - - # update distances - # the new distance is the average of in-cluster distances to the centroids - old_dist = new_dist - new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / len(cluster) - - # update centroids - # the new centroid is the normalized average of the points in the cluster - centroids = normalize( - np.asarray( - [ - np.squeeze(np.asarray(label_repr[c0_idx].sum(axis=0))), - np.squeeze(np.asarray(label_repr[c1_idx].sum(axis=0))), - ] - ) - ) - return cluster[c0_idx], cluster[c1_idx] diff --git a/libmultilabel/nn/dataset_attentionxml.py b/libmultilabel/nn/dataset_attentionxml.py deleted file mode 100644 index d6b935b0b..000000000 --- a/libmultilabel/nn/dataset_attentionxml.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -from typing import Sequence, Optional - -import numpy as np -from numpy import ndarray -from scipy.sparse import csr_matrix, issparse -from torch import Tensor, is_tensor -from torch.utils.data import Dataset - - -class PlainDataset(Dataset): - """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label dataset. - WHY EXISTS: The reason why this class is necessary is that it can process labels in sparse format, while TextDataset - does not. - Moreover, TextDataset implements multilabel binarization in a mandatory way. Nevertheless, AttentionXML already does - this while generating clusters. There is no need to do multilabel binarization again. - - Args: - x (list | ndarray | Tensor): texts. - y (Optional: csr_matrix | ndarray | Tensor): labels. - """ - - def __init__(self, x, y=None): - if y is not None: - assert len(x) == y.shape[0], "Sizes mismatch between texts and labels" - self.x = x - self.y = y - - def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: - item = {"text": self.x[idx]} - - # train/val/test - if self.y is not None: - if issparse(self.y): - y = self.y[idx].toarray().squeeze(0).astype(np.int32) - elif isinstance(self.y, ndarray): - y = self.y[idx].astype(np.int32) - elif is_tensor(self.y): - y = self.y[idx].int() - else: - raise TypeError( - "The type of y should be one of scipy.csr_matrix, numpy.ndarry, and torch.Tensor." - f"But got {type(self.y)} instead." - ) - item["label"] = y - return item - - def __len__(self): - return len(self.x) - - -class PLTDataset(PlainDataset): - """Dataset for model_1 of AttentionXML. - - Args: - x: texts. - y: labels. - num_classes: number of classes. - num_labels_selected: the number of selected labels. Pad any labels that fail to reach this number. - labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k). - label_scores: scores for each label. Shape: (len(x), predict_top_k). - """ - - def __init__( - self, - x, - y: Optional[csr_matrix | ndarray] = None, - *, - num_classes: int, - num_labels_selected: int, - labels_selected: ndarray | Tensor, - label_scores: Optional[ndarray | Tensor] = None, - ): - super().__init__(x, y) - self.num_classes = num_classes - self.num_labels_selected = num_labels_selected - self.labels_selected = labels_selected - self.label_scores = label_scores - - def __getitem__(self, idx: int): - item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])} - - if self.y is not None: - item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32) - - # PyTorch requires inputs to be of the same shape. Pad any instances whose length is below num_labels_selected - # train - if self.label_scores is None: - # add real labels when the number is below num_labels_selected - if len(item["labels_selected"]) < self.num_labels_selected: - samples = np.random.randint( - self.num_classes, - size=self.num_labels_selected - len(item["labels_selected"]), - ) - item["labels_selected"] = np.concatenate([item["labels_selected"], samples]) - - # val/test/pred - else: - item["label_scores"] = self.label_scores[idx] - # add dummy labels when the number is below num_labels_selected - if len(item["labels_selected"]) < self.num_labels_selected: - item["label_scores"] = np.concatenate( - [ - item["label_scores"], - [-np.inf] * (self.num_labels_selected - len(item["labels_selected"])), - ] - ) - item["labels_selected"] = np.concatenate( - [ - item["labels_selected"], - [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])), - ] - ) - return item diff --git a/libmultilabel/nn/model_attentionxml.py b/libmultilabel/nn/model_attentionxml.py deleted file mode 100644 index feb2e2f8f..000000000 --- a/libmultilabel/nn/model_attentionxml.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -from torch import Tensor - -from .model import Model - - -class PLTModel(Model): - def __init__( - self, - classes, - word_dict, - embed_vecs, - network, - loss_function="binary_cross_entropy_with_logits", - log_path=None, - **kwargs, - ): - super().__init__( - classes=classes, - word_dict=word_dict, - embed_vecs=embed_vecs, - network=network, - loss_function=loss_function, - log_path=log_path, - **kwargs, - ) - - def scatter_logits( - self, - logits: Tensor, - labels_selected: Tensor, - label_scores: Tensor, - ) -> Tensor: - """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to - the whole label space. The scores of unsampled labels are set to 0.""" - src = torch.sigmoid(logits.detach()) * label_scores - preds = torch.zeros( - labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype - ) - preds.scatter_(dim=1, index=labels_selected, src=src) - # remove dummy labels - preds = preds[:, :-1] - return preds - - def shared_step(self, batch): - """Return loss and predicted logits of the network. - - Args: - batch (dict): A batch of text and label. - - Returns: - loss (torch.Tensor): Loss between target and predict logits. - pred_logits (torch.Tensor): The predict logits (batch_size, num_classes). - """ - y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1) - pred_logits = self(batch) - loss = self.loss_function(pred_logits, y.float()) - return loss, pred_logits - - def _shared_eval_step(self, batch, batch_idx): - logits = self(batch) - logits = self.scatter_logits(logits, batch["labels_selected"], batch["label_scores"]) - self.eval_metric.update(logits, batch["label"].long()) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - logits = self(batch) - scores, labels = torch.topk(torch.sigmoid(logits) * batch["label_scores"], self.save_k_predictions) - # This calculation is to align with LibMultiLabel class where logits rather than probabilities are returned - logits = torch.logit(scores) - return { - "top_k_pred": torch.take_along_dim(batch["labels_selected"], labels, dim=1).numpy(force=True), - "top_k_pred_scores": logits.numpy(force=True), - } diff --git a/torch_trainer.py b/torch_trainer.py index 349edd36f..8dc259b5d 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -9,7 +9,7 @@ from libmultilabel.nn import data_utils from libmultilabel.nn.model import Model from libmultilabel.nn.nn_utils import init_device, init_model, init_trainer, set_seed -from libmultilabel.nn.plt import PLTTrainer +from libmultilabel.nn.attentionxml import PLTTrainer class TorchTrainer: From 06d3ddc842218c1a78c3537baa32f60a342c5b2b Mon Sep 17 00:00:00 2001 From: Dongli He Date: Fri, 5 Apr 2024 16:03:12 +0400 Subject: [PATCH 28/29] explain +1 in scatter_logits --- docs/tutorial.rst | 2 -- libmultilabel/nn/attentionxml.py | 33 +++++++++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/docs/tutorial.rst b/docs/tutorial.rst index d6119d07c..a5bc458c9 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -5,7 +5,6 @@ Tutorials * `Feature Generation and Parameter Selection for Linear Methods <./auto_examples/plot_linear_gridsearch_tutorial.html>`_ * `Parameter Selection for Neural Networks `_ * `Handling Data with Many Labels <./auto_examples/plot_linear_tree_tutorial.html>`_ -* `Implement Extreme Multi-Label Text Classification with AttentionXML <./auto_examples/plot_AttentionXML_tutorial.html>`_ .. toctree:: @@ -16,5 +15,4 @@ Tutorials ../auto_examples/plot_linear_gridsearch_tutorial tutorials/Parameter_Selection_for_Neural_Networks ../auto_examples/plot_linear_tree_tutorial - ../auto_examples/plot_AttentionXML_tutorial \ No newline at end of file diff --git a/libmultilabel/nn/attentionxml.py b/libmultilabel/nn/attentionxml.py index 54d19509c..c327cbea1 100644 --- a/libmultilabel/nn/attentionxml.py +++ b/libmultilabel/nn/attentionxml.py @@ -137,7 +137,7 @@ def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, ...]: Given the ground-truth labels, [0, 1, 4], the resulting clusters are [0, 2]. Args: - cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree. + cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels . *labels (csr_matrix): labels in CSR sparse format. Returns: @@ -169,7 +169,7 @@ def cluster2label(cluster_mapping, clusters, cluster_scores=None): Also notice that this function deals with DENSE matrix. Args: - cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree. + cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels . clusters (np.ndarray): predicted clusters from model 0. cluster_scores (Optional: np.ndarray): predicted scores of each cluster from model 0. @@ -234,7 +234,7 @@ def fit(self, datasets): clusters = np.load(self.get_cluster_path(), allow_pickle=True) - # each y has been mapped to the cluster indices of its parent + # map each y to the parent cluster indices train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y) trainer = init_trainer( @@ -288,8 +288,8 @@ def fit(self, datasets): model_0 = Model.load_from_checkpoint(best_model_path) logger.info( - f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters and " - f"extract labels from them for level 1 training." + f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters for each instance and " + f"extract labels from these clusters for level 1 training." ) # load training and validation data and predict corresponding level 0 clusters train_dataloader = self.dataloader(PlainDataset(train_x)) @@ -546,11 +546,14 @@ def scatter_logits( """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to the whole label space. The scores of unsampled labels are set to 0.""" src = torch.sigmoid(logits.detach()) * label_scores + # During validation/testing, many fake labels might exist in a batch for the purpose of padding. + # A fake label has index len(classes) and does not belong to the real label space. preds = torch.zeros( labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype ) preds.scatter_(dim=1, index=labels_selected, src=src) - # remove dummy labels + # slicing removes fake labels whose index is exactly len(self.classes) + # afterwards, preds is restored to the real label space preds = preds[:, :-1] return preds @@ -586,6 +589,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): ###################################### Dataset ###################################### + + class PlainDataset(Dataset): """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label dataset. WHY EXISTS: The reason why this class is necessary is that it can process labels in sparse format, while TextDataset @@ -634,7 +639,7 @@ class PLTDataset(PlainDataset): x: texts. y: labels. num_classes: number of classes. - num_labels_selected: the number of selected labels. Pad any labels that fail to reach this number. + num_labels_selected: the number of selected labels. labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k). label_scores: scores for each label. Shape: (len(x), predict_top_k). """ @@ -661,10 +666,12 @@ def __getitem__(self, idx: int): if self.y is not None: item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32) - # PyTorch requires inputs to be of the same shape. Pad any instances whose length is below num_labels_selected - # train + # PyTorch requires inputs to be of the same shape. Pad any instance with length below num_labels_selected by + # randomly selecting labels. + # training if self.label_scores is None: - # add real labels when the number is below num_labels_selected + # randomly add real labels when the number is below num_labels_selected + # some labels might be selected more than once if len(item["labels_selected"]) < self.num_labels_selected: samples = np.random.randint( self.num_classes, @@ -675,7 +682,7 @@ def __getitem__(self, idx: int): # val/test/pred else: item["label_scores"] = self.label_scores[idx] - # add dummy labels when the number is below num_labels_selected + # add fake labels when the number of labels is below num_labels_selected if len(item["labels_selected"]) < self.num_labels_selected: item["label_scores"] = np.concatenate( [ @@ -694,8 +701,8 @@ def __getitem__(self, idx: int): ###################################### Cluster ###################################### -CLUSTER_FILE_EXTENSION = CLUSTER_NAME = "label_clusters" -FILE_EXTENSION = ".npy" +CLUSTER_NAME = "label_clusters" +CLUSTER_FILE_EXTENSION = FILE_EXTENSION = ".npy" def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: int, output_dir: str | Path): From a6521f835fdcfc283143c5700089d6740880ceae Mon Sep 17 00:00:00 2001 From: Dongli He Date: Mon, 8 Apr 2024 16:42:25 +0400 Subject: [PATCH 29/29] remove absolute paths --- example_config/AmazonCat-13K/attentionxml.yml | 1 - example_config/EUR-Lex/attentionxml.yml | 3 --- example_config/Wiki10-31K/attentionxml.yml | 3 --- 3 files changed, 7 deletions(-) diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml index 575e5ef12..e6131df5b 100644 --- a/example_config/AmazonCat-13K/attentionxml.yml +++ b/example_config/AmazonCat-13K/attentionxml.yml @@ -1,5 +1,4 @@ data_name: AmazonCat-13K -# will change path later training_file: data/AmazonCat-13K/train.txt test_file: data/AmazonCat-13K/test.txt # pretrained embeddings diff --git a/example_config/EUR-Lex/attentionxml.yml b/example_config/EUR-Lex/attentionxml.yml index adcc6d8fc..558f4cdc9 100644 --- a/example_config/EUR-Lex/attentionxml.yml +++ b/example_config/EUR-Lex/attentionxml.yml @@ -3,9 +3,6 @@ training_file: data/EUR-Lex/train.txt test_file: data/EUR-Lex/test.txt # pretrained embeddings embed_file: glove.840B.300d -# save path -# will change path later -result_dir: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/results # preprocessing min_vocab_freq: 1 diff --git a/example_config/Wiki10-31K/attentionxml.yml b/example_config/Wiki10-31K/attentionxml.yml index dac9fddb1..bbe2774a5 100644 --- a/example_config/Wiki10-31K/attentionxml.yml +++ b/example_config/Wiki10-31K/attentionxml.yml @@ -1,11 +1,8 @@ data_name: Wiki10-31K -# will change path later training_file: data/Wiki10-31K/train.txt test_file: data/Wiki10-31K/test.txt # pretrained embeddings embed_file: glove.840B.300d -# save path -result_dir: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/results # preprocessing min_vocab_freq: 1