|
3 | 3 | import matplotlib |
4 | 4 | import matplotlib.pyplot as plt |
5 | 5 | import networkx as nx |
| 6 | +import numpy as np |
6 | 7 | from jsonargparse import CLI |
7 | 8 | from PIL import Image |
8 | 9 | from rdkit.Chem import AllChem, BondType, Mol, rdDepictor |
@@ -232,14 +233,34 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: |
232 | 233 | for pos in [conf.GetAtomPosition(atom.GetIdx())] |
233 | 234 | } |
234 | 235 |
|
235 | | - # Generate 3D layout for FG and graph nodes |
236 | | - fg_graph = _get_subgraph_by_node_type(G, "fg") |
237 | | - fg_pos_3d = nx.spring_layout(fg_graph, seed=42, dim=3) |
238 | | - fg_pos = {k: (x, y, z + 2) for k, (x, y, z) in fg_pos_3d.items()} |
| 236 | + # Dictionary to store functional group node positions |
| 237 | + fg_pos = {} |
239 | 238 |
|
240 | | - graph_node_graph = _get_subgraph_by_node_type(G, "graph") |
241 | | - graph_pos_3d = nx.spring_layout(graph_node_graph, seed=123, dim=3) |
242 | | - graph_pos = {k: (x, y, z + 4) for k, (x, y, z) in graph_pos_3d.items()} |
| 239 | + # Loop through each functional group node in the graph |
| 240 | + for fg_node in _get_subgraph_by_node_type(G, "fg").nodes(): |
| 241 | + # Get connected atom nodes (assuming edges are between fg and atom nodes) |
| 242 | + connected_atoms = [ |
| 243 | + nbr |
| 244 | + for nbr in G.neighbors(fg_node) |
| 245 | + if G.nodes[nbr].get("node_type") == "atom" |
| 246 | + ] |
| 247 | + |
| 248 | + # Get the 2D positions of the connected atoms |
| 249 | + positions = np.array([atom_pos[atom] for atom in connected_atoms]) |
| 250 | + x_mean, y_mean = positions[:, 0].mean(), positions[:, 1].mean() |
| 251 | + fg_pos[fg_node] = (x_mean, y_mean, 2) # z = 2 for elevation |
| 252 | + |
| 253 | + graph_node = next(iter(_get_subgraph_by_node_type(G, "graph").nodes())) |
| 254 | + graph_pos_arr = np.array( |
| 255 | + [ |
| 256 | + fg_pos[nbr] |
| 257 | + for nbr in G.neighbors(graph_node) |
| 258 | + if G.nodes[nbr].get("node_type") == "fg" |
| 259 | + ] |
| 260 | + ) |
| 261 | + graph_pos = { |
| 262 | + graph_node: (graph_pos_arr[:, 0].mean(), graph_pos_arr[:, 1].mean(), 4) |
| 263 | + } |
243 | 264 |
|
244 | 265 | pos = {**atom_pos, **fg_pos, **graph_pos} |
245 | 266 |
|
|
0 commit comments