Skip to content

Commit

Permalink
fix: python 3.10 type error with None in edge weight
Browse files Browse the repository at this point in the history
  • Loading branch information
duypham2108 committed Oct 28, 2023
1 parent 09c8d7a commit 3fe707a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 30 deletions.
2 changes: 1 addition & 1 deletion stlearn/spatials/trajectory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
from .compare_transitions import compare_transitions

from .set_root import set_root
from .shortest_path_spatial_PAGA import shortest_path_spatial_PAGA
from .shortest_path_spatial_PAGA import shortest_path_spatial_PAGA
10 changes: 5 additions & 5 deletions stlearn/spatials/trajectory/global_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def global_level(
verbose: bool = True,
copy: bool = False,
) -> Optional[AnnData]:

"""\
Perform global spatial trajectory inference.
Expand Down Expand Up @@ -152,15 +151,18 @@ def global_level(
labels = nx.get_edge_attributes(H_sub, "weight")

for edge, _ in labels.items():

dm = dm_list[order_big_dict[query_dict[edge[0]]]]
sdm = sdm_list[order_big_dict[query_dict[edge[0]]]]

weight = dm[order_dict[edge[0]], order_dict[edge[1]]] * w + sdm[
order_dict[edge[0]], order_dict[edge[1]]
] * (1 - w)
H_sub[edge[0]][edge[1]]["weight"] = weight
# tmp = H_sub

# Set edges with weight=None to weight=0
for u, v, tmp in H_sub.edges(data=True):
if tmp.get("weight") is None:
H_sub[u][v]["weight"] = 0

H_sub = nx.algorithms.tree.minimum_spanning_arborescence(H_sub)
H_nodes = list(range(len(H_sub.nodes)))
Expand Down Expand Up @@ -236,7 +238,6 @@ def ordering_nodes(node_list, use_label, adata):


def spatial_distance_matrix(adata, cluster1, cluster2, use_label):

tmp = adata.obs[adata.obs[use_label] == str(cluster1)]
chosen_adata1 = adata[list(tmp.index)]
tmp = adata.obs[adata.obs[use_label] == str(cluster2)]
Expand Down Expand Up @@ -267,7 +268,6 @@ def spatial_distance_matrix(adata, cluster1, cluster2, use_label):


def ge_distance_matrix(adata, cluster1, cluster2, use_label, use_rep, n_dims):

tmp = adata.obs[adata.obs[use_label] == str(cluster1)]
chosen_adata1 = adata[list(tmp.index)]
tmp = adata.obs[adata.obs[use_label] == str(cluster2)]
Expand Down
1 change: 0 additions & 1 deletion stlearn/spatials/trajectory/pseudotime.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key
if len(path) < max_nodes:
all_paths[str(i) + "_" + str(source) + "_" + str(target)] = path


adata.uns["available_paths"] = all_paths
print(
"All available trajectory paths are stored in adata.uns['available_paths'] with length < "
Expand Down
56 changes: 33 additions & 23 deletions stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import numpy as np
from stlearn.utils import _read_graph

def shortest_path_spatial_PAGA(adata,use_label,key="dpt_pseudotime",):

def shortest_path_spatial_PAGA(
adata,
use_label,
key="dpt_pseudotime",
):
# Read original PAGA graph
G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray())
edge_weights = nx.get_edge_attributes(G, "weight")
G.remove_edges_from((e for e, w in edge_weights.items() if w <0))
G.remove_edges_from((e for e, w in edge_weights.items() if w < 0))
H = G.to_directed()

# Get min_node and max_node
min_node,max_node = find_min_max_node(adata,key,use_label)
min_node, max_node = find_min_max_node(adata, key, use_label)

# Calculate pseudotime for each node
node_pseudotime = {}

Expand All @@ -26,69 +31,74 @@ def shortest_path_spatial_PAGA(adata,use_label,key="dpt_pseudotime",):
if node_pseudotime[edge[0]] - node_pseudotime[edge[1]] > 0:
edge_to_remove.append(edge)
H.remove_edges_from(edge_to_remove)

# Extract all available paths
all_paths = {}
j = 0
for source in H.nodes:
for target in H.nodes:
paths = nx.all_simple_paths(H, source=source, target=target)
for i, path in enumerate(paths):
j+=1
j += 1
all_paths[j] = path

# Filter the target paths from min_node to max_node
target_paths = []
for path in list(all_paths.values()):
if path[0] == min_node and path[-1] == max_node:
target_paths.append(path)

# Get the global graph
G = _read_graph(adata, "global_graph")

centroid_dict = adata.uns["centroid_dict"]
centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict}

# Generate total length of every path. Store by dictionary
dist_dict = {}
for path in target_paths:
path_name = ",".join(list(map(str,path)))
path_name = ",".join(list(map(str, path)))
result = []
query_node = get_node(path, adata.uns["split_node"])
for edge in G.edges():
if (edge[0] in query_node) and (edge[1] in query_node):
result.append(edge)
if len(result) >= len(path):
dist_dict[path_name] = calculate_total_dist(result,centroid_dict)
dist_dict[path_name] = calculate_total_dist(result, centroid_dict)

# Find the shortest path
shortest_path = min(dist_dict, key=lambda x: dist_dict[x])
return shortest_path.split(',')
return shortest_path.split(",")


# get name of cluster by subcluster
def get_cluster(search, dictionary):
    """Return the key of the first entry whose value contains ``search``.

    Parameters
    ----------
    search
        Item to look for inside each value of ``dictionary``.
    dictionary
        Mapping whose values support the ``in`` operator.

    Returns
    -------
    The first key (in the mapping's iteration order) whose value contains
    ``search``, or ``None`` when no value contains it.
    """
    owners = (name for name, members in dictionary.items() if search in members)
    return next(owners, None)


def get_node(node_list, split_node):
    """Collect the entries of ``split_node`` for every node in ``node_list``.

    Parameters
    ----------
    node_list
        Iterable of node identifiers; each one is converted with ``int()``
        before being used as a key of ``split_node``.
    split_node
        Mapping from integer node id to a sequence of values convertible
        to integers.

    Returns
    -------
    numpy.ndarray
        One-dimensional integer array holding the entries for each node, in
        the order the nodes appear in ``node_list``. Empty when
        ``node_list`` is empty.
    """
    # Convert each piece to int on its own and concatenate once: this avoids
    # re-copying a growing array on every iteration and never routes the
    # integer ids through a float64 accumulator (which loses precision
    # above 2**53).
    pieces = [
        np.asarray(split_node[int(node)]).astype(int).ravel()
        for node in node_list
    ]
    if not pieces:
        # np.concatenate rejects an empty list; keep the empty-int-array result.
        return np.array([], dtype=int)
    return np.concatenate(pieces)

def find_min_max_node(adata,key="dpt_pseudotime",use_label="leiden"):
min_cluster = int(adata.obs[adata.obs[key]==0][use_label].values[0])
max_cluster = int(adata.obs[adata.obs[key]==1][use_label].values[0])

return [min_cluster,max_cluster]

def calculate_total_dist(result,centroid_dict):
def find_min_max_node(adata, key="dpt_pseudotime", use_label="leiden"):
    """Return the labels of the observations whose ``key`` value is 0 and 1.

    Parameters
    ----------
    adata
        Object exposing an ``obs`` table that can be filtered with a boolean
        mask and indexed by column name.
    key
        Column of ``adata.obs`` compared against 0 and 1.
    use_label
        Column of ``adata.obs`` holding the label; its value must be
        convertible with ``int()``.

    Returns
    -------
    list
        ``[min_cluster, max_cluster]``: the integer label of the first
        observation with ``key == 0`` and of the first with ``key == 1``.

    Raises
    ------
    IndexError
        If no observation has ``key`` equal to 0, or none equal to 1.
    """

    def _label_where(value):
        # Labels of the rows whose `key` column equals `value`, in table order.
        labels = adata.obs[adata.obs[key] == value][use_label]
        if len(labels) == 0:
            # Same exception type as indexing an empty array, but with a
            # message that says which lookup failed.
            raise IndexError(
                f"no observation with adata.obs[{key!r}] == {value}; "
                "cannot determine the corresponding cluster"
            )
        return int(labels.values[0])

    min_cluster = _label_where(0)
    max_cluster = _label_where(1)

    return [min_cluster, max_cluster]


def calculate_total_dist(result, centroid_dict):
    """Sum the Euclidean lengths of the given edges.

    Parameters
    ----------
    result
        Iterable of edges; elements 0 and 1 of each edge are keys of
        ``centroid_dict``.
    centroid_dict
        Mapping from node key to a coordinate sequence accepted by
        ``math.dist``.

    Returns
    -------
    The accumulated distance over all edges; ``0`` when ``result`` is empty.
    """
    import math

    length = 0
    for edge in result:
        start_point = centroid_dict[edge[0]]
        end_point = centroid_dict[edge[1]]
        length = length + math.dist(start_point, end_point)
    return length

0 comments on commit 3fe707a

Please sign in to comment.