diff --git a/.gitignore b/.gitignore index c909c30..770f8fe 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__ **/*.ckpt logs/* +results/* checkpoints/* data/*/*/*.pt data/*/*/*.pkl diff --git a/configs/msgym.yaml b/configs/msgym.yaml index 1237280..473a627 100644 --- a/configs/msgym.yaml +++ b/configs/msgym.yaml @@ -20,8 +20,8 @@ enable_progress_bar: True # Data dataset: msgym batch_size: 256 -num_workers: 47 -shuffle: True +num_workers: 16 +shuffle: False extra_nodes: True swap: False diff --git a/evaluation_generation.py b/evaluation_generation.py index 87f8e16..d6bfd11 100644 --- a/evaluation_generation.py +++ b/evaluation_generation.py @@ -79,22 +79,12 @@ def calculate_mces(mces, pairs): mces_thld = 100 mces_cache = {} myopic_mces = MyopicMCES( - threshold=20, - solver='HiGHS', - solver_options={ - 'msg': 0, - 'log_to_console': False, - 'output_flag': False, - 'time_limit': 10, # Optional: add timeout - 'log_file': os.devnull, # Redirect logs to nowhere - 'highs_debug_level': 0, - 'highs_verbosity': 'off' - } + threshold=20 ) for k in ks: result_metric = {"accuracy": 0, "similarity": 0, "MCES": 0} count = 0 - sub_dfs = split_dataframe(df1, chunk_size=50) + sub_dfs = split_dataframe(df1, chunk_size=100) for df in tqdm(sub_dfs): smile = list(df["true"])[0] pred_smiles = sorted(list(df["pred"]), key=lambda x: list(df["pred"]).count(x), reverse=True) @@ -124,19 +114,19 @@ def calculate_mces(mces, pairs): # if Chem.MolToSmiles(mol) != Chem.MolToSmiles(GetScaffoldForMol(Chem.MolFromSmiles(scaf_smi))): # print('scaffold match', smile) result_metric["accuracy"] += int(in_top_k) - # dists = [] + dists = [] # pairs = [(smile, pred) for pred, pred_mol in zip(pred_smiles, pred_mols) if pred_mol is not None] # results = calculate_mces(myopic_mces, pairs) # dists = [results.get((smile, pred), mces_thld) for pred in pred_smiles] - # for pred, pred_mol in zip(pred_smiles, pred_mols): - # if pred_mol is None: - # dists.append(mces_thld) - # else: - # if (smile, pred) not in mces_cache: - # mce_val = myopic_mces(smile, pred) - # mces_cache[(smile, pred)] = mce_val - # dists.append(mces_cache[(smile, pred)]) + for pred, pred_mol in zip(pred_smiles, pred_mols): + if pred_mol is None: + dists.append(mces_thld) + else: + if (smile, pred) not in mces_cache: + mce_val = myopic_mces(smile, pred) + mces_cache[(smile, pred)] = mce_val + dists.append(mces_cache[(smile, pred)]) mol_fp = GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) pred_fps = [ GetMorganFingerprintAsBitVect(pred, radius=2, nBits=2048) if pred is not None else None for pred in pred_mols @@ -145,7 +135,7 @@ def calculate_mces(mces, pairs): TanimotoSimilarity(mol_fp, pred) if pred is not None else 0 for pred in pred_fps ] result_metric["similarity"] += max(sims) - # result_metric["MCES"] += min(min(dists), mces_thld) + result_metric["MCES"] += min(min(dists), mces_thld) for key in result_metric: result_metric[key] = result_metric[key] / len(sub_dfs) print(dataset, k, result_metric) diff --git a/metrics_calculation.py b/metrics_calculation.py new file mode 100644 index 0000000..0caaae6 --- /dev/null +++ b/metrics_calculation.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python3 +import os +import argparse +from collections import defaultdict + +import pandas as pd +import numpy as np +from tqdm import tqdm + +from rdkit import Chem, RDLogger, DataStructs +from rdkit.Chem import rdMolDescriptors as rdmd +from rdkit.Chem.rdmolops import RemoveHs +from rdkit.Chem.MolStandardize import rdMolStandardize + +RDLogger.DisableLog("rdApp.*") + +# ------------------------- MCES wrapper (lazy import) ------------------------- +class MyopicMCES: + def __init__( + self, + ind=0, + solver=None, # let myopic_mces pick default if None + threshold=20, + always_stronger_bound=True, + solver_options=None, + ): + # import here so script runs even if MCES is disabled / not installed + from myopic_mces.myopic_mces import MCES as _MCES + import pulp + + self._MCES = _MCES + self.ind = ind + # choose an available solver if none given + self.solver = solver if solver is not None else pulp.listSolvers(onlyAvailable=True)[0] + self.threshold = threshold + self.always_stronger_bound = always_stronger_bound + if solver_options is None: + solver_options = dict(msg=0) + self.solver_options = solver_options + + def __call__(self, s1, s2): + try: + _, dist = self._MCES( + s1=s1, + s2=s2, + ind=self.ind, + threshold=self.threshold, + always_stronger_bound=self.always_stronger_bound, + solver=self.solver, + solver_options=self.solver_options, + ) + return float(dist) + except Exception: + return float(self.threshold) + + +# ------------------------- chemistry helpers --------------------------------- +_lf = rdMolStandardize.LargestFragmentChooser() + +def mol_from_smiles(smi: str): + if not isinstance(smi, str): + return None + try: + m = Chem.MolFromSmiles(smi) + except Exception: + return None + if m is None: + return None + try: + m = _lf.choose(m) # keep largest fragment + except Exception: + pass + try: + m = RemoveHs(m) + except Exception: + pass + return m + +def inchi_block(mol): + try: + return Chem.inchi.MolToInchiKey(mol).split("-")[0] + except Exception: + return None + +def has_scaffold(pred_mol, scaffold_mol, use_chirality=False): + if pred_mol is None or scaffold_mol is None: + return False + try: + return pred_mol.HasSubstructMatch(scaffold_mol, useChirality=use_chirality) + except Exception: + return False + + +# ------------------------- metrics ------------------------------------------- +def tanimoto_max(mol, pred_mols): + if mol is None or not pred_mols: + return 0.0 + fp = rdmd.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) + sims = [] + for pm in pred_mols: + if pm is None: + sims.append(0.0) + else: + pf = rdmd.GetMorganFingerprintAsBitVect(pm, radius=2, nBits=2048) + sims.append(DataStructs.TanimotoSimilarity(fp, pf)) + return max(sims) if sims else 0.0 + +def accuracy_at_k(true_mol, pred_mols, k): + if true_mol is None or not pred_mols or k <= 0: + return 0 + true_block = inchi_block(true_mol) + if true_block is None: + return 0 + topk = pred_mols[:k] + return int(any((pm is not None) and (inchi_block(pm) == true_block) for pm in topk)) + +def mces_min_at_k(mces_fn, true_smiles, pred_smiles, k, cap): + vals = [] + for smi in pred_smiles[:k]: + if smi is None: + vals.append(cap) + else: + vals.append(mces_fn(true_smiles, smi)) + return min(vals) if vals else cap + + +# ------------------------- plotting ------------------------------------------ +def _ensure_matplotlib_headless(): + # ensure plotting works on servers without a display + import matplotlib + try: + matplotlib.get_backend() + except Exception: + pass + matplotlib.use("Agg") + import matplotlib.pyplot as plt + return matplotlib, plt + +def make_histograms(per_df, ks, out_dir, bins=40, dpi=200): + _, plt = _ensure_matplotlib_headless() + plot_dir = os.path.join(out_dir, "plots") + os.makedirs(plot_dir, exist_ok=True) + + def _plot_one_prefix(prefix, xlabel, fname): + # per-k histograms + gray "sum of distributions" overlay + series_by_k = [] + for k in ks: + col = f"{prefix}@{k}" + if col in per_df.columns: + series_by_k.append((k, per_df[col].dropna().values)) + if not series_by_k: + return + + # common bins over pooled data + pooled = np.concatenate([v for _, v in series_by_k]) if series_by_k else np.array([]) + if pooled.size == 0: + return + counts_by_k = [] + hist_bins = bins + # compute bin edges once + counts0, edges = np.histogram(pooled, bins=hist_bins) + bin_centers = 0.5 * (edges[:-1] + edges[1:]) + # recompute each k with same edges + for k, arr in series_by_k: + c, _ = np.histogram(arr, bins=edges) + counts_by_k.append((k, c)) + # gray sum-of-distributions + summed = np.sum([c for _, c in counts_by_k], axis=0) + + plt.figure(figsize=(7, 5)) + # background gray bars (sum) + plt.bar(bin_centers, summed, width=(edges[1]-edges[0]), color="lightgray", edgecolor=None, alpha=0.6, label="sum over k") + # per-k step lines on top + for k, c in counts_by_k: + plt.plot(bin_centers, c, lw=2, label=f"k={k}") + plt.xlabel(xlabel) + plt.ylabel("count (samples)") + plt.title(f"{prefix}: distribution across samples") + plt.grid(alpha=0.3, linestyle="--") + plt.legend() + plt.tight_layout() + plt.savefig(os.path.join(plot_dir, fname), dpi=dpi, bbox_inches="tight") + plt.close() + + # also save individual per-k histograms + for k, arr in series_by_k: + plt.figure(figsize=(7, 5)) + plt.hist(arr, bins=edges, color=None, edgecolor="black") + plt.xlabel(xlabel) + plt.ylabel("count (samples)") + plt.title(f"{prefix} @ k={k}") + plt.grid(alpha=0.3, linestyle="--") + plt.tight_layout() + plt.savefig(os.path.join(plot_dir, f"{prefix}_k{k}.png"), dpi=dpi, bbox_inches="tight") + plt.close() + + # what we can plot + _plot_one_prefix("tanimoto", "Tanimoto (best in top-k)", "tanimoto_multi_k.png") + _plot_one_prefix("frac_scaffold", "fraction with scaffold (top-k)", "scaffold_frac_multi_k.png") + if any(f"acc@{k}" in per_df.columns for k in ks): + _plot_one_prefix("acc", "accuracy (0/1) in top-k", "accuracy_multi_k.png") + if any(f"mces_min@{k}" in per_df.columns for k in ks): + _plot_one_prefix("mces_min", "MCES min (lower is better)", "mces_min_multi_k.png") + + +# ------------------------- main ---------------------------------------------- +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--file_path", required=True, help="CSV with columns: scaffold,pred,true,score,...") + ap.add_argument("--out_dir", required=True, help="Where to write outputs (CSVs, plots)") + ap.add_argument("--ks", default="1,10,100", help="Comma-separated list of k values (e.g. '1,10,100,1000')") + # ↓↓↓ group by TRUE only + ap.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") + ap.add_argument("--group_cols", default="true", + help="Columns to group predictions by (default: true)") + ap.add_argument("--enable_mces", action="store_true", help="Compute MCES (off by default)") + ap.add_argument("--mces_cap", type=float, default=100.0, help="Cap (fallback) for failed MCES") + ap.add_argument("--mces_threshold", type=float, default=20.0, help="MCES solver threshold") + ap.add_argument("--make_hists", action="store_true", help="Save histograms to out_dir/plots") + ap.add_argument("--hist_bins", type=int, default=40, help="Number of bins for histograms") + ap.add_argument("--hist_dpi", type=int, default=300, help="DPI for histogram PNGs") + args = ap.parse_args() + + os.makedirs(args.out_dir, exist_ok=True) + ks = [int(x) for x in args.ks.split(",") if x.strip()] + df = pd.read_csv(args.file_path) + + required_cols = {"true", "pred", "scaffold"} + missing = required_cols - set(df.columns) + if missing: + raise SystemExit(f"Missing columns in input CSV: {missing}") + + # ----- Keep ORIGINAL SAMPLE ORDER (by first appearance) ----- + group_cols = [c.strip() for c in args.group_cols.split(",") if c.strip()] # -> ["true"] + df["_group_id"] = df.groupby(group_cols, sort=False).ngroup() + # ------------------------------------------------------------ + + use_mces = bool(args.enable_mces) + + # Prepare MCES if requested + my_mces = None + if use_mces: + try: + my_mces = MyopicMCES(threshold=args.mces_threshold) + except Exception as e: + raise SystemExit(f"Could not initialize MCES: {e}") + + per_rows = [] + + # group by sample (true only) + for keys, g in tqdm(df.groupby(group_cols, sort=False), desc="Samples"): + keys = keys if isinstance(keys, tuple) else (keys,) + sample = {col: val for col, val in zip(group_cols, keys)} + + true_smi = g["true"].iloc[0] + scaf_smi = g["scaffold"].iloc[0] + preds_smi = list(g["pred"].astype(str).values) + + # build mols and filter invalid predictions; compute k_eff + true_mol = mol_from_smiles(true_smi) + scaf_mol = mol_from_smiles(scaf_smi) + + pred_mols_all = [mol_from_smiles(s) for s in preds_smi] + valid_pairs = [(s, m) for s, m in zip(preds_smi, pred_mols_all) if m is not None] + n_valid = len(valid_pairs) + + if n_valid: + valid_smiles = [s for s, _ in valid_pairs] + freq = pd.Series(valid_smiles).value_counts() + + first_idx = {} + for i, s in enumerate(preds_smi): # original order tie-break + if s not in first_idx: + first_idx[s] = i + valid_smiles_sorted = sorted(valid_smiles, key=lambda x: (-freq[x], first_idx[x])) + pred_mols_sorted = [mol_from_smiles(s) for s in valid_smiles_sorted] + else: + valid_smiles_sorted = [] + pred_mols_sorted = [] + + frac_scaf_all = 0.0 + if n_valid: + scaf_flags_all = [has_scaffold(m, scaf_mol) for _, m in valid_pairs] + frac_scaf_all = float(sum(scaf_flags_all)) / n_valid + + row = { + **sample, + "n_preds": len(preds_smi), + "n_valid": n_valid, + "frac_scaffold_all": frac_scaf_all, + } + + # per-k metrics using k_eff = min(k, n_valid) + for k in ks: + k_eff = min(k, n_valid) + if k_eff > 0: + top_mols = pred_mols_sorted[:k_eff] + top_smiles = valid_smiles_sorted[:k_eff] + + flags_k = [has_scaffold(m, scaf_mol) for m in top_mols] + row[f"frac_scaffold@{k}"] = float(sum(flags_k)) / k_eff + + row[f"acc@{k}"] = accuracy_at_k(true_mol, top_mols, k_eff) + row[f"tanimoto@{k}"] = tanimoto_max(true_mol, top_mols) if true_mol else 0.0 + + if use_mces: + row[f"mces_min@{k}"] = mces_min_at_k(my_mces, true_smi, top_smiles, k_eff, cap=float(args.mces_cap)) + else: + row[f"frac_scaffold@{k}"] = 0.0 + row[f"acc@{k}"] = 0 + row[f"tanimoto@{k}"] = 0.0 + if use_mces: + row[f"mces_min@{k}"] = float(args.mces_cap) + + per_rows.append(row) + + # ---------------- save per-sample ---------------- + per_df = pd.DataFrame(per_rows) + + # restore original sample order using _group_id + order = df[group_cols + ["_group_id"]].drop_duplicates() + per_df = (per_df + .merge(order, on=group_cols, how="left") + .sort_values("_group_id") + .drop(columns="_group_id")) + + per_csv = os.path.join(args.out_dir, "per_sample_metrics.csv") + #per_df.to_csv(per_csv, index=False) + + # ---------------- save summaries ---------------- + num_cols = [c for c in per_df.columns if c not in group_cols] + summary_df = pd.DataFrame({ + "mean_over_samples": per_df[num_cols].mean(numeric_only=True), + "median_over_samples": per_df[num_cols].median(numeric_only=True) + }).T + summary_csv = os.path.join(args.out_dir, "summary_metrics.csv") + #summary_df.to_csv(summary_csv) + + # compact by-k summary (means only) + rows = [] + for k in ks: + row = {"k": k} + for prefix in ["acc", "tanimoto", "frac_scaffold"]: + col = f"{prefix}@{k}" + if col in per_df.columns: + row[prefix] = float(per_df[col].mean()) + if "mces_min@{}".format(k) in per_df.columns: + row["mces_min"] = float(per_df[f"mces_min@{k}"].mean()) + rows.append(row) + byk_df = pd.DataFrame(rows) + byk_csv = os.path.join(args.out_dir, f"summary_by_k_{args.seed}.csv") + byk_df.to_csv(byk_csv, index=False) + + # ---------------- plots (optional) ------------- + if args.make_hists: + make_histograms(per_df, ks, args.out_dir, bins=args.hist_bins, dpi=args.hist_dpi) + + print(f"Wrote per-sample metrics -> {per_csv}") + print(f"Wrote summary metrics -> {summary_csv}") + print(f"Wrote by-k summary -> {byk_csv}") + if args.make_hists: + print(f"Saved histograms in -> {os.path.join(args.out_dir, 'plots')}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/plot_mcs.py b/plot_mcs.py new file mode 100644 index 0000000..c0b2fa5 --- /dev/null +++ b/plot_mcs.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import os +import re +from typing import Tuple, List, Optional + +import numpy as np +import pandas as pd +from rdkit import Chem, RDLogger +from rdkit.Chem import Draw, rdFMCS +from rdkit.Chem.Draw import rdMolDraw2D + +RDLogger.DisableLog("rdApp.*") + +# ------------------------ helpers: parsing/sanitizing ------------------------- +_WS_RE = re.compile(r"[\u00A0\u2007\u202F]") # non-breaking spaces + +def clean_smiles(raw: str) -> Optional[str]: + if raw is None: + return None + s = str(raw).strip().strip('"').strip("'") + s = _WS_RE.sub(" ", s).replace(" ", "") + return s or None + +def mol_from_smiles(smi: str) -> Optional[Chem.Mol]: + s = clean_smiles(smi) + if not s: + return None + m = Chem.MolFromSmiles(s, sanitize=False) + if m is None: + return None + try: + Chem.SanitizeMol(m) + except Exception: + return None + Chem.rdDepictor.Compute2DCoords(m) + return m + +# ------------------------ MCS + indexing utilities --------------------------- +def find_mcs(m1: Chem.Mol, m2: Chem.Mol, + ring_matches_ring_only=True, + complete_rings_only=False, + timeout_s: int = 10) -> Tuple[Optional[Chem.Mol], Tuple[int, ...], Tuple[int, ...]]: + res = rdFMCS.FindMCS( + [m1, m2], + ringMatchesRingOnly=ring_matches_ring_only, + completeRingsOnly=complete_rings_only, + timeout=timeout_s + ) + if res.canceled or not res.smartsString: + return None, tuple(), tuple() + mcs_mol = Chem.MolFromSmarts(res.smartsString) + if mcs_mol is None: + return None, tuple(), tuple() + match1 = m1.GetSubstructMatch(mcs_mol) + match2 = m2.GetSubstructMatch(mcs_mol) + if not match1 or not match2: + return None, tuple(), tuple() + return mcs_mol, tuple(int(i) for i in match1), tuple(int(i) for i in match2) + +def bonds_from_atom_set(mol: Chem.Mol, atom_idx_iter) -> List[int]: + idx_set = set(int(a) for a in atom_idx_iter) + keep = [] + for b in mol.GetBonds(): + a1, a2 = b.GetBeginAtomIdx(), b.GetEndAtomIdx() + if a1 in idx_set and a2 in idx_set: + keep.append(int(b.GetIdx())) + return keep + +def _as_int_list(x): + if x is None: + return [] + out = [] + for v in x: + if isinstance(v, (list, tuple, set, np.ndarray)): + out.extend(int(i) for i in v) + else: + out.append(int(v)) + return out + +def _ensure_list_of_lists(n_mols, items): + """None -> None; [] -> [[],...]; or broadcast single list to per-mol lists""" + if items is None: + return None + if len(items) == 0: + return [[] for _ in range(n_mols)] + if isinstance(items[0], (int, np.integer)): + # single flat list -> broadcast + lst = _as_int_list(items) + return [lst for _ in range(n_mols)] + # already list-of-lists + return [ _as_int_list(it) for it in items ] + +# ------------------------ drawing primitives --------------------------------- +def draw_grid_png(mols, legends=None, hatoms=None, hbonds=None, + cols=2, cell_w=500, cell_h=400, out_png="out.png"): + if legends is None: + legends = [""] * len(mols) + hatoms = _ensure_list_of_lists(len(mols), hatoms) + hbonds = _ensure_list_of_lists(len(mols), hbonds) + + img = Draw.MolsToGridImage( + mols, + molsPerRow=cols, + subImgSize=(cell_w, cell_h), + legends=legends, + highlightAtomLists=hatoms, + highlightBondLists=hbonds, + useSVG=False, # returns a PIL Image + ) + img.save(out_png) + +def draw_overlay_png(m1, m2, hatoms1=None, hbonds1=None, hatoms2=None, hbonds2=None, + w=700, h=500, legend=None, out_png="out.png"): + d2d = rdMolDraw2D.MolDraw2DCairo(w, h) # modern signature: (width, height) + d2d.drawOptions().addStereoAnnotation = False + + d2d.DrawMolecule( + m1, + highlightAtoms=_as_int_list(hatoms1 or []), + highlightBonds=_as_int_list(hbonds1 or []), + legend=legend or "" + ) + d2d.DrawMolecule( + m2, + highlightAtoms=_as_int_list(hatoms2 or []), + highlightBonds=_as_int_list(hbonds2 or []), + ) + d2d.FinishDrawing() + with open(out_png, "wb") as f: + f.write(d2d.GetDrawingText()) + +def draw_single_png(m, hatoms=None, hbonds=None, w=600, h=450, legend=None, out_png="out.png"): + d2d = rdMolDraw2D.MolDraw2DCairo(w, h) + d2d.DrawMolecule( + m, + highlightAtoms=_as_int_list(hatoms or []), + highlightBonds=_as_int_list(hbonds or []), + legend=legend or "", + ) + d2d.FinishDrawing() + with open(out_png, "wb") as f: + f.write(d2d.GetDrawingText()) + +# ------------------------ main workflow -------------------------------------- +def main(): + ap = argparse.ArgumentParser(description="Plot two molecules and highlight their MCS.") + ap.add_argument("--data", required=True, help="CSV with columns: identifier, smiles") + ap.add_argument("--id1", required=True, help="First MSG identifier") + ap.add_argument("--id2", required=True, help="Second MSG identifier") + ap.add_argument("--out_prefix", required=True, help="Output path prefix (e.g., ./plots/m1_m2)") + ap.add_argument("--layout", choices=["side-by-side", "overlay"], default="side-by-side", + help="How to plot the two molecules") + ap.add_argument("--complete_rings_only", action="store_true", help="MCS must contain only complete rings") + ap.add_argument("--rings_match_rings_only", action="store_true", + help="Only match rings to rings in MCS") + ap.add_argument("--mcs_timeout", type=int, default=10, help="MCS timeout (seconds)") + args = ap.parse_args() + + df = pd.read_csv(args.data) + if not {"identifier", "smiles"}.issubset(df.columns): + raise SystemExit("Input CSV must have columns: identifier, smiles") + + id2smi = dict(zip(df["identifier"].astype(str), df["smiles"].astype(str))) + for need in (args.id1, args.id2): + if need not in id2smi: + raise SystemExit(f"Identifier '{need}' not found in data.") + + m1 = mol_from_smiles(id2smi[args.id1]) + m2 = mol_from_smiles(id2smi[args.id2]) + if m1 is None or m2 is None: + raise SystemExit("Failed to parse one of the molecules from SMILES.") + + mcs_mol, match1, match2 = find_mcs( + m1, m2, + ring_matches_ring_only=args.rings_match_rings_only, + complete_rings_only=args.complete_rings_only, + timeout_s=args.mcs_timeout + ) + if mcs_mol is None: + print("[warn] MCS not found (timeout or empty); continuing with unhighlighted drawings.") + match1, match2 = tuple(), tuple() + + mcs_bonds1 = bonds_from_atom_set(m1, match1) + mcs_bonds2 = bonds_from_atom_set(m2, match2) + + n_mcs_atoms = mcs_mol.GetNumAtoms() if mcs_mol is not None else 0 + n1 = m1.GetNumAtoms() + n2 = m2.GetNumAtoms() + denom = max(n1, n2) if max(n1, n2) > 0 else 1 + ratio = n_mcs_atoms / denom + print(f"MCS atoms: {n_mcs_atoms} | mol1 atoms: {n1} | mol2 atoms: {n2} | ratio: {ratio:.4f}") + + out_dir = os.path.dirname(args.out_prefix) + if out_dir and not os.path.isdir(out_dir): + os.makedirs(out_dir, exist_ok=True) + + if args.layout == "side-by-side": + out_img = f"{args.out_prefix}_side_by_side.png" + draw_grid_png( + [m1, m2], + legends=[args.id1, args.id2], + hatoms=[list(match1), list(match2)], + hbonds=[mcs_bonds1, mcs_bonds2], + cols=2, cell_w=500, cell_h=400, + out_png=out_img + ) + else: + out_img = f"{args.out_prefix}_overlay.png" + draw_overlay_png( + m1, m2, + hatoms1=list(match1), hbonds1=mcs_bonds1, + hatoms2=list(match2), hbonds2=mcs_bonds2, + w=700, h=500, legend=f"{args.id1} vs {args.id2}", + out_png=out_img + ) + print(f"[ok] Saved pair image -> {out_img}") + + if mcs_mol is not None and mcs_mol.GetNumAtoms() > 0: + Chem.rdDepictor.Compute2DCoords(mcs_mol) + out_mcs = f"{args.out_prefix}_mcs.png" + draw_single_png(mcs_mol, legend=f"MCS (atoms={n_mcs_atoms})", out_png=out_mcs) + print(f"[ok] Saved MCS image -> {out_mcs}") + + out_txt = f"{args.out_prefix}_mcs_ratio.txt" + with open(out_txt, "w") as f: + f.write(f"id1={args.id1}\n") + f.write(f"id2={args.id2}\n") + f.write(f"mcs_atoms={n_mcs_atoms}\n") + f.write(f"mol1_atoms={n1}\n") + f.write(f"mol2_atoms={n2}\n") + f.write(f"ratio={ratio:.6f}\n") + print(f"[ok] Wrote MCS stats -> {out_txt}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/results_sorting.py b/results_sorting.py new file mode 100644 index 0000000..2ce70a2 --- /dev/null +++ b/results_sorting.py @@ -0,0 +1,26 @@ +import pandas as pd +from rdkit import Chem +import argparse + +def num_atoms(smi): + m = Chem.MolFromSmiles(smi) + return m.GetNumAtoms() if m is not None else None + +ap = argparse.ArgumentParser() +ap.add_argument("--file_path", required=True, help="CSV with columns: scaffold,pred,true,score,...") +args = ap.parse_args() + +df = pd.read_csv(args.file_path) + +df["total_ll"] = df["nll"] + df["ell"] + +# preserve the original dataset order by assigning group IDs +df["group_id"] = (df["true"] + "|" + df["scaffold"]).factorize()[0] + +# sort inside each group but keep global group order +df = df.sort_values(["group_id", "total_ll", "score"], ascending=[True, False, False]) + +# drop the helper column before saving +df = df.drop(columns="group_id") + +df.to_csv(args.file_path.replace(".csv", "_my_ranked.csv"), index=False) \ No newline at end of file diff --git a/sample.py b/sample.py index 523dffe..26ae412 100644 --- a/sample.py +++ b/sample.py @@ -1,6 +1,11 @@ import argparse import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +import torch import pandas as pd +from pathlib import Path +import numpy as np from src.utils import disable_rdkit_logging, parse_yaml_config, set_deterministic from src.analysis.rdkit_functions import build_molecule @@ -8,23 +13,40 @@ from src.data.msgym_dataset import MsGymDataModule, MsGyminfos from src.data.canopus_dataset import CanopusDataModule, Canopusinfos from src.analysis.visualization import MolecularVisualization - +from src.frameworks import diffusion_utils from rdkit import Chem +from rdkit.Chem.MolStandardize import rdMolStandardize from tqdm import tqdm from pdb import set_trace +def _to_float(x): + try: + if hasattr(x, "item"): # torch/np scalar + return float(x.item()) + if isinstance(x, (list, tuple)) and len(x) == 1: + return float(x[0]) + return float(x) + except Exception: + return float("nan") + +_LF = rdMolStandardize.LargestFragmentChooser() + def main(args): + set_deterministic(args.sampling_seed) torch_device = "cuda:0" if args.device == "gpu" else "cpu" + gen = torch.Generator(device=torch_device) + gen.manual_seed(args.sampling_seed if args.sampling_seed is not None else 0) + diffusion_utils.set_sampling_generator(gen) data_root = os.path.join(args.data, args.dataset) checkpoint_name = args.checkpoint.split("/")[-1].replace(".ckpt", "") - output_dir = os.path.join(args.samples, f"{args.dataset}_{args.mode}") + output_dir = os.path.join(args.samples, f"{Path(args.msgym_pkl).stem}") if args.table_name != '': table_name = f"{args.table_name}.csv" else: - table_name = f"{checkpoint_name}_T={args.n_steps}_n={args.n_samples}_seed={args.sampling_seed}.csv" + table_name = f"n={args.n_samples}_seed={args.sampling_seed}.csv" table_path = os.path.join(output_dir, table_name) skip_first_n = 0 @@ -56,7 +78,10 @@ def main(args): shuffle=args.shuffle, extra_nodes=args.extra_nodes, swap=args.swap, - evaluation=False, + # added + msgym_pkl=args.msgym_pkl, + ranks_pkl=args.ranks_pkl, + seed=args.sampling_seed ) dataset_infos = MsGyminfos(datamodule) else: @@ -66,12 +91,10 @@ def main(args): num_workers=args.num_workers, shuffle=args.shuffle, extra_nodes=args.extra_nodes, - swap=args.swap, - evaluation=False, + swap=args.swap ) dataset_infos = Canopusinfos(datamodule) - set_deterministic(args.sampling_seed) model.eval().to(torch_device) visualization_tools = MolecularVisualization(dataset_infos) @@ -93,6 +116,8 @@ def main(args): if args.mode == "test" else datamodule.val_dataloader() ) + print("num test batches:", len(dataloader)) + print("dataset length :", len(dataloader.dataset)) for i, data in enumerate(tqdm(dataloader)): if i * args.batch_size < skip_first_n: print(i , skip_first_n) @@ -118,11 +143,11 @@ def main(args): data=data, batch_id=ident, batch_size=bs, - save_final=32, - keep_chain=32, + save_final=min(len(dataloader.dataset), args.batch_size), + keep_chain=min(len(dataloader.dataset), args.batch_size), number_chain_steps_to_save=40, sample_idx=sample_idx, - save_true_reactants=True, + #save_true_reactants=True, use_one_hot=args.use_one_hot, ) @@ -170,17 +195,27 @@ def main(args): grouped_ells ): - true_n_dummy_atoms = 0 + # --- TRUE --- true_mol = build_molecule( true_mol[0], true_mol[1], dataset_infos.atom_decoder ) + try: + true_mol = _LF.choose(true_mol) + except Exception: + pass true_smi = Chem.MolToSmiles(true_mol, canonical=True) + # --- SCAFFOLD / PRODUCT --- product_mol = build_molecule( product_mol[0], product_mol[1], dataset_infos.atom_decoder ) - product_smi = Chem.MolToSmiles(product_mol) + try: + product_mol = _LF.choose(product_mol) + except Exception: + pass + product_smi = Chem.MolToSmiles(product_mol, canonical=True) + # --- PREDICTIONS --- for pred_mol, pred_score, nll, ell in zip( pred_mols, pred_scores, nlls, ells ): @@ -190,13 +225,18 @@ def main(args): dataset_infos.atom_decoder, return_n_dummy_atoms=True, ) - pred_smi = Chem.MolToSmiles(pred_mol) + try: + pred_mol = _LF.choose(pred_mol) + except Exception: + pass + pred_smi = Chem.MolToSmiles(pred_mol, canonical=True) + true_molecules_smiles.append(true_smi) product_molecules_smiles.append(product_smi) pred_molecules_smiles.append(pred_smi) - computed_scores.append(pred_score) - computed_nlls.append(nll) - computed_ells.append(ell) + computed_scores.append(_to_float(pred_score)) + computed_nlls.append(_to_float(nll)) + computed_ells.append(_to_float(ell)) table = pd.DataFrame( { @@ -210,6 +250,16 @@ def main(args): ) full_table = pd.concat([prev_table, table]) full_table.to_csv(table_path, index=False) + # Optional: delete cached test file after run + if args.delete_test_cache: + root = os.path.join(args.data, args.dataset) + processed = os.path.join(root, "processed", "msgym_test_final.pt") + try: + if os.path.exists(processed): + os.remove(processed) + print(f"Deleted {processed}") + except Exception as e: + print(f"Could not delete {processed}: {e}") if __name__ == "__main__": @@ -233,7 +283,15 @@ def main(args): parser.add_argument( "--table_name", action="store", type=str, required=False, default='' ) - main(args=parse_yaml_config(parser.parse_args())) + # added + parser.add_argument("--msgym_pkl", type=str, default=None, + help="Path to msgym.pkl to use instead of the default.") + parser.add_argument("--ranks_pkl", type=str, default=None, + help="Path to ranks_msgym_pred.pkl (only used for test mode).") + parser.add_argument("--delete_test_cache", action="store_true", + help="Delete processed/msgym_test_final.pt after run.") + parsed_args, _ = parser.parse_known_args() + main(args=parse_yaml_config(parsed_args)) diff --git a/sanitize_results.py b/sanitize_results.py new file mode 100644 index 0000000..57b414f --- /dev/null +++ b/sanitize_results.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# sanitize_results.py +# Usage: +# python sanitize_results.py --file_path table.csv + +import argparse +import re +import pandas as pd +from rdkit import Chem, RDLogger +from rdkit.Chem.MolStandardize import rdMolStandardize +from rdkit.Chem.rdmolops import RemoveHs + +RDLogger.DisableLog("rdApp.*") +_lf = rdMolStandardize.LargestFragmentChooser() +_WS_RE = re.compile(r"[\u00A0\u2007\u202F]") # common non-breaking spaces + +def clean_smiles(raw: str, fix_nitro: bool = True): + if raw is None: + return None + s = str(raw).strip().strip('"').strip("'") + # normalize spaces + s = _WS_RE.sub(" ", s).replace(" ", "") + if not s: + return None + if fix_nitro: + # replace bare N(=O)O with [N+](=O)[O-] when not already bracketed + s = re.sub(r'(? {out_path}") + +if __name__ == "__main__": + main() diff --git a/scaffold_coverage.py b/scaffold_coverage.py new file mode 100644 index 0000000..e5f2595 --- /dev/null +++ b/scaffold_coverage.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +import argparse +import os +from collections import defaultdict + +import pandas as pd +from tqdm import tqdm + +from rdkit import Chem +from rdkit import RDLogger +from rdkit.Chem import Draw + +RDLogger.DisableLog("rdApp.*") + + +def mol_from_smiles(smi: str): + if not isinstance(smi, str) or smi.strip() == "": + return None + try: + return Chem.MolFromSmiles(smi) + except Exception: + return None + + +def has_scaffold(pred_mol, scaffold_mol) -> bool: + if pred_mol is None or scaffold_mol is None: + return False + try: + return pred_mol.HasSubstructMatch(scaffold_mol) + except Exception: + return False + + +def main(): + ap = argparse.ArgumentParser(description="Compute scaffold coverage on generated samples and save per-molecule plots.") + ap.add_argument("--input_csv", required=True, help="CSV from sample.py (columns: scaffold,pred,true,score,nll,ell)") + ap.add_argument("--out_csv", default="scaffold_coverage_summary.csv", help="Where to write per-molecule summary CSV") + ap.add_argument("--plots_dir", default="scaffold_plots", help="Directory to write true_vs_scaffold images") + ap.add_argument("--max_per_mol", type=int, default=None, help="Optionally cap number of preds per molecule (e.g., 100)") + args = ap.parse_args() + + os.makedirs(args.plots_dir, exist_ok=True) + + df = pd.read_csv(args.input_csv) + + # Basic sanity columns + needed = {"scaffold", "pred", "true"} + missing = needed - set(df.columns) + if missing: + raise ValueError(f"Input CSV is missing required columns: {missing}") + + # Group by (true, scaffold) — in msgym runs, each true has a single scaffold + groups = df.groupby(["true", "scaffold"], sort=False) + + rows = [] + overall_n_molecules = 0 + overall_valid_preds = 0 + overall_with_scaffold = 0 + + print(f"Found {len(groups)} molecules in the table") + for (true_smi, scaffold_smi), g in tqdm(groups, desc="Evaluating", unit="mol"): + print(1) + # Optionally cap number of preds + preds = g["pred"].tolist() + if args.max_per_mol is not None: + preds = preds[: args.max_per_mol] + + true_mol = mol_from_smiles(true_smi) + scaf_mol = mol_from_smiles(scaffold_smi) + + # Some scaffolds in data can be multi-fragment ("A.B"). You can decide to skip or try the largest fragment. + # Here: if multi-fragment, we try the **largest** fragment. + if scaf_mol is None and isinstance(scaffold_smi, str) and "." in scaffold_smi: + try: + parts = scaffold_smi.split(".") + parts_mols = [mol_from_smiles(p) for p in parts] + parts_mols = [m for m in parts_mols if m is not None] + if parts_mols: + scaf_mol = max(parts_mols, key=lambda m: m.GetNumAtoms()) + except Exception: + pass + + # Count valid preds + how many contain scaffold + valid_pred_mols = [] + n_with_scaf = 0 + for ps in preds: + pm = mol_from_smiles(ps) + if pm is None: + continue + valid_pred_mols.append(pm) + if has_scaffold(pm, scaf_mol): + n_with_scaf += 1 + + n_preds = len(preds) + n_valid = len(valid_pred_mols) + frac_scaf = (n_with_scaf / n_valid) if n_valid > 0 else 0.0 + + rows.append( + { + "true": true_smi, + "scaffold": scaffold_smi, + "n_rows_in_csv": len(g), # raw rows for this molecule in input + "n_preds_considered": n_preds, # after optional cap + "n_valid_preds": n_valid, # RDKit-parsable preds + "n_with_scaffold": n_with_scaf, # substructure match + "frac_with_scaffold": frac_scaf, + "true_num_atoms": (true_mol.GetNumAtoms() if true_mol else None), + "scaffold_num_atoms": (scaf_mol.GetNumAtoms() if scaf_mol else None), + } + ) + + # Save side-by-side plot of true vs scaffold + # If either is None, we still try to draw what we can. + try: + mols = [] + legends = [] + if true_mol is not None: + mols.append(true_mol) + legends.append("True") + if scaf_mol is not None: + mols.append(scaf_mol) + legends.append("Scaffold") + if mols: + img = Draw.MolsToGridImage( + mols, + molsPerRow=len(mols), + subImgSize=(350, 350), + legends=legends, + ) + # filename from first 32 chars of SMILES to keep paths manageable + base = f"{hash(true_smi) & 0xffffffff:x}" + img_path = os.path.join(args.plots_dir, f"true_vs_scaffold_{base}.png") + img.save(img_path) + except Exception: + pass + + # aggregate + overall_n_molecules += 1 + overall_valid_preds += n_valid + overall_with_scaffold += n_with_scaf + + summary = pd.DataFrame(rows) + summary.to_csv(args.out_csv, index=False) + + print("\n==== Summary ====") + print(f"Molecules evaluated: {overall_n_molecules}") + print(f"Valid predictions total: {overall_valid_preds}") + print(f"Predictions containing scaffold: {overall_with_scaffold}") + overall_frac = (overall_with_scaffold / overall_valid_preds) if overall_valid_preds > 0 else 0.0 + print(f"Overall fraction (contains scaffold | valid preds): {overall_frac:.4f}") + print(f"Per-molecule summary written to: {args.out_csv}") + print(f"Per-molecule plots saved to: {args.plots_dir}") + + # === Histogram === + import matplotlib.pyplot as plt + plt.figure(figsize=(7,5)) + plt.hist(summary["frac_with_scaffold"], bins=20, color="skyblue", edgecolor="black") + plt.xlabel("Fraction of predictions containing scaffold") + plt.ylabel("Number of molecules") + plt.title("Distribution of scaffold coverage per molecule") + hist_path = os.path.join(args.plots_dir, "scaffold_fraction_histogram.png") + plt.tight_layout() + plt.savefig(hist_path) + print(f"Histogram saved to: {hist_path}") + + +if __name__ == "__main__": + main() diff --git a/src/data/msgym_dataset.py b/src/data/msgym_dataset.py index 15cec33..c020aee 100644 --- a/src/data/msgym_dataset.py +++ b/src/data/msgym_dataset.py @@ -1,5 +1,7 @@ import pickle import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import pathlib import re @@ -16,6 +18,7 @@ from rdkit import Chem, RDLogger from rdkit.Chem.rdchem import BondType as BT from rdkit.Chem import rdFMCS +from rdkit.Chem.MolStandardize import rdMolStandardize import torch import torch.nn.functional as F @@ -28,6 +31,8 @@ from src.analysis.rdkit_functions import mol2smiles, build_molecule_with_partial_charges from src.data.abstract_dataset import MolecularDataModule, AbstractDatasetInfos +from src.utils import make_worker_init_fn, make_torch_generator + def to_list(value: Any) -> Sequence: if isinstance(value, Sequence) and not isinstance(value, str): return value @@ -35,6 +40,7 @@ def to_list(value: Any) -> Sequence: return [value] atom_decoder = ['N', 'P', 'B', 'I', 'As', 'Se', 'Cl', 'C', 'F', 'S', 'Br', 'O', 'Si'] +_LF = rdMolStandardize.LargestFragmentChooser() class MsGymDataset(InMemoryDataset): types = {'N': 0, 'P': 1, 'B': 2, 'I': 3, 'As': 4, 'Se': 5, 'Cl': 6, 'C': 7, 'F': 8, 'S': 9, 'Br': 10, 'O': 11, 'Si': 12} @@ -46,10 +52,12 @@ class MsGymDataset(InMemoryDataset): Chem.BondType.AROMATIC: 3 } - def __init__(self, stage, root, transform=None, pre_transform=None, pre_filter=None, preprocess=False): + def __init__(self, stage, root, transform=None, pre_transform=None, pre_filter=None, preprocess=False, msgym_pkl_path=None, ranks_pkl_path=None): self.stage = stage self.atom_decoder = atom_decoder self.remove_h = True + self.msgym_pkl_path = msgym_pkl_path + self.ranks_pkl_path = ranks_pkl_path if self.stage == 'train': self.file_idx = 0 elif self.stage == 'val': @@ -104,9 +112,16 @@ def __getitem__(self, idx): def process(self): preprocess = self.preprocess RDLogger.DisableLog('rdApp.*') - ms_dict = pickle.load(open('./data/msgym/raw/msgym.pkl', 'rb')) + + default_msgym = './data/msgym/raw/msgym.pkl' + msgym_path = self.msgym_pkl_path or default_msgym + ms_dict = pickle.load(open(msgym_path, 'rb')) + if self.stage =='test': - sca_dict = pickle.load(open('./data/msgym/raw/ranks_msgym_pred.pkl', 'rb')) + if self.ranks_pkl_path is None: + default_ranks = './data/msgym/raw/ranks_msgym_pred.pkl' + self.ranks_pkl_path = default_ranks + sca_dict = pickle.load(open(self.ranks_pkl_path, 'rb')) # sca_dict = {ele[1]: Chem.MolToSmiles(Chem.MolFromSmarts(ele[4][0])) for ele in sca_list} else: sca_dict = None @@ -165,13 +180,17 @@ def process(self): class MsGymDataModule(MolecularDataModule): DATASET_CLASS = MsGymDataset - def __init__(self, data_root, batch_size, num_workers, shuffle, extra_nodes=False, evaluation=False, swap=False): + def __init__(self, data_root, batch_size, num_workers, shuffle, extra_nodes=False, evaluation=True, swap=False, msgym_pkl=None, ranks_pkl=None, seed=42): super().__init__(batch_size, num_workers, shuffle) self.extra_nodes = extra_nodes self.evaluation = evaluation self.swap = swap self.data_root = data_root - self.train_smiles = [] + if not self.evaluation: + self.train_smiles = [] + self.msgym_pkl = msgym_pkl + self.ranks_pkl = ranks_pkl + self.seed = seed self.prepare_data() self.preprocess = False @@ -183,11 +202,22 @@ def prepare_data(self): base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] root_path = os.path.join(base_path, self.data_root) - datasets = { - 'train': MsGymDataset(stage='train', root=root_path, preprocess=False), - 'val': MsGymDataset(stage='val', root=root_path, preprocess=False), - 'test': MsGymDataset(stage='test', root=root_path, preprocess=False) - } + if self.evaluation: + datasets = { + 'test': MsGymDataset(stage='test', root=root_path, preprocess=False, + msgym_pkl_path=self.msgym_pkl, ranks_pkl_path=self.ranks_pkl) + } + else: + datasets = { + 'train': MsGymDataset(stage='train', root=root_path, preprocess=False, + msgym_pkl_path=self.msgym_pkl, ranks_pkl_path=None), + 'val': MsGymDataset(stage='val', root=root_path, preprocess=False, + msgym_pkl_path=self.msgym_pkl, ranks_pkl_path=None), + 'test': MsGymDataset(stage='test', root=root_path, preprocess=False, + msgym_pkl_path=self.msgym_pkl, ranks_pkl_path=self.ranks_pkl), + } + wif = make_worker_init_fn(self.seed) + dlgen = make_torch_generator(self.seed) self.dataloaders = {} for split, dataset in datasets.items(): self.dataloaders[split] = DataLoader( @@ -195,9 +225,14 @@ def prepare_data(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=(self.shuffle and split == 'train'), + worker_init_fn=wif, # <<< add + generator=dlgen, + persistent_workers=(self.num_workers > 0), + pin_memory=True ) print(len(datasets['test'])) - self.train_smiles = datasets['train'].r_smiles + if not self.evaluation: + self.train_smiles = datasets['train'].r_smiles class MsGyminfos(AbstractDatasetInfos): max_n_dummy_nodes = 10 @@ -312,6 +347,10 @@ def align_molecular_node_features(pyg1, pyg2, atom_decoder): def create_scaffold_graph(smiles, atom_decoder, i, ms, sca_dict=None, key_list=[], use_scaffold=True, source='train'): mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + # leave only main fragment + mol = _LF.choose(mol) mol = Chem.RemoveAllHs(mol) atoms = [atom.GetSymbol() for atom in mol.GetAtoms()] pyg_graph = molecule_to_pyg_graph(mol, atom_decoder, smiles, ms) @@ -319,12 +358,19 @@ def create_scaffold_graph(smiles, atom_decoder, i, ms, sca_dict=None, key_list=[ if sca_dict != None: try: p_mol = Chem.MolFromSmiles(sca_dict[key_list[i]]) - except: + except Exception: + return None + if p_mol is None: return None + # largest fragment + p_mol = _LF.choose(p_mol) else: p_mol = mol # pyg_graph = molecule_to_pyg_graph(p_mol, atom_decoder, smiles, ms, key_list[i]) scaffold = GetScaffoldForMol(p_mol) + if scaffold is None: + return None + scaffold = _LF.choose(scaffold) # largest fragment too scaffold = add_missing_atoms(scaffold, atoms) scaffold_g = molecule_to_pyg_graph(scaffold, atom_decoder, smiles, ms) if scaffold_g == None: @@ -337,8 +383,10 @@ def create_scaffold_graph(smiles, atom_decoder, i, ms, sca_dict=None, key_list=[ scaffold_x = scaffold_g.x else: scaffold = GetScaffoldForMol(mol) - if scaffold == None: + if scaffold is None: return None + # largest fragment + scaffold = _LF.choose(scaffold) scaffold_nodes = align_scaffold_to_molecule(mol, scaffold) # Create a mask for edges that only connect nodes within the scaffold edge_start, edge_end = pyg_graph.edge_index diff --git a/src/data/utils.py b/src/data/utils.py index f31737d..24ca0ce 100644 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -1,3 +1,6 @@ +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import torch import torch_geometric diff --git a/src/features/extra_features.py b/src/features/extra_features.py index 8760865..86d37dc 100644 --- a/src/features/extra_features.py +++ b/src/features/extra_features.py @@ -1,6 +1,10 @@ +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import torch from src.data import utils +from src.frameworks import diffusion_utils class DummyExtraFeatures: @@ -168,7 +172,9 @@ def get_eigenvectors_features(vectors, node_mask, n_connected, k=2): # Create an indicator for the nodes outside the largest connected components first_ev = torch.round(vectors[:, :, 0], decimals=3) * node_mask # bs, n # Add random value to the mask to prevent 0 from becoming the mode - random = torch.randn(bs, n, device=node_mask.device) * (~node_mask) # bs, n + eps = (torch.arange(n, device=node_mask.device, dtype=first_ev.dtype) + 1) / (n + 1) + eps = eps.unsqueeze(0).expand(bs, -1) + random = eps * (~node_mask) # bs, n first_ev = first_ev + random most_common = torch.mode(first_ev, dim=1).values # values: bs -- indices: bs mask = ~ (first_ev == most_common.unsqueeze(1)) diff --git a/src/features/extra_features_molecular.py b/src/features/extra_features_molecular.py index 273e6e4..a428b24 100644 --- a/src/features/extra_features_molecular.py +++ b/src/features/extra_features_molecular.py @@ -1,3 +1,6 @@ +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import torch from src.data import utils diff --git a/src/frameworks/diffusion_utils.py b/src/frameworks/diffusion_utils.py index 7ee5fd3..c7f29b7 100644 --- a/src/frameworks/diffusion_utils.py +++ b/src/frameworks/diffusion_utils.py @@ -1,3 +1,7 @@ +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +from typing import Optional, Tuple import torch from torch.nn import functional as F import numpy as np @@ -5,48 +9,54 @@ from src.data.utils import PlaceHolder +# --------------------------- +# Determinism plumbing +# --------------------------- +# You set this once from your main script: +# from src.frameworks import diffusion_utils +# gen = torch.Generator(device=torch_device); gen.manual_seed(seed) +# diffusion_utils.set_sampling_generator(gen) +# All multinomial/randn calls below will use it. +_GEN: Optional[torch.Generator] = None + +def set_sampling_generator(gen: torch.Generator): + global _GEN + _GEN = gen + +def _require_gen(device): + if _GEN is None: + # Fall back to default (non-deterministic). Prefer to set via set_sampling_generator(). + return torch.default_generator + # Torch requires the generator device to match the target device for CUDA ops. + # If you created it on the correct device in main (recommended), this is fine. + return _GEN + +# --------------------------- +# Small helpers +# --------------------------- def sum_except_batch(x): return x.reshape(x.size(0), -1).sum(dim=-1) - def assert_correctly_masked(variable, node_mask): assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \ 'Variables not masked properly.' +def sample_gaussian(size, *, device, generator): + return torch.randn(size, device=device, generator=generator) -def sample_gaussian(size): - x = torch.randn(size) - return x - - -def sample_gaussian_with_mask(size, node_mask): - x = torch.randn(size) - x = x.type_as(node_mask.float()) - x_masked = x * node_mask - return x_masked - +def sample_gaussian_with_mask(size, node_mask, *, device, generator): + x = torch.randn(size, device=device, generator=generator) + return (x * node_mask.to(x.dtype)) def clip_noise_schedule(alphas2, clip_value=0.001): - """ - For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. This may help improve stability during - sampling. - """ alphas2 = np.concatenate([np.ones(1), alphas2], axis=0) - alphas_step = (alphas2[1:] / alphas2[:-1]) - alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.) alphas2 = np.cumprod(alphas_step, axis=0) - return alphas2 - def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ steps = timesteps + 2 x = np.linspace(0, steps, steps) alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 @@ -55,25 +65,19 @@ def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1): betas = np.clip(betas, a_min=0, a_max=0.999) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) - if raise_to_power != 1: alphas_cumprod = np.power(alphas_cumprod, raise_to_power) - return alphas_cumprod - def cosine_beta_schedule_discrete(timesteps, s=0.008): - """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """ steps = timesteps + 2 x = np.linspace(0, steps, steps) - alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) betas = 1 - alphas return betas.squeeze() - def polynomial_beta_schedule_discrete(timesteps, s=0.008): steps = timesteps + 2 x = np.linspace(0, steps, steps) @@ -81,7 +85,6 @@ def polynomial_beta_schedule_discrete(timesteps, s=0.008): betas = 1 - alphas return betas.squeeze() - def linear_beta_schedule_discrete(timesteps, s=0.008): steps = timesteps + 2 x = np.linspace(0, steps, steps) @@ -89,358 +92,237 @@ def linear_beta_schedule_discrete(timesteps, s=0.008): betas = 1 - alphas return betas.squeeze() - def custom_beta_schedule_discrete(timesteps, average_num_nodes=50, s=0.008): - """ Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ. """ steps = timesteps + 2 x = np.linspace(0, steps, steps) - alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1]) betas = 1 - alphas - assert timesteps >= 100 - - p = 4 / 5 # 1 - 1 / num_edge_classes + p = 4 / 5 num_edges = average_num_nodes * (average_num_nodes - 1) / 2 - - # First 100 steps: only a few updates per graph updates_per_graph = 1.2 beta_first = updates_per_graph / (p * num_edges) - betas[betas < beta_first] = beta_first return np.array(betas) - - def gaussian_KL(q_mu, q_sigma): - """Computes the KL distance between a normal distribution and the standard normal. - Args: - q_mu: Mean of distribution q. - q_sigma: Standard deviation of distribution q. - p_mu: Mean of distribution p. - p_sigma: Standard deviation of distribution p. - Returns: - The KL distance, summed over all dimensions except the batch dim. - """ return sum_except_batch((torch.log(1 / q_sigma) + 0.5 * (q_sigma ** 2 + q_mu ** 2) - 0.5)) - def cdf_std_gaussian(x): return 0.5 * (1. + torch.erf(x / math.sqrt(2))) - def SNR(gamma): - """Computes signal to noise ratio (alpha^2/sigma^2) given gamma.""" return torch.exp(-gamma) - def inflate_batch_array(array, target_shape): - """ - Inflates the batch array (array) with only a single axis (i.e. shape = (batch_size,), or possibly more empty - axes (i.e. shape (batch_size, 1, ..., 1)) to match the target shape. - """ target_shape = (array.size(0),) + (1,) * (len(target_shape) - 1) return array.view(target_shape) - def sigma(gamma, target_shape): - """Computes sigma given gamma.""" return inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)), target_shape) - def alpha(gamma, target_shape): - """Computes alpha given gamma.""" return inflate_batch_array(torch.sqrt(torch.sigmoid(-gamma)), target_shape) - def check_mask_correct(variables, node_mask): for i, variable in enumerate(variables): if len(variable) > 0: assert_correctly_masked(variable, node_mask) - def check_tensor_same_size(*args): for i, arg in enumerate(args): if i == 0: continue assert args[0].size() == arg.size() - def sigma_and_alpha_t_given_s(gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_size: torch.Size): - """ - Computes sigma t given s, using gamma_t and gamma_s. Used during sampling. - - These are defined as: - alpha t given s = alpha t / alpha s, - sigma t given s = sqrt(1 - (alpha t given s) ^2 ). - """ - sigma2_t_given_s = inflate_batch_array( - -torch.expm1(F.softplus(gamma_s) - F.softplus(gamma_t)), target_size - ) - - # alpha_t_given_s = alpha_t / alpha_s + sigma2_t_given_s = inflate_batch_array(-torch.expm1(F.softplus(gamma_s) - F.softplus(gamma_t)), target_size) log_alpha2_t = F.logsigmoid(-gamma_t) log_alpha2_s = F.logsigmoid(-gamma_s) log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s - alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s) alpha_t_given_s = inflate_batch_array(alpha_t_given_s, target_size) - sigma_t_given_s = torch.sqrt(sigma2_t_given_s) - return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s - def reverse_tensor(x): return x[torch.arange(x.size(0) - 1, -1, -1)] +# --------------------------- +# RNG-using functions (deterministic) +# --------------------------- -def sample_feature_noise(X_size, E_size, y_size, node_mask): - """Standard normal noise for all features. - Output size: X.size(), E.size(), y.size() """ - # TODO: How to change this for the multi-gpu case? - epsX = sample_gaussian(X_size) - epsE = sample_gaussian(E_size) - epsy = sample_gaussian(y_size) +def sample_feature_noise(X_size, E_size, y_size, node_mask, *, device, generator=None): + """Standard normal noise for all features, symmetric for edges.""" + gen = generator or _require_gen(device) + epsX = sample_gaussian(X_size, device=device, generator=gen) + epsE = sample_gaussian(E_size, device=device, generator=gen) + epsy = sample_gaussian(y_size, device=device, generator=gen) - float_mask = node_mask.float() - epsX = epsX.type_as(float_mask) - epsE = epsE.type_as(float_mask) - epsy = epsy.type_as(float_mask) + float_mask = node_mask.to(epsX.dtype) + epsX = epsX * float_mask + epsy = epsy * float_mask - # Get upper triangular part of edge noise, without main diagonal - upper_triangular_mask = torch.zeros_like(epsE) - indices = torch.triu_indices(row=epsE.size(1), col=epsE.size(2), offset=1) - upper_triangular_mask[:, indices[0], indices[1], :] = 1 + # Mask edges by node presence + epsE = epsE * float_mask.unsqueeze(2) * float_mask.unsqueeze(1) + # Upper triangular only (then mirror) + upper_triangular_mask = torch.zeros_like(epsE, device=device) + i, j = torch.triu_indices(row=epsE.size(1), col=epsE.size(2), offset=1, device=device) + upper_triangular_mask[:, i, j, :] = 1 epsE = epsE * upper_triangular_mask - epsE = (epsE + torch.transpose(epsE, 1, 2)) - - assert (epsE == torch.transpose(epsE, 1, 2)).all() + epsE = epsE + epsE.transpose(1, 2) + assert (epsE == epsE.transpose(1, 2)).all() return PlaceHolder(X=epsX, E=epsE, y=epsy).mask(node_mask) - -def sample_normal(mu_X, mu_E, mu_y, sigma, node_mask): - """Samples from a Normal distribution.""" - # TODO: change for multi-gpu case - eps = sample_feature_noise(mu_X.size(), mu_E.size(), mu_y.size(), node_mask).type_as(mu_X) +def sample_normal(mu_X, mu_E, mu_y, sigma, node_mask, *, device, generator=None): + gen = generator or _require_gen(device) + eps = sample_feature_noise(mu_X.size(), mu_E.size(), mu_y.size(), node_mask, + device=device, generator=gen).type_as(mu_X) X = mu_X + sigma * eps.X E = mu_E + sigma.unsqueeze(1) * eps.E y = mu_y + sigma.squeeze(1) * eps.y return PlaceHolder(X=X, E=E, y=y) +def sample_discrete_features(probX, probE, node_mask, *, generator=None): + """ + Sample nodes/edges with Multinomial using a fixed generator for determinism. + probX: (bs, n, dx_out) + probE: (bs, n, n, de_out) + """ + device = probX.device + gen = generator or _require_gen(device) -def check_issues_norm_values(gamma, norm_val1, norm_val2, num_stdevs=8): - """ Check if 1 / norm_value is still larger than 10 * standard deviation. """ - zeros = torch.zeros((1, 1)) - gamma_0 = gamma(zeros) - sigma_0 = sigma(gamma_0, target_shape=zeros.size()).item() - max_norm_value = max(norm_val1, norm_val2) - if sigma_0 * num_stdevs > 1. / max_norm_value: - raise ValueError( - f'Value for normalization value {max_norm_value} probably too ' - f'large with sigma_0 {sigma_0:.5f} and ' - f'1 / norm_value = {1. / max_norm_value}') - - -def sample_discrete_features(probX, probE, node_mask): - ''' Sample features from multinomial distribution with given probabilities (probX, probE, proby) - :param probX: bs, n, dx_out node features - :param probE: bs, n, n, de_out edge features - :param proby: bs, dy_out global features. - ''' bs, n, _ = probX.shape - # Noise X - # The masked rows should define probability distributions as well - probX[~node_mask] = 1 / probX.shape[-1] - - # Flatten the probability tensor to sample with multinomial - probX = probX.reshape(bs * n, -1) # (bs * n, dx_out) + probX = probX.clone() - # Sample X - X_t = probX.multinomial(1) # (bs * n, 1) - X_t = X_t.reshape(bs, n) # (bs, n) - X_t_probs = torch.gather(probX, 1, X_t.reshape(-1, 1)).reshape(bs, n) + # masked nodes -> uniform + probX[~node_mask] = 1.0 / probX.shape[-1] + probX_flat = probX.reshape(bs * n, -1) - # Noise E - # The masked rows should define probability distributions as well - inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2)) - diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1) + X_t = probX_flat.multinomial(1, generator=gen).reshape(bs, n) + X_t_probs = torch.gather(probX_flat, 1, X_t.reshape(-1, 1)).reshape(bs, n) - probE[inverse_edge_mask] = 1 / probE.shape[-1] - probE[diag_mask.bool()] = 1 / probE.shape[-1] + inverse_edge_mask = ~(node_mask.unsqueeze(1) & node_mask.unsqueeze(2)) + diag_mask = torch.eye(n, device=device, dtype=torch.bool).unsqueeze(0).expand(bs, -1, -1) - probE = probE.reshape(bs * n * n, -1) # (bs * n * n, de_out) + probE = probE.clone() + probE[inverse_edge_mask] = 1.0 / probE.shape[-1] + probE[diag_mask] = 1.0 / probE.shape[-1] - # Sample E - E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n) + probE_flat = probE.reshape(bs * n * n, -1) + E_t = probE_flat.multinomial(1, generator=gen).reshape(bs, n, n) E_t = torch.triu(E_t, diagonal=1) - E_t = (E_t + torch.transpose(E_t, 1, 2)) - E_t_probs = torch.gather(probE, 1, E_t.reshape(-1, 1)).reshape(bs, n, n) - - total_probs_X = X_t_probs.mean(dim=1) # Sum over nodes - total_probs_E = E_t_probs.mean(dim=[1, 2]) # Sum over edges - - total_probs = total_probs_X + total_probs_E # Shape [bs] - - return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t)), total_probs + E_t = E_t + E_t.transpose(1, 2) + E_t_probs = torch.gather(probE_flat, 1, E_t.reshape(-1, 1)).reshape(bs, n, n) -def compute_posterior_distribution(M, M_t, Qt_M, Qsb_M, Qtb_M): - ''' M: X or E - Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T - ''' - # Flatten feature tensors - M = M.flatten(start_dim=1, end_dim=-2).to(torch.float32) # (bs, N, d) with N = n or n * n - M_t = M_t.flatten(start_dim=1, end_dim=-2).to(torch.float32) # same - - Qt_M_T = torch.transpose(Qt_M, -2, -1) # (bs, d, d) - - left_term = M_t @ Qt_M_T # (bs, N, d) - right_term = M @ Qsb_M # (bs, N, d) - product = left_term * right_term # (bs, N, d) + total_probs = X_t_probs.mean(dim=1) + E_t_probs.mean(dim=[1, 2]) - denom = M @ Qtb_M # (bs, N, d) @ (bs, d, d) = (bs, N, d) - denom = (denom * M_t).sum(dim=-1) # (bs, N, d) * (bs, N, d) + sum = (bs, N) - # denom = product.sum(dim=-1) - # denom[denom == 0.] = 1 + return PlaceHolder( + X=X_t, + E=E_t, + y=torch.zeros(bs, 0, device=device, dtype=X_t.dtype) + ), total_probs - prob = product / denom.unsqueeze(-1) # (bs, N, d) +# --------------------------- +# (the rest is unchanged / deterministic math) +# --------------------------- +def compute_posterior_distribution(M, M_t, Qt_M, Qsb_M, Qtb_M): + M = M.flatten(start_dim=1, end_dim=-2).to(torch.float32) + M_t = M_t.flatten(start_dim=1, end_dim=-2).to(torch.float32) + Qt_M_T = torch.transpose(Qt_M, -2, -1) + left_term = M_t @ Qt_M_T + right_term = M @ Qsb_M + product = left_term * right_term + denom = M @ Qtb_M + denom = (denom * M_t).sum(dim=-1) + prob = product / denom.unsqueeze(-1) return prob - def compute_batched_over0_posterior_distribution(X_t, Qt, Qsb, Qtb): - """ M: X or E - Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0 - X_t: bs, n, dt or bs, n, n, dt - Qt: bs, d_t-1, dt - Qsb: bs, d0, d_t-1 - Qtb: bs, d0, dt. - """ - # Flatten feature tensors - # Careful with this line. It does nothing if X is a node feature. If X is an edge features it maps to - # bs x (n ** 2) x d - X_t = X_t.flatten(start_dim=1, end_dim=-2).to(torch.float32) # bs x N x dt - - Qt_T = Qt.transpose(-1, -2) # bs, dt, d_t-1 - left_term = X_t @ Qt_T # bs, N, d_t-1 - left_term = left_term.unsqueeze(dim=2) # bs, N, 1, d_t-1 - - right_term = Qsb.unsqueeze(1) # bs, 1, d0, d_t-1 - numerator = left_term * right_term # bs, N, d0, d_t-1 - - X_t_transposed = X_t.transpose(-1, -2) # bs, dt, N - - prod = Qtb @ X_t_transposed # bs, d0, N - prod = prod.transpose(-1, -2) # bs, N, d0 - denominator = prod.unsqueeze(-1) # bs, N, d0, 1 + X_t = X_t.flatten(start_dim=1, end_dim=-2).to(torch.float32) + Qt_T = Qt.transpose(-1, -2) + left_term = X_t @ Qt_T + left_term = left_term.unsqueeze(dim=2) + right_term = Qsb.unsqueeze(1) + numerator = left_term * right_term + X_t_transposed = X_t.transpose(-1, -2) + prod = Qtb @ X_t_transposed + prod = prod.transpose(-1, -2) + denominator = prod.unsqueeze(-1) denominator[denominator == 0] = 1e-6 - out = numerator / denominator return out - def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask): - # Add a small value everywhere to avoid nans pred_X = pred_X + 1e-7 pred_X = pred_X / torch.sum(pred_X, dim=-1, keepdim=True) - pred_E = pred_E + 1e-7 pred_E = pred_E / torch.sum(pred_E, dim=-1, keepdim=True) - - # Set masked rows to arbitrary distributions, so it doesn't contribute to loss - row_X = torch.zeros(true_X.size(-1), dtype=torch.float, device=true_X.device) - row_X[0] = 1. - row_E = torch.zeros(true_E.size(-1), dtype=torch.float, device=true_E.device) - row_E[0] = 1. + row_X = torch.zeros(true_X.size(-1), dtype=torch.float, device=true_X.device); row_X[0] = 1. + row_E = torch.zeros(true_E.size(-1), dtype=torch.float, device=true_E.device); row_E[0] = 1. diag_mask = ~torch.eye(node_mask.size(1), device=node_mask.device, dtype=torch.bool).unsqueeze(0) true_X[~node_mask] = row_X true_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E pred_X[~node_mask] = row_X pred_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E - return true_X, true_E, pred_X, pred_E - def posterior_distributions(X, E, y, X_t, E_t, y_t, Qt, Qsb, Qtb): - prob_X = compute_posterior_distribution(M=X, M_t=X_t, Qt_M=Qt.X, Qsb_M=Qsb.X, Qtb_M=Qtb.X) # (bs, n, dx) - prob_E = compute_posterior_distribution(M=E, M_t=E_t, Qt_M=Qt.E, Qsb_M=Qsb.E, Qtb_M=Qtb.E) # (bs, n * n, de) - + prob_X = compute_posterior_distribution(M=X, M_t=X_t, Qt_M=Qt.X, Qsb_M=Qsb.X, Qtb_M=Qtb.X) + prob_E = compute_posterior_distribution(M=E, M_t=E_t, Qt_M=Qt.E, Qsb_M=Qsb.E, Qtb_M=Qtb.E) return PlaceHolder(X=prob_X, E=prob_E, y=y_t) +def sample_discrete_feature_noise(limit_dist, node_mask, *, generator=None): + device = node_mask.device + gen = generator or _require_gen(device) -def sample_discrete_feature_noise(limit_dist, node_mask): - """ Sample from the limit distribution of the diffusion process""" bs, n_max = node_mask.shape - x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1) - e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1) - y_limit = limit_dist.y[None, :].expand(bs, -1) - U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max) - U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max) - U_y = torch.empty((bs, 0)) - - long_mask = node_mask.long() - U_X = U_X.type_as(long_mask) - U_E = U_E.type_as(long_mask) - U_y = U_y.type_as(long_mask) - - U_X = F.one_hot(U_X, num_classes=x_limit.shape[-1]).float() - U_E = F.one_hot(U_E, num_classes=e_limit.shape[-1]).float() - - # Get upper triangular part of edge noise, without main diagonal - upper_triangular_mask = torch.zeros_like(U_E) - indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1) - upper_triangular_mask[:, indices[0], indices[1], :] = 1 + x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1).to(device) + e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1).to(device) + y_limit = limit_dist.y[None, :].expand(bs, -1).to(device) - U_E = U_E * upper_triangular_mask - U_E = (U_E + torch.transpose(U_E, 1, 2)) + U_X = x_limit.flatten(end_dim=-2).multinomial(1, generator=gen).reshape(bs, n_max) + U_E = e_limit.flatten(end_dim=-2).multinomial(1, generator=gen).reshape(bs, n_max, n_max) + U_y = torch.empty((bs, 0), device=device, dtype=U_X.dtype) - assert (U_E == torch.transpose(U_E, 1, 2)).all() + U_X = F.one_hot(U_X, num_classes=x_limit.shape[-1]).to(x_limit.dtype) + U_E = F.one_hot(U_E, num_classes=e_limit.shape[-1]).to(e_limit.dtype) - return PlaceHolder(X=U_X, E=U_E, y=U_y).mask(node_mask) + upper_triangular_mask = torch.zeros_like(U_E, device=device) + i, j = torch.triu_indices(U_E.size(1), U_E.size(2), offset=1, device=device) + upper_triangular_mask[:, i, j, :] = 1 + U_E = U_E * upper_triangular_mask + U_E = U_E + U_E.transpose(1, 2) + assert (U_E == U_E.transpose(1, 2)).all() + return PlaceHolder(X=U_X, E=U_E, y=U_y).mask(node_mask) def cbo0pdi_X(X_t, Qt, Qsb, Qtb): - """ Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0 - X_t: (bs, n, dt) - Qt: (bs, n, d_t-1, dt) - Qsb: (bs, n, d0, d_t-1) - Qtb: (bs, n, d_0, dt) - """ - - Qt_T = Qt.transpose(-1, -2) # bs, n, dt, d_t-1 - left_term = X_t.unsqueeze(-2) @ Qt_T # bs, n, 1, d_t-1 - numerator = left_term * Qsb # bs, n, d0, d_t-1 - denominator = Qtb @ X_t.unsqueeze(-1) # bs, n, d_0, 1 + Qt_T = Qt.transpose(-1, -2) + left_term = X_t.unsqueeze(-2) @ Qt_T + numerator = left_term * Qsb + denominator = Qtb @ X_t.unsqueeze(-1) denominator[denominator == 0] = 1e-6 - out = numerator / denominator # bs, n, d0, d_t-1 - + out = numerator / denominator return out - def cbo0pdi_E(E_t, Qt, Qsb, Qtb): - """ Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0 - E_t: (bs, n, n, dt) - Qt: (bs, n, n, d_t-1, dt) - Qsb: (bs, n, n, d0, d_t-1) - Qtb: (bs, n, n, d_0, dt) - """ - E_t = E_t.flatten(start_dim=1, end_dim=2).to(torch.float32) # bs, N, dt - Qt = Qt.flatten(start_dim=1, end_dim=2).to(torch.float32) # bs, N, d_t-1, dt - Qsb = Qsb.flatten(start_dim=1, end_dim=2).to(torch.float32) # bs, N, d0, d_t-1 - Qtb = Qtb.flatten(start_dim=1, end_dim=2).to(torch.float32) # bs, N, d_0, dt - - Qt_T = Qt.transpose(-1, -2) # bs, N, dt, d_t-1 - left_term = E_t.unsqueeze(-2) @ Qt_T # bs, N, 1, d_t-1 - numerator = left_term * Qsb # bs, N, d0, d_t-1 - denominator = Qtb @ E_t.unsqueeze(-1) # bs, N, d_0, 1 + E_t = E_t.flatten(start_dim=1, end_dim=2).to(torch.float32) + Qt = Qt.flatten(start_dim=1, end_dim=2).to(torch.float32) + Qsb = Qsb.flatten(start_dim=1, end_dim=2).to(torch.float32) + Qtb = Qtb.flatten(start_dim=1, end_dim=2).to(torch.float32) + Qt_T = Qt.transpose(-1, -2) + left_term = E_t.unsqueeze(-2) @ Qt_T + numerator = left_term * Qsb + denominator = Qtb @ E_t.unsqueeze(-1) denominator[denominator == 0] = 1e-6 - out = numerator / denominator # bs, N, d0, d_t-1 - - return out + out = numerator / denominator + return out \ No newline at end of file diff --git a/src/frameworks/markov_bridge.py b/src/frameworks/markov_bridge.py index 610a26c..b766afd 100644 --- a/src/frameworks/markov_bridge.py +++ b/src/frameworks/markov_bridge.py @@ -1,8 +1,10 @@ +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import torch import torch.nn as nn import torch.nn.functional as F import pytorch_lightning as pl -import os from src.data import utils from src.frameworks.noise_schedule import InterpolationTransition, PredefinedNoiseScheduleDiscrete @@ -16,7 +18,6 @@ import numpy as np from rdkit import Chem from pdb import set_trace -torch.set_float32_matmul_precision('medium') class MarkovBridge(pl.LightningModule): @@ -92,6 +93,10 @@ def __init__( self.extra_features = extra_features self.domain_features = domain_features self.use_context = use_context + + self._gen = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu") + # set its seed from your main seed after constructing the model + self._gen.manual_seed(42) self.model = GraphTransformer( n_layers=n_layers, @@ -392,7 +397,7 @@ def apply_noise(self, X, E, y, X_T, E_T, y_T, node_mask, s, t_int=None): # When evaluating, the loss for t=0 is computed separately lowest_t = 0 if self.training else 1 if t_int == None: - t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float() # (bs, 1) + t_int = torch.randint(lowest_t, self.T + 1, (X.size(0), 1), device=X.device, generator=self._gen).float() s_int = t_int - 1 t_float = t_int / self.T @@ -419,7 +424,10 @@ def apply_noise(self, X, E, y, X_T, E_T, y_T, node_mask, s, t_int=None): probX = (X.unsqueeze(-2) @ Qtb.X).squeeze(-2) # (bs, n, dx_out) probE = (E.unsqueeze(-2) @ Qtb.E).squeeze(-2) # (bs, n, n, de_out) - sampled_t,score = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask) + sampled_t, score = diffusion_utils.sample_discrete_features( + probX=probX, probE=probE, node_mask=node_mask, + generator=None # it will use the global one we set + ) X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output) E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output) @@ -665,12 +673,17 @@ def sample_chain( molecule_list = utils.create_pred_target_molecules(X, E, data.batch, batch_size) # torch.save(molecule_list, f'sampled_s.pt') # exit() + # Replace NaN/inf with large finite sentinels so CSV never writes blanks + nll = torch.nan_to_num(nll, nan=0.0, posinf=1e9, neginf=-1e9) + ell = torch.nan_to_num(ell, nan=0.0, posinf=1e9, neginf=-1e9) + return ( chain_X, chain_E, r_chain_X, r_chain_E, true_molecule_list, scaffolds_list, molecule_list, pred, score, nll.detach().cpu().numpy().tolist(), ell.detach().cpu().numpy().tolist(), ) - + + def visualize( self, chain_X, @@ -763,8 +776,11 @@ def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, X_T, E_T, y_T, node_mask, co assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all() assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all() - sampled_s, score = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask) - + sampled_s, score = diffusion_utils.sample_discrete_features( + prob_X, prob_E, node_mask=node_mask, + generator=None # use global gen + ) + X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float() E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float() @@ -774,12 +790,19 @@ def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, X_T, E_T, y_T, node_mask, co out_one_hot = utils.PlaceHolder(X=X_t, E=E_s, y=torch.zeros(y_t.shape[0], 0)) out_discrete = utils.PlaceHolder(X=X_t, E=E_s, y=torch.zeros(y_t.shape[0], 0)) - # Likelihood - node_log_likelihood = torch.log(prob_X) + torch.log(pred_X) + # ---- Likelihood (safe) + eps = 1e-12 + # ensure strictly positive probs before logs + prob_X_safe = prob_X.clamp_min(eps) + prob_E_safe = prob_E.clamp_min(eps) + pred_X_safe = pred_X.clamp_min(eps) + pred_E_safe = pred_E.clamp_min(eps) + + node_log_likelihood = (torch.log(prob_X_safe) + torch.log(pred_X_safe)) node_log_likelihood = (node_log_likelihood * X_s).sum(-1).sum(-1) - edge_log_likelihood = torch.log(prob_E) + torch.log(pred_E) - edge_log_likelihood = (edge_log_likelihood * E_s).sum(-1).sum(-1).sum(-1) # bxnxnxK + edge_log_likelihood = (torch.log(prob_E_safe) + torch.log(pred_E_safe)) + edge_log_likelihood = (edge_log_likelihood * E_s).sum(-1).sum(-1).sum(-1) return ( out_one_hot.mask(node_mask).type_as(y_t), diff --git a/src/frameworks/noise_schedule.py b/src/frameworks/noise_schedule.py index 3bf23ab..bacd7f5 100644 --- a/src/frameworks/noise_schedule.py +++ b/src/frameworks/noise_schedule.py @@ -1,4 +1,7 @@ import numpy as np +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import torch from src.data import utils from src.frameworks import diffusion_utils @@ -104,24 +107,27 @@ def __init__(self, x_classes: int, e_classes: int, y_classes: int): if self.y_classes > 0: self.u_y = self.u_y / self.y_classes - def get_Qt(self, beta_t, device): - """ Returns one-step transition matrices for X and E, from step t - 1 to step t. - Qt = (1 - beta_t) * I + beta_t / K + def get_Qt(self, *, beta_t, X_T, E_T, y_T, node_mask, device): + # ... build q_x (bs, n, dx_in, dx_out) and q_e (bs, n, n, de_in, de_out) ... - beta_t: (bs) noise level between 0 and 1 - returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). - """ - beta_t = beta_t.unsqueeze(1) - beta_t = beta_t.to(device) - self.u_x = self.u_x.to(device) - self.u_e = self.u_e.to(device) - self.u_y = self.u_y.to(device) + # 1) Node transitions: set absent-node rows to identity + mask_x = node_mask.unsqueeze(-1).unsqueeze(-1) # (bs, n, 1, 1), bool + eye_x = torch.eye(q_x.shape[-1], device=device) # (dx_out, dx_out) + eye_x = eye_x.expand(q_x.shape[0], q_x.shape[1], -1, -1) # (bs, n, dx_out, dx_out) - q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye(self.X_classes, device=device).unsqueeze(0) - q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye(self.E_classes, device=device).unsqueeze(0) - q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye(self.y_classes, device=device).unsqueeze(0) + # If dx_in == dx_out this works directly; if your q_x is (dx_in, dx_out) with dx_in==dx_out (=13), ok. + # If your implementation stores (dx_out, dx_out), adjust accordingly. + q_x = torch.where(mask_x, q_x, eye_x) - return utils.PlaceHolder(X=q_x, E=q_e, y=q_y) + # 2) Edge transitions: set edges touching absent nodes to identity + edge_mask = (node_mask.unsqueeze(2) & node_mask.unsqueeze(1)) # (bs, n, n) + mask_e = edge_mask.unsqueeze(-1).unsqueeze(-1) # (bs, n, n, 1, 1) + eye_e = torch.eye(q_e.shape[-1], device=device) # (de_out, de_out) + eye_e = eye_e.expand(q_e.shape[0], q_e.shape[1], q_e.shape[2], -1, -1) # (bs, n, n, de_out, de_out) + + q_e = torch.where(mask_e, q_e, eye_e) + + return PlaceHolder(X=q_x, E=q_e, y=None) def get_Qt_bar(self, alpha_bar_t, device): """ Returns t-step transition matrices for X and E, from step 0 to step t. @@ -237,115 +243,58 @@ def __init__(self, x_classes: int, e_classes: int, y_classes: int): self.E_classes = e_classes self.y_classes = y_classes - def get_Qt(self, beta_t, X_T, E_T, y_T, node_mask, device): - """X_T (bs, n, dx), E_T (bs, n, n, de)""" - """ Returns two transition matrix for X and E""" - - beta_t = beta_t.unsqueeze(1) # (bs, 1, 1) - beta_t = beta_t.to(device) - - q_x_1 = (1 - beta_t) * torch.eye(self.X_classes, device=device) # (bs, dx, dx) - q_x_2 = beta_t.unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) # (bs, n, dx, dx) - q_x = q_x_1.unsqueeze(1) + q_x_2 - q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device) - - q_e_1 = (1 - beta_t) * torch.eye(self.E_classes, device=device) # (bs, de, de) - q_e_2 = beta_t.unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) # (bs, n, n, de, de) - q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2 - - diag = torch.eye(E_T.shape[1], dtype=torch.bool).unsqueeze(0).expand(E_T.shape[0], -1, -1) - q_e[diag] = torch.eye(q_e.shape[-1], device=device) - - edge_mask = node_mask[:, None, :] & node_mask[:, :, None] - q_e[~edge_mask] = torch.eye(q_e.shape[-1], device=device) - - return utils.PlaceHolder(X=q_x, E=q_e, y=y_T) - - def get_Qt_bar(self, alpha_bar_t, X_T, E_T, y_T, node_mask, device): - """ - alpha_bar_t: (bs, 1) - X_T: (bs, n, dx) - E_T: (bs, n, n, de) - y_T: (bs, dy) - - Returns transition matrices for X, E, and y - """ - alpha_bar_t = alpha_bar_t.unsqueeze(1) # (bs, 1, 1) - alpha_bar_t = alpha_bar_t.to(device) - - q_x_1 = alpha_bar_t * torch.eye(self.X_classes, device=device) # (bs, dx, dx) - q_x_2 = (1 - alpha_bar_t).unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) # (bs, n, dx, dx) - q_x = q_x_1.unsqueeze(1) + q_x_2 - q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device) + def _apply_node_mask_identity(self, q_x, node_mask, device): + # q_x: (bs, n, dx, dx); node_mask: (bs, n) True=present + bs, n, dx, _ = q_x.shape + mask4 = node_mask.unsqueeze(-1).unsqueeze(-1) # (bs,n,1,1) + eye_x = torch.eye(dx, device=device).expand(bs, n, dx, dx) # (bs,n,dx,dx) + return torch.where(mask4, q_x, eye_x) - q_e_1 = alpha_bar_t * torch.eye(self.E_classes, device=device) # (bs, de, de) - q_e_2 = (1 - alpha_bar_t).unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) # (bs, n, n, de, de) - q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2 + def _apply_edge_mask_identity(self, q_e, node_mask, device): + # q_e: (bs, n, n, de, de) + bs, n, de, _ = q_e.shape[0], q_e.shape[1], q_e.shape[-1], q_e.shape[-1] + eye_e = torch.eye(de, device=device).view(1,1,1,de,de).expand(bs, n, n, de, de) - diag = torch.eye(E_T.shape[1], dtype=torch.bool).unsqueeze(0).expand(E_T.shape[0], -1, -1) - q_e[diag] = torch.eye(q_e.shape[-1], device=device) + # self-edges -> identity + diag = torch.eye(n, dtype=torch.bool, device=device).view(1, n, n, 1, 1).expand(bs, n, n, 1, 1) + q_e = torch.where(diag, eye_e, q_e) - edge_mask = node_mask[:, None, :] & node_mask[:, :, None] - q_e[~edge_mask] = torch.eye(q_e.shape[-1], device=device) - - return utils.PlaceHolder(X=q_x, E=q_e, y=y_T) - - -class InterpolationTransition: - def __init__(self, x_classes: int, e_classes: int, y_classes: int): - self.X_classes = x_classes - self.E_classes = e_classes - self.y_classes = y_classes + # edges touching absent node -> identity + edge_present = (node_mask.unsqueeze(2) & node_mask.unsqueeze(1)) # (bs,n,n) + mask_e = edge_present.unsqueeze(-1).unsqueeze(-1) # (bs,n,n,1,1) + q_e = torch.where(mask_e, q_e, eye_e) + return q_e def get_Qt(self, beta_t, X_T, E_T, y_T, node_mask, device): - """X_T (bs, n, dx), E_T (bs, n, n, de)""" - """ Returns two transition matrix for X and E""" - - beta_t = beta_t.unsqueeze(1) # (bs, 1, 1) - beta_t = beta_t.to(device) - - q_x_1 = (1 - beta_t) * torch.eye(self.X_classes, device=device) # (bs, dx, dx) - q_x_2 = beta_t.unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) # (bs, n, dx, dx) - q_x = q_x_1.unsqueeze(1) + q_x_2 - q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device) + beta_t = beta_t.unsqueeze(1).to(device) # (bs,1) - q_e_1 = (1 - beta_t) * torch.eye(self.E_classes, device=device) # (bs, de, de) - q_e_2 = beta_t.unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) # (bs, n, n, de, de) - q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2 + # nodes + q_x_1 = (1 - beta_t) * torch.eye(self.X_classes, device=device) # (bs,dx,dx) + q_x_2 = beta_t.unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) # (bs,n,dx,dx) + q_x = q_x_1.unsqueeze(1) + q_x_2 # (bs,n,dx,dx) + q_x = self._apply_node_mask_identity(q_x, node_mask, device) - diag = torch.eye(E_T.shape[1], dtype=torch.bool).unsqueeze(0).expand(E_T.shape[0], -1, -1) - q_e[diag] = torch.eye(q_e.shape[-1], device=device) - - edge_mask = node_mask[:, None, :] & node_mask[:, :, None] - q_e[~edge_mask] = torch.eye(q_e.shape[-1], device=device) + # edges + q_e_1 = (1 - beta_t) * torch.eye(self.E_classes, device=device) # (bs,de,de) + q_e_2 = beta_t.unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) # (bs,n,n,de,de) + q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2 # (bs,n,n,de,de) + q_e = self._apply_edge_mask_identity(q_e, node_mask.to(torch.bool), device) return utils.PlaceHolder(X=q_x, E=q_e, y=y_T) def get_Qt_bar(self, alpha_bar_t, X_T, E_T, y_T, node_mask, device): - """ - alpha_bar_t: (bs, 1) - X_T: (bs, n, dx) - E_T: (bs, n, n, de) - y_T: (bs, dy) + alpha_bar_t = alpha_bar_t.unsqueeze(1).to(device) # (bs,1) - Returns transition matrices for X, E, and y - """ - alpha_bar_t = alpha_bar_t.unsqueeze(1) # (bs, 1, 1) - alpha_bar_t = alpha_bar_t.to(device) - - q_x_1 = alpha_bar_t * torch.eye(self.X_classes, device=device) # (bs, dx, dx) - q_x_2 = (1 - alpha_bar_t).unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) # (bs, n, dx, dx) + # nodes + q_x_1 = alpha_bar_t * torch.eye(self.X_classes, device=device) + q_x_2 = (1 - alpha_bar_t).unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) q_x = q_x_1.unsqueeze(1) + q_x_2 - q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device) + q_x = self._apply_node_mask_identity(q_x, node_mask, device) - q_e_1 = alpha_bar_t * torch.eye(self.E_classes, device=device) # (bs, de, de) - q_e_2 = (1 - alpha_bar_t).unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) # (bs, n, n, de, de) + # edges + q_e_1 = alpha_bar_t * torch.eye(self.E_classes, device=device) + q_e_2 = (1 - alpha_bar_t).unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2 - - diag = torch.eye(E_T.shape[1], dtype=torch.bool).unsqueeze(0).expand(E_T.shape[0], -1, -1) - q_e[diag] = torch.eye(q_e.shape[-1], device=device) - - edge_mask = node_mask[:, None, :] & node_mask[:, :, None] - q_e[~edge_mask] = torch.eye(q_e.shape[-1], device=device) + q_e = self._apply_edge_mask_identity(q_e, node_mask.to(torch.bool), device) return utils.PlaceHolder(X=q_x, E=q_e, y=y_T) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 4b98b65..b162961 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,7 +2,9 @@ import yaml import numpy as np import random - +import os +os.environ["PYTHONHASHSEED"] = "42" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" import rdkit.rdBase as rkrb import rdkit.RDLogger as rkl import torch # Should be imported after RDKit for some reason @@ -23,6 +25,10 @@ def set_deterministic(seed): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False + # added + torch.use_deterministic_algorithms(True) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False def parse_yaml_config(args): @@ -33,3 +39,17 @@ def parse_yaml_config(args): arg_dict[key] = value args.config = args.config.name return args + +# added + +def make_worker_init_fn(base_seed: int): + def _fn(worker_id: int): + s = base_seed + worker_id + import random, numpy as np, torch + random.seed(s); np.random.seed(s); torch.manual_seed(s) + return _fn + +def make_torch_generator(seed: int): + g = torch.Generator() + g.manual_seed(seed) + return g \ No newline at end of file