From d9c9b505806d35b23b95925079d2bbcb067d4aed Mon Sep 17 00:00:00 2001
From: Dongli He
Date: Thu, 21 Mar 2024 12:25:07 +0400
Subject: [PATCH] remove label selection from dataset

---
 example_config/AmazonCat-13K/attentionxml.yml |  42 -------
 libmultilabel/nn/cluster.py                   |   2 +-
 libmultilabel/nn/datasets_AttentionXML.py     |  75 +++---------
 libmultilabel/nn/model_AttentionXML.py        |  21 +---
 libmultilabel/nn/plt.py                       | 109 +++++++++---------
 5 files changed, 71 insertions(+), 178 deletions(-)
 delete mode 100644 example_config/AmazonCat-13K/attentionxml.yml

diff --git a/example_config/AmazonCat-13K/attentionxml.yml b/example_config/AmazonCat-13K/attentionxml.yml
deleted file mode 100644
index 9c416ed2..00000000
--- a/example_config/AmazonCat-13K/attentionxml.yml
+++ /dev/null
@@ -1,42 +0,0 @@
-data_name: AmazonCat-13K
-# will change path later
-training_file: data/AmazonCat-13K/train.txt
-test_file: data/AmazonCat-13K/test.txt
-# pretrained embeddings
-embed_file: glove.840B.300d
-
-# preprocessing
-min_vocab_freq: 1
-max_seq_length: 500
-
-# label tree related parameters
-cluster_size: 8
-save_k_predictions: 64
-
-# data
-batch_size: 200
-val_size: 4000
-shuffle: true
-
-# eval
-eval_batch_size: 200
-monitor_metrics: [P@1, P@3, P@5, nDCG@3, nDCG@5, RP@3, RP@5]
-val_metric: nDCG@5
-
-# train
-seed: 1337
-epochs: 10
-optimizer: adam
-learning_rate: 0.001
-# early stopping
-patience: 10
-
-# model
-model_name: AttentionXML
-network_config:
-  embed_dropout: 0.2
-  post_encoder_dropout: 0.5
-  rnn_dim: 1024
-  rnn_layers: 1
-  linear_size: [512, 256]
-  freeze_embed_training: false
diff --git a/libmultilabel/nn/cluster.py b/libmultilabel/nn/cluster.py
index 8cbcc658..35617531 100644
--- a/libmultilabel/nn/cluster.py
+++ b/libmultilabel/nn/cluster.py
@@ -54,7 +54,7 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i
     for _ in range(height):
         next_clusters = []
         for cluster in clusters:
-            next_clusters.extend(_split_cluster(cluster, label_repr[cluster]))
+            next_clusters += _split_cluster(cluster, label_repr[cluster])
         clusters = next_clusters
         logger.info(f"Grouped {len(clusters)} clusters")
 
diff --git a/libmultilabel/nn/datasets_AttentionXML.py b/libmultilabel/nn/datasets_AttentionXML.py
index e1e3b50f..5f4046c3 100644
--- a/libmultilabel/nn/datasets_AttentionXML.py
+++ b/libmultilabel/nn/datasets_AttentionXML.py
@@ -19,11 +19,11 @@ class PlainDataset(Dataset):
     this while generating clusters. There is no need to do multilabel binarization again.
 
     Args:
-        x: texts
-        y: labels
+        x (list | ndarray | Tensor): texts
+        y (csr_matrix | ndarray | Tensor, optional): labels
     """
 
-    def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None):
+    def __init__(self, x, y=None):
         if y is not None:
             assert len(x) == y.shape[0], "Size mismatch between texts and labels"
         self.x = x
@@ -56,10 +56,8 @@ class PLTDataset(PlainDataset):
     Args:
         x: texts
         y: labels
-        num_classes: number of classes.
-        mapping: mapping from clusters to labels. Shape: (len(clusters), cluster_size).
-        clusters_selected: sampled predicted clusters from model_0. Shape: (len(x), predict_top_k).
-        cluster_scores: corresponding scores. Shape: (len(x), predict_top_k)
+        labels_selected: labels expanded from the sampled or predicted clusters of model_0. Shape: (len(x), predict_top_k * cluster_size).
+        label_scores: scores for each selected label. Shape: (len(x), predict_top_k * cluster_size)
     """
 
     def __init__(
@@ -67,66 +65,19 @@ def __init__(
         x,
         y: Optional[csr_matrix | ndarray] = None,
         *,
-        num_classes: int,
-        mapping: ndarray,
-        clusters_selected: ndarray | Tensor,
-        cluster_scores: Optional[ndarray | Tensor] = None,
+        labels_selected: ndarray | Tensor,
+        label_scores: Optional[ndarray | Tensor] = None,
    ):
         super().__init__(x, y)
-        self.num_classes = num_classes
-        self.mapping = mapping
-        self.clusters_selected = clusters_selected
-        self.cluster_scores = cluster_scores
-        self.label_scores = None
-
-        # labels_selected are labels extracted from cluster_selected.
-        # shape: (len(x), len(clusters_selected) * cluster_size)
-        self.labels_selected = [
-            np.concatenate(self.mapping[labels])
-            for labels in tqdm(self.clusters_selected, leave=False, desc="Retrieving labels from selected clusters")
-        ]
-        if self.cluster_scores is not None:
-            # label_scores are probability scores corresponding to labels_selected.
-            # shape: (len(x), len(clusters_selected) * cluster_size)
-            self.label_scores = [
-                np.repeat(scores, [len(i) for i in self.mapping[labels]])
-                for labels, scores in zip(self.clusters_selected, self.cluster_scores)
-            ]
-
-        # top_k * n (n <= cluster_size). number of maximum possible number selected labels at the current level.
-        self.num_labels_selected = self.clusters_selected.shape[1] * max(len(clusters) for clusters in self.mapping)
+        self.labels_selected = labels_selected
+        self.label_scores = label_scores
 
     def __getitem__(self, idx: int):
-        item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx], dtype=np.int64)}
+        item = {"text": self.x[idx], "labels_selected": np.asarray(self.labels_selected[idx])}
 
-        # train/valid/test
         if self.y is not None:
-            item["label"] = self.y[idx].toarray().squeeze(0)
-
-        # train
-        if self.label_scores is None:
-            # As networks require input to be of fixed shape, randomly select labels when the number of selected label
-            # is not enough
-            if len(item["labels_selected"]) < self.num_labels_selected:
-                sample = np.random.randint(
-                    self.num_classes, size=self.num_labels_selected - len(item["labels_selected"])
-                )
-                item["labels_selected"] = np.concatenate([item["labels_selected"], sample])
-        # valid/test
-        else:
-            item["label_scores"] = self.label_scores[idx]
+            item["label"] = self.y[idx, item["labels_selected"]].toarray().squeeze(0)
 
-            # add dummy elements when less than required
-            if len(item["labels_selected"]) < self.num_labels_selected:
-                item["label_scores"] = np.concatenate(
-                    [item["label_scores"], [-np.inf] * (self.num_labels_selected - len(item["labels_selected"]))]
-                )
-                item["labels_selected"] = np.concatenate(
-                    [
-                        item["labels_selected"],
-                        [self.num_classes] * (self.num_labels_selected - len(item["labels_selected"])),
-                    ]
-                )
-
-            item["label_scores"] = np.asarray(item["label_scores"], dtype=np.float32)
+        if self.label_scores is not None:
+            item["label_scores"] = self.label_scores[idx]
 
         return item
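With `labels_selected` now precomputed outside the dataset, `__getitem__` reduces to a sparse row slice over the selected labels. A minimal sketch of the new contract on toy data (the two documents, six labels, and label ids below are invented for illustration; only `PLTDataset` itself is the real class from this diff):

```python
import numpy as np
from scipy.sparse import csr_matrix

from libmultilabel.nn.datasets_AttentionXML import PLTDataset

x = ["doc a", "doc b"]  # two toy instances
y = csr_matrix(np.array([[1, 1, 0, 0, 1, 0],
                         [0, 0, 1, 0, 0, 1]]))  # ground-truth labels over 6 classes
# labels already expanded from the predicted clusters (see cluster2label in plt.py below)
labels_selected = [np.array([0, 1, 4, 5]), np.array([2, 3, 4, 5])]

item = PLTDataset(x, y, labels_selected=labels_selected)[0]
print(item["labels_selected"])  # [0 1 4 5]
print(item["label"])            # [1 1 1 0] -> y[0] restricted to the selected labels
```

Because the target is already restricted to `labels_selected`, no random padding or dummy-label bookkeeping is needed inside the dataset anymore.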
diff --git a/libmultilabel/nn/model_AttentionXML.py b/libmultilabel/nn/model_AttentionXML.py
index 88b9685f..366b6ad7 100644
--- a/libmultilabel/nn/model_AttentionXML.py
+++ b/libmultilabel/nn/model_AttentionXML.py
@@ -34,29 +34,10 @@ def scatter_logits(
         """For each instance, we only have predictions on selected labels. This subroutine maps these predictions to
         the whole label space. The scores of unsampled labels are set to 0."""
         src = torch.sigmoid(logits.detach()) * label_scores
-        preds = torch.zeros(
-            labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
-        )
+        preds = torch.zeros(labels_selected.size(0), len(self.classes), device=labels_selected.device, dtype=src.dtype)
         preds.scatter_(dim=1, index=labels_selected, src=src)
-        # remove dummy labels
-        preds = preds[:, :-1]
         return preds
 
-    def shared_step(self, batch):
-        """Return loss and predicted logits of the network.
-
-        Args:
-            batch (dict): A batch of text and label.
-
-        Returns:
-            loss (torch.Tensor): Loss between target and predict logits.
-            pred_logits (torch.Tensor): The predict logits (batch_size, num_classes).
-        """
-        y = torch.take_along_dim(batch["label"], batch["labels_selected"], dim=1)
-        logits = self(batch)
-        loss = self.loss_function(logits, y)
-        return loss, logits
-
     def _shared_eval_step(self, batch, batch_idx):
         logits = self(batch)
         logits = self.scatter_logits(logits, batch["labels_selected"], batch["label_scores"])
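With the dummy-label padding gone, `scatter_logits` no longer needs the extra `len(self.classes) + 1` column; it is now a plain `scatter_` into the real label space. A self-contained sketch of the same mapping with toy tensors (shapes, values, and `num_classes` are made up; no repo code required):

```python
import torch

num_classes = 6
logits = torch.tensor([[2.0, 0.0, -1.0]])       # raw predictions on 3 selected labels
label_scores = torch.tensor([[0.9, 0.5, 0.1]])  # scores inherited from their clusters
labels_selected = torch.tensor([[4, 0, 2]])     # the label ids those columns refer to

src = torch.sigmoid(logits) * label_scores
preds = torch.zeros(labels_selected.size(0), num_classes, dtype=src.dtype)
preds.scatter_(dim=1, index=labels_selected, src=src)
print(preds)  # only columns 4, 0, 2 are non-zero; unselected labels stay at 0
```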
diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py
index b91a1368..3b4ee50a 100644
--- a/libmultilabel/nn/plt.py
+++ b/libmultilabel/nn/plt.py
@@ -126,61 +126,66 @@ def __init__(
         # save path
         self.log_path = config.log_path
 
-    def label2cluster(self, cluster_mapping, *ys) -> Generator[csr_matrix, ...]:
+    def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, None, None]:
         """Map labels to their corresponding clusters in CSR sparse format.
-
-        Suppose there are 6 labels and clusters are [(0, 1), (2, 3), (4, 5)] and ys of a given instance is [0, 1, 4].
-        The clusters of the instance are [0, 2].
+        Note that this function deals with SPARSE matrices.
+        Assume there are 6 labels clustered as [(0, 1), (2, 3), (4, 5)]. Here (0, 1) is the cluster with index 0, and so on.
+        Given the ground-truth labels [0, 1, 4], the resulting clusters are [0, 2].
 
         Args:
-            cluster_mapping: mapping from clusters to labels.
-            *ys: sparse labels.
+            cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree.
+            *labels (csr_matrix): labels in CSR sparse format.
 
         Returns:
-            Generator[csr_matrix]: clusters generated from labels
+            Generator[csr_matrix]: resulting clusters converted from labels in CSR sparse format
         """
         mapping = np.empty(self.num_classes, dtype=np.uint64)
         for idx, clusters in enumerate(cluster_mapping):
             mapping[clusters] = idx
 
-        def _label2cluster(y: csr_matrix) -> csr_matrix:
+        def _label2cluster(label: csr_matrix) -> csr_matrix:
             row = []
             col = []
             data = []
-            for i in range(y.shape[0]):
+            for i in range(label.shape[0]):
                 # n includes all mapped ancestor clusters
-                n = np.unique(mapping[y.indices[y.indptr[i] : y.indptr[i + 1]]])
+                n = np.unique(mapping[label.indices[label.indptr[i] : label.indptr[i + 1]]])
                 row += [i] * len(n)
                 col += n.tolist()
                 data += [1] * len(n)
-            return csr_matrix((data, (row, col)), shape=(y.shape[0], len(cluster_mapping)))
-
-        return (_label2cluster(y) for y in ys)
-
-    # def cluster2label(self, cluster_mapping, *ys):
-    #     """Map clusters to their corresponding labels. Notice this function only deals with dense matrix.
-    #
-    #     Args:
-    #         cluster_mapping: mapping from clusters to labels.
-    #         *ys: sparse clusters.
-    #
-    #     Returns:
-    #         Generator[csr_matrix]: labels generated from clusters
-    #     """
-    #
-    #     def _cluster2label(y: csr_matrix) -> csr_matrix:
-    #         self.labels_selected = [np.concatenate(cluster_mapping[labels]) for labels in y]
-    #     return (_cluster2label(y) for y in ys)
-
-    # def generate_goals(self, cluster_scores, y):
-    #     if cluster_scores is not None:
-    #         # label_scores are corresponding scores for selected labels and
-    #         # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), cluster_size * top_k)
-    #         # notice how scores repeat for each cluster.
-    #         self.label_scores = [
-    #             np.repeat(scores, [len(i) for i in cluster_mapping[labels]])
-    #             for labels, scores in zip(y, cluster_scores)
-    #         ]
+            return csr_matrix((data, (row, col)), shape=(label.shape[0], len(cluster_mapping)))
+
+        return (_label2cluster(label) for label in labels)
+
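As a sanity check on the mapping logic above, the docstring's 6-label example can be traced through the same `indices`/`indptr` arithmetic as `_label2cluster`. This standalone sketch inlines `self.num_classes` as 6 and uses a plain list in place of the cluster ndarray:

```python
import numpy as np
from scipy.sparse import csr_matrix

num_classes = 6
cluster_mapping = [np.array([0, 1]), np.array([2, 3]), np.array([4, 5])]

mapping = np.empty(num_classes, dtype=np.uint64)  # label id -> cluster id
for idx, clusters in enumerate(cluster_mapping):
    mapping[clusters] = idx

y = csr_matrix(np.array([[1, 1, 0, 0, 1, 0]]))  # one instance with labels {0, 1, 4}
n = np.unique(mapping[y.indices[y.indptr[0] : y.indptr[1]]])
print(n)  # [0 2] -> the clusters of the instance, as in the docstring
```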
+    @staticmethod
+    def cluster2label(cluster_mapping, clusters, cluster_scores=None):
+        """Expand clusters to their corresponding labels and, if available, assign scores to each label.
+        Labels inside the same cluster share that cluster's score. This function is applied to predictions from model 0.
+        Note that its behavior differs from that of label2cluster.
+        Also note that this function deals with DENSE matrices.
+
+        Args:
+            cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree.
+            clusters (np.ndarray): predicted clusters from model 0.
+            cluster_scores (np.ndarray, optional): predicted scores of each cluster from model 0.
+
+        Returns:
+            list[np.ndarray]: labels expanded from clusters; a (labels, scores) tuple when cluster_scores is given.
+        """
+
+        labels_selected = []
+
+        if cluster_scores is not None:
+            # label_scores are the scores corresponding to the selected labels; scores repeat within each cluster.
+            # shape: (len(x), cluster_size * top_k)
+            label_scores = []
+            for score, cluster in zip(cluster_scores, clusters):
+                label_scores += [np.repeat(score, [len(labels) for labels in cluster_mapping[cluster]])]
+                labels_selected += [np.concatenate(cluster_mapping[cluster])]
+            return labels_selected, label_scores
+        else:
+            labels_selected = [np.concatenate(cluster_mapping[cluster]) for cluster in clusters]
+            return labels_selected
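The static method above replaces the per-item expansion that `PLTDataset` used to do in its constructor. A hand-run on a tiny ragged mapping (all values invented; in the real pipeline `cluster_mapping` is loaded via `np.load(self.get_cluster_path(), allow_pickle=True)`):

```python
import numpy as np

from libmultilabel.nn.plt import PLTTrainer

# ragged mapping: 4 clusters over 8 labels (cluster sizes 2, 3, 1, 2)
cluster_mapping = np.array(
    [np.array([0, 1]), np.array([2, 3, 4]), np.array([5]), np.array([6, 7])], dtype=object
)
clusters = np.array([[1, 3], [0, 2]])                # top-2 predicted clusters per instance
cluster_scores = np.array([[0.9, 0.4], [0.8, 0.3]])  # their predicted scores

labels, scores = PLTTrainer.cluster2label(cluster_mapping, clusters, cluster_scores)
print(labels[0])  # [2 3 4 6 7]           -> labels of clusters 1 and 3
print(scores[0])  # [0.9 0.9 0.9 0.4 0.4] -> each cluster score repeated per label
```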
 
     def fit(self, datasets):
         """Fit the model to the training dataset
@@ -291,7 +296,7 @@ def fit(self, datasets):
         val_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in val_pred]))
         val_clusters_pred = np.vstack([i["top_k_pred"] for i in val_pred])
 
-        clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64)
+        train_clusters_selected = np.empty((len(train_x), self.predict_top_k), dtype=np.int64)
         for i, ys in enumerate(tqdm(train_clusters_pred, leave=False, desc="Sampling clusters")):
             # relevant clusters are positive
             pos = set(train_y_clustered.indices[train_y_clustered.indptr[i] : train_y_clustered.indptr[i + 1]])
@@ -317,7 +322,10 @@ def fit(self, datasets):
                     break
             if len(selected) < self.predict_top_k:
                 selected = (list(selected) + list(pos - selected))[: self.predict_top_k]
-            clusters_selected[i] = np.asarray(list(selected))
+            train_clusters_selected[i] = np.asarray(list(selected))
+
+        train_labels_selected = PLTTrainer.cluster2label(clusters, train_clusters_selected)
+        val_labels_pred, val_scores_pred = PLTTrainer.cluster2label(clusters, val_clusters_pred, val_scores_pred)
 
         trainer = init_trainer(
             self.checkpoint_dir,
@@ -339,9 +347,7 @@ def fit(self, datasets):
             PLTDataset(
                 train_x,
                 train_y,
-                num_classes=self.num_classes,
-                mapping=clusters,
-                clusters_selected=clusters_selected,
+                labels_selected=train_labels_selected,
             ),
             shuffle=self.shuffle,
         )
@@ -349,10 +355,8 @@ def fit(self, datasets):
         val_dataloader = self.eval_dataloader(
             PLTDataset(
                 val_x,
                 val_y,
-                num_classes=self.num_classes,
-                mapping=clusters,
-                clusters_selected=val_clusters_pred,
-                cluster_scores=val_scores_pred,
+                labels_selected=val_labels_pred,
+                label_scores=val_scores_pred,
             ),
         )
 
@@ -435,19 +439,18 @@ def test(self, dataset):
         logger.info(f"Predicting level 0, Top: {self.predict_top_k}")
         test_pred = trainer.predict(model_0, test_dataloader)
 
-        test_pred_scores = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred]))
-        test_pred_cluters = np.vstack([i["top_k_pred"] for i in test_pred])
+        test_scores_pred = expit(np.vstack([i["top_k_pred_scores"] for i in test_pred]))
+        test_clusters_pred = np.vstack([i["top_k_pred"] for i in test_pred])
 
         clusters = np.load(self.get_cluster_path(), allow_pickle=True)
+        test_labels_pred, test_scores_pred = PLTTrainer.cluster2label(clusters, test_clusters_pred, test_scores_pred)
 
         test_dataloader = self.eval_dataloader(
             PLTDataset(
                 test_x,
                 test_y,
-                num_classes=self.num_classes,
-                mapping=clusters,
-                clusters_selected=test_pred_cluters,
-                cluster_scores=test_pred_scores,
+                labels_selected=test_labels_pred,
+                label_scores=test_scores_pred,
             ),
         )
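Taken together, the reworked test-time flow is sketched below. Random logits stand in for `trainer.predict(model_0, ...)` and the mapping/shapes are toy values; `cluster2label` and `PLTDataset` are the APIs introduced in this patch:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.special import expit

from libmultilabel.nn.datasets_AttentionXML import PLTDataset
from libmultilabel.nn.plt import PLTTrainer

rng = np.random.default_rng(0)
top_k = 2
cluster_mapping = np.array(
    [np.array([0, 1]), np.array([2, 3, 4]), np.array([5]), np.array([6, 7])], dtype=object
)
test_x = ["doc a", "doc b"]
test_y = csr_matrix(rng.integers(0, 2, size=(2, 8)))

raw = rng.normal(size=(2, len(cluster_mapping)))                  # stand-in for model-0 logits
test_clusters_pred = np.argsort(raw, axis=1)[:, ::-1][:, :top_k]  # top-k clusters per instance
test_scores_pred = expit(np.take_along_axis(raw, test_clusters_pred, axis=1))

test_labels_pred, test_scores_pred = PLTTrainer.cluster2label(cluster_mapping, test_clusters_pred, test_scores_pred)
test_dataset = PLTDataset(test_x, test_y, labels_selected=test_labels_pred, label_scores=test_scores_pred)
item = test_dataset[0]  # keys: text, labels_selected, label, label_scores
```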