
Commit

merged models into Siamese
kyriakosschwarz committed Mar 26, 2021
1 parent 9de1989 commit 371b51b
Showing 20 changed files with 2,241 additions and 4,357 deletions.
15 changes: 2 additions & 13 deletions README.md
@@ -1,17 +1,6 @@
# AttentionDDI

This repository contains the PyTorch implementation of the AttentionDDI model.

AttentionDDI is a Siamese multi-head self-attention, multi-modal neural network for drug-drug interaction (DDI) prediction.
# side-effects

## Installation

* `git clone` the repo and `cd` into it.
* Run `pip install -e .` to install the repo's python package.

## Running

1. Use `notebooks/jupyter/AttnWSiamese_data_generation.ipynb` to generate DataTensors from the drug similarity matrices.
2. Use `notebooks/jupyter/AttnWSiamese_hyperparam.ipynb` to find the best-performing model hyperparameters.
3. Use `notebooks/jupyter/AttnWSiamese_train_eval.ipynb` to train and test with the best hyperparameters.
4. Use `notebooks/jupyter/AttnWSiamese_AttnWeights.ipynb` to plot the attention weights.
* Run `pip install -e .` to install the repo's python package.
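
A quick, hypothetical sanity check after `pip install -e .`: the imports below only use module and class names that appear in this commit's diffs (SOURCES.txt, `ddi/model.py`, `ddi/model_attn_siamese.py`) and simply confirm the editable install is visible to Python.

```python
# Illustrative post-install check; module/class names are taken from the
# diffs in this commit, not from separate documentation.
from ddi.dataset import DDIDataTensor, PartitionDataTensor
from ddi.model import NDD_Code
from ddi.model_attn_siamese import DDI_Transformer_Softmax

print("ddi package importable:", NDD_Code.__name__, DDI_Transformer_Softmax.__name__)
```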
2 changes: 0 additions & 2 deletions ddi.egg-info/SOURCES.txt
@@ -3,8 +3,6 @@ setup.py
ddi/__init__.py
ddi/dataset.py
ddi/losses.py
ddi/model.py
ddi/model_attn.py
ddi/model_attn_siamese.py
ddi/run_workflow.py
ddi/utilities.py
Binary file modified ddi/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file modified ddi/__pycache__/dataset.cpython-36.pyc
Binary file not shown.
Binary file modified ddi/__pycache__/losses.cpython-36.pyc
Binary file not shown.
Binary file modified ddi/__pycache__/model.cpython-36.pyc
Binary file not shown.
Binary file modified ddi/__pycache__/model_attn_siamese.cpython-36.pyc
Binary file not shown.
Binary file modified ddi/__pycache__/run_workflow.cpython-36.pyc
Binary file not shown.
Binary file modified ddi/__pycache__/utilities.cpython-36.pyc
Binary file not shown.
20 changes: 14 additions & 6 deletions ddi/dataset.py
@@ -12,9 +12,11 @@

class DDIDataTensor(Dataset):

def __init__(self, X_a, X_b, y):
def __init__(self, y, X_a, X_b):

self.X_a = X_a # tensor.float32, (drug pairs, features)
self.X_b = X_b # tensor.float32, (drug pairs, features)

# drug interactions
self.y = y # tensor.float32, (drug pairs,)
self.num_samples = self.y.size(0) # int, number of drug pairs
@@ -44,22 +46,28 @@ def __len__(self):

class PartitionDataTensor(Dataset):

def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num):
def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num, is_siamese):
self.ddi_datatensor = ddi_datatensor # instance of :class:`DDIDataTensor`
self.gip_datatensor = gip_datatensor # instance of :class:`GIPDataTensor`
self.partition_ids = partition_ids # list of indices for drug pairs
self.dsettype = dsettype # string, dataset type (i.e. train, validation, test)
self.fold_num = fold_num # int, fold number
self.num_samples = len(self.partition_ids) # int, number of drug pairs in the partition
self.is_siamese = is_siamese

def __getitem__(self, indx):
target_id = self.partition_ids[indx]
X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
# combine gip with other matrices
X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
return X_a_comb, X_b_comb, y, ddi_indx
X_comb = torch.cat([X_a_comb, X_b_comb])#.view(-1)

if (self.is_siamese):
return X_a_comb, X_b_comb, y, ddi_indx
else:
return X_comb, y, ddi_indx

def __len__(self):
return(self.num_samples)
@@ -271,14 +279,14 @@ def report_label_distrib(labels):
print("class:", label, "norm count:", norm_counts[i])


def generate_partition_datatensor(ddi_datatensor, gip_dtensor_perfold, data_partitions):
def generate_partition_datatensor(ddi_datatensor, gip_dtensor_perfold, data_partitions, is_siamese):
datatensor_partitions = {}
for fold_num in data_partitions:
datatensor_partitions[fold_num] = {}
gip_datatensor = gip_dtensor_perfold[fold_num]
for dsettype in data_partitions[fold_num]:
target_ids = data_partitions[fold_num][dsettype]
datatensor_partition = PartitionDataTensor(ddi_datatensor, gip_datatensor, target_ids, dsettype, fold_num)
datatensor_partition = PartitionDataTensor(ddi_datatensor, gip_datatensor, target_ids, dsettype, fold_num, is_siamese)
datatensor_partitions[fold_num][dsettype] = datatensor_partition
compute_class_weights_per_fold_(datatensor_partitions)
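
To make the effect of the new `is_siamese` flag concrete, here is a minimal sketch: the two stub datasets are hypothetical stand-ins that only mimic the `__getitem__` signatures `PartitionDataTensor` relies on, so the real `DDIDataTensor`/`GIPDataTensor` tensors are not needed.

```python
# Sketch of the is_siamese switch in PartitionDataTensor.__getitem__.
# StubDDI / StubGIP are hypothetical stand-ins for DDIDataTensor / GIPDataTensor.
import torch
from ddi.dataset import PartitionDataTensor

class StubDDI:
    def __getitem__(self, i):
        # (X_a, X_b, y, ddi_indx): per-pair drug features and interaction label
        return torch.ones(3, 5), torch.ones(3, 5), torch.tensor(1.0), i

class StubGIP:
    def __getitem__(self, i):
        # (X_a_gip, X_b_gip, gip_indx): GIP features for the same pair
        return torch.zeros(1, 5), torch.zeros(1, 5), i

siamese = PartitionDataTensor(StubDDI(), StubGIP(), [0], 'train', 0, is_siamese=True)
X_a_comb, X_b_comb, y, idx = siamese[0]
print(X_a_comb.shape, X_b_comb.shape)   # torch.Size([4, 5]) each: ddi rows + gip rows

merged = PartitionDataTensor(StubDDI(), StubGIP(), [0], 'train', 0, is_siamese=False)
X_comb, y, idx = merged[0]
print(X_comb.shape)                     # torch.Size([8, 5]): both drugs stacked together
```

With `is_siamese=True` the two drug representations stay separate, as the Siamese branches expect; with `is_siamese=False` they are concatenated into a single input, which the non-Siamese models in `run_workflow.py` consume.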

27 changes: 27 additions & 0 deletions ddi/model.py
@@ -0,0 +1,27 @@
import torch
from torch import nn
import torch.nn.functional as F

class NDD_Code(nn.Module):
def __init__(self, D_in=1096, H1=400, H2=300, D_out=1, drop=0.5):
super(NDD_Code, self).__init__()
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(D_in, H1) # Fully Connected
self.fc2 = nn.Linear(H1, H2)
self.fc3 = nn.Linear(H2, D_out)
self.drop = nn.Dropout(drop)
self._init_weights()

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.drop(x)
x = F.relu(self.fc2(x))
x = self.drop(x)
x = self.fc3(x)
return x

def _init_weights(self):
for m in self.modules():
if(isinstance(m, nn.Linear)):
nn.init.xavier_normal_(m.weight.data)
m.bias.data.uniform_(-1,0)
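
A small, hypothetical usage sketch of the re-added `NDD_Code` baseline on random data; the sigmoid/BCE pairing mirrors the NDD branch of `run_ddi` further down in this commit.

```python
# Forward pass of the NDD_Code MLP on dummy data (illustrative only).
import torch
from ddi.model import NDD_Code

model = NDD_Code(D_in=1096, H1=400, H2=300, D_out=1, drop=0.5)
X_batch = torch.randn(32, 1096)              # (batch, flattened pair features)
y_batch = torch.randint(0, 2, (32, 1)).float()

y_pred_logit = model(X_batch)                # raw logits, shape (32, 1)
y_pred_prob = torch.sigmoid(y_pred_logit)    # probabilities
y_pred_clss = (y_pred_prob > 0.5).int()      # hard 0/1 predictions

loss = torch.nn.BCEWithLogitsLoss()(y_pred_logit, y_batch)
print(y_pred_logit.shape, loss.item())
```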
65 changes: 64 additions & 1 deletion ddi/model_attn_siamese.py
@@ -244,4 +244,67 @@ def forward(self, Z_a, Z_b):

out = torch.cat([Z_a, Z_b, dist], axis=-1)
y = self.Wy(out)
return self.log_softmax(y), dist
return self.log_softmax(y), dist

class DDI_Transformer_Softmax(nn.Module):

def __init__(self, input_size=586, input_embed_dim=64, num_attn_heads=8, mlp_embed_factor=2,
nonlin_func=nn.ReLU(), pdropout=0.3, num_transformer_units=12,
pooling_mode = 'attn', num_classes=2):

super().__init__()

embed_size = input_size #input_embed_dim

self.Wembed = nn.Linear(input_size, embed_size)

trfunit_layers = [TransformerUnit(embed_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout) for i in range(num_transformer_units)]
self.trfunit_pipeline = nn.Sequential(*trfunit_layers)

self.Wy = nn.Linear(embed_size, num_classes)
self.pooling_mode = pooling_mode
if pooling_mode == 'attn':
self.pooling = FeatureEmbAttention(embed_size)
elif pooling_mode == 'mean':
self.pooling = torch.mean

# perform log softmax on the feature dimension
self.log_softmax = nn.LogSoftmax(dim=-1)
self._init_params_()


def _init_params_(self):
for p_name, p in self.named_parameters():
param_dim = p.dim()
if param_dim > 1: # weight matrices
nn.init.xavier_uniform_(p)
elif param_dim == 1: # bias parameters
if p_name.endswith('bias'):
nn.init.uniform_(p, a=-1.0, b=1.0)

def forward(self, X):
"""
Args:
X: tensor, (batch, ddi similarity type vector, input_size)
"""

X = self.Wembed(X)
z = self.trfunit_pipeline(X)

# mean pooling TODO: add global attention layer or other pooling strategy
# pool across similarity type vectors
# Note: z.mean(dim=1) will change shape of z to become (batch, input_size)
# we can keep dimension by running z.mean(dim=1, keepdim=True) to have (batch, 1, input_size)

# pool across similarity type vectors
if self.pooling_mode == 'attn':
z, fattn_w_norm = self.pooling(z)
# Note: z.mean(dim=1) or self.pooling(z, dim=1) will change shape of z to become (batch, embedding dim)
# we can keep dimension by running z.mean(dim=1, keepdim=True) to have (batch, 1, embedding dim)
elif self.pooling_mode == 'mean':
z = self.pooling(z, dim=1)
fattn_w_norm = None

y = self.Wy(z)

return self.log_softmax(y) #,fattn_w_norm
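
A hedged usage sketch of `DDI_Transformer_Softmax`: the dummy shapes and the 2-head choice are assumptions (2 divides the 586-dim embedding, in case the custom attention inside `TransformerUnit` requires the head count to divide the embedding size), and the NLLLoss pairing mirrors the Transformer branch added to `run_ddi` below.

```python
# Illustrative forward pass; shapes and hyperparameters are assumptions.
import torch
from ddi.model_attn_siamese import DDI_Transformer_Softmax

model = DDI_Transformer_Softmax(input_size=586, num_attn_heads=2,
                                num_transformer_units=1, pooling_mode='attn')
X = torch.randn(4, 8, 586)            # (batch, similarity-type vectors, input_size)
logsoftmax_scores = model(X)          # (batch, num_classes=2), log-probabilities

_, y_pred_clss = torch.max(logsoftmax_scores, -1)   # predicted class per pair
y_batch = torch.randint(0, 2, (4,))                 # dummy interaction labels
loss = torch.nn.NLLLoss()(logsoftmax_scores, y_batch)
print(y_pred_clss.tolist(), loss.item())
```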
99 changes: 75 additions & 24 deletions ddi/run_workflow.py
@@ -3,7 +3,7 @@
import itertools
from .utilities import get_device, create_directory, ReaderWriter, perfmetric_report, plot_loss, add_weight_decay_except_attn
from .model import NDD_Code
from .model_attn_siamese import DDI_SiameseTrf, DDI_Transformer, FeatureEmbAttention
from .model_attn_siamese import DDI_SiameseTrf, DDI_Transformer, DDI_Transformer_Softmax, FeatureEmbAttention
from .dataset import construct_load_dataloaders
from .losses import ContrastiveLoss, CosEmbLoss
import numpy as np
@@ -164,39 +164,59 @@ def run_ddi(data_partition, dsettypes, config, options, wrk_dir,
data_loaders, epoch_loss_avgbatch, score_dict, class_weights, flog_out = cld
device = get_device(to_gpu, gpu_index) # gpu device
fdtype = options['fdtype']
if('train' in class_weights):
class_weights = class_weights['train'][1].type(fdtype).to(device) # update class weights to fdtype tensor
else:
class_weights = torch.tensor([1]).type(fdtype).to(device) # weighting all cases equally

print("class weights", class_weights)
# binary cross entropy
loss_func = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights, reduction='mean')


num_epochs = options.get('num_epochs', 50)
fold_num = options.get('fold_num')

# parse config dict
model_config = config['model_config']
model_name = options['model_name']

if(model_name == 'NDD'):
if('train' in class_weights):
class_weights = class_weights['train'][1].type(fdtype).to(device) # update class weights to fdtype tensor
else:
class_weights = torch.tensor([1]).type(fdtype).to(device) # weighting all cases equally

elif(model_name == 'Transformer'):
if('train' in class_weights):
class_weights = class_weights['train'].type(fdtype).to(device) # update class weights to fdtype tensor
else:
class_weights = torch.tensor([1]*2).type(fdtype).to(device) # weighting all cases equally

print("class weights", class_weights)
# binary cross entropy
loss_bce = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights, reduction='mean')
loss_nlll = torch.nn.NLLLoss(weight=class_weights, reduction='mean') # negative log likelihood loss

print("run_ddi model_name:", model_name)

if(model_name == 'NDD'):
# ddi model
ddi_model = NDD_Code(D_in=options['input_dim'],
H1=model_config.fc1_dim,
H2=model_config.fc2_dim,
D_out=1,
drop=model_config.p_dropout)

elif(model_name == 'Transformer'):
ddi_model = DDI_Transformer_Softmax(input_size=options['input_dim'],
input_embed_dim=model_config.input_embed_dim,
num_attn_heads=model_config.num_attn_heads,
mlp_embed_factor=model_config.mlp_embed_factor,
nonlin_func=model_config.nonlin_func,
pdropout=model_config.p_dropout,
num_transformer_units=model_config.num_transformer_units,
pooling_mode=model_config.pooling_mode)


# define optimizer and group parameters
models_param = list(ddi_model.parameters())
models = [(ddi_model, model_name)]

if(state_dict_dir): # load state dictionary of saved models
num_train_epochs = 20
for m, m_name in models:
m.load_state_dict(torch.load(os.path.join(state_dict_dir, '{}_{}.pkl'.format(m_name, num_train_epochs)), map_location=device))
m.load_state_dict(torch.load(os.path.join(state_dict_dir, '{}.pkl'.format(m_name)), map_location=device))

# update models fdtype and move to device
for m, m_name in models:
@@ -252,23 +272,45 @@ def run_ddi(data_partition, dsettypes, config, options, wrk_dir,

X_batch, y_batch, ids = samples_batch

if(model_name == 'NDD'):
X_batch = torch.flatten(X_batch, 1)

X_batch = X_batch.to(device)
y_batch = y_batch.reshape(-1, 1)
y_batch = y_batch.to(device)


with torch.set_grad_enabled(dsettype == 'train'):
num_samples_perbatch = X_batch.size(0)
y_pred_logit = ddi_model(X_batch)
y_pred_prob = sigmoid(y_pred_logit)
y_pred_clss = torch.zeros(y_pred_prob.shape, device=device, dtype=torch.int32)
y_pred_clss[y_pred_prob > 0.5] = 1

if(model_name == 'NDD'):
y_pred_logit = ddi_model(X_batch)
y_pred_prob = sigmoid(y_pred_logit)
y_pred_clss = torch.zeros(y_pred_prob.shape, device=device, dtype=torch.int32)
y_pred_clss[y_pred_prob > 0.5] = 1

y_batch = y_batch.reshape(-1, 1)
y_batch = y_batch.type(torch.int64).to(device)
loss = loss_bce(y_pred_logit, y_batch.type(fdtype))


elif(model_name == 'Transformer'):
logsoftmax_scores = ddi_model(X_batch)

__, y_pred_clss = torch.max(logsoftmax_scores, -1)
y_pred_prob = torch.exp(logsoftmax_scores.detach().cpu()).numpy()

y_batch = y_batch.reshape(-1)
y_batch = y_batch.type(torch.int64).to(device)

loss = loss_nlll(logsoftmax_scores, y_batch)


pred_class.extend(y_pred_clss.view(-1).tolist())
ref_class.extend(y_batch.view(-1).tolist())
prob_scores.extend(y_pred_prob.view(-1).tolist())
# prob_scores.append(y_pred_prob.view(-1).tolist())
prob_scores.append(y_pred_prob.tolist())
ddi_ids.extend(ids.tolist())


loss = loss_func(y_pred_logit, y_batch)
if(dsettype == 'train'):
# backward step (i.e. compute gradients)
loss.backward()
@@ -279,9 +321,14 @@ def run_ddi(data_partition, dsettypes, config, options, wrk_dir,
epoch_loss += loss.item()

epoch_loss_avgbatch[dsettype].append(epoch_loss/len(data_loader))
modelscore = perfmetric_report(pred_class, ref_class, prob_scores, epoch+1, flog_out[dsettype])
# modelscore = perfmetric_report(pred_class, ref_class, prob_scores, epoch+1, flog_out[dsettype])
prob_scores_arr = np.concatenate(prob_scores, axis=0)


if(model_name == 'NDD'):
modelscore = perfmetric_report(pred_class, ref_class, prob_scores_arr, epoch, flog_out[dsettype])
elif(model_name == 'Transformer'):
modelscore = perfmetric_report(pred_class, ref_class, prob_scores_arr[:,1], epoch, flog_out[dsettype])

perf = modelscore.s_aupr
if(perf > score_dict[dsettype].s_aupr):
score_dict[dsettype] = modelscore
@@ -298,6 +345,7 @@ def run_ddi(data_partition, dsettypes, config, options, wrk_dir,
# dump_scores
dump_dict_content(score_dict, list(score_dict.keys()), 'score', wrk_dir)


def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
state_dict_dir=None, to_gpu=True, gpu_index=0):
pid = "{}".format(os.getpid()) # process id description
@@ -328,7 +376,8 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
model_config = config['model_config']
model_name = options['model_name']


print("run_ddiTrf model_name:", model_name)

if(model_name == 'Transformer'):
ddi_model = DDI_Transformer(input_size=options['input_dim'],
input_embed_dim=model_config.input_embed_dim,
@@ -434,6 +483,7 @@ def run_ddiTrf(data_partition, dsettypes, config, options, wrk_dir,
__, y_pred_clss = torch.max(logsoftmax_scores, -1)

y_pred_prob = torch.exp(logsoftmax_scores.detach().cpu()).numpy()
# print("y_pred_prob shape", y_pred_prob.shape)

pred_class.extend(y_pred_clss.view(-1).tolist())
ref_class.extend(y_batch.view(-1).tolist())
@@ -615,6 +665,7 @@ def train_val_run(datatensor_partitions, config_map, train_val_dir, fold_gpu_map
options['fold_num'] = fold_num
data_partition = datatensor_partitions[fold_num]
path = os.path.join(train_val_dir, 'train_val', 'fold_{}'.format(fold_num))
# normpath
wrk_dir = create_directory(path)
print(wrk_dir)
if options.get('loss_func') == 'bceloss':
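
The `run_ddi` changes above make the loss setup model-dependent: a scalar `pos_weight` with `BCEWithLogitsLoss` for NDD versus a per-class weight vector with `NLLLoss` for the Transformer. A minimal sketch of the two setups with made-up class weights:

```python
# The two loss configurations run_ddi now switches between (dummy weights/data).
import torch

fdtype = torch.float32

# NDD branch: single positive-class weight, logits of shape (batch, 1)
pos_weight = torch.tensor([3.2]).type(fdtype)        # e.g. #negatives / #positives
loss_bce = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).type(fdtype)
print(loss_bce(logits, targets).item())

# Transformer branch: one weight per class, log-softmax scores of shape (batch, 2)
class_weights = torch.tensor([1.0, 3.2]).type(fdtype)
loss_nlll = torch.nn.NLLLoss(weight=class_weights, reduction='mean')
logsoftmax_scores = torch.log_softmax(torch.randn(8, 2), dim=-1)
labels = torch.randint(0, 2, (8,))
print(loss_nlll(logsoftmax_scores, labels).item())
```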
2 changes: 1 addition & 1 deletion ddi/utilities.py
@@ -163,7 +163,7 @@ def create_directory(folder_name, directory="current"):
path_current_dir = directory
print("path_current_dir", path_current_dir)

path_new_dir = os.path.join(path_current_dir, folder_name)
path_new_dir = os.path.normpath(os.path.join(path_current_dir, folder_name))
if not os.path.exists(path_new_dir):
os.makedirs(path_new_dir)
return(path_new_dir)
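
The `normpath` change in `create_directory` is small but visible; a tiny illustration with made-up POSIX paths:

```python
# os.path.normpath collapses redundant separators and '..' segments
# before the directory is created (paths here are hypothetical).
import os

raw = os.path.join('experiments/', 'train_val/../train_val/fold_0')
print(raw)                    # experiments/train_val/../train_val/fold_0
print(os.path.normpath(raw))  # experiments/train_val/fold_0
```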