diff --git a/ddi/dataset.py b/ddi/dataset.py
index dfe98e1..7e13454 100644
--- a/ddi/dataset.py
+++ b/ddi/dataset.py
@@ -56,7 +56,7 @@ def __getitem__(self, indx):
         target_id = self.partition_ids[indx]
         X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
         X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
-        assert ddi_indx == gip_indx
+        # assert ddi_indx == gip_indx
         # combine gip with other matrices
         X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
         X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
@@ -66,6 +66,34 @@ def __len__(self):
         return(self.num_samples)
 
 
+# no GIP involved
+# class PartitionDataTensor(Dataset):
+
+#     def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num):
+#         self.ddi_datatensor = ddi_datatensor  # instance of :class:`DDIDataTensor`
+#         self.gip_datatensor = gip_datatensor  # instance of :class:`GIPDataTensor`
+#         self.partition_ids = partition_ids  # list of indices for drug pairs
+#         self.dsettype = dsettype  # string, dataset type (i.e. train, validation, test)
+#         self.fold_num = fold_num  # int, fold number
+#         self.num_samples = len(self.partition_ids)  # int, number of docs in the partition
+
+#     def __getitem__(self, indx):
+#         target_id = self.partition_ids[indx]
+#         X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
+#         # X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
+#         # assert ddi_indx == gip_indx
+#         # combine gip with other matrices
+#         # X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
+#         # X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
+#         X_a_comb = X_a
+#         X_b_comb = X_b
+#         return X_a_comb, X_b_comb, y, ddi_indx
+
+#     def __len__(self):
+#         return(self.num_samples)
+
+
+
 def construct_load_dataloaders(dataset_fold, dsettypes, config, wrk_dir):
     """construct dataloaders for the dataset for one fold
@@ -108,12 +136,14 @@ def construct_load_dataloaders(dataset_fold, dsettypes, config, wrk_dir):
     return (data_loaders, epoch_loss_avgbatch, score_dict, class_weights, flog_out)
 
 
-def preprocess_features(feat_fpath, dsetname):
+def preprocess_features(feat_fpath, dsetname, fill_diag = None):
     if dsetname in {'DS1', 'DS3'}:
         X_fea = np.loadtxt(feat_fpath,dtype=float,delimiter=",")
     elif dsetname == 'DS2':
         X_fea = pd.read_csv(feat_fpath).values[:,1:]
     X_fea = X_fea.astype(np.float32)
+    if fill_diag is not None:
+        np.fill_diagonal(X_fea, fill_diag)
     return get_features_from_simmatrix(X_fea)
 
@@ -122,12 +152,34 @@ def get_features_from_simmatrix(sim_mat):
     """
     Args:
         sim_mat: np.array, mxm (drug pair similarity matrix)
     """
     r, c = np.triu_indices(len(sim_mat),1) # take indices off the diagnoal by 1
-    return np.concatenate((sim_mat[r], sim_mat[c]), axis=1)
+    return np.concatenate((sim_mat[r], sim_mat[c], sim_mat[r,c].reshape(-1,1), sim_mat[c,r].reshape(-1,1)), axis=1)
+
+# def get_features_from_simmatrix(sim_mat):
+#     """
+#     Args:
+#         sim_mat: np.array, mxm (drug pair similarity matrix)
+#     """
+#     r, c = np.triu_indices(len(sim_mat),1) # take indices off the diagonal by 1
+#     rl, cl = np.tril_indices(len(sim_mat),0)
+#     r_comb = r.tolist() + rl.tolist()
+#     c_comb = c.tolist() + cl.tolist()
+#     return np.concatenate((sim_mat[r_comb], sim_mat[c_comb]), axis=1)
 
 def preprocess_labels(interaction_fpath, dsetname):
     interaction_mat = get_interaction_mat(interaction_fpath, dsetname)
     return get_y_from_interactionmat(interaction_mat)
 
+def get_y_from_interactionmat(interaction_mat):
+    r, c = np.triu_indices(len(interaction_mat),1) # take indices off the diagonal by 1
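+    # note: np.triu_indices(len(interaction_mat), 1) enumerates the strict upper triangle,
+    # so each unordered drug pair (i, j) with i < j contributes exactly one label below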
+    return interaction_mat[r,c]
+
+# def get_y_from_interactionmat(interaction_mat):
+#     r, c = np.triu_indices(len(interaction_mat),1)
+#     rl, cl = np.tril_indices(len(interaction_mat),0)
+#     r_comb = r.tolist() + rl.tolist()
+#     c_comb = c.tolist() + cl.tolist()
+#     return interaction_mat[r_comb,c_comb]
+
 def compute_gip_profile(adj, bw=1.):
     """approach based on Olayan et al. https://doi.org/10.1093/bioinformatics/btx731 """
@@ -156,6 +208,15 @@ def construct_sampleid_ddipairs(interaction_mat):
     sid_ddipairs = {sid:ddi_pair for sid, ddi_pair in enumerate(zip(r,c))}
     return sid_ddipairs
 
+# def construct_sampleid_ddipairs(interaction_mat):
+#     # take indices off the diagonal by 1
+#     r, c = np.triu_indices(len(interaction_mat),1)
+#     rl, cl = np.tril_indices(len(interaction_mat),0)
+#     r_comb = r.tolist() + rl.tolist()
+#     c_comb = c.tolist() + cl.tolist()
+#     sid_ddipairs = {sid:ddi_pair for sid, ddi_pair in enumerate(zip(r_comb, c_comb))}
+#     return sid_ddipairs
+
 def get_num_drugs(interaction_fpath, dsetname):
     if dsetname in {'DS1', 'DS3'}:
         interaction_matrix = np.loadtxt(interaction_fpath,dtype=float,delimiter=",")
@@ -178,10 +239,6 @@ def get_similarity_matrix(feat_fpath, dsetname):
     X_fea = X_fea.astype(np.float32)
     return X_fea
 
-def get_y_from_interactionmat(interaction_mat):
-    r, c = np.triu_indices(len(interaction_mat),1) # take indices off the diagnoal by 1
-    return interaction_mat[r,c]
-
 def create_setvector_features(X, num_sim_types):
     """reshape concatenated features from every similarity type matrix into set of vectors per ddi example"""
     e = X[np.newaxis, :, :]
@@ -374,7 +431,7 @@ def compute_class_weights(labels_tensor):
     classes, counts = np.unique(labels_tensor, return_counts=True)
     # print("classes", classes)
     # print("counts", counts)
-    class_weights = compute_class_weight('balanced', classes, labels_tensor.numpy())
+    class_weights = compute_class_weight('balanced', classes=classes, y=labels_tensor.numpy())
     return class_weights
 
diff --git a/ddi/losses.py b/ddi/losses.py
index b61b669..ca698cf 100644
--- a/ddi/losses.py
+++ b/ddi/losses.py
@@ -26,6 +26,39 @@ def forward(self, dist, target):
         # repel = target * (0.5 * torch.pow(torch.clamp(margin - dist, min=0.0), 2))
         # attract = (1-target) * 0.5 * torch.pow(dist, 2)
 
+        if reduction == 'mean':
+            loss_contrastive = torch.mean(repel + attract)
+        elif reduction == 'sum':
+            loss_contrastive = torch.sum(repel + attract)
+        elif reduction == 'none':
+            loss_contrastive = repel + attract
+        return loss_contrastive
+
+
+class CosEmbLoss(nn.Module):
+    """
+    Cosine Embedding loss
+    """
+
+    def __init__(self, margin, reduction='mean'):
+        super().__init__()
+        self.margin = margin
+        self.reduction = reduction
+
+    def forward(self, sim, target):
+        """
+        Args:
+            sim: tensor, (batch, ), computed similarity between two inputs
+            target: tensor, (batch, ), labels (0/1)
+        """
+        margin = self.margin
+        reduction = self.reduction
+        repel = (1-target) * (torch.clamp(sim - margin, min=0.0))
+        attract = target * (1-sim)
+
+        # repel = target * (0.5 * torch.pow(torch.clamp(margin - dist, min=0.0), 2))
+        # attract = (1-target) * 0.5 * torch.pow(dist, 2)
+
         if reduction == 'mean':
             loss_contrastive = torch.mean(repel + attract)
         elif reduction == 'sum':
             loss_contrastive = torch.sum(repel + attract)
         elif reduction == 'none':
             loss_contrastive = repel + attract
         return loss_contrastive
diff --git a/ddi/model_attn_siamese.py b/ddi/model_attn_siamese.py
index 86eff2f..a4c9003 100644
--- a/ddi/model_attn_siamese.py
+++ b/ddi/model_attn_siamese.py
@@ -164,9 +164,8 @@ def __init__(self, input_size=586, input_embed_dim=64, num_attn_heads=8, mlp_emb
                  pooling_mode = 'attn'):
         super().__init__()
-
-        embed_size = input_embed_dim
-
+        embed_size = input_size
+        # embed_size = input_embed_dim
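+        # note: embed_size is tied to input_size here, so the transformer units run on the
+        # raw similarity-feature width; Wembed below is kept but no longer applied in forward()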
         self.Wembed = nn.Linear(input_size, embed_size)
 
         trfunit_layers = [TransformerUnit(embed_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout) for i in range(num_transformer_units)]
@@ -190,7 +189,7 @@ def forward(self, X):
             X: tensor, (batch, ddi similarity type vector, input_size)
         """
 
-        X = self.Wembed(X)
+        # X = self.Wembed(X)
         z = self.trfunit_pipeline(X)
 
         # mean pooling TODO: add global attention layer or other pooling strategy
@@ -225,12 +224,13 @@ def __init__(self, input_dim, dist, num_classes=2):
             self.dist = nn.CosineSimilarity(dim=1)
             self.alpha = 1
 
-        self.pooling = FeatureEmbAttention(input_dim)
-        self.Wy = nn.Linear(input_dim+1, num_classes)
+        # self.pooling = FeatureEmbAttention(input_dim)
+        self.Wy = nn.Linear(2*input_dim+1, num_classes)
         # perform log softmax on the feature dimension
         self.log_softmax = nn.LogSoftmax(dim=-1)
 
         self._init_params_()
+        print('updated')
 
     def _init_params_(self):
@@ -246,13 +246,15 @@ def forward(self, Z_a, Z_b):
         dist = self.dist(Z_a, Z_b).reshape(-1,1)
         # update dist to distance measure if cosine is chosen
         dist = self.alpha * (1-dist) + (1-self.alpha) * dist
-        # concat both vectors to pass to feature attention for unified representation
-        # Z_a: (batch, 1, embedding dim)
-        Z_a = Z_a.unsqueeze(1)
-        Z_b = Z_b.unsqueeze(1)
-        # Z_union: (batch, embedding dim)
-        Z_union, __ = self.pooling(torch.cat([Z_a, Z_b], axis=1))
-        # print("Z_union:", Z_union.shape)
-        out = torch.cat([Z_union, dist], axis=-1)
+        # # concat both vectors to pass to feature attention for unified representation
+        # # Z_a: (batch, 1, embedding dim)
+        # Z_a = Z_a.unsqueeze(1)
+        # Z_b = Z_b.unsqueeze(1)
+        # # Z_union: (batch, embedding dim)
+        # Z_union, __ = self.pooling(torch.cat([Z_a, Z_b], axis=1))
+        # # print("Z_union:", Z_union.shape)
+        # out = torch.cat([Z_union, dist], axis=-1)
+
+        out = torch.cat([Z_a, Z_b, dist], axis=-1)
         y = self.Wy(out)
         return self.log_softmax(y), dist
\ No newline at end of file
diff --git a/ddi/run_workflow.py b/ddi/run_workflow.py
index 4b74fe2..c912337 100644
--- a/ddi/run_workflow.py
+++ b/ddi/run_workflow.py
@@ -6,7 +6,7 @@
 # from .model_attn import DDI_Transformer
 from .model_attn_siamese import DDI_SiameseTrf, DDI_Transformer, FeatureEmbAttention
 from .dataset import construct_load_dataloaders
-from .losses import ContrastiveLoss
+from .losses import ContrastiveLoss, CosEmbLoss
 import numpy as np
 import pandas as pd
 import torch
@@ -330,6 +330,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
         print("class weights", class_weights)
         loss_func = torch.nn.NLLLoss(weight=class_weights, reduction='mean') # negative log likelihood loss
         loss_contrastive = ContrastiveLoss(options.get('contrastiveloss_margin', 0.5), reduction='mean')
+        # loss_contrastive = CosEmbLoss(options.get('contrastiveloss_margin', 0.5), reduction='mean')
         loss_contrastive.type(fdtype).to(device)
         # loss_attn = FeatureEmbAttention(1)
         # loss_attn.type(fdtype).to(device)
@@ -351,7 +352,9 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
                                 pdropout=model_config.p_dropout,
                                 num_transformer_units=model_config.num_transformer_units,
                                 pooling_mode=model_config.pooling_mode)
-    ddi_siamese = DDI_SiameseTrf(model_config.input_embed_dim,model_config.dist_opt, num_classes=2)
+    ddi_siamese = DDI_SiameseTrf(options['input_dim'],model_config.dist_opt, num_classes=2)
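+    # note: options['input_dim'] is expected to match the transformer output width, which
+    # now equals input_size (see embed_size = input_size in DDI_Transformer.__init__ above)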
+
+    # ddi_siamese = DDI_SiameseTrf(model_config.input_embed_dim,model_config.dist_opt, num_classes=2)
 
 
     # define optimizer and group parameters
@@ -379,8 +382,9 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
     c_step_size = int(np.ceil(5*num_iter)) # this should be 2-10 times num_iter
     base_lr = 3e-4
     max_lr = 5*base_lr # 3-5 times base_lr
-    # optimizer = torch.optim.Adam(models_param, weight_decay=weight_decay, lr=base_lr)
-    optimizer = torch.optim.Adam(models_param, lr=base_lr)
+    print('max lr', max_lr)
+    optimizer = torch.optim.Adam(models_param, weight_decay=weight_decay, lr=base_lr)
+    # optimizer = torch.optim.Adam(models_param, lr=base_lr)
     cyc_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr, max_lr,
                                                       step_size_up=c_step_size,
                                                       mode='triangular', cycle_momentum=False)
@@ -462,7 +466,8 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
                     ddi_ids.extend(ids.tolist())
 
                     cl = loss_func(logsoftmax_scores, y_batch)
-                    # dl = loss_contrastive(dist.reshape(-1), y_batch)
+
+                    dl = loss_contrastive(dist.reshape(-1), y_batch.type(fdtype))
                     # print(cl)
                     # print('cl', cl.shape)
                     # print('dl', dl.shape)
@@ -473,8 +478,8 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
                     # loss = loss.mean()
                     # # print(loss)
-                    # loss = 0.8*cl + 0.2*dl
-                    loss = cl
+                    loss = cl + dl
+                    # loss = cl
                     # loss = 0.8*loss_func(logsoftmax_scores, y_batch) + 0.2*loss_contrastive(dist.reshape(-1), y_batch)
                     # loss = loss_func(logsoftmax_scores, y_batch)