Skip to content

Commit 200d2cd

Browse files
committed
num_undir_edges key fix
1 parent 65dfda6 commit 200d2cd

3 files changed

Lines changed: 6 additions & 5 deletions

File tree

chebai_graph/preprocessing/properties/augmented_properties.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def get_property_value(self, augmented_mol: Dict) -> List:
5353
[self.get_bond_value(bond) for bond in fg_graph_node_edges.values()]
5454
)
5555

56-
num_directed_edges = augmented_mol[self.MAIN_KEY]["num_undirected_edges"] // 2
56+
num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2
5757
assert (
5858
len(prop_list) == num_directed_edges
5959
), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "

chebai_graph/preprocessing/properties/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
ATOM_FG_EDGE = "atom_fg_lvl"
1111
FG_GRAPHNODE_EDGE = "fg_graphNode_lvl"
1212
EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_EDGE}
13+
NUM_EDGES = "num_undirected_edges"

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _read_data(self, smiles: str) -> GeomData | None:
181181

182182
# Empty features initialized; node and edge features can be added later
183183
x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0))
184-
edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"], 0))
184+
edge_attr = torch.zeros((augmented_molecule["edges"][NUM_EDGES], 0))
185185

186186
assert (
187187
edge_index.shape[0] == 2
@@ -297,7 +297,7 @@ def _augment_graph_structure(
297297
ATOM_FG_EDGE: atom_fg_edges,
298298
WITHIN_FG_EDGE: internal_fg_edges,
299299
FG_GRAPHNODE_EDGE: fg_to_graph_edges,
300-
"num_undirected_edges": self._num_of_edges * 2, # Undirected edges
300+
NUM_EDGES: self._num_of_edges * 2, # Undirected edges
301301
}
302302
return undirected_edge_index, node_info, edge_info
303303

@@ -449,12 +449,12 @@ def _construct_fg_level_structure(
449449
source_fg is not None and target_fg is not None
450450
), "Each bond should have a fg node on both end"
451451

452-
internal_edge_index[0].append(source_fg)
453-
internal_edge_index[1].append(target_fg)
454452
edge_str = f"{source_fg}_{target_fg}"
455453
if edge_str not in internal_fg_edges:
456454
# If two atoms of a FG points to atom(s) belonging to another FG. In this case, only one edge is counted.
457455
# Eg. In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG
456+
internal_edge_index[0].append(source_fg)
457+
internal_edge_index[1].append(target_fg)
458458
internal_fg_edges[edge_str] = {EDGE_LEVEL: WITHIN_FG_EDGE}
459459
self._num_of_edges += 1
460460

0 commit comments

Comments
 (0)