diff --git a/analysis/chemical_space_plots.py b/analysis/chemical_space_plots.py
new file mode 100644
index 0000000..6fb85e0
--- /dev/null
+++ b/analysis/chemical_space_plots.py
@@ -0,0 +1,484 @@
+#!/usr/bin/env python
+"""
+Chemical space visualization for a SINGLE generated set.
+
+Each run plots:
+  - Papyrus background (50k random subset)
+  - All inhibitors for a given target (AKT1 / CDK2)
+  - Generated molecules from ONE phase (I/II/III),
+    for that target, optionally filtered by encoder.
+
+Usage examples:
+
+    # AKT1, Phase III, ProtT5
+    python analysis/chemical_space_plots.py \
+        --target AKT1 \
+        --phase 3 \
+        --encoder prot_t5 \
+        --include-tsne
+
+    # CDK2, Phase I, all encoders together
+    python analysis/chemical_space_plots.py \
+        --target CDK2 \
+        --phase I
+
+Output:
+    plots/chemical_space/umap_AKT1_phaseIII_prot_t5.png
+    plots/chemical_space/tsne_AKT1_phaseIII_prot_t5.png   (if --include-tsne)
+"""
+
+import os
+import sys
+import argparse
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# clean, paper-friendly style
+plt.style.use("seaborn-v0_8-whitegrid")
+plt.rcParams.update(
+    {
+        "axes.titlesize": 14,
+        "axes.labelsize": 12,
+        "legend.fontsize": 10,
+        "xtick.labelsize": 10,
+        "ytick.labelsize": 10,
+    }
+)
+
+# -------------------------------------------------------------------------
+# Import project utilities
+# -------------------------------------------------------------------------
+
+# Add the repo root to sys.path so 'prot2mol' imports work from any CWD.
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from prot2mol.utils_fps import generate_morgan_fingerprints_parallel  # type: ignore
+
+try:
+    import umap  # umap-learn
+except ImportError as e:
+    raise ImportError("Please install 'umap-learn' for UMAP.") from e
+
+try:
+    from sklearn.manifold import TSNE
+except ImportError as e:
+    raise ImportError("Please install 'scikit-learn' for t-SNE.") from e
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+
+
+# -------------------------------------------------------------------------
+# Helpers
+# -------------------------------------------------------------------------
+
+def _load_parquet(path: Path) -> pd.DataFrame:
+    if not path.exists():
+        raise FileNotFoundError(f"Parquet file not found: {path}")
+    print(f"[INFO] Loading: {path}")
+    return pd.read_parquet(path)
+
+
+def normalize_phase(phase_str: str) -> Tuple[str, str]:
+    """
+    Normalize user phase arg to:
+      - numeric: "1" / "2" / "3"    (for filenames)
+      - roman:   "I" / "II" / "III" (for labels)
+    """
+    s = phase_str.strip().upper()
+    if s in {"1", "I", "PHASE1", "PHASE_I"}:
+        return "1", "I"
+    if s in {"2", "II", "PHASE2", "PHASE_II"}:
+        return "2", "II"
+    if s in {"3", "III", "PHASE3", "PHASE_III"}:
+        return "3", "III"
+    raise ValueError(f"Unrecognized phase: {phase_str}")
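+
+# Illustrative behavior of normalize_phase (a sketch of the contract above;
+# the phase argument is case-insensitive):
+#
+#     normalize_phase("3")         # -> ("3", "III")
+#     normalize_phase("phase_ii")  # -> ("2", "II")
+#     normalize_phase("iv")        # -> raises ValueError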
+
+
+def load_background_smiles(
+    processed_ref_dir: Path, n_background: int, random_state: int
+) -> List[str]:
+    """Sample Papyrus background subset."""
+    all_path = processed_ref_dir / "papyrus_all_processed.parquet"
+    df = _load_parquet(all_path)
+
+    if "is_valid" in df.columns:
+        df = df[df["is_valid"]]
+
+    smiles = df["smiles_canonical"].dropna().astype(str)
+    if len(smiles) <= n_background:
+        print(
+            f"[WARN] Requested {n_background} background; "
+            f"only {len(smiles)} available. Using all."
+        )
+        subset = smiles
+    else:
+        subset = smiles.sample(n=n_background, random_state=random_state)
+
+    print(f"[INFO] Background molecules: {len(subset)}")
+    return subset.tolist()
+
+
+def load_inhibitor_smiles(processed_ref_dir: Path, target: str) -> List[str]:
+    """Load all inhibitors for a given target (AKT1 / CDK2)."""
+    target_lower = target.lower()
+    ref_path = processed_ref_dir / f"{target_lower}_ref_processed.parquet"
+    df = _load_parquet(ref_path)
+
+    if "is_valid" in df.columns:
+        df = df[df["is_valid"]]
+
+    smiles = df["smiles_canonical"].dropna().astype(str)
+    print(f"[INFO] {target} inhibitors: {len(smiles)} molecules")
+    return smiles.tolist()
+
+
+def load_generated_smiles_single(
+    processed_gen_dir: Path,
+    target: str,
+    phase_num: str,
+    encoder: Optional[str] = None,
+    max_generated: Optional[int] = None,
+    random_state: int = 42,
+) -> List[str]:
+    """
+    Load generated molecules for ONE phase + ONE target.
+    Optionally filter by encoder; optionally subsample.
+    """
+    target_lower = target.lower()
+    path = processed_gen_dir / f"phase{phase_num}_{target_lower}_processed.parquet"
+    df = _load_parquet(path)
+
+    mask = pd.Series(True, index=df.index)
+
+    # safety: filter again by target / encoder / validity
+    if "target_id" in df.columns:
+        mask &= df["target_id"].astype(str).str.upper() == target.upper()
+    if encoder is not None and "encoder" in df.columns:
+        mask &= df["encoder"].astype(str) == encoder
+    if "is_valid" in df.columns:
+        mask &= df["is_valid"]
+
+    df = df[mask]
+    smiles = df["smiles_canonical"].dropna().astype(str)
+
+    if max_generated is not None and len(smiles) > max_generated:
+        smiles = smiles.sample(n=max_generated, random_state=random_state)
+
+    print(
+        f"[INFO] Generated set – target={target}, phase={phase_num}, "
+        f"encoder={encoder if encoder else 'ALL'}, "
+        f"mols={len(smiles)}"
+    )
+    return smiles.tolist()
+
+
+def build_fps_and_labels(
+    bg_smiles: List[str],
+    inh_smiles: List[str],
+    gen_smiles: List[str],
+) -> Tuple[np.ndarray, np.ndarray, List[str]]:
+    """
+    Build fingerprint matrix + labels for three datasets:
+      0: Background
+      1: Reference inhibitors
+      2: Generated (single set)
+    """
+    datasets = [
+        ("Background", bg_smiles),
+        ("Reference inhibitors", inh_smiles),
+        ("Generated", gen_smiles),
+    ]
+
+    smiles_all: List[str] = []
+    labels_idx: List[int] = []
+
+    for idx, (_, smiles) in enumerate(datasets):
+        smiles_all.extend(smiles)
+        labels_idx.extend([idx] * len(smiles))
+
+    print(f"[INFO] Total molecules in embedding: {len(smiles_all)}")
+
+    fps = generate_morgan_fingerprints_parallel(
+        smiles_all, radius=2, nBits=1024, n_jobs=None
+    )
+    labels_idx_arr = np.array(labels_idx, dtype=np.int32)
+    label_names = [name for name, _ in datasets]
+
+    return fps, labels_idx_arr, label_names
+
+
+def compute_umap(fps: np.ndarray, random_state: int) -> np.ndarray:
+    """2D UMAP (Dice distance)."""
+    reducer = umap.UMAP(
+        n_neighbors=50,
+        min_dist=0.8,
+        metric="dice",
+        random_state=random_state,
+    )
+    return reducer.fit_transform(fps)
+
+
+def compute_tsne(fps: np.ndarray, random_state: int) -> np.ndarray:
+    """2D t-SNE (Jaccard distance over binary fingerprints)."""
+    tsne = TSNE(
+        n_components=2,
+        metric="jaccard",
+        perplexity=30,
+        learning_rate="auto",
+        init="pca",
+        random_state=random_state,
+    )
+    return tsne.fit_transform(fps)
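+
+# Note on the metrics above: Morgan fingerprints are binary bit vectors, so
+# set-overlap distances (Dice for UMAP, Jaccard for t-SNE) are more meaningful
+# than Euclidean distance. A minimal sketch of what the two distances compute,
+# assuming a and b are 0/1 numpy arrays of equal length:
+#
+#     inter = np.sum(a * b)
+#     dice_dist = 1.0 - 2.0 * inter / (a.sum() + b.sum())
+#     jaccard_dist = 1.0 - inter / (a.sum() + b.sum() - inter)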
+
+
+def plot_embedding(
+    coords: np.ndarray,
+    labels_idx: np.ndarray,
+    label_names: List[str],
+    target: str,
+    phase_roman: str,
+    encoder: Optional[str],
+    method: str,
+    output_dir: Path,
+) -> None:
+    """Scatter plot for one embedding (UMAP or t-SNE)."""
+    encoder_label = encoder.upper() if encoder is not None else "ALL"
+    method_lower = method.lower()
+
+    # colors similar to DrugGEN style
+    color_bg = "#B0B0B0"   # light grey
+    color_ref = "#F6A800"  # orange
+    color_gen = "#1f77b4"  # blue
+
+    fig, ax = plt.subplots(figsize=(8, 6))
+
+    for idx, name in enumerate(label_names):
+        mask = labels_idx == idx
+        if not np.any(mask):
+            continue
+        x = coords[mask, 0]
+        y = coords[mask, 1]
+
+        if name == "Background":
+            ax.scatter(
+                x,
+                y,
+                s=3,
+                c=color_bg,
+                alpha=0.15,
+                linewidth=0,
+                label="Papyrus background",
+            )
+        elif name == "Reference inhibitors":
+            ax.scatter(
+                x,
+                y,
+                s=12,
+                c=color_ref,
+                alpha=0.9,
+                linewidth=0.3,
+                edgecolors="black",
+                label=f"{target} inhibitors",
+            )
+        else:  # Generated
+            gen_label = f"Generated – Phase {phase_roman}"
+            if encoder is not None:
+                gen_label += f" ({encoder_label})"
+            ax.scatter(
+                x,
+                y,
+                s=10,
+                c=color_gen,
+                alpha=0.9,
+                linewidth=0,
+                label=gen_label,
+            )
+
+    if method_lower == "umap":
+        ax.set_xlabel("UMAP-1")
+        ax.set_ylabel("UMAP-2")
+    else:
+        ax.set_xlabel("t-SNE-1")
+        ax.set_ylabel("t-SNE-2")
+
+    title = f"{method} projection – {target}, Phase {phase_roman}"
+    if encoder is not None:
+        title += f" ({encoder_label})"
+    ax.set_title(title)
+
+    ax.legend(
+        frameon=True,
+        loc="upper left",
+        bbox_to_anchor=(1.02, 1.0),
+        borderaxespad=0.0,
+    )
+
+    fig.tight_layout()
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    fname_parts = [
+        method_lower,
+        target.upper(),
+        f"phase{phase_roman}",
+    ]
+    if encoder is not None:
+        fname_parts.append(encoder)
+    fname = "_".join(fname_parts) + ".png"
+    out_path = output_dir / fname
+    fig.savefig(out_path, dpi=300)
+    plt.close(fig)
+
+    print(f"[INFO] Saved {method} plot to: {out_path}")
+
+
+# -------------------------------------------------------------------------
+# CLI
+# -------------------------------------------------------------------------
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="UMAP / t-SNE chemical space plot for a SINGLE generated set."
+    )
+
+    default_proc_dir = REPO_ROOT / "data" / "processed"
+    default_out_dir = REPO_ROOT / "plots" / "chemical_space"
+
+    parser.add_argument(
+        "--target",
+        type=str,
+        required=True,
+        help="Target to analyze, e.g. AKT1 or CDK2.",
+    )
+    parser.add_argument(
+        "--phase",
+        type=str,
+        required=True,
+        help="Phase of generated set: 1 / 2 / 3 or I / II / III.",
+    )
+    parser.add_argument(
+        "--encoder",
+        type=str,
+        default=None,
+        help="Optional encoder filter (e.g. 'prot_t5', 'esm2'). "
+             "If omitted, all encoders in that phase file are used.",
+    )
+    parser.add_argument(
+        "--n-background",
+        type=int,
+        default=50000,
+        help="Number of Papyrus background molecules (default: 50000).",
+    )
+    parser.add_argument(
+        "--max-generated",
+        type=int,
+        default=None,
+        help="Optional cap on number of generated molecules "
+             "from this phase (e.g. 5000). Default: use all.",
+    )
+    parser.add_argument(
+        "--include-tsne",
+        action="store_true",
+        help="Also compute and save t-SNE plot.",
+    )
+    parser.add_argument(
+        "--processed-dir",
+        type=str,
+        default=str(default_proc_dir),
+        help=f"Base processed data dir (default: {default_proc_dir}).",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=str(default_out_dir),
+        help=f"Directory to save plots (default: {default_out_dir}).",
+    )
+    parser.add_argument(
+        "--random-state",
+        type=int,
+        default=42,
+        help="Random seed for sampling / embeddings (default: 42).",
+    )
+
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+
+    proc_dir = Path(args.processed_dir).resolve()
+    processed_ref_dir = proc_dir / "reference"
+    processed_gen_dir = proc_dir / "generated"
+    output_dir = Path(args.output_dir).resolve()
+
+    target = args.target.upper()
+    phase_num, phase_roman = normalize_phase(args.phase)
+    encoder = args.encoder
+
+    print("[INFO] Repository root :", REPO_ROOT)
+    print("[INFO] Processed dir   :", proc_dir)
+    print("[INFO] Target          :", target)
+    print("[INFO] Phase           :", phase_roman)
+    print("[INFO] Encoder filter  :", encoder if encoder else "ALL")
+
+    # 1) Load datasets
+    bg_smiles = load_background_smiles(
+        processed_ref_dir,
+        n_background=args.n_background,
+        random_state=args.random_state,
+    )
+    inh_smiles = load_inhibitor_smiles(processed_ref_dir, target)
+    gen_smiles = load_generated_smiles_single(
+        processed_gen_dir,
+        target=target,
+        phase_num=phase_num,
+        encoder=encoder,
+        max_generated=args.max_generated,
+        random_state=args.random_state,
+    )
+
+    if len(gen_smiles) == 0:
+        print("[WARN] No generated molecules found for this configuration.")
+        return
+
+    # 2) Fingerprints + labels
+    fps, labels_idx, label_names = build_fps_and_labels(
+        bg_smiles, inh_smiles, gen_smiles
+    )
+
+    # 3) UMAP
+    print("[INFO] Computing UMAP embedding...")
+    umap_coords = compute_umap(fps, random_state=args.random_state)
+    plot_embedding(
+        coords=umap_coords,
+        labels_idx=labels_idx,
+        label_names=label_names,
+        target=target,
+        phase_roman=phase_roman,
+        encoder=encoder,
+        method="UMAP",
+        output_dir=output_dir,
+    )
+
+    # 4) t-SNE (optional)
+    if args.include_tsne:
+        print("[INFO] Computing t-SNE embedding (may be slow)...")
+        tsne_coords = compute_tsne(fps, random_state=args.random_state)
+        plot_embedding(
+            coords=tsne_coords,
+            labels_idx=labels_idx,
+            label_names=label_names,
+            target=target,
+            phase_roman=phase_roman,
+            encoder=encoder,
+            method="tSNE",
+            output_dir=output_dir,
+        )
+
+    print("[INFO] Done.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/analysis/phase1_plots.py b/analysis/phase1_plots.py
new file mode 100644
index 0000000..0cd9533
--- /dev/null
+++ b/analysis/phase1_plots.py
@@ -0,0 +1,483 @@
+#!/usr/bin/env python
+"""
+Phase 1 analysis: 1D distribution plots for generated vs. reference molecules.
+
+Goal:
+- Check whether generated molecules physically resemble the known inhibitors
+  in the reference dataset (AKT1 or CDK2).
+- Properties to analyze:
+  - Molecular Weight (MW)
+  - LogP (Lipophilicity)
+  - Hydrogen Bond Donors (HBD)
+  - Hydrogen Bond Acceptors (HBA)
+  - QED
+  - SA (Synthetic Accessibility)
+
+Visualization:
+- Overlay distribution plots (histograms or KDE curves) for:
+  - Multiple generated sets (e.g., Phase I, II, III for a given target)
+  - Reference training dataset (e.g., akt1_ref_processed.parquet)
+
+Usage example (--generated accepts several paths after a single flag):
+
+    python analysis/phase1_plots.py \
+        --generated data/processed/generated/phase1_akt1_processed.parquet \
+                    data/processed/generated/phase2_akt1_processed.parquet \
+                    data/processed/generated/phase3_akt1_processed.parquet \
+        --reference data/processed/reference/akt1_ref_processed.parquet \
+        --target AKT1 \
+        --properties mw logp hbd hba qed sa \
+        --plot-type kde
+"""
+
+import sys
+import argparse
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+# clean, paper-friendly style
+plt.style.use("seaborn-v0_8-whitegrid")
+plt.rcParams.update({
+    "axes.titlesize": 14,
+    "axes.labelsize": 12,
+    "legend.fontsize": 10,
+    "xtick.labelsize": 10,
+    "ytick.labelsize": 10,
+})
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+sys.path.append(str(REPO_ROOT))
+
+# Property configuration: column name in processed parquet and pretty label
+PROPERTY_INFO: Dict[str, Dict[str, str]] = {
+    "mw": {"col": "mw", "label": "Molecular weight (Da)"},
+    "logp": {"col": "logp", "label": "LogP"},
+    "hbd": {"col": "hbd", "label": "H-bond donors"},
+    "hba": {"col": "hba", "label": "H-bond acceptors"},
+    "qed": {"col": "qed", "label": "QED score"},
+    "sa": {"col": "sa", "label": "SA score"},
+}
+
+# -------------------------------------------------------------------------
+# Helper functions
+# -------------------------------------------------------------------------
+
+def infer_generated_label(df: pd.DataFrame, fallback: str) -> str:
+    """Label generated dataset as 'Phase X (TARGET)' if possible."""
+    phase = None
+    target = None
+
+    if "phase" in df.columns and df["phase"].notna().any():
+        vals = df["phase"].dropna().unique()
+        if len(vals) == 1:
+            phase = str(vals[0])
+
+    if "target_id" in df.columns and df["target_id"].notna().any():
+        vals = df["target_id"].dropna().unique()
+        if len(vals) == 1:
+            target = str(vals[0])
+
+    if phase is not None and target is not None:
+        return f"Phase {phase} ({target})"
+    if phase is not None:
+        return f"Phase {phase}"
+    return fallback
+
+
+def infer_reference_label(df_ref: pd.DataFrame, fallback: str = "Reference") -> str:
+    """Label reference dataset as 'Reference (TARGET)' if possible."""
+    label = fallback
+    if "target_symbol" in df_ref.columns and df_ref["target_symbol"].notna().any():
+        vals = df_ref["target_symbol"].dropna().unique()
+        if len(vals) == 1:
+            label = f"Reference ({vals[0]})"
+    return label
+
+
+def parse_properties(props: List[str]) -> List[str]:
+    props = [p.lower() for p in props]
+    if "all" in props:
+        return list(PROPERTY_INFO.keys())
+
+    unknown = [p for p in props if p not in PROPERTY_INFO]
+    if unknown:
+        raise ValueError(
+            f"Unknown properties requested: {unknown}. "
+            f"Valid options: {list(PROPERTY_INFO.keys())} or 'all'."
+        )
+    return props
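+
+# Illustrative behavior (a sketch of the contract above):
+#
+#     parse_properties(["ALL"])        # -> ["mw", "logp", "hbd", "hba", "qed", "sa"]
+#     parse_properties(["MW", "qed"])  # -> ["mw", "qed"]
+#     parse_properties(["foo"])        # -> raises ValueError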
+
+
+def filter_df_for_target_and_encoder(
+    df: pd.DataFrame,
+    target: Optional[str],
+    encoder: Optional[str],
+    is_reference: bool = False,
+) -> pd.DataFrame:
+    """
+    Filter a dataframe by target and encoder if those filters are provided.
+
+    - For generated sets:
+        * target name (AKT1/CDK2) -> column 'target_id'
+        * CHEMBL id (CHEMBL4282)  -> column 'protein_chembl_id'
+    - For reference sets:
+        * target name -> 'target_symbol'
+        * CHEMBL id   -> 'target_chembl_id'
+    """
+    filtered = df.copy()
+
+    if target is not None:
+        t_upper = target.upper()
+        is_chembl = t_upper.startswith("CHEMBL")
+
+        col_to_use = None
+        if is_reference:
+            if is_chembl and "target_chembl_id" in filtered.columns:
+                col_to_use = "target_chembl_id"
+            elif "target_symbol" in filtered.columns:
+                col_to_use = "target_symbol"
+            elif "target_id" in filtered.columns:
+                col_to_use = "target_id"
+        else:
+            if is_chembl and "protein_chembl_id" in filtered.columns:
+                col_to_use = "protein_chembl_id"
+            elif "target_id" in filtered.columns:
+                col_to_use = "target_id"
+            elif "target_symbol" in filtered.columns:
+                col_to_use = "target_symbol"
+
+        if col_to_use is not None:
+            filtered = filtered[
+                filtered[col_to_use].astype(str).str.upper() == t_upper
+            ]
+
+    if encoder is not None and not is_reference:
+        if "encoder" in filtered.columns:
+            filtered = filtered[filtered["encoder"] == encoder]
+
+    return filtered
+
+
+def ensure_non_empty(df: pd.DataFrame, label: str, prop_col: str) -> np.ndarray:
+    """Return non-NaN values for the given column (as float array)."""
+    if prop_col not in df.columns:
+        print(f"[WARN] Column '{prop_col}' not found in dataset '{label}'. Skipping.")
+        return np.array([])
+
+    values = df[prop_col].to_numpy(dtype=float)
+    values = values[~np.isnan(values)]
+    if values.size == 0:
+        print(f"[WARN] No valid values for '{prop_col}' in dataset '{label}'. Skipping.")
+    return values
+
+
+# -------------------------------------------------------------------------
+# Plotting
+# -------------------------------------------------------------------------
+
+def _compute_x_limits(arrays: List[np.ndarray]) -> Optional[tuple]:
+    """
+    Compute x-limits from all value arrays.
+
+    We use a wide percentile range (0.5–99.5) and a bit more padding so that
+    the right tail of the KDE can smoothly decay towards zero instead of
+    being cut in the middle of the slope.
+    """
+    non_empty = [a for a in arrays if a.size > 0]
+    if not non_empty:
+        return None
+
+    concat = np.concatenate(non_empty)
+    q_low, q_high = np.percentile(concat, [0.5, 99.5])
+
+    if q_low == q_high:
+        # fallback if all values are identical
+        return q_low - 1.0, q_high + 1.0
+
+    # slightly larger padding than before
+    pad = 0.10 * (q_high - q_low)
+    return q_low - pad, q_high + pad
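+
+# Worked example (approximate): for values spread uniformly over 0..100,
+# the 0.5/99.5 percentiles are ~0.5 and ~99.5, so pad = 0.10 * 99.0 = 9.9
+# and the returned limits are roughly (-9.4, 109.4):
+#
+#     _compute_x_limits([np.linspace(0.0, 100.0, 1001)])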
+ """ + info = PROPERTY_INFO[prop_key] + col = info["col"] + xlabel = info["label"] + + ref_vals = reference["values"] + ref_label = reference["label"] + + # Prepare all values for x-limits + all_vals = [ref_vals] + [g["values"] for g in generated_list] + x_limits = _compute_x_limits(all_vals) + + fig, ax = plt.subplots(figsize=(8, 6)) + + # ---- reference ---- + if ref_vals.size > 0: + mean_ref = float(ref_vals.mean()) + label_ref = f"{ref_label} ({mean_ref:.2f})" + + if plot_type == "hist": + ax.hist( + ref_vals, + bins=bins, + density=True, + histtype="stepfilled", + alpha=0.4, + label=label_ref, + ) + elif plot_type == "kde": + sns.kdeplot( + x=ref_vals, + ax=ax, + fill=True, + alpha=0.35, + linewidth=2, + label=label_ref, + bw_adjust=1.2, + gridsize=512, + cut=0, + ) + + else: + raise ValueError(f"Unsupported plot_type: {plot_type}") + + # ---- generated sets ---- + for gen in generated_list: + vals = gen["values"] + if vals.size == 0: + continue + mean_gen = float(vals.mean()) + label_gen = f"{gen['label']} ({mean_gen:.2f})" + + if plot_type == "hist": + ax.hist( + vals, + bins=bins, + density=True, + histtype="step", + linewidth=2, + alpha=0.9, + label=label_gen, + ) + elif plot_type == "kde": + sns.kdeplot( + x=vals, + ax=ax, + fill=True, + linewidth=2, + label=label_gen, + bw_adjust=1.2, + gridsize=512, + cut=0, + ) + + # ---- axis labels / title / legend ---- + ax.set_xlabel(xlabel) + ax.set_ylabel("Density") + + title_parts = [f"Distribution of {xlabel}"] + if target is not None: + title_parts.append(f"(Target: {target})") + if encoder is not None: + title_parts.append(f"[Encoder: {encoder}]") + ax.set_title(" ".join(title_parts)) + + if x_limits is not None: + ax.set_xlim(*x_limits) + + ax.legend(frameon=True) + + fig.tight_layout() + + # filename + target_part = f"_target-{target}" if target is not None else "" + encoder_part = f"_encoder-{encoder}" if encoder is not None else "" + fname = f"{prop_key}_{plot_type}{target_part}{encoder_part}.png" + out_path = output_dir / fname + + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=300) + plt.close(fig) + + print(f"[INFO] Saved plot: {out_path} (property={col}, type={plot_type})") + +# ------------------------------------------------------------------------- +# CLI +# ------------------------------------------------------------------------- + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Phase 1: plot 1D distributions (generated vs reference) " + "for MW, LogP, HBD/HBA, QED, SA." + ) + + default_output_dir = REPO_ROOT / "plots" / "phase1" + + parser.add_argument( + "--generated", + type=str, + nargs="+", + required=True, + help="Path(s) to processed generated parquet files " + "(e.g., phase1_akt1_processed.parquet phase2_akt1_processed.parquet ...).", + ) + parser.add_argument( + "--reference", + type=str, + required=True, + help="Path to processed reference parquet file " + "(e.g., akt1_ref_processed.parquet or cdk2_ref_processed.parquet).", + ) + parser.add_argument( + "--properties", + type=str, + nargs="+", + default=["mw", "logp", "hbd", "hba", "qed", "sa"], + help="Properties to plot. Options: mw logp hbd hba qed sa or 'all'. 
" + "Default: mw logp hbd hba qed sa", + ) + parser.add_argument( + "--plot-type", + type=str, + choices=["hist", "kde"], + default="kde", + help="Type of 1D distribution plot: 'hist' or 'kde' (default: kde).", + ) + parser.add_argument( + "--output-dir", + type=str, + default=str(default_output_dir), + help=f"Directory to save plots (default: {default_output_dir})", + ) + parser.add_argument( + "--target", + type=str, + default=None, + help="Optional target filter (e.g., AKT1, CDK2 or CHEMBL4282).", + ) + parser.add_argument( + "--encoder", + type=str, + default=None, + help="Optional encoder filter for generated datasets " + "(e.g., 'prot_t5' or 'esm2').", + ) + parser.add_argument( + "--bins", + type=int, + default=60, + help="Number of bins for histogram (default: 60).", + ) + + return parser.parse_args() + +# ------------------------------------------------------------------------- +# Main +# ------------------------------------------------------------------------- + +def main() -> None: + args = parse_args() + output_dir = Path(args.output_dir).resolve() + + # 1) reference + ref_path = Path(args.reference).resolve() + if not ref_path.exists(): + raise FileNotFoundError(f"Reference parquet file not found: {ref_path}") + + print(f"[INFO] Loading reference dataset: {ref_path}") + df_ref = pd.read_parquet(ref_path) + df_ref = filter_df_for_target_and_encoder(df_ref, args.target, encoder=None, is_reference=True) + ref_label = infer_reference_label(df_ref) + print(f"[INFO] Reference label: {ref_label}, rows after filtering: {len(df_ref)}") + + # 2) generated + generated_datasets: List[Dict] = [] + for gpath_str in args.generated: + gpath = Path(gpath_str).resolve() + if not gpath.exists(): + print(f"[WARN] Generated parquet not found, skipping: {gpath}") + continue + + print(f"[INFO] Loading generated dataset: {gpath}") + df_gen = pd.read_parquet(gpath) + df_gen = filter_df_for_target_and_encoder( + df_gen, + target=args.target, + encoder=args.encoder, + is_reference=False, + ) + + stem = gpath.stem + if stem.endswith("_processed"): + stem = stem[:-10] + gen_label = infer_generated_label(df_gen, fallback=stem) + + print(f"[INFO] Generated label: {gen_label}, rows after filtering: {len(df_gen)}") + + generated_datasets.append({"label": gen_label, "df": df_gen}) + + if len(generated_datasets) == 0: + print("[WARN] No generated datasets loaded; nothing to plot.") + return + + # 3) properties + props = parse_properties(args.properties) + print(f"[INFO] Properties to plot: {props}") + print(f"[INFO] Plot type: {args.plot_type}") + if args.target: + print(f"[INFO] Target filter: {args.target}") + if args.encoder: + print(f"[INFO] Encoder filter: {args.encoder}") + print(f"[INFO] Output directory: {output_dir}") + + # 4) plot each property + for prop_key in props: + col = PROPERTY_INFO[prop_key]["col"] + + ref_vals = ensure_non_empty(df_ref, ref_label, col) + ref_dict = {"label": ref_label, "values": ref_vals} + + gen_value_dicts: List[Dict] = [] + for item in generated_datasets: + label = item["label"] + df_gen = item["df"] + vals = ensure_non_empty(df_gen, label, col) + if vals.size > 0: + gen_value_dicts.append({"label": label, "values": vals}) + + if (ref_vals.size == 0) and (len(gen_value_dicts) == 0): + print(f"[WARN] No valid values for property '{col}' across all datasets; skipping plot.") + continue + + plot_property_distribution( + prop_key=prop_key, + reference=ref_dict, + generated_list=gen_value_dicts, + plot_type=args.plot_type, + output_dir=output_dir, + target=args.target, + 
+
+def safe_decode_selfies(selfies_str: Optional[str]) -> Optional[str]:
+    """Decode a SELFIES string to SMILES. Return None if decoding fails."""
+    if selfies_str is None or (isinstance(selfies_str, float) and np.isnan(selfies_str)):
+        return None
+    try:
+        return sf.decoder(selfies_str)
+    except Exception:
+        return None
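+
+# Example of the decoding step (standard selfies behavior):
+#
+#     sf.decoder("[C][C][O]")   # -> "CCO"
+#     safe_decode_selfies(None) # -> None (e.g., a missing cell in the CSV)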
+
+
+def add_common_rdkit_descriptors(
+    df: pd.DataFrame,
+    smiles_col: str = "smiles_raw",
+    recompute_sa_qed_logp: bool = False,
+    existing_sa_col: Optional[str] = None,
+    existing_qed_col: Optional[str] = None,
+    existing_logp_col: Optional[str] = None,
+) -> pd.DataFrame:
+    """
+    Add common RDKit-based descriptors to the DataFrame:
+      - smiles_canonical
+      - is_valid
+      - mw  (molecular weight)
+      - hbd (H-bond donors)
+      - hba (H-bond acceptors)
+      - sa, qed, logp (either copied from existing columns or recomputed)
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Input dataframe with at least a column `smiles_col`.
+    smiles_col : str
+        Column name containing SMILES strings to canonicalize and featurize.
+    recompute_sa_qed_logp : bool
+        If True, recompute SA/QED/LogP using RDKit functions.
+        If False, try to copy them from the existing columns
+        (existing_sa_col, existing_qed_col, existing_logp_col).
+    existing_sa_col, existing_qed_col, existing_logp_col : Optional[str]
+        Column names to copy SA/QED/LogP from if recompute_sa_qed_logp is False.
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with new descriptor columns added.
+    """
+    # Canonical SMILES
+    df["smiles_canonical"] = df[smiles_col].apply(
+        lambda s: canonic_smiles(s) if isinstance(s, str) and len(s) > 0 else None
+    )
+
+    # Build RDKit Mol objects once and reuse them for all descriptors
+    mols: List[Optional[Chem.Mol]] = [get_mol(smi) for smi in df["smiles_canonical"]]
+    df["is_valid"] = [mol is not None for mol in mols]
+
+    # MW, HBD, HBA (NaN for invalid molecules)
+    df["mw"] = [
+        float(Descriptors.MolWt(mol)) if mol is not None else np.nan
+        for mol in mols
+    ]
+    df["hbd"] = [
+        float(Lipinski.NumHDonors(mol)) if mol is not None else np.nan
+        for mol in mols
+    ]
+    df["hba"] = [
+        float(Lipinski.NumHAcceptors(mol)) if mol is not None else np.nan
+        for mol in mols
+    ]
+
+    # SA, QED, LogP
+    existing_cols = [existing_sa_col, existing_qed_col, existing_logp_col]
+    if (not recompute_sa_qed_logp) and all(
+        col is not None and col in df.columns for col in existing_cols
+    ):
+        # Copy from existing columns
+        df["sa"] = df[existing_sa_col]
+        df["qed"] = df[existing_qed_col]
+        df["logp"] = df[existing_logp_col]
+    else:
+        # Recompute from RDKit Mols
+        df["sa"] = sascorer_calculation(mols)
+        df["qed"] = qed_calculation(mols)
+        df["logp"] = logp_calculation(mols)
+
+    return df
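+
+# Minimal usage sketch (illustrative; any dataframe with a 'smiles_raw' column
+# works, assuming the prot2mol helpers return None for invalid SMILES, which
+# the is_valid logic above already relies on):
+#
+#     toy = pd.DataFrame({"smiles_raw": ["CCO", "c1ccccc1", "not_a_smiles"]})
+#     toy = add_common_rdkit_descriptors(toy, recompute_sa_qed_logp=True)
+#     toy[["smiles_canonical", "is_valid", "mw", "hbd", "hba"]]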
+ """ + if not raw_path.exists(): + raise FileNotFoundError(f"Raw generated file not found: {raw_path}") + + print(f"[INFO] Processing generated dataset: {raw_path.name} " + f"(phase={phase_label}, target={target_label})") + + df = pd.read_csv(raw_path, sep=csv_sep) + + # Basic metadata + df["source_type"] = "generated" + df["phase"] = phase_label + df["target_id"] = target_label + + # Protein CHEMBL ID and mapping to target symbol (if desired) + df["protein_chembl_id"] = df["Protein_ID"] + df["target_symbol"] = df["protein_chembl_id"].map( + TARGET_CHEMBL_TO_SYMBOL + ).fillna(df["protein_chembl_id"]) + + # Normalize/rename key columns (keep original names too if you like) + df["generated_selfies"] = df["Generated_SELFIES"] + df["smiles_raw"] = df["Generated_SMILES"] + df["encoder"] = df["Protein_Encoder"] + df["model_name"] = df["Model_Name"] + df["predicted_pchembl"] = df["Predicted_pChEMBL"] + + df["generation_temperature"] = df["Generation_Temperature"] + df["generation_top_p"] = df["Generation_Top_p"] + df["max_length"] = df["Max_Length"] + df["batch_size"] = df["Batch_Size"] + df["generation_timestamp"] = df["Generation_Timestamp"] + + df["test_similarity"] = df["Test_Similarity"] + df["train_similarity"] = df["Train_Similarity"] + + # Add RDKit descriptors; SA/QED/LogP are already present in columns + df = add_common_rdkit_descriptors( + df, + smiles_col="smiles_raw", + recompute_sa_qed_logp=False, + existing_sa_col="SA_Score", + existing_qed_col="QED_Score", + existing_logp_col="LogP_Score", + ) + + # Keep only the columns we care about (optional, but keeps files clean) + keep_cols = [ + # metadata + "source_type", + "phase", + "target_id", + "protein_chembl_id", + "target_symbol", + "encoder", + "model_name", + "generation_temperature", + "generation_top_p", + "max_length", + "batch_size", + "generation_timestamp", + # sequences + "generated_selfies", + "smiles_raw", + "smiles_canonical", + "is_valid", + # scores & similarities + "predicted_pchembl", + "sa", + "qed", + "logp", + "test_similarity", + "train_similarity", + # RDKit descriptors + "mw", + "hbd", + "hba", + ] + + keep_cols = [c for c in keep_cols if c in df.columns] + df = df[keep_cols] + + processed_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(processed_path, index=False) + print(f"[INFO] Saved processed generated dataset to: {processed_path}") + +def process_all_generated_datasets( + raw_generated_dir: Path, + proc_generated_dir: Path, + generated_sep: str = DEFAULT_GENERATED_CSV_SEP, +) -> None: + """ + Run processing for all generated datasets with a fixed naming scheme. + + This assumes the following potential raw filenames inside raw_generated_dir: + + phase1_akt1.csv + phase1_cdk2.csv + phase2_akt1.csv + phase2_cdk2.csv + phase3_akt1.csv + phase3_cdk2.csv + + For any of these that do NOT exist on disk, the script will just print a warning + and skip them instead of raising an error. 
+ """ + config = [ + {"name": "phase1_akt1", "phase": "I", "target_label": "AKT1"}, + {"name": "phase1_cdk2", "phase": "I", "target_label": "CDK2"}, + {"name": "phase2_akt1", "phase": "II", "target_label": "AKT1"}, + {"name": "phase2_cdk2", "phase": "II", "target_label": "CDK2"}, + {"name": "phase3_akt1", "phase": "III", "target_label": "AKT1"}, + {"name": "phase3_cdk2", "phase": "III", "target_label": "CDK2"}, + ] + + any_processed = False + + for cfg in config: + raw_path = raw_generated_dir / f"{cfg['name']}.csv" + proc_path = proc_generated_dir / f"{cfg['name']}_processed.parquet" + + if not raw_path.exists(): + print(f"[WARN] Skipping {raw_path.name}: file not found in {raw_generated_dir}") + continue + + process_single_generated_dataset( + raw_path=raw_path, + processed_path=proc_path, + phase_label=cfg["phase"], + target_label=cfg["target_label"], + csv_sep=generated_sep, + ) + any_processed = True + + if not any_processed: + print("[WARN] No generated CSV files were found; nothing was processed.") + +# ------------------------------------------------------------------------- +# Papyrus / reference dataset processing +# ------------------------------------------------------------------------- + +def process_papyrus_reference( + raw_path: Path, + out_all_path: Path, + out_akt1_path: Path, + out_cdk2_path: Path, + papyrus_sep: str = DEFAULT_PAPYRUS_CSV_SEP, + drop_duplicate_smiles: bool = True, +) -> None: + """ + Process Papyrus filtered dataset into: + - papyrus_all_processed.parquet + - akt1_ref_processed.parquet + - cdk2_ref_processed.parquet + + This step: + - decodes SELFIES to SMILES, + - canonicalizes SMILES, + - computes RDKit descriptors (SA, QED, LogP, MW, HBD, HBA), + - optionally drops duplicate SMILES within each subset. + """ + if not raw_path.exists(): + raise FileNotFoundError(f"Raw Papyrus file not found: {raw_path}") + + print(f"[INFO] Processing Papyrus reference dataset: {raw_path.name}") + + df = pd.read_csv(raw_path, sep=papyrus_sep) + + # Basic metadata + df["source_type"] = "ref_papyrus_all" + df["target_chembl_id"] = df["Target_CHEMBL_ID"] + df["target_symbol"] = df["target_chembl_id"].map( + TARGET_CHEMBL_TO_SYMBOL + ).fillna(df["target_chembl_id"]) + + df["target_fasta"] = df["Target_FASTA"] + df["selfies"] = df["Compound_SELFIES"] + + # Decode SELFIES -> SMILES + df["smiles_raw"] = df["selfies"].apply(safe_decode_selfies) + + # Add RDKit descriptors; for Papyrus we recompute SA/QED/LogP + df = add_common_rdkit_descriptors( + df, + smiles_col="smiles_raw", + recompute_sa_qed_logp=True, + ) + + out_all_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(out_all_path, index=False) + print(f"[INFO] Saved papyrus_all_processed to: {out_all_path}") + + # Build AKT1 and CDK2 subsets + akt1_df = df[df["target_symbol"] == "AKT1"].copy() + cdk2_df = df[df["target_symbol"] == "CDK2"].copy() + + akt1_df["source_type"] = "ref_akt1" + cdk2_df["source_type"] = "ref_cdk2" + + if drop_duplicate_smiles: + if "smiles_canonical" in akt1_df.columns: + akt1_df = akt1_df.drop_duplicates(subset=["smiles_canonical"]) + if "smiles_canonical" in cdk2_df.columns: + cdk2_df = cdk2_df.drop_duplicates(subset=["smiles_canonical"]) + + akt1_df.to_parquet(out_akt1_path, index=False) + cdk2_df.to_parquet(out_cdk2_path, index=False) + print(f"[INFO] Saved akt1_ref_processed to: {out_akt1_path}") + print(f"[INFO] Saved cdk2_ref_processed to: {out_cdk2_path}") + +# ------------------------------------------------------------------------- +# CLI & main +# 
+
+
+# -------------------------------------------------------------------------
+# CLI & main
+# -------------------------------------------------------------------------
+
+def parse_args() -> argparse.Namespace:
+    """Parse command-line arguments for the dataset preparation script."""
+    parser = argparse.ArgumentParser(
+        description="Prepare processed datasets (generated + Papyrus reference) "
+                    "for Prot2Mol analysis."
+    )
+
+    default_raw_dir = REPO_ROOT / "data" / "raw"
+    default_processed_dir = REPO_ROOT / "data" / "processed"
+
+    parser.add_argument(
+        "--raw-dir",
+        type=str,
+        default=str(default_raw_dir),
+        help=f"Base directory for raw data (default: {default_raw_dir})",
+    )
+    parser.add_argument(
+        "--processed-dir",
+        type=str,
+        default=str(default_processed_dir),
+        help=f"Base directory for processed data (default: {default_processed_dir})",
+    )
+    parser.add_argument(
+        "--generated-sep",
+        type=str,
+        default=DEFAULT_GENERATED_CSV_SEP,
+        help=f"CSV separator for generated datasets (default: '{DEFAULT_GENERATED_CSV_SEP}')",
+    )
+    parser.add_argument(
+        "--papyrus-sep",
+        type=str,
+        default=DEFAULT_PAPYRUS_CSV_SEP,
+        help=f"CSV separator for Papyrus dataset (default: '{DEFAULT_PAPYRUS_CSV_SEP}')",
+    )
+    parser.add_argument(
+        "--papyrus-file",
+        type=str,
+        default=None,
+        help="Path to Papyrus filtered CSV. "
+             "If not provided, defaults to <raw-dir>/reference/papyrus_filtered.csv",
+    )
+
+    return parser.parse_args()
+
+
+def main() -> None:
+    """Entry point for the Stage 1 dataset preparation pipeline."""
+    args = parse_args()
+
+    raw_dir = Path(args.raw_dir).resolve()
+    proc_dir = Path(args.processed_dir).resolve()
+
+    raw_generated_dir = raw_dir / "generated"
+    raw_reference_dir = raw_dir / "reference"
+
+    proc_generated_dir = proc_dir / "generated"
+    proc_reference_dir = proc_dir / "reference"
+
+    proc_generated_dir.mkdir(parents=True, exist_ok=True)
+    proc_reference_dir.mkdir(parents=True, exist_ok=True)
+
+    # Papyrus file path (either user-specified or default)
+    papyrus_path = (
+        Path(args.papyrus_file).resolve()
+        if args.papyrus_file is not None
+        else (raw_reference_dir / "papyrus_filtered.csv")
+    )
+
+    papyrus_all_proc_path = proc_reference_dir / "papyrus_all_processed.parquet"
+    akt1_ref_proc_path = proc_reference_dir / "akt1_ref_processed.parquet"
+    cdk2_ref_proc_path = proc_reference_dir / "cdk2_ref_processed.parquet"
+
+    print("[INFO] Repository root   :", REPO_ROOT)
+    print("[INFO] Raw data dir      :", raw_dir)
+    print("[INFO] Processed data dir:", proc_dir)
+    print("[INFO] Raw generated dir :", raw_generated_dir)
+    print("[INFO] Raw reference dir :", raw_reference_dir)
+    print("[INFO] Papyrus raw file  :", papyrus_path)
+
+    print("[INFO] Starting Stage 1 dataset preparation...")
+
+    # 1) Process generated datasets (Phase I/II/III, AKT1/CDK2)
+    process_all_generated_datasets(
+        raw_generated_dir=raw_generated_dir,
+        proc_generated_dir=proc_generated_dir,
+        generated_sep=args.generated_sep,
+    )
+
+    # 2) Process Papyrus reference dataset and its AKT1/CDK2 subsets
+    process_papyrus_reference(
+        raw_path=papyrus_path,
+        out_all_path=papyrus_all_proc_path,
+        out_akt1_path=akt1_ref_proc_path,
+        out_cdk2_path=cdk2_ref_proc_path,
+        papyrus_sep=args.papyrus_sep,
+    )
+
+    print("[INFO] Stage 1 dataset preparation completed.")
+
+
+if __name__ == "__main__":
+    main()