update for paper review 📃
orisenbazuru committed Jun 10, 2021
1 parent 4905d42 commit 9c723ca
Showing 9 changed files with 288 additions and 479 deletions.
6 changes: 4 additions & 2 deletions .gitignore
@@ -51,5 +51,7 @@ venv.bak/
# orisenbazuru
explore.py
notebooks/orisenbazuru/*
cluster/data/medinfmk/ddi/processed/*
ideas.txt.rtf
notebooks/archive/*
ideas.txt.rtf
trained_models
data/processed
14 changes: 12 additions & 2 deletions README.md
@@ -1,6 +1,16 @@
# side-effects
# 📣 AttentionDDI 💊

This repository contains the PyTorch implementation of the AttentionDDI model.

AttentionDDI is a Siamese, multi-modal neural network model built on multi-head self-attention and used for drug-drug interaction (DDI) prediction.
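
As a rough orientation, the sketch below shows the Siamese pairing idea only: one shared encoder embeds each drug of a pair and a classifier scores the pair. It is a minimal stand-in, not code from this repository; the class name, layer sizes, and mean pooling are assumptions (the repository's `DDI_SiameseTrf` head instead compares the two embeddings with a distance measure and returns log-softmax scores, as the diff of `ddi/model_attn_siamese.py` below shows).

```python
import torch
import torch.nn as nn

class TinySiameseDDI(nn.Module):
    """Hypothetical, simplified stand-in for the AttentionDDI setup:
    a shared self-attention encoder applied to both drugs of a pair,
    followed by a classifier over the concatenated pair embedding."""
    def __init__(self, input_size, embed_size=64, num_heads=4, num_classes=2):
        super().__init__()
        self.proj = nn.Linear(input_size, embed_size)
        self.attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.classifier = nn.Linear(2 * embed_size, num_classes)

    def encode(self, X):
        # X: (batch, num_similarity_types, input_size)
        z = self.proj(X)
        z, _ = self.attn(z, z, z)   # self-attention across similarity types
        return z.mean(dim=1)        # pool across similarity types

    def forward(self, X_a, X_b):
        z_a = self.encode(X_a)      # the same weights encode both drugs (Siamese)
        z_b = self.encode(X_b)
        return self.classifier(torch.cat([z_a, z_b], dim=-1))
```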

## Installation

* `git clone` the repo and `cd` into it.
* Run `pip install -e .` to install the repo's python package.

## Running 🏃

1. Use `notebooks/jupyter/AttnWSiamese_data_generation.ipynb` to generate DataTensors from the drug similarity matrices.
2. Use `notebooks/jupyter/AttnWSiamese_hyperparam.ipynb` to find the best-performing model hyperparameters.
3. Use `notebooks/jupyter/AttnWSiamese_train_eval.ipynb` to train and test on the best hyperparameters (see the sketch after this list).
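
For orientation, the notebooks presumably drive the entry points in `ddi/run_workflow.py`. The sketch below is one way those functions might be called directly; `datatensor_partition` and `config_map` are placeholders assumed to come out of the notebooks above, and the directory and GPU mapping are made-up examples, not paths used by the repo.

```python
# Sketch only: the inputs below are placeholders produced by the notebooks above.
from ddi.run_workflow import train_test_partition, test_partition

datatensor_partition = ...   # from AttnWSiamese_data_generation.ipynb
config_map = ...             # (model config, run options) from AttnWSiamese_hyperparam.ipynb
fold_gpu_map = {fold: 0 for fold in range(5)}   # e.g. run all 5 folds on GPU 0
tr_val_dir = 'experiments/attnddi'              # hypothetical output directory

# train each fold, then evaluate on the test split under the same directory
train_test_partition(datatensor_partition, config_map, tr_val_dir, fold_gpu_map)

# re-run testing only; results land under test_<suffix>/fold_<k>
test_partition(datatensor_partition, config_map, tr_val_dir, fold_gpu_map,
               suffix_testfname='rerun')
```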
28 changes: 4 additions & 24 deletions ddi/dataset.py
@@ -12,7 +12,7 @@

class DDIDataTensor(Dataset):

def __init__(self, y, X_a, X_b):
def __init__(self, X_a, X_b, y):

self.X_a = X_a # tensor.float32, (drug pairs, features)
self.X_b = X_b # tensor.float32, (drug pairs, features)
@@ -46,7 +46,7 @@ def __len__(self):

class PartitionDataTensor(Dataset):

def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num, is_siamese):
def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num, is_siamese=True):
self.ddi_datatensor = ddi_datatensor # instance of :class:`DDIDataTensor`
self.gip_datatensor = gip_datatensor # instance of :class:`GIPDataTensor`
self.partition_ids = partition_ids # list of indices for drug pairs
@@ -60,6 +60,7 @@ def __getitem__(self, indx):
X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
# combine gip with other matrices
X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
# (sim_types, features)
X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
X_comb = torch.cat([X_a_comb, X_b_comb])#.view(-1)
@@ -187,27 +188,6 @@ def get_y_from_interactionmat(interaction_mat):
# c_comb = c.tolist() + cl.tolist()
# return interaction_mat[r_comb,c_comb]

def compute_gip_profile(adj, bw=1.):
"""approach based on Olayan et al. https://doi.org/10.1093/bioinformatics/btx731 """

ga = np.dot(adj,np.transpose(adj))
ga = bw*ga/np.mean(np.diag(ga))
di = np.diag(ga)
x = np.tile(di,(1,di.shape[0])).reshape(di.shape[0],di.shape[0])
d =x+np.transpose(x)-2*ga
return np.exp(-d)

def compute_kernel(mat, k_bandwidth, epsilon=1e-9):
"""computes gaussian kernel from 2D matrix
Approach based on van Laarhoven et al. doi:10.1093/bioinformatics/btr500
"""
r, c = mat.shape # 2D matrix
# computes pairwise l2 distance
dist_kernel = squareform(pdist(mat, metric='euclidean')**2)
gamma = k_bandwidth/(np.clip((scpnorm(mat, axis=1, keepdims=True)**2) * 1/c, a_min=epsilon, a_max=None))
return np.exp(-gamma*dist_kernel)
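
As a reading aid (reconstructed from the code above, not quoted from the cited papers): with $a_i$ the rows of `adj`, $x_i$ the rows of `mat`, $n$ the number of rows, and $c$ the number of columns, the two helpers compute Gaussian (RBF) kernels of the form

$$K^{\mathrm{GIP}}_{ij} = \exp\!\left(-\,\mathrm{bw}\;\frac{\lVert a_i - a_j\rVert^2}{\frac{1}{n}\sum_k \lVert a_k\rVert^2}\right),
\qquad
K_{ij} = \exp\!\left(-\gamma_i\,\lVert x_i - x_j\rVert^2\right),\quad
\gamma_i = \frac{k_{\mathrm{bandwidth}}}{\max\!\left(\lVert x_i\rVert^2/c,\;\epsilon\right)}.$$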

def construct_sampleid_ddipairs(interaction_mat):
# take indices off the diagonal by 1
@@ -339,7 +319,7 @@ def report_label_distrib(labels):
print("class:", label, "norm count:", norm_counts[i])


def generate_partition_datatensor(ddi_datatensor, gip_dtensor_perfold, data_partitions, is_siamese):
def generate_partition_datatensor(ddi_datatensor, gip_dtensor_perfold, data_partitions, is_siamese=True):
datatensor_partitions = {}
for fold_num in data_partitions:
datatensor_partitions[fold_num] = {}
31 changes: 19 additions & 12 deletions ddi/model_attn_siamese.py
@@ -58,16 +58,19 @@ def forward(self, X):
Args:
X: tensor, (batch, ddi similarity type vector, input_size)
"""

bsize, num_modal, inp_dim = X.shape
attn_tensor = X.new_zeros((bsize, num_modal, num_modal))
out = []
for SH_layer in self.multihead_pipeline:
z, __ = SH_layer(X)
z, attn_w_normalized = SH_layer(X)
out.append(z)
attn_tensor += attn_w_normalized
# concat on the feature dimension
out = torch.cat(out, -1)
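# average the attention maps accumulated from the single-head layers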
attn_tensor = attn_tensor/len(self.multihead_pipeline)

# return a unified vector mapping of the different self-attention blocks
return self.Wz(out)
return self.Wz(out), attn_tensor


class TransformerUnit(nn.Module):
@@ -98,15 +101,15 @@ def forward(self, X):
X: tensor, (batch, ddi similarity type vector, input_size)
"""
# z is tensor of size (batch, ddi similarity type vector, input_size)
z = self.multihead_attn(X)
z, attn_tensor = self.multihead_attn(X)
# layer norm with residual connection
z = self.layernorm_1(z + X)
z = self.dropout(z)
z_ff= self.MLP(z)
z = self.layernorm_2(z_ff + z)
z = self.dropout(z)

return z
return z, attn_tensor

class FeatureEmbAttention(nn.Module):
def __init__(self, input_dim):
@@ -167,7 +170,7 @@ def __init__(self, input_size=586, input_embed_dim=64, num_attn_heads=8, mlp_emb
self.Wembed = nn.Linear(input_size, embed_size)

trfunit_layers = [TransformerUnit(embed_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout) for i in range(num_transformer_units)]
self.trfunit_pipeline = nn.Sequential(*trfunit_layers)
self.trfunit_pipeline = nn.ModuleList(trfunit_layers)

self.pooling_mode = pooling_mode
if pooling_mode == 'attn':
@@ -187,9 +190,15 @@ def forward(self, X):
X: tensor, (batch, ddi similarity type vector, input_size)
"""

# X = self.Wembed(X)
# mean pooling TODO: add global attention layer or other pooling strategy
z = self.trfunit_pipeline(X)
bsize, num_modal, inp_dim = X.shape
attn_tensor = X.new_zeros((bsize, num_modal, num_modal))
xinput = X
for encunit in self.trfunit_pipeline:
z, attn_h_tensor = encunit(xinput)
xinput = z
attn_tensor += attn_h_tensor
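# average the attention maps across the stacked transformer units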
attn_tensor = attn_tensor/len(self.trfunit_pipeline)

# pool across similarity type vectors
# Note: z.mean(dim=1) will change shape of z to become (batch, input_size)
Expand All @@ -204,7 +213,7 @@ def forward(self, X):
z = self.pooling(z, dim=1)
fattn_w_norm = None

return z, fattn_w_norm
return z, fattn_w_norm, attn_tensor

class DDI_SiameseTrf(nn.Module):

@@ -226,9 +235,7 @@ def __init__(self, input_dim, dist, num_classes=2):
# perform log softmax on the feature dimension
self.log_softmax = nn.LogSoftmax(dim=-1)

self._init_params_()
print('updated')

self._init_params_()

def _init_params_(self):
_init_model_params(self.named_parameters())
50 changes: 37 additions & 13 deletions ddi/run_workflow.py
@@ -184,7 +184,7 @@ def run_ddi(data_partition, dsettypes, config, options, wrk_dir,
else:
class_weights = torch.tensor([1]*2).type(fdtype).to(device) # weighting all cases equally

print("class weights", class_weights)
# print("class weights", class_weights)
# binary cross entropy
loss_bce = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights, reduction='mean')
loss_nlll = torch.nn.NLLLoss(weight=class_weights, reduction='mean') # negative log likelihood loss
@@ -354,7 +354,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
cld = construct_load_dataloaders(data_partition, dsettypes, dataloader_config, wrk_dir)
# dictionaries by dsettypes
data_loaders, epoch_loss_avgbatch, score_dict, class_weights, flog_out = cld
print(flog_out)
# print(flog_out)
device = get_device(to_gpu, gpu_index) # gpu device
fdtype = options['fdtype']

Expand All @@ -363,7 +363,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
else:
class_weights = torch.tensor([1]*2).type(fdtype).to(device) # weighting all cases equally

print("class weights", class_weights)
# print("class weights", class_weights)
loss_func = torch.nn.NLLLoss(weight=class_weights, reduction='mean') # negative log likelihood loss
loss_contrastive = ContrastiveLoss(options.get('contrastiveloss_margin', 0.5), reduction='mean')
loss_contrastive.type(fdtype).to(device)
@@ -401,10 +401,9 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
for m, m_name in models:
m.type(fdtype).to(device)

print('cool')
if('train' in data_loaders):
weight_decay = options.get('weight_decay', 1e-4)
print('weight_decay', weight_decay)
# print('weight_decay', weight_decay)
# split model params into attn parameters and other params
# models_param = add_weight_decay_except_attn([ddi_model, ddi_siamese], weight_decay)
# see paper Cyclical Learning rates for Training Neural Networks for parameters' choice
Expand All @@ -415,7 +414,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
c_step_size = int(np.ceil(5*num_iter)) # this should be 2-10 times num_iter
base_lr = 3e-4
max_lr = 5*base_lr # 3-5 times base_lr
print('max lr', max_lr)
# print('max lr', max_lr)
optimizer = torch.optim.Adam(models_param, weight_decay=weight_decay, lr=base_lr)
cyc_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr, max_lr, step_size_up=c_step_size,
mode='triangular', cycle_momentum=False)
@@ -432,6 +431,8 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
ReaderWriter.dump_data(options, os.path.join(config_dir, 'exp_options.pkl'))
# store attention weights for validation and test set
seqid_fattnw_map = {dsettype: {'X_a':{}, 'X_b':{}} for dsettype in data_loaders if dsettype in {'test'}}
seqid_hattnw_map = {dsettype: {'X_a':{}, 'X_b':{}} for dsettype in data_loaders if dsettype in {'test'}}
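# fattn: pooling-attention weights; hattn: self-attention maps averaged over heads/units (naming inferred from the model's return values)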

pair_names = ('a', 'b')

for epoch in range(num_epochs):
@@ -469,13 +470,17 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,

with torch.set_grad_enabled(dsettype == 'train'):
num_samples_perbatch = X_a.size(0)
z_a, fattn_w_scores_a = ddi_model(X_a)
z_b, fattn_w_scores_b = ddi_model(X_b)
z_a, fattn_w_scores_a, hattn_w_scores_a = ddi_model(X_a)
z_b, fattn_w_scores_b, hattn_w_scores_b = ddi_model(X_b)

if(dsettype in seqid_fattnw_map and model_config.pooling_mode == 'attn'):
for l, attn_scores in enumerate((fattn_w_scores_a, fattn_w_scores_b)):
suffix = pair_names[l]
seqid_fattnw_map[dsettype][f'X_{suffix}'].update({sid.item():attn_scores[c].detach().cpu() for c, sid in enumerate(ids)})

for l, attn_scores in enumerate((hattn_w_scores_a, hattn_w_scores_b)):
suffix = pair_names[l]
seqid_hattnw_map[dsettype][f'X_{suffix}'].update({sid.item():attn_scores[c].detach().cpu() for c, sid in enumerate(ids)})


logsoftmax_scores, dist = ddi_siamese(z_a, z_b)
@@ -520,6 +525,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
elif(dsettype == 'test'):
# dump attention weights for the test data
dump_dict_content(seqid_fattnw_map, ['test'], 'sampleid_fattnw_map', wrk_dir)
dump_dict_content(seqid_hattnw_map, ['test'], 'sampleid_hattnw_map', wrk_dir)
if dsettype in {'test', 'validation'}:
predictions_df = build_predictions_df(ddi_ids, ref_class, pred_class, prob_scores_arr)
predictions_path = os.path.join(wrk_dir, f'predictions_{dsettype}.csv')
@@ -632,7 +638,7 @@ def get_best_config_from_hyperparamsearch(hyperparam_search_dir, num_folds=5, nu
if(os.path.isfile(score_file)):
try:
mscore = ReaderWriter.read_data(score_file)
print(mscore)
# print(mscore)
scores[config_num, 0] = mscore.best_epoch_indx
scores[config_num, 1] = mscore.s_precision
scores[config_num, 2] = mscore.s_recall
@@ -677,9 +683,13 @@ def train_val_run(datatensor_partitions, config_map, train_val_dir, fold_gpu_map
state_dict_dir=None, to_gpu=True,
gpu_index=fold_gpu_map[fold_num])



def test_run(datatensor_partitions, config_map, train_val_dir, test_dir, fold_gpu_map, num_epochs=1):
def test_run(datatensor_partitions,
config_map,
train_val_dir,
test_dir,
fold_gpu_map,
suffix_testfname=None,
num_epochs=1):
dsettypes = ['test']
mconfig, options = config_map
options['num_epochs'] = num_epochs # override number of epochs using user specified value
@@ -692,7 +702,10 @@ def test_run(datatensor_partitions, config_map, train_val_dir, test_dir, fold_gp
if os.path.exists(train_dir):
# load state_dict pth
state_dict_pth = os.path.join(train_dir, 'model_statedict')
path = os.path.join(test_dir, 'test', 'fold_{}'.format(fold_num))
if suffix_testfname:
path = os.path.join(test_dir, f'test_{suffix_testfname}', 'fold_{}'.format(fold_num))
else:
path = os.path.join(test_dir, 'test', 'fold_{}'.format(fold_num))
test_wrk_dir = create_directory(path)
if options.get('loss_func') == 'bceloss':
run_ddi(data_partition, dsettypes, mconfig, options, test_wrk_dir,
@@ -711,6 +724,17 @@ def train_test_partition(datatensor_partition, config_map, tr_val_dir, fold_gpu_
train_val_run(datatensor_partition, config_map, tr_val_dir, fold_gpu_map, num_epochs=config_epochs)
test_run(datatensor_partition, config_map, tr_val_dir, tr_val_dir, fold_gpu_map, num_epochs=1)

def test_partition(datatensor_partition, config_map, tr_val_dir, fold_gpu_map, suffix_testfname):
config_epochs = config_map[0]['model_config'].num_epochs
print(config_epochs)
test_run(datatensor_partition,
config_map,
tr_val_dir,
tr_val_dir,
fold_gpu_map,
suffix_testfname=suffix_testfname,
num_epochs=1)

def train_test_hyperparam_conf(hyperparam_comb, gpu_num, datatensor_partition, fold_gpu_map, exp_dir, num_drugs, queue, exp_iden):
text_to_save = str(hyperparam_comb)
print("hyperparam_comb:", text_to_save, "gpu num:", str(gpu_num))
