Skip to content

Commit

Permalink
Refactored gsva
Browse files Browse the repository at this point in the history
  • Loading branch information
PauBadiaM committed Aug 8, 2024
1 parent 506a050 commit 4db0c85
Showing 1 changed file with 90 additions and 62 deletions.
152 changes: 90 additions & 62 deletions decoupler/method_gsva.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from scipy.sparse import csr_matrix
from scipy.sparse import issparse
import math

from .pre import extract, rename_net, filt_min_n, return_data, break_ties
Expand Down Expand Up @@ -140,76 +140,99 @@ def density(mat, kcdf):
return mat


@nb.njit(nb.types.Tuple((nb.f8[:, :], nb.i8[:, :]))(nb.f8[:, :]), parallel=True, cache=True)
def nb_get_D_I(mat):
    """
    Replace every row of ``mat`` by symmetric rank weights and return the sort order.

    Parameters
    ----------
    mat : ndarray of float64, 2D
        Input values. It is overwritten in place, row by row.

    Returns
    -------
    mat : ndarray of float64, 2D
        The same array, where each entry now holds ``|r - n / 2|`` with ``r``
        running from ``n`` (largest value of the row) down to ``1`` (smallest)
        and ``n`` the number of columns.
    order : ndarray of int64, 2D
        Per row, the 0-based column indices sorted by decreasing value.
    """
    n_rows, n_cols = mat.shape

    # Weight assigned to each position of the descending ordering
    weights = np.abs(np.arange(n_cols, 0, -1, nb.f8) - n_cols / 2)

    order = np.zeros(mat.shape, dtype=nb.i8)
    for row in nb.prange(n_rows):
        order[row] = np.argsort(-mat[row])
        # Scatter the weights back to the original column positions
        ranked = np.zeros(n_cols, dtype=nb.f8)
        ranked[order[row]] = weights
        mat[row] = ranked

    return mat, order


@nb.njit(nb.f8(nb.f8[:], nb.i8[:], nb.i8, nb.i8[:], nb.i8[:], nb.i8, nb.f8), cache=True)
def ks_sample(D, Idx, n_genes, geneset_mask, fset, n_geneset, dec):

sum_gset = 0.0
for i in nb.prange(n_geneset):
sum_gset += D[fset[i]]

mx_value_sign = 0.0
cum_sum = 0.0
mx_pos = 0.0
mx_neg = 0.0

for i in nb.prange(n_genes):
idx = Idx[i]
if geneset_mask[idx] == 1:
cum_sum += D[idx] / sum_gset
@nb.njit(nb.types.Tuple((nb.i8[:, :], nb.i8[:, :]))(nb.f8[:, :]), parallel=True, cache=True)
def order_rankstat(mat):
    """
    Compute, for every row, the descending ordering of the columns and a
    symmetric rank statistic.

    Parameters
    ----------
    mat : ndarray of float64, 2D
        Input values. It is not modified.

    Returns
    -------
    ord_mat : ndarray of int64, 2D
        Per row, the 1-based column indices sorted by decreasing value.
    rst_mat : ndarray of int64, 2D
        Per row and column, ``|n_cols - j - n_cols // 2|`` where ``j`` is the
        0-based position of that column in the descending ordering.
    """
    n_rows, n_cols = mat.shape
    half = n_cols // 2
    ord_mat = np.zeros((n_rows, n_cols), dtype=nb.i8)
    rst_mat = np.zeros((n_rows, n_cols), dtype=nb.i8)

    # Rows are independent of each other, so prange lets parallel=True
    # distribute them across threads instead of processing them serially.
    for i in nb.prange(n_rows):
        # Shift to 1-based indices, which is what the downstream helpers expect
        order = np.argsort(-mat[i, :]) + 1
        rst = np.zeros(n_cols, dtype=nb.i8)
        for j in range(n_cols):
            rst[order[j] - 1] = abs(n_cols - j - half)
        ord_mat[i, :] = order
        rst_mat[i, :] = rst

    return ord_mat, rst_mat


@nb.njit(nb.types.UniTuple(nb.f8, 2)(nb.i8[:], nb.i8, nb.i8[:], nb.i8[:], nb.i8), cache=True)
def rnd_walk(gsetidx, k, generanking, rankstat, n):
    """
    Run the weighted random walk for one feature set and return its extremes.

    Parameters
    ----------
    gsetidx : ndarray of int64
        1-based positions (within the walk) of the features in the set.
    k : int
        Number of entries of ``gsetidx`` to use.
    generanking : ndarray of int64
        1-based feature indices; ``generanking[p] - 1`` is used to look up
        the weight of walk position ``p`` in ``rankstat``.
    rankstat : ndarray of int64
        Weight of every feature.
    n : int
        Length of the walk.

    Returns
    -------
    walkstatpos : float
        Maximum of the walk (in-set cumulative fraction minus out-of-set
        cumulative fraction).
    walkstatneg : float
        Minimum of the walk.
    """
    # int64 accumulators match the dtype of rankstat, so the cumulative sums
    # cannot wrap around the way a 32-bit buffer could on very wide inputs.
    stepcdfingeneset = np.zeros(n, dtype=np.int64)
    stepcdfoutgeneset = np.ones(n, dtype=np.int64)

    # Positions hit by the set step by their weight; all others step by one
    for i in range(k):
        idx = gsetidx[i] - 1
        stepcdfingeneset[idx] = rankstat[generanking[idx] - 1]
        stepcdfoutgeneset[idx] = 0

    # In-place cumulative sums
    for i in range(1, n):
        stepcdfingeneset[i] += stepcdfingeneset[i-1]
        stepcdfoutgeneset[i] += stepcdfoutgeneset[i-1]

    # NOTE(review): both denominators below can be zero (set covering every
    # position, or a set whose weights sum to zero) - confirm callers filter
    # such sets out.
    walkstatpos = -np.inf
    walkstatneg = np.inf
    for i in range(n):
        wlkstat = (stepcdfingeneset[i] / stepcdfingeneset[-1]) - (stepcdfoutgeneset[i] / stepcdfoutgeneset[-1])
        if wlkstat > walkstatpos:
            walkstatpos = wlkstat
        if wlkstat < walkstatneg:
            walkstatneg = wlkstat

    return walkstatpos, walkstatneg


@nb.njit(nb.f8(nb.i8[:], nb.i8[:], nb.i8[:], nb.b1, nb.b1), cache=True)
def score_geneset(gsetidx, generanking, rankstat, maxdiff, absrnk):
    """
    Turn the extremes of the random walk of one feature set into a single score.

    Parameters
    ----------
    gsetidx : ndarray of int64
        1-based positions of the set's features, forwarded to ``rnd_walk``.
    generanking : ndarray of int64
        1-based feature ordering, forwarded to ``rnd_walk``; its length is
        used as the length of the walk.
    rankstat : ndarray of int64
        Weight of every feature, forwarded to ``rnd_walk``.
    maxdiff : bool
        If True, combine the maximum and minimum of the walk; otherwise
        return whichever of the two has the larger magnitude.
    absrnk : bool
        Only used when ``maxdiff`` is True. If True the score is
        ``max - min``, otherwise ``max + min``.

    Returns
    -------
    es : float
        Enrichment score.
    """
    n = len(generanking)
    k = len(gsetidx)
    walkstatpos, walkstatneg = rnd_walk(gsetidx, k, generanking, rankstat, n)

    if maxdiff:
        if absrnk:
            es = walkstatpos - walkstatneg
        else:
            es = walkstatpos + walkstatneg
    else:
        # Keep the signed extreme with the largest absolute value
        es = walkstatpos if abs(walkstatpos) > abs(walkstatneg) else walkstatneg

    return es


@nb.njit(nb.i8[:](nb.i8[:], nb.i8[:]), cache=True)
def match(a, b):
    """
    Locate every element of ``a`` inside ``b``.

    Parameters
    ----------
    a : ndarray of int64
        Values to look up.
    b : ndarray of int64
        Values to search in.

    Returns
    -------
    out : ndarray of int64
        For each element of ``a``, its 1-based position in ``b``, or 0 when
        the value is negative or does not occur in ``b``. If a value occurs
        several times in ``b``, the last position is reported.
    """
    upper = 0 if len(b) == 0 else np.max(b)

    # Direct-address table: value -> 0-based position in b (-1 when absent)
    lookup = np.full(upper + 1, -1, dtype=nb.i8)
    for pos in range(len(b)):
        val = b[pos]
        if val >= 0 and val <= upper:
            lookup[val] = pos

    # Zero means "not found"; hits are shifted to 1-based positions
    out = np.zeros(len(a), dtype=nb.i8)
    for i, val in enumerate(a):
        if val >= 0 and val <= upper:
            out[i] = lookup[val] + 1

    return out


@nb.njit(nb.f8[:](nb.i8[:, :], nb.i8[:, :], nb.i8[:], nb.b1, nb.b1), parallel=True, cache=True)
def ks_fset(ord, rst, fset, maxdiff, absrnk):
    """
    Score one feature set against every row of the ranking matrices.

    Parameters
    ----------
    ord : ndarray of int64, 2D
        Per row, the 1-based feature ordering.
    rst : ndarray of int64, 2D
        Per row, the weight of every feature.
    fset : ndarray of int64
        Feature identifiers of the set, expressed in the same 1-based
        numbering as the values stored in ``ord``.
    maxdiff : bool
        Forwarded to ``score_geneset``.
    absrnk : bool
        Forwarded to ``score_geneset``.

    Returns
    -------
    res : ndarray of float64
        One score per row of ``ord``.
    """
    n_samples = ord.shape[0]
    res = np.zeros(n_samples, dtype=nb.f8)

    # Rows are scored independently, one per parallel iteration
    for i in nb.prange(n_samples):
        generanking = ord[i]
        rankstat = rst[i]
        # 1-based position of each set member inside this row's ordering
        genesetsrankidx = match(fset, generanking)
        res[i] = score_geneset(genesetsrankidx, generanking, rankstat, maxdiff, absrnk)

    return res


def gsva(mat, net, kcdf=False, maxdiff=True, absrnk=False, verbose=False):
    """
    Compute one enrichment score per row of ``mat`` and per feature set of ``net``.

    Parameters
    ----------
    mat : ndarray or scipy sparse matrix
        Input matrix. Sparse input is converted to a dense array first.
    net : object indexable through ``.iloc``
        Feature sets. Each entry is shifted by one before scoring, so it is
        presumably an array of 0-based column indices of ``mat`` - confirm
        against the caller.
    kcdf : bool
        Forwarded to ``density``.
    maxdiff : bool
        Forwarded to ``ks_fset``.
    absrnk : bool
        Forwarded to ``ks_fset``.
    verbose : bool
        Whether to show a progress bar over the feature sets.

    Returns
    -------
    acts : ndarray of float64, 2D
        Scores with one row per row of ``mat`` and one column per feature set.
    """
    if issparse(mat):
        mat = mat.toarray()

    # Get feature Density
    mat = density(mat, kcdf=kcdf)

    # Rank features once; the result is reused for every feature set
    ord_mat, rst_mat = order_rankstat(mat)

    # Run GSVA for each feature set
    acts = np.zeros((ord_mat.shape[0], len(net)))
    for j in tqdm(range(len(net)), disable=not verbose):
        # The ranking helpers use 1-based feature indices
        fset = net.iloc[j] + 1
        acts[:, j] = ks_fset(ord_mat, rst_mat, fset, maxdiff, absrnk)

    return acts

Expand Down Expand Up @@ -265,6 +288,13 @@ def run_gsva(mat, net, source='source', target='target', kcdf='gaussian', mx_dif
# Extract sparse matrix and array of genes
m, r, c = extract(mat, use_raw=use_raw, verbose=verbose)

# Remove repeated features
if issparse(m):
m = m.toarray()
msk = ~np.all(m == m[0, :], axis=0)
m = m[:, msk]
c = c[msk]

# Transform net
net = rename_net(net, source=source, target=target, weight=None)
net = filt_min_n(c, net, min_n=min_n)
Expand All @@ -281,9 +311,7 @@ def run_gsva(mat, net, source='source', target='target', kcdf='gaussian', mx_dif
print('Running gsva on mat with {0} samples and {1} targets for {2} sources.'.format(m.shape[0], len(c), len(net)))

# Run GSVA
if isinstance(m, csr_matrix):
m = m.toarray()
estimate = gsva(m, net, kcdf=kcdf, verbose=verbose)
estimate = gsva(m, net, kcdf=kcdf, maxdiff=mx_diff, absrnk=abs_rnk, verbose=verbose)

# Transform to df
estimate = pd.DataFrame(estimate, index=r, columns=net.index)
Expand Down

0 comments on commit 4db0c85

Please sign in to comment.