1+ import networkx as nx
2+ import numpy as np
3+ from stlearn .utils import _read_graph
4+
5+ def shortest_path_spatial_PAGA (adata ,use_label ,key = "dpt_pseudotime" ,):
6+ # Read original PAGA graph
7+ G = nx .from_numpy_array (adata .uns ["paga" ]["connectivities" ].toarray ())
8+ edge_weights = nx .get_edge_attributes (G , "weight" )
9+ G .remove_edges_from ((e for e , w in edge_weights .items () if w < 0 ))
10+ H = G .to_directed ()
11+
12+ # Get min_node and max_node
13+ min_node ,max_node = find_min_max_node (adata ,key ,use_label )
14+
15+ # Calculate pseudotime for each node
16+ node_pseudotime = {}
17+
18+ for node in H .nodes :
19+ node_pseudotime [node ] = adata .obs .query (use_label + " == '" + str (node ) + "'" )[
20+ key
21+ ].max ()
22+
23+ # Force original PAGA to directed PAGA based on pseudotime
24+ edge_to_remove = []
25+ for edge in H .edges :
26+ if node_pseudotime [edge [0 ]] - node_pseudotime [edge [1 ]] > 0 :
27+ edge_to_remove .append (edge )
28+ H .remove_edges_from (edge_to_remove )
29+
30+ # Extract all available paths
31+ all_paths = {}
32+ j = 0
33+ for source in H .nodes :
34+ for target in H .nodes :
35+ paths = nx .all_simple_paths (H , source = source , target = target )
36+ for i , path in enumerate (paths ):
37+ j += 1
38+ all_paths [j ] = path
39+
40+ # Filter the target paths from min_node to max_node
41+ target_paths = []
42+ for path in list (all_paths .values ()):
43+ if path [0 ] == min_node and path [- 1 ] == max_node :
44+ target_paths .append (path )
45+
46+ # Get the global graph
47+ G = _read_graph (adata , "global_graph" )
48+
49+ centroid_dict = adata .uns ["centroid_dict" ]
50+ centroid_dict = {int (key ): centroid_dict [key ] for key in centroid_dict }
51+
52+ # Generate total length of every path. Store by dictionary
53+ dist_dict = {}
54+ for path in target_paths :
55+ path_name = "," .join (list (map (str ,path )))
56+ result = []
57+ query_node = get_node (path , adata .uns ["split_node" ])
58+ for edge in G .edges ():
59+ if (edge [0 ] in query_node ) and (edge [1 ] in query_node ):
60+ result .append (edge )
61+ if len (result ) >= len (path ):
62+ dist_dict [path_name ] = calculate_total_dist (result ,centroid_dict )
63+
64+ # Find the shortest path
65+ shortest_path = min (dist_dict , key = lambda x : dist_dict [x ])
66+ return shortest_path .split (',' )
67+
68+ # get name of cluster by subcluster
69+ def get_cluster (search , dictionary ):
70+ for cl , sub in dictionary .items ():
71+ if search in sub :
72+ return cl
73+
74+ def get_node (node_list , split_node ):
75+ result = np .array ([])
76+ for node in node_list :
77+ result = np .append (result , np .array (split_node [int (node )]).astype (int ))
78+ return result .astype (int )
79+
80+ def find_min_max_node (adata ,key = "dpt_pseudotime" ,use_label = "leiden" ):
81+ min_cluster = int (adata .obs [adata .obs [key ]== 0 ][use_label ].values [0 ])
82+ max_cluster = int (adata .obs [adata .obs [key ]== 1 ][use_label ].values [0 ])
83+
84+ return [min_cluster ,max_cluster ]
85+
86+ def calculate_total_dist (result ,centroid_dict ):
87+ import math
88+ total_dist = 0
89+ for edge in result :
90+ source = centroid_dict [edge [0 ]]
91+ target = centroid_dict [edge [1 ]]
92+ dist = math .dist (source ,target )
93+ total_dist += dist
94+ return total_dist
0 commit comments