
Commit

loss and model update
orisenbazuru committed Aug 17, 2020
1 parent ce3d519 commit 550a571
Showing 4 changed files with 126 additions and 29 deletions.
73 changes: 65 additions & 8 deletions ddi/dataset.py
@@ -56,7 +56,7 @@ def __getitem__(self, indx):
target_id = self.partition_ids[indx]
X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
assert ddi_indx == gip_indx
# assert ddi_indx == gip_indx
# combine gip with other matrices
X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
@@ -66,6 +66,34 @@ def __len__(self):
return(self.num_samples)


# no GIP involved
# class PartitionDataTensor(Dataset):

# def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num):
# self.ddi_datatensor = ddi_datatensor # instance of :class:`DDIDataTensor`
# self.gip_datatensor = gip_datatensor # instance of :class:`GIPDataTensor`
# self.partition_ids = partition_ids # list of indices for drug pairs
# self.dsettype = dsettype # string, dataset type (i.e. train, validation, test)
# self.fold_num = fold_num # int, fold number
# self.num_samples = len(self.partition_ids) # int, number of drug pairs in the partition

# def __getitem__(self, indx):
# target_id = self.partition_ids[indx]
# X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
# # X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
# # assert ddi_indx == gip_indx
# # combine gip with other matrices
# # X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
# # X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
# X_a_comb = X_a
# X_b_comb = X_b
# return X_a_comb, X_b_comb, y, ddi_indx

# def __len__(self):
# return(self.num_samples)



def construct_load_dataloaders(dataset_fold, dsettypes, config, wrk_dir):
"""construct dataloaders for the dataset for one fold
@@ -108,12 +136,14 @@ def construct_load_dataloaders(dataset_fold, dsettypes, config, wrk_dir):

return (data_loaders, epoch_loss_avgbatch, score_dict, class_weights, flog_out)

def preprocess_features(feat_fpath, dsetname):
def preprocess_features(feat_fpath, dsetname, fill_diag = None):
if dsetname in {'DS1', 'DS3'}:
X_fea = np.loadtxt(feat_fpath,dtype=float,delimiter=",")
elif dsetname == 'DS2':
X_fea = pd.read_csv(feat_fpath).values[:,1:]
X_fea = X_fea.astype(np.float32)
if fill_diag is not None:
np.fill_diagonal(X_fea, fill_diag)
return get_features_from_simmatrix(X_fea)

def get_features_from_simmatrix(sim_mat):
@@ -122,12 +152,34 @@ def get_features_from_simmatrix(sim_mat):
sim_mat: np.array, mxm (drug pair similarity matrix)
"""
r, c = np.triu_indices(len(sim_mat),1) # take indices above the diagonal (offset by 1)
return np.concatenate((sim_mat[r], sim_mat[c]), axis=1)
return np.concatenate((sim_mat[r], sim_mat[c], sim_mat[r,c].reshape(-1,1), sim_mat[c,r].reshape(-1,1)), axis=1)

# def get_features_from_simmatrix(sim_mat):
# """
# Args:
# sim_mat: np.array, mxm (drug pair similarity matrix)
# """
# r, c = np.triu_indices(len(sim_mat),1) # take indices above the diagonal (offset by 1)
# rl, cl = np.tril_indices(len(sim_mat),0)
# r_comb = r.tolist() + rl.tolist()
# c_comb = c.tolist() + cl.tolist()
# return np.concatenate((sim_mat[r_comb], sim_mat[c_comb]), axis=1)
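
For reference (not part of this commit): a minimal NumPy sketch of what the updated get_features_from_simmatrix returns. For an m x m similarity matrix, each of the m(m-1)/2 upper-triangle drug pairs gets a feature vector of length 2m + 2 — the row of drug a, the row of drug b, and the two cross-similarity entries sim_mat[a, b] and sim_mat[b, a].

```python
import numpy as np

# hypothetical 4x4 drug-drug similarity matrix
sim_mat = np.array([[1.0, 0.2, 0.5, 0.1],
                    [0.2, 1.0, 0.3, 0.7],
                    [0.5, 0.3, 1.0, 0.4],
                    [0.1, 0.7, 0.4, 1.0]])

r, c = np.triu_indices(len(sim_mat), 1)      # pair indices above the diagonal
feats = np.concatenate((sim_mat[r], sim_mat[c],
                        sim_mat[r, c].reshape(-1, 1),
                        sim_mat[c, r].reshape(-1, 1)), axis=1)
print(feats.shape)                           # (6, 10): m*(m-1)/2 pairs, 2*m + 2 features each
```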

def preprocess_labels(interaction_fpath, dsetname):
interaction_mat = get_interaction_mat(interaction_fpath, dsetname)
return get_y_from_interactionmat(interaction_mat)

def get_y_from_interactionmat(interaction_mat):
r, c = np.triu_indices(len(interaction_mat),1) # take indices above the diagonal (offset by 1)
return interaction_mat[r,c]

# def get_y_from_interactionmat(interaction_mat):
# r, c = np.triu_indices(len(interaction_mat),1)
# rl, cl = np.tril_indices(len(interaction_mat),0)
# r_comb = r.tolist() + rl.tolist()
# c_comb = c.tolist() + cl.tolist()
# return interaction_mat[r_comb,c_comb]
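
The labels follow the same upper-triangle pair ordering, so y[i] lines up with the i-th feature row above. A tiny check (not part of the commit) with a hypothetical interaction matrix:

```python
import numpy as np

interaction_mat = np.array([[0, 1, 0, 1],
                            [1, 0, 1, 0],
                            [0, 1, 0, 0],
                            [1, 0, 0, 0]])
r, c = np.triu_indices(len(interaction_mat), 1)
y = interaction_mat[r, c]
print(y)    # [1 0 1 1 0 0] -- one 0/1 label per upper-triangle drug pair
```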

def compute_gip_profile(adj, bw=1.):
"""approach based on Olayan et al. https://doi.org/10.1093/bioinformatics/btx731 """

@@ -156,6 +208,15 @@ def construct_sampleid_ddipairs(interaction_mat):
sid_ddipairs = {sid:ddi_pair for sid, ddi_pair in enumerate(zip(r,c))}
return sid_ddipairs

# def construct_sampleid_ddipairs(interaction_mat):
# # take indices above the diagonal (offset by 1)
# r, c = np.triu_indices(len(interaction_mat),1)
# rl, cl = np.tril_indices(len(interaction_mat),0)
# r_comb = r.tolist() + rl.tolist()
# c_comb = c.tolist() + cl.tolist()
# sid_ddipairs = {sid:ddi_pair for sid, ddi_pair in enumerate(zip(r_comb, c_comb))}
# return sid_ddipairs

def get_num_drugs(interaction_fpath, dsetname):
if dsetname in {'DS1', 'DS3'}:
interaction_matrix = np.loadtxt(interaction_fpath,dtype=float,delimiter=",")
@@ -178,10 +239,6 @@ def get_similarity_matrix(feat_fpath, dsetname):
X_fea = X_fea.astype(np.float32)
return X_fea

def get_y_from_interactionmat(interaction_mat):
r, c = np.triu_indices(len(interaction_mat),1) # take indices above the diagonal (offset by 1)
return interaction_mat[r,c]

def create_setvector_features(X, num_sim_types):
"""reshape concatenated features from every similarity type matrix into set of vectors per ddi example"""
e = X[np.newaxis, :, :]
@@ -374,7 +431,7 @@ def compute_class_weights(labels_tensor):
classes, counts = np.unique(labels_tensor, return_counts=True)
# print("classes", classes)
# print("counts", counts)
class_weights = compute_class_weight('balanced', classes, labels_tensor.numpy())
class_weights = compute_class_weight('balanced', classes=classes, y=labels_tensor.numpy())
return class_weights
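
The keyword-argument form matches newer scikit-learn releases, where classes and y of compute_class_weight are keyword-only. A small self-contained check (not part of the commit; values follow the 'balanced' heuristic n_samples / (n_classes * count)):

```python
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

labels_tensor = torch.tensor([0, 0, 0, 1, 1])
classes, counts = np.unique(labels_tensor.numpy(), return_counts=True)
class_weights = compute_class_weight('balanced', classes=classes, y=labels_tensor.numpy())
print(class_weights)    # [0.8333... 1.25]  ->  5/(2*3) and 5/(2*2)
```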


33 changes: 33 additions & 0 deletions ddi/losses.py
@@ -26,6 +26,39 @@ def forward(self, dist, target):
# repel = target * (0.5 * torch.pow(torch.clamp(margin - dist, min=0.0), 2))
# attract = (1-target) * 0.5 * torch.pow(dist, 2)

if reduction == 'mean':
loss_contrastive = torch.mean(repel + attract)
elif reduction == 'sum':
loss_contrastive = torch.sum(repel + attract)
elif reduction == 'none':
loss_contrastive = repel + attract
return loss_contrastive


class CosEmbLoss(nn.Module):
"""
Cosine Embedding loss
"""

def __init__(self, margin, reduction='mean'):
super().__init__()
self.margin = margin
self.reduction = reduction

def forward(self, sim, target):
"""
Args:
sim: tensor, (batch, ), computed similarity between two inputs
target: tensor, (batch, ), labels (0/1)
"""
margin = self.margin
reduction = self.reduction
repel = (1-target) * (torch.clamp(sim - margin, min=0.0))
attract = target * (1-sim)

# repel = target * (0.5 * torch.pow(torch.clamp(margin - dist, min=0.0), 2))
# attract = (1-target) * 0.5 * torch.pow(dist, 2)

if reduction == 'mean':
loss_contrastive = torch.mean(repel + attract)
elif reduction == 'sum':
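A usage sketch for the new CosEmbLoss (assuming the truncated reduction branches mirror ContrastiveLoss above): it operates on a cosine similarity per pair, pulling interacting pairs (target = 1) toward similarity 1 and penalizing non-interacting pairs only once their similarity exceeds the margin.

```python
import torch
import torch.nn as nn
from ddi.losses import CosEmbLoss   # the class added in this commit

z_a = torch.randn(3, 8)                      # hypothetical embeddings for 3 drug pairs
z_b = torch.randn(3, 8)
target = torch.tensor([1., 0., 1.])          # 1 = interacting pair

sim = nn.CosineSimilarity(dim=1)(z_a, z_b)   # (batch,)
loss = CosEmbLoss(margin=0.5, reduction='mean')(sim, target)
```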
30 changes: 16 additions & 14 deletions ddi/model_attn_siamese.py
@@ -164,9 +164,8 @@ def __init__(self, input_size=586, input_embed_dim=64, num_attn_heads=8, mlp_emb
pooling_mode = 'attn'):

super().__init__()

embed_size = input_embed_dim

embed_size = input_size
# embed_size = input_embed_dim
self.Wembed = nn.Linear(input_size, embed_size)

trfunit_layers = [TransformerUnit(embed_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout) for i in range(num_transformer_units)]
@@ -190,7 +189,7 @@ def forward(self, X):
X: tensor, (batch, ddi similarity type vector, input_size)
"""

X = self.Wembed(X)
# X = self.Wembed(X)
z = self.trfunit_pipeline(X)

# mean pooling TODO: add global attention layer or other pooling strategy
Expand Down Expand Up @@ -225,12 +224,13 @@ def __init__(self, input_dim, dist, num_classes=2):
self.dist = nn.CosineSimilarity(dim=1)
self.alpha = 1

self.pooling = FeatureEmbAttention(input_dim)
self.Wy = nn.Linear(input_dim+1, num_classes)
# self.pooling = FeatureEmbAttention(input_dim)
self.Wy = nn.Linear(2*input_dim+1, num_classes)
# perform log softmax on the feature dimension
self.log_softmax = nn.LogSoftmax(dim=-1)

self._init_params_()
print('updated')


def _init_params_(self):
@@ -246,13 +246,15 @@ def forward(self, Z_a, Z_b):
dist = self.dist(Z_a, Z_b).reshape(-1,1)
# update dist to distance measure if cosine is chosen
dist = self.alpha * (1-dist) + (1-self.alpha) * dist
# concat both vectors to pass to feature attention for unified representation
# Z_a: (batch, 1, embedding dim)
Z_a = Z_a.unsqueeze(1)
Z_b = Z_b.unsqueeze(1)
# Z_union: (batch, embedding dim)
Z_union, __ = self.pooling(torch.cat([Z_a, Z_b], axis=1))
# print("Z_union:", Z_union.shape)
out = torch.cat([Z_union, dist], axis=-1)
# # concat both vectors to pass to feature attention for unified representation
# # Z_a: (batch, 1, embedding dim)
# Z_a = Z_a.unsqueeze(1)
# Z_b = Z_b.unsqueeze(1)
# # Z_union: (batch, embedding dim)
# Z_union, __ = self.pooling(torch.cat([Z_a, Z_b], axis=1))
# # print("Z_union:", Z_union.shape)
# out = torch.cat([Z_union, dist], axis=-1)

out = torch.cat([Z_a, Z_b, dist], axis=-1)
y = self.Wy(out)
return self.log_softmax(y), dist
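
With the feature-attention pooling over [Z_a, Z_b] commented out, the classifier head now sees the raw concatenation [Z_a, Z_b, dist], which is why Wy grows from input_dim + 1 to 2*input_dim + 1 inputs. A quick shape check (not part of the commit) with hypothetical sizes:

```python
import torch

batch, input_dim = 4, 586                    # hypothetical sizes
Z_a, Z_b = torch.randn(batch, input_dim), torch.randn(batch, input_dim)
dist = torch.rand(batch, 1)                  # cosine-derived distance per pair

out = torch.cat([Z_a, Z_b, dist], dim=-1)
print(out.shape)                                         # torch.Size([4, 1173]) == 2*input_dim + 1
print(torch.nn.Linear(2*input_dim + 1, 2)(out).shape)    # torch.Size([4, 2])
```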
19 changes: 12 additions & 7 deletions ddi/run_workflow.py
@@ -6,7 +6,7 @@
# from .model_attn import DDI_Transformer
from .model_attn_siamese import DDI_SiameseTrf, DDI_Transformer, FeatureEmbAttention
from .dataset import construct_load_dataloaders
from .losses import ContrastiveLoss
from .losses import ContrastiveLoss, CosEmbLoss
import numpy as np
import pandas as pd
import torch
@@ -330,6 +330,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
print("class weights", class_weights)
loss_func = torch.nn.NLLLoss(weight=class_weights, reduction='mean') # negative log likelihood loss
loss_contrastive = ContrastiveLoss(options.get('contrastiveloss_margin', 0.5), reduction='mean')
# loss_contrastive = CosEmbLoss(options.get('contrastiveloss_margin', 0.5), reduction='mean')
loss_contrastive.type(fdtype).to(device)
# loss_attn = FeatureEmbAttention(1)
# loss_attn.type(fdtype).to(device)
@@ -351,7 +352,9 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
pdropout=model_config.p_dropout,
num_transformer_units=model_config.num_transformer_units,
pooling_mode=model_config.pooling_mode)
ddi_siamese = DDI_SiameseTrf(model_config.input_embed_dim,model_config.dist_opt, num_classes=2)
ddi_siamese = DDI_SiameseTrf(options['input_dim'],model_config.dist_opt, num_classes=2)

# ddi_siamese = DDI_SiameseTrf(model_config.input_embed_dim,model_config.dist_opt, num_classes=2)


# define optimizer and group parameters
@@ -379,8 +382,9 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
c_step_size = int(np.ceil(5*num_iter)) # this should be 2-10 times num_iter
base_lr = 3e-4
max_lr = 5*base_lr # 3-5 times base_lr
# optimizer = torch.optim.Adam(models_param, weight_decay=weight_decay, lr=base_lr)
optimizer = torch.optim.Adam(models_param, lr=base_lr)
print('max lr', max_lr)
optimizer = torch.optim.Adam(models_param, weight_decay=weight_decay, lr=base_lr)
# optimizer = torch.optim.Adam(models_param, lr=base_lr)
cyc_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr, max_lr, step_size_up=c_step_size,
mode='triangular', cycle_momentum=False)

@@ -462,7 +466,8 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
ddi_ids.extend(ids.tolist())

cl = loss_func(logsoftmax_scores, y_batch)
# dl = loss_contrastive(dist.reshape(-1), y_batch)

dl = loss_contrastive(dist.reshape(-1), y_batch.type(fdtype))
# print(cl)
# print('cl', cl.shape)
# print('dl', dl.shape)
@@ -473,8 +478,8 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
# loss = loss.mean()
# # print(loss)

# loss = 0.8*cl + 0.2*dl
loss = cl
loss = cl + dl
# loss = cl
# loss = 0.8*loss_func(logsoftmax_scores, y_batch) + 0.2*loss_contrastive(dist.reshape(-1), y_batch)
# loss = loss_func(logsoftmax_scores, y_batch)

