Skip to content

Commit 0f137b8

Browse files
committed
3d ◀️ mean position of connected atoms for aug nodes
1 parent 4b4179b commit 0f137b8

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

chebai_graph/preprocessing/utils/visualize_augmented_molecule.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib
44
import matplotlib.pyplot as plt
55
import networkx as nx
6+
import numpy as np
67
from jsonargparse import CLI
78
from PIL import Image
89
from rdkit.Chem import AllChem, BondType, Mol, rdDepictor
@@ -232,14 +233,34 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None:
232233
for pos in [conf.GetAtomPosition(atom.GetIdx())]
233234
}
234235

235-
# Generate 3D layout for FG and graph nodes
236-
fg_graph = _get_subgraph_by_node_type(G, "fg")
237-
fg_pos_3d = nx.spring_layout(fg_graph, seed=42, dim=3)
238-
fg_pos = {k: (x, y, z + 2) for k, (x, y, z) in fg_pos_3d.items()}
236+
# Dictionary to store functional group node positions
237+
fg_pos = {}
239238

240-
graph_node_graph = _get_subgraph_by_node_type(G, "graph")
241-
graph_pos_3d = nx.spring_layout(graph_node_graph, seed=123, dim=3)
242-
graph_pos = {k: (x, y, z + 4) for k, (x, y, z) in graph_pos_3d.items()}
239+
# Loop through each functional group node in the graph
240+
for fg_node in _get_subgraph_by_node_type(G, "fg").nodes():
241+
# Get connected atom nodes (assuming edges are between fg and atom nodes)
242+
connected_atoms = [
243+
nbr
244+
for nbr in G.neighbors(fg_node)
245+
if G.nodes[nbr].get("node_type") == "atom"
246+
]
247+
248+
# Get the 2D positions of the connected atoms
249+
positions = np.array([atom_pos[atom] for atom in connected_atoms])
250+
x_mean, y_mean = positions[:, 0].mean(), positions[:, 1].mean()
251+
fg_pos[fg_node] = (x_mean, y_mean, 2) # z = 2 for elevation
252+
253+
graph_node = next(iter(_get_subgraph_by_node_type(G, "graph").nodes()))
254+
graph_pos_arr = np.array(
255+
[
256+
fg_pos[nbr]
257+
for nbr in G.neighbors(graph_node)
258+
if G.nodes[nbr].get("node_type") == "fg"
259+
]
260+
)
261+
graph_pos = {
262+
graph_node: (graph_pos_arr[:, 0].mean(), graph_pos_arr[:, 1].mean(), 4)
263+
}
243264

244265
pos = {**atom_pos, **fg_pos, **graph_pos}
245266

0 commit comments

Comments
 (0)