siamese attn model
orisenbazuru committed Aug 9, 2020
1 parent b15dffa commit ce3d519
Showing 6 changed files with 530 additions and 73 deletions.
135 changes: 113 additions & 22 deletions ddi/dataset.py
@@ -1,41 +1,67 @@
import os
import numpy as np
import torch
from .utilities import ModelScore, ReaderWriter
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn.utils.class_weight import compute_class_weight
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import norm as scpnorm
import pandas as pd
from .utilities import ModelScore, ReaderWriter


class DDIDataTensor(Dataset):

def __init__(self, X_feat, y):
self.X_feat = X_feat # tensor.float32, (drug pairs, features)
def __init__(self, X_a, X_b, y):
self.X_a = X_a # tensor.float32, (drug pairs, features)
self.X_b = X_b # tensor.float32, (drug pairs, features)
# drug interactions
self.y = y # tensor.float32, (drug pairs,)
self.num_samples = self.y.size(0) # int, number of drug pairs

def __getitem__(self, indx):

return(self.X_feat[indx], self.y[indx], indx)
return(self.X_a[indx], self.X_b[indx], self.y[indx], indx)

def __len__(self):
return(self.num_samples)

class GIPDataTensor(Dataset):

def __init__(self, X_a, X_b):
self.X_a = X_a # tensor.float32, (drug pairs, gip features)
self.X_b = X_b # tensor.float32, (drug pairs, gip features)
self.num_samples = self.X_a.size(0) # int, number of drug pairs

def __getitem__(self, indx):

return(self.X_a[indx], self.X_b[indx], indx)

def __len__(self):
return(self.num_samples)


class PartitionDataTensor(Dataset):

def __init__(self, ddi_datatensor, partition_ids, dsettype, fold_num):
def __init__(self, ddi_datatensor, gip_datatensor, partition_ids, dsettype, fold_num):
self.ddi_datatensor = ddi_datatensor # instance of :class:`DDIDataTensor`
self.gip_datatensor = gip_datatensor # instance of :class:`GIPDataTensor`
self.partition_ids = partition_ids # list of indices for drug pairs
self.dsettype = dsettype # string, dataset type (i.e. train, validation, test)
self.fold_num = fold_num # int, fold number
self.num_samples = len(self.partition_ids) # int, number of drug pairs in the partition

def __getitem__(self, indx):
target_id = self.partition_ids[indx]
return self.ddi_datatensor[target_id]

X_a, X_b, y, ddi_indx = self.ddi_datatensor[target_id]
X_a_gip, X_b_gip, gip_indx = self.gip_datatensor[target_id]
assert ddi_indx == gip_indx
# combine gip with other matrices
X_a_comb = torch.cat([X_a, X_a_gip], axis=0)
X_b_comb = torch.cat([X_b, X_b_gip], axis=0)
return X_a_comb, X_b_comb, y, ddi_indx

def __len__(self):
return(self.num_samples)
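
A minimal sketch of how these tensors compose, using hypothetical toy shapes (4 drug pairs, 6 similarity features and 3 GIP features per drug); PartitionDataTensor concatenates the two feature blocks per drug:

import torch
from ddi.dataset import DDIDataTensor, GIPDataTensor, PartitionDataTensor

ddi_dt = DDIDataTensor(X_a=torch.rand(4, 6), X_b=torch.rand(4, 6),
                       y=torch.tensor([1., 0., 1., 0.]))
gip_dt = GIPDataTensor(X_a=torch.rand(4, 3), X_b=torch.rand(4, 3))
part = PartitionDataTensor(ddi_dt, gip_dt, partition_ids=[0, 2],
                           dsettype='train', fold_num=0)
X_a_comb, X_b_comb, y, indx = part[0]
print(X_a_comb.shape)  # torch.Size([9]): 6 similarity + 3 GIP features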

@@ -82,43 +108,106 @@ def construct_load_dataloaders(dataset_fold, dsettypes, config, wrk_dir):

return (data_loaders, epoch_loss_avgbatch, score_dict, class_weights, flog_out)

def preprocess_features(feat_fpath):
X_fea = np.loadtxt(feat_fpath,dtype=float,delimiter=",")
r, c = np.triu_indices(len(X_fea),1) # indices of the upper triangle, one above the diagonal
return np.concatenate((X_fea[r], X_fea[c]), axis=1)
def preprocess_features(feat_fpath, dsetname):
if dsetname in {'DS1', 'DS3'}:
X_fea = np.loadtxt(feat_fpath,dtype=float,delimiter=",")
elif dsetname == 'DS2':
X_fea = pd.read_csv(feat_fpath).values[:,1:]
X_fea = X_fea.astype(np.float32)
return get_features_from_simmatrix(X_fea)

def get_features_from_simmatrix(sim_mat):
"""
Args:
sim_mat: np.array, m x m (drug similarity matrix over m drugs)
"""
r, c = np.triu_indices(len(sim_mat),1) # indices of the upper triangle, one above the diagonal
return np.concatenate((sim_mat[r], sim_mat[c]), axis=1)
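
For an m x m similarity matrix this yields one row per unordered drug pair (i < j), formed by concatenating row i and row j; a toy shape check:

import numpy as np
from ddi.dataset import get_features_from_simmatrix

sim_mat = np.arange(9, dtype=float).reshape(3, 3)  # toy 3x3 drug similarity matrix
feats = get_features_from_simmatrix(sim_mat)
print(feats.shape)  # (3, 6): pairs (0,1), (0,2), (1,2), each row_i ++ row_j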

def preprocess_labels(interaction_fpath):
interaction_matrix = np.loadtxt(interaction_fpath,dtype=float,delimiter=",")
r, c = np.triu_indices(len(interaction_matrix),1) # indices of the upper triangle, one above the diagonal
return interaction_matrix[r,c]
def preprocess_labels(interaction_fpath, dsetname):
interaction_mat = get_interaction_mat(interaction_fpath, dsetname)
return get_y_from_interactionmat(interaction_mat)

def compute_gip_profile(adj, bw=1.):
"""approach based on Olayan et al. https://doi.org/10.1093/bioinformatics/btx731 """

ga = np.dot(adj,np.transpose(adj))
ga = bw*ga/np.mean(np.diag(ga))
di = np.diag(ga)
x = np.tile(di,(1,di.shape[0])).reshape(di.shape[0],di.shape[0])
d = x + np.transpose(x) - 2*ga
return np.exp(-d)
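
A quick sanity check on a toy adjacency matrix; since d has a zero diagonal, the resulting Gaussian interaction-profile kernel has exp(0) = 1 on its diagonal:

import numpy as np
from ddi.dataset import compute_gip_profile

adj = np.array([[0., 1., 1.],
                [1., 0., 0.],
                [1., 0., 0.]])          # toy drug-drug interaction matrix
K = compute_gip_profile(adj, bw=1.)
print(K.shape)                          # (3, 3)
print(np.allclose(np.diag(K), 1.0))     # True: zero self-distance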

def compute_kernel(mat, k_bandwidth, epsilon=1e-9):
"""computes gaussian kernel from 2D matrix
Approach based on van Laarhoven et al. doi:10.1093/bioinformatics/btr500
"""
r, c = mat.shape # 2D matrix
# computes pairwise l2 distance
dist_kernel = squareform(pdist(mat, metric='euclidean')**2)
gamma = k_bandwidth/(np.clip((scpnorm(mat, axis=1, keepdims=True)**2) * 1/c, a_min=epsilon, a_max=None))
return np.exp(-gamma*dist_kernel)
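
The same check applies here; the per-row bandwidth gamma has shape (r, 1) and broadcasts over the (r, r) squared-distance matrix:

import numpy as np
from ddi.dataset import compute_kernel

mat = np.random.rand(5, 8)              # 5 drugs, 8 features each
K = compute_kernel(mat, k_bandwidth=1.)
print(K.shape)                          # (5, 5)
print(np.allclose(np.diag(K), 1.0))     # True: pdist gives zero self-distance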

def construct_sampleid_ddipairs(interaction_mat):
# indices of the upper triangle, one above the diagonal
r, c = np.triu_indices(len(interaction_mat),1)
sid_ddipairs = {sid:ddi_pair for sid, ddi_pair in enumerate(zip(r,c))}
return sid_ddipairs

def get_num_drugs(interaction_fpath, dsetname):
if dsetname in {'DS1', 'DS3'}:
interaction_matrix = np.loadtxt(interaction_fpath,dtype=float,delimiter=",")
elif dsetname == 'DS2':
interaction_matrix = pd.read_csv(interaction_fpath).values[:,1:]
return interaction_matrix.shape[0]

def get_interaction_mat(interaction_fpath, dsetname):
if dsetname in {'DS1', 'DS3'}:
interaction_matrix = np.loadtxt(interaction_fpath,dtype=float,delimiter=",")
elif dsetname == 'DS2':
interaction_matrix = pd.read_csv(interaction_fpath).values[:,1:]
return interaction_matrix.astype(np.int32)

def get_similarity_matrix(feat_fpath, dsetname):
if dsetname in {'DS1', 'DS3'}:
X_fea = np.loadtxt(feat_fpath,dtype=float,delimiter=",")
elif dsetname == 'DS2':
X_fea = pd.read_csv(feat_fpath).values[:,1:]
X_fea = X_fea.astype(np.float32)
return X_fea

def get_y_from_interactionmat(interaction_mat):
r, c = np.triu_indices(len(interaction_mat),1) # indices of the upper triangle, one above the diagonal
return interaction_mat[r,c]

def create_setvector_features(X, num_sim_types):
"""reshape concatenated features from every similarity type matrix into set of vectors per ddi example"""
e = X[np.newaxis, :, :]
# print('e.shape', e.shape)
f = np.transpose(e, axes=(0, 2, 1))
# print('f.shape', f.shape)
splitter = 2*num_sim_types
splitter = num_sim_types
g = np.concatenate(np.split(f, splitter, axis=1), axis=0)
# print('g.shape', g.shape)
h = np.transpose(g, axes=(2,0, 1))
# print('h.shape', h.shape)
return h
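
A shape walkthrough with hypothetical sizes: a flat (pairs, num_sim_types * feat_dim) matrix is split into one feature vector per similarity type:

import numpy as np
from ddi.dataset import create_setvector_features

X = np.random.rand(10, 6)                          # 10 pairs, 2 sim types x 3 features
h = create_setvector_features(X, num_sim_types=2)
print(h.shape)                                     # (10, 2, 3)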

def get_stratified_partitions(ddi_datatensor, num_folds=5, valid_set_portion=0.1, random_state=42):
def get_stratified_partitions(y, num_folds=5, valid_set_portion=0.1, random_state=42):
"""Generate 5-fold stratified sample of drug-pair ids based on the interaction label
Args:
ddi_datatensor: instance of :class:`DDIDataTensor`
y: ddi labels
"""
skf_trte = StratifiedKFold(n_splits=num_folds, random_state=random_state, shuffle=True) # split train and test

skf_trv = StratifiedShuffleSplit(n_splits=2,
test_size=valid_set_portion,
random_state=random_state) # split train and validation
data_partitions = {}
X = ddi_datatensor.X_feat
y = ddi_datatensor.y
X = np.zeros(len(y)) # dummy feature placeholder; the stratified split uses only y
fold_num = 0
for train_index, test_index in skf_trte.split(X,y):

@@ -262,13 +351,15 @@ def report_label_distrib(labels):
print("class:", label, "norm count:", norm_counts[i])


def generate_partition_datatensor(ddi_datatensor, data_partitions):
def generate_partition_datatensor(ddi_datatensor, gip_dtensor_perfold, data_partitions):
datatensor_partitions = {}
for fold_num in data_partitions:
datatensor_partitions[fold_num] = {}
gip_datatensor = gip_dtensor_perfold[fold_num]
for dsettype in data_partitions[fold_num]:
target_ids = data_partitions[fold_num][dsettype]
datatensor_partition = PartitionDataTensor(ddi_datatensor, target_ids, dsettype, fold_num)
datatensor_partition = PartitionDataTensor(ddi_datatensor, gip_datatensor, target_ids, dsettype, fold_num)
datatensor_partitions[fold_num][dsettype] = datatensor_partition
compute_class_weights_per_fold_(datatensor_partitions)
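
A hedged wiring sketch of the pieces above, assuming both functions return the dicts they build internally; the file path is hypothetical, ddi_dt is a DDIDataTensor, and gip_dtensor_perfold maps fold_num to a GIPDataTensor:

from ddi.dataset import (preprocess_labels, get_stratified_partitions,
                         generate_partition_datatensor)

y = preprocess_labels('data/DS1/interactions.csv', 'DS1')  # hypothetical path
partitions = get_stratified_partitions(y, num_folds=5)
folds = generate_partition_datatensor(ddi_dt, gip_dtensor_perfold, partitions)
train_set = folds[0]['train']   # a PartitionDataTensor for fold 0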

35 changes: 35 additions & 0 deletions ddi/losses.py
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

class ContrastiveLoss(nn.Module):
"""
Margin-based contrastive loss over precomputed pair distances: pulls target=1 pairs together and pushes target=0 pairs apart up to margin
"""

def __init__(self, margin, reduction='mean', eps=1e-8):
super().__init__()
self.margin = margin
self.reduction = reduction
self.eps = eps

def forward(self, dist, target):
"""
Args:
dist: tensor, (batch, ), computed distance between two inputs
target: tensor, (batch, ), labels (0/1)
"""
margin = self.margin
reduction = self.reduction
repel = (1-target) * (0.5 * torch.pow(torch.clamp(margin - dist, min=0.0), 2))
attract = target * 0.5 * torch.pow(dist, 2)

# repel = target * (0.5 * torch.pow(torch.clamp(margin - dist, min=0.0), 2))
# attract = (1-target) * 0.5 * torch.pow(dist, 2)

if reduction == 'mean':
loss_contrastive = torch.mean(repel + attract)
elif reduction == 'sum':
loss_contrastive = torch.sum(repel + attract)
elif reduction == 'none':
loss_contrastive = repel + attract
return loss_contrastive
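
A usage sketch with hypothetical siamese embeddings; with the active branch above, target = 1 pairs are pulled together while target = 0 pairs are pushed apart up to margin:

import torch
from ddi.losses import ContrastiveLoss

criterion = ContrastiveLoss(margin=1.0, reduction='mean')
z_a = torch.randn(4, 16, requires_grad=True)   # embedding of drug a per pair
z_b = torch.randn(4, 16, requires_grad=True)   # embedding of drug b per pair
dist = torch.norm(z_a - z_b, p=2, dim=1)       # (batch,) euclidean distances
target = torch.tensor([1., 0., 1., 0.])        # 1 = interacting pair
loss = criterion(dist, target)
loss.backward()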
23 changes: 13 additions & 10 deletions ddi/model_attn.py
@@ -6,7 +6,7 @@ class SH_SelfAttention(nn.Module):
"""
def __init__(self, input_size):

super(SH_SelfAttention, self).__init__()
super().__init__()
# define query, key and value transformation matrices
# usually input_size is equal to embed_size
self.embed_size = input_size
@@ -43,7 +43,7 @@ class MH_SelfAttention(nn.Module):
"""
def __init__(self, input_size, num_attn_heads):

super(MH_SelfAttention, self).__init__()
super().__init__()

layers = [SH_SelfAttention(input_size) for i in range(num_attn_heads)]

@@ -74,7 +74,7 @@ class TransformerUnit(nn.Module):

def __init__(self, input_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout):

super(TransformerUnit, self).__init__()
super().__init__()

embed_size = input_size
self.multihead_attn = MH_SelfAttention(input_size, num_attn_heads)
@@ -150,22 +150,23 @@ def forward(self, X):

class DDI_Transformer(nn.Module):

def __init__(self, input_size=586, num_attn_heads=8, mlp_embed_factor=2,
def __init__(self, input_size=586, input_embed_dim=64, num_attn_heads=8, mlp_embed_factor=2,
nonlin_func=nn.ReLU(), pdropout=0.3, num_transformer_units=12,
pooling_mode = 'attn', num_classes=2):

super(DDI_Transformer, self).__init__()

embed_size = input_size
super().__init__()

trfunit_layers = [TransformerUnit(input_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout) for i in range(num_transformer_units)]
embed_size = input_embed_dim

self.Wembed = nn.Linear(input_size, embed_size)

trfunit_layers = [TransformerUnit(embed_size, num_attn_heads, mlp_embed_factor, nonlin_func, pdropout) for i in range(num_transformer_units)]
self.trfunit_pipeline = nn.Sequential(*trfunit_layers)
embed_size = input_size

self.Wy = nn.Linear(embed_size, num_classes)
self.pooling_mode = pooling_mode
if pooling_mode == 'attn':
self.pooling = FeatureEmbAttention(input_size)
self.pooling = FeatureEmbAttention(embed_size)
elif pooling_mode == 'mean':
self.pooling = torch.mean

@@ -188,6 +189,8 @@ def forward(self, X):
Args:
X: tensor, (batch, ddi similarity type vector, input_size)
"""

X = self.Wembed(X)
z = self.trfunit_pipeline(X)

# mean pooling TODO: add global attention layer or other pooling strategy
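
A shape sketch of the new embedding path with toy sizes, assuming the transformer units preserve the (batch, set, embed) shape through their residual connections:

import torch
from ddi.model_attn import DDI_Transformer

model = DDI_Transformer(input_size=586, input_embed_dim=64,
                        num_transformer_units=2)
X = torch.rand(3, 4, 586)          # (batch, similarity-type vectors, input_size)
z = model.Wembed(X)                # (3, 4, 64): per-vector linear embedding
z = model.trfunit_pipeline(z)      # (3, 4, 64), assuming shape-preserving units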