fix misused prob & improve readability
donglihe-hub committed Mar 7, 2024
1 parent 8d11756 commit e46d796
Showing 8 changed files with 121 additions and 144 deletions.
11 changes: 4 additions & 7 deletions example_config/AmazonCat-13K/attentionxml.yml
@@ -1,11 +1,9 @@
 data_name: AmazonCat-13K
-training_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/train.txt
-test_file: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/test.txt
+# will change path later
+training_file: data/AmazonCat-13K/train.txt
+test_file: data/AmazonCat-13K/test.txt
 # pretrained embeddings
 embed_file: glove.840B.300d
-embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding
-# save path
-result_dir: /l/users/dongli.he/libml/LibMultiLabel/AmazonCat-13K/results
 
 # preprocessing
 min_vocab_freq: 1
@@ -28,11 +26,10 @@ val_metric: nDCG@5
 # train
 seed: 1337
 epochs: 10
-# https://github.com/Lightning-AI/lightning/issues/8826
 optimizer: adam
 learning_rate: 0.001
 # early stopping
-patience: 5
+patience: 10
 
 # model
 model_name: AttentionXML
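Note on the config change: patience is raised from 5 to 10 to match epochs, so early stopping can no longer cut the run short. A minimal sketch of inspecting such a config with PyYAML (illustrative only; LibMultiLabel's own CLI does the real parsing):

import yaml

# Illustrative sketch, not LibMultiLabel's actual config loader.
with open("example_config/AmazonCat-13K/attentionxml.yml") as f:
    config = yaml.safe_load(f)

# After this commit patience == epochs, so the early-stopping callback
# would only trigger if validation stalled for the entire run.
print(config["epochs"], config["patience"], config["val_metric"])

The EUR-Lex and Wiki10-31K configs below receive the same treatment: patience is raised to 30 to match their 30 epochs.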
9 changes: 4 additions & 5 deletions example_config/EUR-Lex/attentionxml.yml
@@ -1,10 +1,10 @@
 data_name: EUR-Lex
-training_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/train.txt
-test_file: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/test.txt
+training_file: data/EUR-Lex/train.txt
+test_file: data/EUR-Lex/test.txt
 # pretrained embeddings
 embed_file: glove.840B.300d
-embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding
 # save path
+# will change path later
 result_dir: /l/users/dongli.he/libml/LibMultiLabel/EUR-Lex/results
 
 # preprocessing
@@ -28,11 +28,10 @@ val_metric: nDCG@5
 # train
 seed: 1337
 epochs: 30
-# https://github.com/Lightning-AI/lightning/issues/8826
 optimizer: adam
 learning_rate: 0.001
 # early stopping
-patience: 5
+patience: 30
 
 # model
 model_name: AttentionXML
9 changes: 4 additions & 5 deletions example_config/Wiki10-31K/attentionxml.yml
@@ -1,9 +1,9 @@
 data_name: Wiki10-31K
-training_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/train.txt
-test_file: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/test.txt
+# will change path later
+training_file: data/Wiki10-31K/train.txt
+test_file: data/Wiki10-31K/test.txt
 # pretrained embeddings
 embed_file: glove.840B.300d
-embed_cache_dir: /l/users/dongli.he/libml/LibMultiLabel/embedding
 # save path
 result_dir: /l/users/dongli.he/libml/LibMultiLabel/Wiki10-31K/results
 
@@ -28,11 +28,10 @@ val_metric: nDCG@5
 # train
 seed: 1337
 epochs: 30
-# https://github.com/Lightning-AI/lightning/issues/8826
 optimizer: adam
 learning_rate: 0.001
 # early stopping
-patience: 5
+patience: 30
 
 # model
 model_name: AttentionXML
35 changes: 13 additions & 22 deletions libmultilabel/nn/cluster.py
@@ -38,24 +38,23 @@ def build_label_tree(sparse_x: csr_matrix, sparse_y: csr_matrix, cluster_size: i
     # meta info
     logger.info("Label clustering started")
     logger.info(f"Cluster size: {cluster_size}")
-    num_labels = sparse_y.shape[1]
     # The height of the tree satisfies the following inequality:
     # 2**(tree_height - 1) * cluster_size < num_labels <= 2**tree_height * cluster_size
-    height = int(np.ceil(np.log2(num_labels / cluster_size)))
+    height = int(np.ceil(np.log2(sparse_y.shape[1] / cluster_size)))
     logger.info(f"Labels will be grouped into {2**height} clusters")
 
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    # For each label, sum up instances relevant to the label and normalize to get the label representation
+    # For each label, sum up normalized instances relevant to the label and normalize to get the label representation
     label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x)))
 
     # clustering by a binary tree:
     # at each layer, split each cluster into two. Leaf nodes correspond to the obtained clusters.
-    clusters = [np.arange(num_labels)]
+    clusters = [np.arange(sparse_y.shape[1])]
     for _ in range(height):
         next_clusters = []
         for cluster in clusters:
-            next_clusters.extend(_split_cluster(cluster, label_repr))
+            next_clusters.extend(_split_cluster(cluster, label_repr[cluster]))
         clusters = next_clusters
     logger.info(f"Having grouped {len(clusters)} clusters")
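To make the height inequality concrete, here is a small self-contained sketch; the label count below is hypothetical (roughly AmazonCat-13K sized) and the matrices are random:

import numpy as np
from scipy.sparse import csr_matrix, csc_matrix
from sklearn.preprocessing import normalize

cluster_size = 8
num_labels = 13330  # hypothetical count, for illustration only
# Smallest height such that num_labels <= 2**height * cluster_size.
height = int(np.ceil(np.log2(num_labels / cluster_size)))
assert 2 ** (height - 1) * cluster_size < num_labels <= 2**height * cluster_size
print(height, 2**height)  # 11 -> 2048 leaf clusters

# Label representations: each label's row is the normalized sum of the
# normalized instances tagged with that label.
sparse_x = csr_matrix(np.random.rand(4, 5))  # 4 instances, 5 features
sparse_y = csr_matrix((np.random.rand(4, 3) > 0.5).astype(np.float32))  # 3 labels
label_repr = normalize(sparse_y.T @ csc_matrix(normalize(sparse_x)))  # (3, 5), unit rows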

@@ -71,18 +70,10 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n
     Args:
         cluster: a subset of labels
-        label_repr: the normalized representations of the relationship between labels and texts
+        label_repr: the normalized representations of the relationship between labels and texts of the given cluster
     """
-    tol = 1e-4
-
-    # the normalized label representations corresponding to the cluster
-    tgt_repr = label_repr[cluster]
-
-    # the number of labels in the cluster
-    n = len(cluster)
-
     # Randomly choose two points as initial centroids and obtain their label representations
-    centroids = tgt_repr[np.random.choice(n, size=2, replace=False)].toarray()
+    centroids = label_repr[np.random.choice(len(cluster), size=2, replace=False)].toarray()
 
     # Initialize distances (cosine similarity)
     # Cosine similarity always falls to the interval [-1, 1]
 
@@ -93,30 +84,30 @@ def _split_cluster(cluster: ndarray, label_repr: csr_matrix) -> tuple[ndarray, n
     c0_idx = None
     c1_idx = None
 
-    while new_dist - old_dist >= tol:
-        # Notice that tgt_repr and centroids.T have been normalized
+    while new_dist - old_dist >= 1e-4:
+        # Notice that label_repr and centroids.T have been normalized
         # Thus, dist indicates the cosine similarity between points and centroids.
-        dist = tgt_repr @ centroids.T  # shape: (n, 2)
+        dist = label_repr @ centroids.T  # shape: (n, 2)
 
         # generate clusters
         # let a = dist[:, 1] - dist[:, 0]; the larger an element of a is, the closer the point is to c1
-        k = n // 2
+        k = len(cluster) // 2
         c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k)
         c0_idx = c_idx[:k]
         c1_idx = c_idx[k:]
 
         # update distances
         # the new distance is the average of in-cluster distances to the centroids
         old_dist = new_dist
-        new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / n
+        new_dist = (dist[c0_idx, 0].sum() + dist[c1_idx, 1].sum()) / len(cluster)
 
         # update centroids
         # the new centroid is the normalized average of the points in the cluster
         centroids = normalize(
             np.asarray(
                 [
-                    np.squeeze(np.asarray(tgt_repr[c0_idx].sum(axis=0))),
-                    np.squeeze(np.asarray(tgt_repr[c1_idx].sum(axis=0))),
+                    np.squeeze(np.asarray(label_repr[c0_idx].sum(axis=0))),
+                    np.squeeze(np.asarray(label_repr[c1_idx].sum(axis=0))),
                 ]
             )
         )
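The loop above is in effect a balanced spherical 2-means: because the representations are unit-normalized, the dot product is cosine similarity, and np.argpartition on the margin dist[:, 1] - dist[:, 0] forces the two child clusters to (nearly) equal size. A toy sketch of the splitting step alone, on synthetic data:

import numpy as np
from sklearn.preprocessing import normalize

rng = np.random.default_rng(0)
points = normalize(rng.normal(size=(10, 4)))  # 10 unit-norm label vectors
centroids = points[rng.choice(10, size=2, replace=False)]

dist = points @ centroids.T  # cosine similarities, shape (10, 2)
k = len(points) // 2
# The k points least attracted to c1 (relative to c0) go to c0, the rest to c1.
c_idx = np.argpartition(dist[:, 1] - dist[:, 0], kth=k)
c0_idx, c1_idx = c_idx[:k], c_idx[k:]
assert len(c0_idx) == len(c1_idx) == k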
44 changes: 23 additions & 21 deletions libmultilabel/nn/datasets_AttentionXML.py
@@ -12,17 +12,20 @@
 
 
 class PlainDataset(Dataset):
-    """Basic class for multi-label dataset."""
+    """Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label datasets.
+    This class is necessary because it can process labels in sparse format, while TextDataset
+    cannot.
+    Moreover, TextDataset always performs multilabel binarization, but AttentionXML already does
+    this while generating clusters, so there is no need to binarize the labels again.
 
-    def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None):
-        """General dataset class for multi-label dataset.
+    Args:
+        x: texts
+        y: labels
+    """
 
-        Args:
-            x: texts
-            y: labels
-        """
+    def __init__(self, x: list | ndarray | Tensor, y: Optional[csr_matrix | ndarray | Tensor] = None):
         if y is not None:
-            assert len(x) == y.shape[0], "Sizes mismatch between x and y"
+            assert len(x) == y.shape[0], "Sizes mismatch between texts and labels"
         self.x = x
         self.y = y
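A minimal usage sketch with toy data (the diff does not show __getitem__, so only construction and the length/assertion behavior are exercised):

from scipy.sparse import csr_matrix

x = ["heart disease treatment", "solar panel efficiency"]  # toy texts
y = csr_matrix([[0, 1, 0], [1, 0, 1]])  # shape (2, 3): 2 instances, 3 labels
ds = PlainDataset(x, y)  # passes the len(x) == y.shape[0] assertion
print(len(ds))  # 2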

@@ -48,7 +51,16 @@ def __len__(self):
 
 
 class PLTDataset(PlainDataset):
-    """Dataset class for AttentionXML."""
+    """Dataset for model_1 of AttentionXML.
+
+    Args:
+        x: texts
+        y: labels
+        num_classes: number of classes.
+        mapping: mapping from clusters to labels. Shape: (len(clusters), cluster_size).
+        clusters_selected: clusters sampled from the predictions of model_0. Shape: (len(x), predict_top_k).
+        cluster_scores: the corresponding scores. Shape: (len(x), predict_top_k).
+    """
 
     def __init__(
         self,
@@ -60,17 +72,6 @@ def __init__(
         clusters_selected: ndarray | Tensor,
         cluster_scores: Optional[ndarray | Tensor] = None,
     ):
-        """Dataset for AttentionXML.
-
-        Args:
-            x: texts
-            y: labels
-            num_classes: number of clusters at the current level.
-            mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(clusters), cluster_size). Map from clusters to labels.
-            clusters_selected: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted clusters
-            from last level.
-            cluster_scores: corresponding scores. shape: (len(x), top_k)
-        """
         super().__init__(x, y)
         self.num_classes = num_classes
         self.mapping = mapping
@@ -105,7 +106,8 @@ def __getitem__(self, idx: int):
 
         # train
         if self.label_scores is None:
-            # randomly select clusters as selected labels when less than required
+            # As the network requires fixed-shape inputs, randomly select labels when the number of
+            # selected labels is not enough
             if len(item["labels_selected"]) < self.num_labels_selected:
                 sample = np.random.randint(
                     self.num_classes, size=self.num_labels_selected - len(item["labels_selected"])
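In isolation, the candidate-label construction reads like the following sketch (names and shapes are illustrative, not the exact dataset code):

import numpy as np

num_classes = 16  # total number of labels
mapping = np.arange(16).reshape(8, 2)  # 8 clusters, cluster_size = 2
clusters_selected = np.array([7, 1, 6])  # clusters predicted by model_0
labels_selected = mapping[clusters_selected].ravel()  # [14, 15, 2, 3, 12, 13]

num_labels_selected = 8
if len(labels_selected) < num_labels_selected:
    # Pad with random labels so every sample yields a fixed-size candidate set.
    pad = np.random.randint(num_classes, size=num_labels_selected - len(labels_selected))
    labels_selected = np.concatenate([labels_selected, pad])
print(labels_selected.shape)  # (8,)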
7 changes: 3 additions & 4 deletions libmultilabel/nn/model_AttentionXML.py
@@ -25,15 +25,14 @@ def __init__(
             **kwargs,
         )
 
-    def multilabel_binarize(
+    def scatter_preds(
         self,
         logits: Tensor,
         labels_selected: Tensor,
         label_scores: Tensor,
     ) -> Tensor:
-        """self-implemented MultiLabelBinarizer for AttentionXML"""
+        """Map predictions from sample space to label space. The scores of unsampled labels are set to 0."""
         src = torch.sigmoid(logits.detach()) * label_scores
-        # make sure preds and src use the same precision, e.g., either float16 or float32
         preds = torch.zeros(
             labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
         )
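In isolation, the renamed method amounts to a scatter from each sample's candidate set back into the full label space. A standalone sketch (shapes are illustrative; the extra column is assumed here to absorb padding indices):

import torch

batch, num_candidates, num_classes = 2, 3, 6
logits = torch.randn(batch, num_candidates)
label_scores = torch.rand(batch, num_candidates)  # cluster scores from model_0
labels_selected = torch.randint(num_classes, (batch, num_candidates))

src = torch.sigmoid(logits) * label_scores
# num_classes + 1 columns: the last is assumed to absorb padding indices.
preds = torch.zeros(batch, num_classes + 1, dtype=src.dtype)
preds.scatter_(dim=1, index=labels_selected, src=src)
print(preds[:, :-1].shape)  # torch.Size([2, 6]); unsampled labels stay 0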
@@ -65,7 +64,7 @@ def _shared_eval_step(self, batch, batch_idx):
         labels_selected = batch["labels_selected"]
         label_scores = batch["label_scores"]
         logits = self.network(x, labels_selected=labels_selected)["logits"]
-        y_pred = self.multilabel_binarize(logits, labels_selected, label_scores)
+        y_pred = self.scatter_preds(logits, labels_selected, label_scores)
         self.eval_metric.update(y_pred, y.long())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/modules.py
@@ -232,7 +232,7 @@ def forward(self, inputs, labels_selected):
 
 
 class MultilayerLinearOutput(nn.Module):
-    def __init__(self, linear_size: list[int], output_size: int):
+    def __init__(self, linear_size, output_size):
         super().__init__()
         self.linears = nn.ModuleList(nn.Linear(in_s, out_s) for in_s, out_s in zip(linear_size[:-1], linear_size[1:]))
         self.output = nn.Linear(linear_size[-1], output_size)
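For reference, the zip over linear_size pairs consecutive widths, so an instantiation sketch looks like this (sizes are made up):

# Assuming the class is importable from this module:
from libmultilabel.nn.networks.modules import MultilayerLinearOutput

# With linear_size=[256, 128]: self.linears holds a single Linear(256 -> 128),
# and self.output is Linear(128 -> output_size).
layer = MultilayerLinearOutput(linear_size=[256, 128], output_size=1)
print(layer)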