From 6efa304e9afe082b2f0164d8a98a01c6bdada833 Mon Sep 17 00:00:00 2001 From: Brandon Walker Date: Fri, 13 Jun 2025 12:39:14 -0400 Subject: [PATCH] rebase --- README.md | 15 + docker/Dockerfile.hcase | 50 +- docker/Singularity.def | 61 ++ environment.yml | 21 +- environment_cpu.yml | 19 +- examples/EmbedUSPTO/embed_uspto.py | 420 +++++++++-- hcase/fingerprints.py | 33 +- hcase/hcase.py | 357 +++++---- hcase/plot.py | 1088 ++++++++++++++++++++++++++++ hcase/substructures.py | 2 +- 10 files changed, 1849 insertions(+), 217 deletions(-) create mode 100644 docker/Singularity.def create mode 100644 hcase/plot.py diff --git a/README.md b/README.md index 9845783..d2dd4dd 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,21 @@ To test the HCASE method, you can run the `embed_uspto.py` script located in the This script will execute the embedding process and provide the results. +## Interactive Heatmap Plotting and Features + +The HCASE repository includes an interactive heatmap visualization tool for exploring the distribution of embedded chemical structures. This tool is implemented in `hcase/plot.py` and provides the following features: + +- **Automatic Deployment**: The Dash app for the heatmap visualization is automatically launched when running the USPTO embedding workflow (e.g., `examples/EmbedUSPTO/embed_uspto.py`). No separate launch step is required. +- **Frequency Heatmap**: Visualizes the frequency of structures in each region of the embedded space, using log-scaled bins for better dynamic range. +- **Color Scheme**: The heatmap uses a discrete, inverted color scale—light colors represent low frequency counts, and dark colors represent high frequency counts. +- **Plot Area Border**: The plot area (data region) is outlined with a clear border, making the data region visually distinct. Axis ticks and labels are positioned just outside this border. +- **Interactive Search**: Users can search for specific reference or target structures by SMILES or InChIKey, and matching points are highlighted on the heatmap. +- **Structure Display**: Clicking on a heatmap cell displays the reference structure and a grid of corresponding target structures, with options to highlight substructure overlaps. +- **Frequency Range Slider**: A slider allows users to filter the heatmap by frequency range, using log-scale bins. +- **Legend**: A color legend explains the mapping between frequency bins and colors. + +To use the heatmap tool, simply run the USPTO embedding workflow as described above. The Dash app will be deployed automatically, enabling detailed exploration of the chemical embedding space and helping users identify regions of interest, structure clusters, and outliers. + ## For Contributors We welcome contributions to the HCASE method! Here's how you can contribute: diff --git a/docker/Dockerfile.hcase b/docker/Dockerfile.hcase index 49e8d85..30b9bde 100644 --- a/docker/Dockerfile.hcase +++ b/docker/Dockerfile.hcase @@ -1,9 +1,47 @@ -FROM python:3.12-slim-bookworm +# Use a base image with CUDA and Python (adjust CUDA version as needed) +FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 -RUN apt-get update && apt-get install -y gcc +# Set environment variables to prevent some interactive prompts +ENV DEBIAN_FRONTEND=noninteractive -COPY ./requirements.txt requirements.txt -RUN pip install -r requirements.txt +# Install dependencies +RUN apt-get update && apt-get install -y \ + wget \ + git \ + curl \ + ca-certificates \ + build-essential \ + python3 \ + python3-pip \ + python3-venv \ + && rm -rf /var/lib/apt/lists/* -RUN mkdir /app -WORKDIR /app +# Symlink python3 and pip3 +RUN ln -s /usr/bin/python3 /usr/bin/python && ln -s /usr/bin/pip3 /usr/bin/pip + +# Upgrade pip and install conda (Miniconda) +ENV CONDA_DIR=/opt/conda +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p $CONDA_DIR && \ + rm miniconda.sh +ENV PATH=$CONDA_DIR/bin:$PATH + +# IMPORTANT: Build context must be the repo root, not the docker/ folder! +# Example: docker build -f docker/Dockerfile.hcase . +# This will copy the entire repo into /workspace in the container. +COPY . /workspace + +# Create and activate environment +RUN conda env create -f /workspace/environment.yml && conda clean -a + +# Activate the environment by default +SHELL ["conda", "run", "-n", "base", "/bin/bash", "-c"] + +# Install cupy using pip or conda as needed (adjust CUDA version) +RUN pip install cupy-cuda12x + +# Set working directory inside container +WORKDIR /workspace + +# Default command +CMD ["python"] diff --git a/docker/Singularity.def b/docker/Singularity.def new file mode 100644 index 0000000..5a70d06 --- /dev/null +++ b/docker/Singularity.def @@ -0,0 +1,61 @@ +Bootstrap: docker +From: nvidia/cuda:12.2.0-devel-ubuntu22.04 + +%environment + export PATH=/opt/conda/bin:$PATH + export DEBIAN_FRONTEND=noninteractive + export WORKDIR=/workspace + +%post + apt-get update && apt-get install -y \ + wget \ + git \ + curl \ + ca-certificates \ + build-essential \ + python3 \ + python3-pip \ + python3-venv \ + && rm -rf /var/lib/apt/lists/* + + # Symlink python3 and pip3 + ln -s /usr/bin/python3 /usr/bin/python || true + ln -s /usr/bin/pip3 /usr/bin/pip || true + + # Install Miniconda + wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh + bash miniconda.sh -b -p /opt/conda + rm miniconda.sh + + # Create workspace directory + mkdir -p /workspace + + # Create conda environment + /opt/conda/bin/conda env create -f /workspace/environment.yml + /opt/conda/bin/conda clean -a + + # Install cupy (adjust CUDA version if needed) + /opt/conda/bin/pip install cupy-cuda12x + +%files + aux_code /workspace/aux_code + data /workspace/data + examples /workspace/examples + hcase /workspace/hcase + log /workspace/log + plots /workspace/plots + tests /workspace/tests + workflows /workspace/workflows + environment.yml /workspace/environment.yml + requirements.txt /workspace/requirements.txt + README.md /workspace/README.md + LICENSE /workspace/LICENSE + Makefile /workspace/Makefile + pyproject.toml /workspace/pyproject.toml + setup.py /workspace/setup.py + DISCLAIMER /workspace/DISCLAIMER + NOTES /workspace/NOTES + +%runscript + cd /workspace + exec /opt/conda/bin/python "$@" diff --git a/environment.yml b/environment.yml index b62c6d5..5c4c833 100644 --- a/environment.yml +++ b/environment.yml @@ -3,14 +3,17 @@ channels: - conda-forge - anaconda dependencies: - - python==3.13.2 - - cupy==13.4.0 - - rdkit==2024.09.6 + - python==3.11 + - rdkit==2025.03.3 + - cupy==13.4.1 - pandarallel==1.6.5 - - ipywidgets==8.1.5 - - pip==25.0.1 - - numpy==2.2.3 - - tqdm==4.65.0 - - pytest + - ipywidgets==8.1.7 + - pip==25.1.1 + - numpy==2.3.0 + - tqdm==4.67.1 + - pytest==8.4.0 - pip: - - -e . + - -e . + - molplotly==1.1.7 + - dash==2.11.1 + - jupyter-dash==0.4.2 diff --git a/environment_cpu.yml b/environment_cpu.yml index 6816aeb..077f524 100644 --- a/environment_cpu.yml +++ b/environment_cpu.yml @@ -3,13 +3,16 @@ channels: - conda-forge - anaconda dependencies: - - python==3.13.2 - - rdkit==2024.09.6 + - python==3.11 + - rdkit==2025.03.3 - pandarallel==1.6.5 - - ipywidgets==8.1.5 - - pip==25.0.1 - - numpy==2.2.3 - - tqdm==4.65.0 - - pytest + - ipywidgets==8.1.7 + - pip==25.1.1 + - numpy==2.3.0 + - tqdm==4.67.1 + - pytest==8.4.0 - pip: - - -e . + - -e . + - molplotly==1.1.7 + - dash==2.11.1 + - jupyter-dash==0.4.2 diff --git a/examples/EmbedUSPTO/embed_uspto.py b/examples/EmbedUSPTO/embed_uspto.py index 7fa364d..790087b 100644 --- a/examples/EmbedUSPTO/embed_uspto.py +++ b/examples/EmbedUSPTO/embed_uspto.py @@ -2,78 +2,392 @@ import re from tqdm.auto import tqdm tqdm.pandas() -from hcase.fingerprints import compute_substructure_batched, compute_fingerprints_batched, undo_stringify_fingerprint +try: + import cupy as cp +except: + pass +from rdkit import Chem +from rdkit.Chem import inchi +import os +from pandarallel import pandarallel +pandarallel.initialize(progress_bar=True) +import numpy as np +import time +import argparse + +from hcase.fingerprints import ( + compute_substructure_batched, + compute_fingerprints_batched, + undo_stringify_fingerprint, +) from hcase import embed +from hcase.plot import plot_interactive_hs -def extract_reaction_smiles_from_csv(file_path): +def get_unique_fingerprint_indices_per_file(all_fps_dfs): """ - Extracts all ReactionSmiles from a CSV file. - - Args: - file_path (str): The path to the CSV file. - - Returns: - list: A list of extracted ReactionSmiles strings. + For each file, return the indices of fingerprints unique to that file + relative to all previous files (first file: all unique, next: unique vs previous, etc). + Returns a list of lists of indices (one list per file). """ - df = pd.read_csv(file_path, delimiter='\t') - return df['ReactionSmiles'].tolist() - + seen = set() + unique_indices_per_file = [] + for df in all_fps_dfs: + fps = df['fingerprint'].apply(lambda x: tuple(map(int, x.strip('[]').replace('\n', '').split()))) + indices = [] + for idx, fp in enumerate(fps): + if fp not in seen: + indices.append(idx) + seen.add(fp) + unique_indices_per_file.append(indices) + return unique_indices_per_file -def extract_product_smiles(reaction_smiles): +def report_fingerprint_intersections(all_fps_dfs, filenames): """ - Extracts the product SMILES from a reaction SMILES string by returning - everything after the last '>' symbol, and removes any unwanted annotations - like '|f:0.1,2.3,4.5,8.9|'. - - Args: - reaction_smiles (str): The reaction SMILES string. - - Returns: - str: The cleaned product SMILES string. + For each file, compute unique fingerprints and intersection across all files. + Report statistics and associated substructures. """ + # Convert fingerprints to hashable type (tuple of ints) + fps_sets = [] + fps_to_substructs = [] + for df in all_fps_dfs: + # Convert stringified fingerprint to tuple + fps = df['fingerprint'].apply(lambda x: tuple(map(int, x.strip('[]').replace('\n', '').split()))) + fps_sets.append(set(fps)) + # Map fingerprint tuple to substructure(s) + fp2sub = {} + for fp, sub in zip(fps, df['substructure']): + fp2sub.setdefault(fp, set()).add(sub) + fps_to_substructs.append(fp2sub) + + # Intersection across all files + intersection = set.intersection(*fps_sets) + print(f"\n=== Fingerprint Intersection Statistics ===") + print(f"Number of files: {len(filenames)}") + print(f"Fingerprints common to all files (intersection): {len(intersection)}") + if intersection: + # Collect all substructures associated with intersection fingerprints + substructs = set() + for fp in intersection: + for fp2sub in fps_to_substructs: + substructs.update(fp2sub.get(fp, [])) + substructs_list = list(substructs) + print("Associated substructures (intersection):") + for sub in substructs_list[:20]: + print(f" - {sub}") + if len(substructs_list) > 20: + print(f" ... ({len(substructs_list)-20} more not shown)") + + # Unique, intersection, and "shared with some but not all" fingerprints per file + unique_indices_per_file = get_unique_fingerprint_indices_per_file(all_fps_dfs) + for i, (df, indices, fname) in enumerate(zip(all_fps_dfs, unique_indices_per_file, filenames)): + # Print unique IDs, structures, substructures, and unique fingerprints + unique_ids = set(df['ID']) + unique_structures = set(df['structure']) + unique_substructures = set(df['substructure']) + unique_fingerprints = set(df['fingerprint'].apply(lambda x: tuple(map(int, x.strip('[]').replace('\n', '').split())))) + print(f"\nFile: {fname}") + print(f" Number of unique IDs: {len(unique_ids)}") + print(f" Number of unique structures: {len(unique_structures)}") + print(f" Number of unique substructures: {len(unique_substructures)}") + print(f" Number of unique fingerprints: {len(unique_fingerprints)}") + + fps = df['fingerprint'].apply(lambda x: tuple(map(int, x.strip('[]').replace('\n', '').split()))) + fps_set = set(fps) + # Unique: only in this file + others = set.union(*(set(df2['fingerprint'].apply(lambda x: tuple(map(int, x.strip('[]').replace('\n', '').split())))) for j, df2 in enumerate(all_fps_dfs) if j != i)) + unique_fps = fps_set - others + # Intersection: in all files + intersection_fps = set.intersection(*[set(df2['fingerprint'].apply(lambda x: tuple(map(int, x.strip('[]').replace('\n', '').split())))) for df2 in all_fps_dfs]) + # Shared with some but not all: in this file, not unique, not in intersection + shared_some_fps = fps_set - unique_fps - intersection_fps + + print(f"\nFile: {fname}") + print(f" Unique fingerprints: {len(unique_fps)}") + print(f" In intersection (common to all): {len(fps_set & intersection_fps)}") + print(f" Shared with some but not all: {len(shared_some_fps)}") + + if unique_fps: + substructs = set(df.loc[fps.isin(unique_fps), 'substructure']) + substructs_list = list(substructs) + print(" Associated substructures (unique):") + for sub in substructs_list[:20]: + print(f" - {sub}") + if len(substructs_list) > 20: + print(f" ... ({len(substructs_list)-20} more not shown)") + if shared_some_fps: + substructs = set(df.loc[fps.isin(shared_some_fps), 'substructure']) + substructs_list = list(substructs) + print(" Associated substructures (shared with some but not all):") + for sub in substructs_list[:20]: + print(f" - {sub}") + if len(substructs_list) > 20: + print(f" ... ({len(substructs_list)-20} more not shown)") + + +def canonicalize_rxn_smiles(rxn_smiles: str) -> str: + def canon_and_sort(smiles_part): + mols = [Chem.MolFromSmiles(smi) for smi in smiles_part.split('.') if smi] + canons = [Chem.MolToSmiles(mol, canonical=True) for mol in mols if mol is not None] + return '.'.join(sorted(canons)) + + parts = rxn_smiles.split('>') + if len(parts) == 3: + reactants, agents, products = parts + elif len(parts) == 2: + reactants, products = parts + agents = '' + else: + return None + + reactants = canon_and_sort(reactants) + agents = canon_and_sort(agents) + products = canon_and_sort(products) + + return f"{reactants}>{agents}>{products}" + + +def extract_reaction_smiles_from_csv(file_path, id_prefix="USPTO"): + # Auto-detect delimiter + import csv + with open(file_path, 'r', encoding='utf-8') as f: + sample = f.read(2048) + sniffer = csv.Sniffer() + delimiter = sniffer.sniff(sample).delimiter + df = pd.read_csv(file_path, delimiter=delimiter, usecols=['ReactionSmiles'], dtype={'ReactionSmiles': str}) + df = df.dropna().reset_index(drop=True) + df['ID'] = [f"{id_prefix}{i+1:07d}" for i in range(len(df))] + return df[['ID', 'ReactionSmiles']] + + +def extract_product_smiles(reaction_smiles): if '>' in reaction_smiles: - # Extract everything after the last '>' product_smiles = reaction_smiles.split('>')[-1] else: - # If no '>' is found, return the original string (could be a single molecule) product_smiles = reaction_smiles - - # Remove any unwanted parts such as '|f:0.1,2.3,4.5,8.9|' product_smiles = re.sub(r'\|f:[\d.,]+\|', '', product_smiles) - return product_smiles.strip() -file_name = '2001_Sep2016_USPTOapplications_smiles.rsmi' -substructure_file = 'substructures.csv' -fingerprint_file = 'fingerprints.csv' -skip_fingerprint_generation = False +def smiles_to_inchikey(smiles): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + try: + return inchi.MolToInchiKey(mol) + except: + return None -if not skip_fingerprint_generation: +def process(file_path, substructure_file='substructures.csv', + fingerprint_file='fingerprints.csv', + preprocessed_file='preprocessed_rxns.csv', + fps_inchikey_file='df_fps_with_inchikey.csv', id_prefix='USPTO'): - ref_smiles = extract_reaction_smiles_from_csv(file_name) - product_smiles = [extract_product_smiles(reaction_smiles) for reaction_smiles in ref_smiles] + # Get absolute directory path of the current file + base_dir = os.path.dirname(os.path.abspath(__file__)) - # at some point computing BMS was getting "Killed" signal so turned off parrallelization - compute_substructure_batched(product_smiles, csv_file=substructure_file, parallel=False) - compute_fingerprints_batched(substructure_file=substructure_file, csv_file=fingerprint_file, parallel=False) + # Update file paths to absolute paths relative to this script + file_path = os.path.join(base_dir, file_path) + substructure_file = os.path.join(base_dir, substructure_file) + fingerprint_file = os.path.join(base_dir, fingerprint_file) + preprocessed_file = os.path.join(base_dir, preprocessed_file) + fps_inchikey_file = os.path.join(base_dir, fps_inchikey_file) + # --- Preprocessing --- + if os.path.exists(preprocessed_file): + df_rxns = pd.read_csv(preprocessed_file) + original_count = None # Can't know original if only loading preprocessed + else: + df_rxns_raw = extract_reaction_smiles_from_csv(file_path, id_prefix) + original_count = len(df_rxns_raw) + print(f" [Preprocessing] Rows after loading CSV: {len(df_rxns_raw)}") + df_rxns = df_rxns_raw.copy() + df_rxns['canonical_rxn'] = df_rxns['ReactionSmiles'].parallel_apply(canonicalize_rxn_smiles) + print(f" [Preprocessing] Rows after canonicalization (dropna): {df_rxns['canonical_rxn'].notna().sum()}") + df_rxns = df_rxns.dropna(subset=['canonical_rxn']) + df_rxns['product_smiles'] = df_rxns['canonical_rxn'].apply(extract_product_smiles) + print(f" [Preprocessing] Rows after extracting product_smiles: {len(df_rxns)}") + df_rxns = df_rxns[df_rxns['product_smiles'].str.contains(r'^[^\.]+$')] # single-product only + print(f" [Preprocessing] Rows after filtering for single-product: {len(df_rxns)}") + df_rxns = df_rxns.drop_duplicates(subset=['ID']).reset_index(drop=True) + print(f" [Preprocessing] Rows after drop_duplicates on ID: {len(df_rxns)}") -df_fingerprints = pd.read_csv(fingerprint_file, usecols=['fingerprint']) -ref_fingerprints = df_fingerprints['fingerprint'].values -ref_fingerprints = undo_stringify_fingerprint(ref_fingerprints) -try: - import cupy as cp - use_cupy = cp.cuda.is_available() -except ImportError: - use_cupy = False -embedded_data = embed( - ref_fingerprints, - n_dim=2, - use_cupy=use_cupy, - target_batch_size=1000, - ref_batch_size=50000, - max_hc_order_only=True -) + print(f"Saving preprocessed data to {preprocessed_file}") + df_rxns.to_csv(preprocessed_file, index=False) + + + # --- Fingerprints --- + if os.path.exists(fps_inchikey_file): + print(f"Loading cached fingerprint+InChIKey data from {fps_inchikey_file}") + df_fps = pd.read_csv(fps_inchikey_file) + else: + if not os.path.exists(substructure_file) or not os.path.exists(fingerprint_file): + compute_substructure_batched(df_rxns['product_smiles'].tolist(), csv_file=substructure_file, parallel=False) + compute_fingerprints_batched(substructure_file=substructure_file, csv_file=fingerprint_file, parallel=False) + + df_fps = pd.read_csv(fingerprint_file, usecols=['structure', 'substructure', 'fingerprint']) + + structure_to_id = dict(zip(df_rxns['product_smiles'], df_rxns['ID'])) + df_fps['ID'] = df_fps['structure'].map(structure_to_id) + + print(f" [Fingerprinting] Rows in df_fps: {len(df_fps)}") + print(f" [Fingerprinting] Unique IDs in df_fps: {df_fps['ID'].nunique()}") + print(f" [Fingerprinting] Unique structures in df_fps: {df_fps['structure'].nunique()}") + + print("Computing InChIKeys...") + df_fps['InChIKey'] = df_fps['substructure'].parallel_apply(smiles_to_inchikey) + + print(f"Saving fingerprint+InChIKey data to {fps_inchikey_file}") + df_fps.to_csv(fps_inchikey_file, index=False) + + # --- Embedding --- + # (Embedding and sorting are now performed only after concatenation of all files in the main block.) + postprocessed_count = len(df_rxns) if 'df_rxns' in locals() else None + return original_count, postprocessed_count, df_fps + + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='Process and embed reaction SMILES.') + parser.add_argument('--file_paths', type=str, nargs='+', required=True, help='List of input SMILES CSV files') + parser.add_argument('--special_file', type=str, default=None, help='File path of the special file whose InChIKeys will be highlighted') + args = parser.parse_args() + + all_fps_dfs = [] + substructure_sets = [] + substructure_filenames = [] + special_ref_inchikeys = None + timings = {} + start = time.time() + start_total = start + timings['Preprocessing'] = time.time() - start + for file_path in args.file_paths: + prefix = os.path.splitext(os.path.basename(file_path))[0] + substructure_file = f"{prefix}_substructures.csv" + fingerprint_file = f"{prefix}_fingerprints.csv" + preprocessed_file = f"{prefix}_preprocessed_rxns.csv" + fps_inchikey_file = f"{prefix}_df_fps_with_inchikey.csv" + + # Run up to fingerprint+InChIKey generation + original_count, postprocessed_count, df_fps = process( + file_path=file_path, + substructure_file=substructure_file, + fingerprint_file=fingerprint_file, + preprocessed_file=preprocessed_file, + fps_inchikey_file=fps_inchikey_file, + id_prefix=prefix + ) + + print(f"\n=== File: {file_path} ===") + if original_count is not None: + print(f" Original input count: {original_count}") + else: + print(f" Original input count: (unknown, loaded from preprocessed)") + print(f" Post-processed count: {postprocessed_count}") + + all_fps_dfs.append(df_fps) + substructure_sets.append(set(df_fps['substructure'])) + substructure_filenames.append(file_path) + # If this is the special file, extract its InChIKeys + if args.special_file is not None and os.path.abspath(file_path) == os.path.abspath(args.special_file): + special_ref_inchikeys = df_fps['InChIKey'].tolist() + + timings['Fingerprint + InChIKey'] = time.time() - start + + # Report fingerprint intersection and uniqueness statistics + report_fingerprint_intersections(all_fps_dfs, substructure_filenames) + + # Concatenate only unique fingerprints (relative to previous files) + unique_indices_per_file = get_unique_fingerprint_indices_per_file(all_fps_dfs) + filtered_dfs = [df.iloc[indices] for df, indices in zip(all_fps_dfs, unique_indices_per_file)] + combined_fps = pd.concat(filtered_dfs, ignore_index=True) + + # Compute special_unique_ref_inchikeys if special file is provided + special_unique_ref_inchikeys = None + if args.special_file is not None: + # Find index of special file in substructure_filenames + special_idx = None + for i, fname in enumerate(substructure_filenames): + if os.path.abspath(fname) == os.path.abspath(args.special_file): + special_idx = i + break + if special_idx is not None: + df_special = all_fps_dfs[special_idx] + unique_idxs = unique_indices_per_file[special_idx] + special_unique_ref_inchikeys = df_special.iloc[unique_idxs]['InChIKey'].tolist() + + # Embedding and plotting on the combined data + target_fingerprints = undo_stringify_fingerprint(combined_fps['fingerprint'].values) + target_ids = np.array(combined_fps['substructure'].values) + parent_target_ids = np.array(combined_fps['structure'].values) + parent_reaction_ids = np.array(combined_fps['ID'].values) + target_inchikeys = np.array(combined_fps['InChIKey'].values) + + sort_idx = np.lexsort(target_fingerprints.T[::-1]) + target_fingerprints = target_fingerprints[sort_idx] + target_ids = target_ids[sort_idx] + parent_target_ids = parent_target_ids[sort_idx] + parent_reaction_ids = parent_reaction_ids[sort_idx] + target_inchikeys = target_inchikeys[sort_idx] + + ref_fingerprints, unique_indices = np.unique(target_fingerprints, axis=0, return_index=True) + ref_ids = target_ids[unique_indices] + parent_ref_ids = parent_target_ids[unique_indices] + ref_inchikeys = target_inchikeys[unique_indices] + try: + use_cupy = cp.cuda.is_available() + except: + use_cupy = False + + + + + embedded, ref_only_embedded, hc_orders, ref_ids, closest_indices_all = embed( + target_fingerprints, + ref_fingerprints=ref_fingerprints, + ref_ids=ref_ids, + target_ids=target_ids, + n_dim=2, + use_cupy=use_cupy, + target_batch_size=1000, + ref_batch_size=40000, + max_hc_order_only=True + ) + + timings['Embedding'] = time.time() - start + + # --- Total --- + timings['Total'] = time.time() - start_total + + print("\n=== Timing Summary ===") + for step, secs in timings.items(): + print(f"{step:<25}: {secs:.2f} seconds") + + # Print all unique and common special InChIKeys before plotting + if special_ref_inchikeys is not None: + unique_set = set(special_unique_ref_inchikeys) if special_unique_ref_inchikeys is not None else set() + all_set = set(special_ref_inchikeys) + common_set = all_set - unique_set + print("\nSpecial Unique InChIKeys:") + for ik in sorted(unique_set): + print(f" {ik}") + print("\nSpecial Common InChIKeys:") + for ik in sorted(common_set): + print(f" {ik}") + + plot_interactive_hs( + embedded, + target_ids, + hc_orders, + ref_ids, + parent_target_ids, + parent_ref_ids, + closest_indices_all, + parent_reaction_ids, + target_inchikeys, + ref_inchikeys, + special_ref_inchikeys=special_ref_inchikeys, + special_unique_ref_inchikeys=special_unique_ref_inchikeys + ) diff --git a/hcase/fingerprints.py b/hcase/fingerprints.py index 62d2cc2..0a0223c 100644 --- a/hcase/fingerprints.py +++ b/hcase/fingerprints.py @@ -5,7 +5,6 @@ from pandarallel import pandarallel from tqdm import tqdm import os -import gc from hcase.substructures import smiles2bmscaffold @@ -41,7 +40,14 @@ def compute_substructure(structures: np.ndarray, method=smiles2bmscaffold): df_valid = df[df['substructure'] != 'NA'].copy() return df_valid -def compute_substructure_batched(structures, method=smiles2bmscaffold, batch_size=100000, csv_file='substructures.csv', parallel=True): +def compute_substructure_batched( + structures, + method=smiles2bmscaffold, + batch_size=100000, + csv_file='substructures.csv', + parallel=True, + deduplicate=True +): """Computes substructures in batches, appending to a CSV file and resuming from previous progress.""" # Check existing results @@ -49,27 +55,17 @@ def compute_substructure_batched(structures, method=smiles2bmscaffold, batch_siz if os.path.exists(csv_file): df_existing = pd.read_csv(csv_file, usecols=['structure']) processed_structures = set(df_existing['structure']) - # Free memory - del df_existing - gc.collect() - # Filter only new structures structures = [s for s in structures if s not in processed_structures] - # Free memory - del processed_structures - gc.collect() # If everything is processed, exit early if not structures: return # Process in batches - first_batch = not os.path.exists(csv_file) # If file exists, append instead + first_batch = not os.path.exists(csv_file) batches = [structures[i:i + batch_size] for i in range(0, len(structures), batch_size)] - # Free memory - del structures - gc.collect() for batch in tqdm(batches, desc="Processing substructure batches", unit="batch"): df = pd.DataFrame({'structure': batch}) @@ -83,6 +79,12 @@ def compute_substructure_batched(structures, method=smiles2bmscaffold, batch_siz df.to_csv(csv_file, mode='w' if first_batch else 'a', header=first_batch, index=False) first_batch = False + # Deduplicate structure→substructure, keeping only the first instance + if deduplicate: + df_all = pd.read_csv(csv_file) + df_dedup = df_all.drop_duplicates(subset='substructure', keep='first') + df_dedup.to_csv(csv_file, index=False) + def compute_fingerprints(df: pd.DataFrame, method=lambda x: smiles2scaffoldkey(x, return_as_string=True)): """Computes fingerprints from substructures and filters invalid ones.""" @@ -102,7 +104,6 @@ def compute_fingerprints_batched(substructure_file='substructures.csv', if os.path.exists(csv_file): for chunk in pd.read_csv(csv_file, usecols=['substructure'], dtype=str, chunksize=10000): processed_substructures.update(chunk['substructure'].values) - gc.collect() # Get total line count (excluding header) with open(substructure_file) as f: @@ -136,10 +137,6 @@ def compute_fingerprints_batched(substructure_file='substructures.csv', # Append results new_substructures.to_csv(csv_file, mode='w' if first_batch else 'a', header=first_batch, index=False) first_batch = False - - # Free memory - del new_substructures, df_substructures - gc.collect() # Update progress bar pbar.update(1) diff --git a/hcase/hcase.py b/hcase/hcase.py index 8a9951b..bf55617 100644 --- a/hcase/hcase.py +++ b/hcase/hcase.py @@ -13,9 +13,64 @@ import numpy as np from tqdm import tqdm import os -import pandas as pd +from typing import List, Dict, Tuple from hilbertcurve.hilbertcurve import HilbertCurve + +def _load_index_npz(cache_dir: str, index_filename: str = "index.npz"): + index_path = os.path.join(cache_dir, index_filename) + if os.path.exists(index_path): + data = np.load(index_path, allow_pickle=True) + identifiers = data["identifiers"] + batch_files = data["batch_files"] + row_in_batch = data["row_in_batch"] + return identifiers, batch_files, row_in_batch + else: + return np.array([], dtype=str), np.array([], dtype=str), np.array([], dtype=int) + +def _save_index_npz(identifiers, batch_files, row_in_batch, cache_dir: str, index_filename: str = "index.npz"): + index_path = os.path.join(cache_dir, index_filename) + np.savez(index_path, identifiers=identifiers, batch_files=batch_files, row_in_batch=row_in_batch) + +def _get_next_batch_number_npz(cache_dir: str) -> int: + existing = [f for f in os.listdir(cache_dir) if f.startswith("batch_") and f.endswith(".npz")] + if not existing: + return 1 + nums = [int(f.split("_")[1].split(".")[0]) for f in existing] + return max(nums) + 1 + +def _add_closest_indices_batch_npz(identifiers: np.ndarray, closest_indices: np.ndarray, cache_dir: str): + assert len(identifiers) == closest_indices.shape[0] + batch_num = _get_next_batch_number_npz(cache_dir) + batch_file = f"batch_{batch_num:05d}.npz" + batch_path = os.path.join(cache_dir, batch_file) + np.savez(batch_path, identifiers=identifiers, closest_indices=closest_indices) + return batch_file + +def _get_closest_indices_from_cache_npz(identifiers: np.ndarray, cache_dir: str, idx_identifiers: np.ndarray, idx_batch_files: np.ndarray, idx_row_in_batch: np.ndarray) -> Dict[str, np.ndarray]: + # Map: batch_file -> list of (identifier, row_in_batch) + result = {} + if len(idx_identifiers) == 0: + return result + # Find indices in index for requested identifiers + id_to_idx = {id_: i for i, id_ in enumerate(idx_identifiers)} + found = [id_ for id_ in identifiers if id_ in id_to_idx] + if not found: + return result + # Group by batch_file + batch_map = {} + for id_ in found: + idx = id_to_idx[id_] + batch_file = idx_batch_files[idx] + row = idx_row_in_batch[idx] + batch_map.setdefault(batch_file, []).append((id_, row)) + for batch_file, id_row_list in batch_map.items(): + batch_path = os.path.join(cache_dir, batch_file) + with np.load(batch_path, allow_pickle=True) as data: + batch_indices = data["closest_indices"] + for identifier, row_idx in id_row_list: + result[identifier] = batch_indices[row_idx] + return result from hcase.distance_metrics import weighted_minkowski_1_5_distance, euclidean_distance, manhattan_distance, cosine_distance @@ -28,6 +83,23 @@ def compute_max_phc_order(ref_scaffold_smiles: np.ndarray) -> int: + """ + Compute the maximum PHC order for a given set of reference scaffolds. + + The maximum PHC order is the highest power of 4 that is less than or equal to the number of reference scaffolds. + + The equation is Ceil(log(M, 4)), where M is the number of reference scaffolds. + + Parameters + ---------- + ref_scaffold_smiles : ndarray + A 1D numpy array of SMILES strings, where each element is a reference scaffold. + + Returns + ------- + int + The maximum PHC order. + """ log_base = 4 M = ref_scaffold_smiles.shape[0] max_z = math.ceil(math.log(M, log_base)) @@ -35,6 +107,29 @@ def compute_max_phc_order(ref_scaffold_smiles: np.ndarray) -> int: def compute_hilbert_embeddings(bucket_ids: np.ndarray, hc_order: int, n_dim: int): + """ + Compute the Hilbert space embeddings for given bucket IDs. + + This function calculates the coordinates of points in a multidimensional + Hilbert space using the specified Hilbert Curve order and dimensionality. + + Parameters + ---------- + bucket_ids : np.ndarray + A 1D numpy array of integers representing the bucket IDs. Each bucket ID + corresponds to a point on the Hilbert curve. + hc_order : int + The order of the Hilbert Curve, which determines its resolution. + n_dim : int + The number of dimensions of the Hilbert space. + + Returns + ------- + np.ndarray + A 2D numpy array where each row contains the coordinates of a point in + the Hilbert space corresponding to a bucket ID. + """ + hilbert_curve = HilbertCurve(hc_order, n_dim) embedded_hs_coordinates = np.array([np.array(hilbert_curve.point_from_distance(b - 1)) for b in bucket_ids]) return embedded_hs_coordinates @@ -42,143 +137,161 @@ def compute_hilbert_embeddings(bucket_ids: np.ndarray, hc_order: int, n_dim: int def process_buckets(bucket_size: int, closest_order: np.ndarray, hc_order: int, n_dim: int): + """ + Processes bucket information and computes Hilbert space embeddings. + + This function assigns bucket IDs to the given closest order values based + on the specified bucket size. It then computes the corresponding Hilbert + space coordinates for these bucket IDs using the provided Hilbert Curve + order and dimensionality. + + Parameters + ---------- + bucket_size : int + The size of each bucket used for assigning bucket IDs. + closest_order : np.ndarray + A 1D numpy array of order values for which bucket IDs need to be + assigned. + hc_order : int + The order of the Hilbert Curve, which determines its resolution. + n_dim : int + The number of dimensions of the Hilbert space. + + Returns + ------- + tuple + A tuple containing: + - bucket_ids: np.ndarray + A 1D numpy array of integers representing the assigned bucket IDs. + - embedded_hs_coordinates: np.ndarray + A 2D numpy array where each row contains the coordinates of a point + in the Hilbert space corresponding to a bucket ID. + """ + bucket_ids = (np.round(closest_order / bucket_size) + 1).astype(int) # Assign bucket ID embedded_hs_coordinates = compute_hilbert_embeddings(bucket_ids, hc_order, n_dim) - + return bucket_ids, embedded_hs_coordinates +def embed(target_fingerprints: np.ndarray, ref_fingerprints: np.ndarray = None, target_ids: List[str] = None, + ref_ids: List[str] = None, + n_dim: int = 2, + use_cupy=False, target_batch_size=1000, ref_batch_size=50000, hc_order=None, + distance_fn='minkowski_1_5', max_hc_order_only=True, cache_dir: str = "embedding_cache"): -def embed(ref_fingerprints: np.ndarray, n_dim: int = 2, - use_cupy=False, target_batch_size=1500, ref_batch_size=50000, hc_order=None, - distance_fn='minkowski_1_5', target_fingerprints: np.ndarray = None, - max_hc_order_only=True, output_dir="embedding_results"): - """ - Embeds new structures based on reference fingerprints. - """ - if distance_fn not in distance_fn_mapping: raise ValueError(f"Unknown distance function: {distance_fn}. Choose from {list(distance_fn_mapping.keys())}") distance_fn = distance_fn_mapping[distance_fn] + # 1. Sort target fingerprints & target_ids + sort_idx_target = np.lexsort(target_fingerprints.T[::-1]) + target_fingerprints = target_fingerprints[sort_idx_target] + if target_ids is not None: + target_ids = np.array(target_ids)[sort_idx_target] + else: + target_ids = np.array([str(i) for i in range(target_fingerprints.shape[0])]) + + # 2. Create ref_fingerprints and unique indices from sorted target_fingerprints + if ref_fingerprints is None: + ref_fingerprints, unique_indices = np.unique(target_fingerprints, axis=0, return_index=True) + if ref_ids is None and target_ids is not None: + ref_ids = target_ids[unique_indices] + if use_cupy: import cupy as cp - ref_fingerprints = cp.asarray(ref_fingerprints, dtype=cp.float32) - - if target_fingerprints is None: - target_fingerprints = ref_fingerprints + target_fingerprints = cp.asarray(target_fingerprints, dtype=cp.float32) num_targets = target_fingerprints.shape[0] num_refs = ref_fingerprints.shape[0] + if num_targets < target_batch_size: + target_batch_size = num_targets + if num_refs < ref_batch_size: + ref_batch_size = num_refs max_z = compute_max_phc_order(ref_fingerprints) hc_order_vals = [max_z] if max_hc_order_only else (range(2, max_z + 1) if hc_order is None else [hc_order]) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - for i in tqdm(range(0, num_targets, target_batch_size), desc="Processing target batches", unit="batch"): - all_exist = True - for hc_order_val in hc_order_vals: - if not os.path.exists(f"{output_dir}/embedding_batch_{i}_{hc_order_val}.csv"): - all_exist = False - if all_exist: - continue - - target_fingerprint_batch = target_fingerprints[i:i + target_batch_size] - - # Initialize storage for the closest reference index and min distances - closest_indices = np.full(target_fingerprint_batch.shape[0], -1, dtype=int) - min_distances = np.full(target_fingerprint_batch.shape[0], np.inf, dtype=np.float32) - - for j in tqdm(range(0, num_refs, ref_batch_size), desc="Processing reference batches"): - ref_fingerprint_batch = ref_fingerprints[j:j + ref_batch_size] - - if use_cupy: - target_fingerprint_batch_cp = cp.asarray(target_fingerprint_batch, dtype=cp.float32) - ref_fingerprint_batch_cp = cp.asarray(ref_fingerprint_batch, dtype=cp.float32) - distances_batch = distance_fn(target_fingerprint_batch_cp, ref_fingerprint_batch_cp, use_cupy=True).get() - else: - distances_batch = distance_fn(target_fingerprint_batch, ref_fingerprint_batch) - - # Find the closest reference for each target in this batch - batch_closest_indices = np.argmin(distances_batch, axis=0) - batch_min_distances = np.min(distances_batch, axis=0) - - batch_closest_indices += j # Adjust indices to global reference index space - - # Update only where we found a closer reference - update_mask = batch_min_distances < min_distances - min_distances[update_mask] = batch_min_distances[update_mask] - closest_indices[update_mask] = batch_closest_indices[update_mask] - - # Convert closest indices to hierarchical order - order = np.arange(1, num_refs + 1) - closest_order = order[closest_indices] - - - # Process hierarchical ordering - for hc_order_val in hc_order_vals: - bucket_nr = math.pow(math.pow(2, hc_order_val), n_dim) - bucket_size = float(num_refs / (bucket_nr - 1)) - bucket_ids, embedded_hs_coordinates = process_buckets(bucket_size, closest_order, hc_order_val, n_dim) - - # Save batch results to disk - df = pd.DataFrame({ - 'closest_order': closest_order, - 'bucket_ids': bucket_ids, - 'hc_order_val': [hc_order_val] * len(closest_order) - }) - - for dim in range(embedded_hs_coordinates.shape[1]): - df[f'coord_{dim}'] = embedded_hs_coordinates[:, dim] - - df.to_csv(f"{output_dir}/embedding_batch_{i}_{hc_order_val}.csv", index=False) - - # Combine final results by hc_order - combined_results_by_hc_order = {} - - for filename in os.listdir(output_dir): - if filename.endswith(".csv"): - df = pd.read_csv(os.path.join(output_dir, filename)) - - # Iterate over each unique hc_order_val in the batch - for hc_order_val in np.unique(df['hc_order_val'].to_numpy()): - if hc_order_val not in combined_results_by_hc_order: - # Initialize a new entry for this hc_order_val - combined_results_by_hc_order[hc_order_val] = { - 'closest_order': [], - 'bucket_ids': [], - 'embedded_hs_coordinates': [], - 'hc_order_val': [] - } - - # Append results for this hc_order_val - mask = df['hc_order_val'] == hc_order_val - combined_results_by_hc_order[hc_order_val]['closest_order'].append(df[mask]['closest_order'].to_numpy()) - combined_results_by_hc_order[hc_order_val]['bucket_ids'].append(df[mask]['bucket_ids'].to_numpy()) - - # hc_order_val is the same across the batch, so we can just repeat it - hc_order_val_array = np.full(df[mask]['hc_order_val'].shape, hc_order_val) - combined_results_by_hc_order[hc_order_val]['hc_order_val'].append(hc_order_val_array) - - # Concatenate coordinates (all columns that start with 'coord_') - embedded_hs_coordinates = df[mask][[col for col in df.columns if col.startswith('coord_')]].to_numpy() - combined_results_by_hc_order[hc_order_val]['embedded_hs_coordinates'].append(embedded_hs_coordinates) - - # Now, concatenate arrays within each hc_order - final_combined_results = { - 'closest_order': [], - 'bucket_ids': [], - 'embedded_hs_coordinates': [], - 'hc_order_val': [] - } - - for hc_order_val, result in combined_results_by_hc_order.items(): - final_combined_results['closest_order'].append(np.concatenate(result['closest_order'])) - final_combined_results['bucket_ids'].append(np.concatenate(result['bucket_ids'])) - final_combined_results['embedded_hs_coordinates'].append(np.concatenate(result['embedded_hs_coordinates'])) - final_combined_results['hc_order_val'].append(np.concatenate(result['hc_order_val'])) - - return (final_combined_results['closest_order'], final_combined_results['bucket_ids'], - final_combined_results['embedded_hs_coordinates'], final_combined_results['hc_order_val']) \ No newline at end of file + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + idx_identifiers, idx_batch_files, idx_row_in_batch = _load_index_npz(cache_dir) + else: + idx_identifiers = np.array([], dtype=str) + idx_batch_files = np.array([], dtype=str) + idx_row_in_batch = np.array([], dtype=int) + + # Determine which target_ids are already cached + id_to_idx = {id_: i for i, id_ in enumerate(idx_identifiers)} + cached_mask = np.array([id_ in id_to_idx for id_ in target_ids]) + uncached_mask = ~cached_mask + + num_cached = np.sum(cached_mask) + num_uncached = np.sum(uncached_mask) + print(f"[HCASE] {num_cached} targets found in cache, {num_uncached} to compute.") + + # Prepare arrays for all results + closest_indices_all = np.full(num_targets, -1, dtype=int) + + # Fill in cached results + if np.any(cached_mask): + cached_ids = target_ids[cached_mask] + print(f"[HCASE] Loading cached closest_indices for {len(cached_ids)} targets.") + cached_closest_indices = _get_closest_indices_from_cache_npz(cached_ids, cache_dir, idx_identifiers, idx_batch_files, idx_row_in_batch) + for i, id_ in enumerate(target_ids): + if cached_mask[i] and id_ in cached_closest_indices: + closest_indices_all[i] = cached_closest_indices[id_] + + # Compute for uncached targets + uncached_indices = np.where(uncached_mask)[0] + if len(uncached_indices) > 0: + print(f"[HCASE] Computing closest_indices for {len(uncached_indices)} uncached targets...") + for batch_start in tqdm(range(0, len(uncached_indices), target_batch_size), desc="Processing uncached batches", unit="batch"): + batch_idx = uncached_indices[batch_start:batch_start + target_batch_size] + target_fingerprint_batch = target_fingerprints[batch_idx] + batch_size = target_fingerprint_batch.shape[0] + min_distances = np.full(batch_size, np.inf, dtype=np.float32) + closest_indices = np.full(batch_size, -1, dtype=int) + for j in tqdm(range(0, num_refs, ref_batch_size), desc="Processing reference batches"): + ref_fingerprint_batch = ref_fingerprints[j:j + ref_batch_size] + if use_cupy: + target_fingerprint_batch_cp = cp.asarray(target_fingerprint_batch, dtype=cp.float32) + ref_fingerprint_batch_cp = cp.asarray(ref_fingerprint_batch, dtype=cp.float32) + distances_batch = distance_fn(target_fingerprint_batch_cp, ref_fingerprint_batch_cp, use_cupy=True).get() + else: + distances_batch = distance_fn(target_fingerprint_batch, ref_fingerprint_batch) + batch_closest_indices = np.argmin(distances_batch, axis=0) + j + batch_min_distances = np.min(distances_batch, axis=0) + update_mask = batch_min_distances < min_distances + min_distances[update_mask] = batch_min_distances[update_mask] + closest_indices[update_mask] = batch_closest_indices[update_mask] + # Save this batch to cache + batch_target_ids = target_ids[batch_idx] + batch_file = _add_closest_indices_batch_npz(batch_target_ids, closest_indices, cache_dir) + # Update index arrays + new_identifiers = np.concatenate([idx_identifiers, batch_target_ids]) + new_batch_files = np.concatenate([idx_batch_files, np.array([batch_file] * len(batch_target_ids), dtype=str)]) + new_row_in_batch = np.concatenate([idx_row_in_batch, np.arange(len(batch_target_ids), dtype=int)]) + _save_index_npz(new_identifiers, new_batch_files, new_row_in_batch, cache_dir) + # Update in-memory index for this run + idx_identifiers = new_identifiers + idx_batch_files = new_batch_files + idx_row_in_batch = new_row_in_batch + # Fill in results + closest_indices_all[batch_idx] = closest_indices + + order = np.arange(1, num_refs + 1) + closest_order = order[closest_indices_all] + all_embeddings = [] + ref_only_embeddings = [] + all_hc_orders = [] + + for hc_order_val in hc_order_vals: + bucket_nr = math.pow(2, hc_order_val) ** n_dim + bucket_size = num_refs / (bucket_nr - 1) + bucket_ids, embedded_hs_coordinates = process_buckets(bucket_size, closest_order, hc_order_val, n_dim) + bucket_ids, ref_embedded_hs_coordinates = process_buckets(bucket_size, order, hc_order_val, n_dim) + all_embeddings.append(embedded_hs_coordinates) + ref_only_embeddings.append(ref_embedded_hs_coordinates) + all_hc_orders.append(hc_order_val) + + return all_embeddings, ref_only_embeddings, all_hc_orders, ref_ids, closest_indices_all diff --git a/hcase/plot.py b/hcase/plot.py new file mode 100644 index 0000000..af19bb4 --- /dev/null +++ b/hcase/plot.py @@ -0,0 +1,1088 @@ +import base64 +from io import BytesIO +from PIL import Image, ImageDraw, ImageFont +import pandas as pd +import numpy as np +import plotly.graph_objects as go +from dash import Dash, dcc, html +from dash.dependencies import Input, Output +from rdkit import Chem +from rdkit.Chem import rdDepictor +from rdkit.Chem.Draw import rdMolDraw2D +from rdkit.Chem import rdFMCS +import seaborn as sns + +TRANSPARENT_PNG_BASE64 = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8Xw8AAr8B9FzEBq8AAAAASUVORK5CYII=" + +def create_custom_linear_bins(freq_matrix, n_bins): + """ + Create custom bins for frequency: + - Bin 0: [0, 0] + - Bin 1: (0, 1] + - Bin 2: (1, 2) + - Bin 3: [2, 10) + - Bin 4: [10, 100) + - Bin 5: [100, 1000) + - Bin 6: [1000, 10000) + - Bin 7: [10000, inf) + """ + return np.array([0, 1, 2, 10, 100, 1000, 10000, np.inf]) + +def generate_heatmap_figure( + agg_combined, plot_type, freq_range, x_centers, y_centers, discrete_colorscale, n_bins, + special_ref_inchikeys=None, special_unique_ref_inchikeys=None +): + # Helper to create log-spaced bins, avoids zero by starting from min_bin_val + + # Step 1: Aggregate frequency by bin + freq_df = agg_combined.groupby(['y_bin', 'x_bin'])["frequency"].sum().unstack(fill_value=0) + freq_matrix = freq_df.values + + # Build a 2D array of ref_ids matching freq_matrix shape + refid_df = agg_combined.groupby(['y_bin', 'x_bin'])["ref_id"].first().unstack(fill_value=None) + refid_matrix = refid_df.values + + # Build a 2D array of base64 images for those ref_ids, but always use blank image for speed + img_matrix = [] + for row in refid_matrix: + img_row = [] + for _ in row: + img_b64 = TRANSPARENT_PNG_BASE64 + img_row.append(img_b64) + img_matrix.append(img_row) + + # Step 2: Filter bins based on aggregated frequency (after binning) + if freq_range is not None: + min_freq, max_freq = freq_range + mask = (freq_matrix >= min_freq) & (freq_matrix <= max_freq) + else: + mask = np.ones_like(freq_matrix, dtype=bool) + + # Step 3: Apply mask to the frequency matrix + masked_freq_matrix = np.where(mask, freq_matrix, np.nan) + + heatmap_matrix = masked_freq_matrix + + # Step 5: Bin the heatmap values into discrete color bins + val_min = np.nanmin(masked_freq_matrix) + val_max = np.nanmax(masked_freq_matrix) + + bins = create_custom_linear_bins(freq_matrix, n_bins) + + # Assign bin 0 to all values that are exactly zero, rest use normal binning + bin_indices = np.digitize(heatmap_matrix, bins) - 1 + bin_indices = np.where(heatmap_matrix == 0, 0, bin_indices) # bin 0 for zeros + + # Replace NaNs in heatmap_matrix with bin 0 (lowest bin) + bin_indices = np.where(np.isnan(heatmap_matrix), 0, bin_indices) + + # Build customdata as [[ [freq, img_b64], ...], ... ] + customdata = [] + for i in range(len(freq_matrix)): + row = [] + for j in range(len(freq_matrix[i])): + freq = freq_matrix[i][j] + img_b64 = img_matrix[i][j] + row.append([freq, img_b64]) + customdata.append(row) + + # Step 6: Create Plotly heatmap + fig = go.Figure(go.Heatmap( + z=bin_indices.tolist(), + x=x_centers.tolist(), + y=y_centers.tolist(), + colorscale=discrete_colorscale, + zmin=0, + zmax=n_bins - 1, + zauto=False, + showscale=False, + customdata=customdata, + hovertemplate=( + "dim_1: %{x}
" + "dim_2: %{y}
" + "frequency: %{customdata[0]}" + ), + name="" + )) + + # Add special markers for ref_inchikeys if provided + if special_ref_inchikeys is not None: + # Determine which are unique and which are common + unique_set = set(special_unique_ref_inchikeys) if special_unique_ref_inchikeys is not None else set() + all_set = set(special_ref_inchikeys) + common_set = all_set - unique_set + + # Plot unique (red star) + if unique_set: + unique_points = agg_combined[agg_combined['ref_inchikey'].isin(unique_set)] + if not unique_points.empty: + unique_x = [x_centers[int(row['x_bin'])] for _, row in unique_points.iterrows()] + unique_y = [y_centers[int(row['y_bin'])] for _, row in unique_points.iterrows()] + fig.add_trace(go.Scatter( + x=unique_x, + y=unique_y, + mode='markers', + marker=dict( + size=24, + color='red', + symbol='star', + opacity=1.0, + line=dict(width=3, color='black') + ), + name='Special Unique InChIKeys', + showlegend=True + )) + + # Plot common (green star) + if common_set: + common_points = agg_combined[agg_combined['ref_inchikey'].isin(common_set)] + if not common_points.empty: + common_x = [x_centers[int(row['x_bin'])] for _, row in common_points.iterrows()] + common_y = [y_centers[int(row['y_bin'])] for _, row in common_points.iterrows()] + fig.add_trace(go.Scatter( + x=common_x, + y=common_y, + mode='markers', + marker=dict( + size=24, + color='green', + symbol='star', + opacity=1.0, + line=dict(width=3, color='black') + ), + name='Special Common InChIKeys', + showlegend=True + )) + + fig.update_layout( + title={ + 'text': f"Heatmap of {plot_type}", + 'x': 0.5, + 'xanchor': 'center', + }, + xaxis_title="dim_1", + yaxis_title="dim_2", + yaxis_autorange='reversed', + xaxis=dict( + showline=True, + linecolor='black', + linewidth=2, + mirror=True, + ticks='outside' + ), + yaxis=dict( + showline=True, + linecolor='black', + linewidth=2, + mirror=True, + ticks='outside' + ) + ) + + # Step 7: Return bin edges and ranges + freq_bins = bins + min_freq = val_min + max_freq = val_max + + return fig, freq_bins, min_freq, max_freq + + +def rgb_tuple_to_str(rgb_tuple): + r, g, b = [int(255 * x) for x in rgb_tuple] + return f'rgb({r},{g},{b})' + +def add_label_inside_image(img, label, label_height=30): + width, height = img.size + draw = ImageDraw.Draw(img) + + try: + font = ImageFont.truetype("arial.ttf", 14) + except IOError: + font = ImageFont.load_default() + + bbox = draw.textbbox((0, 0), label, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + text_x = (width - text_width) // 2 + text_y = height - label_height + (label_height - text_height) // 2 # inside bottom + + draw.text((text_x, text_y), label, fill='black', font=font) + return img + +def get_mcs_match(parent_mol, child_mol, min_atoms_ratio=0.5): + mcs_result = rdFMCS.FindMCS([parent_mol, child_mol]) + if mcs_result.canceled or not mcs_result.smartsString: + return None + + mcs_mol = Chem.MolFromSmarts(mcs_result.smartsString) + parent_match = parent_mol.GetSubstructMatch(mcs_mol) + child_match = child_mol.GetSubstructMatch(mcs_mol) + + if len(child_match) < min_atoms_ratio * child_mol.GetNumAtoms(): + print(f"MCS too small: {len(child_match)} < {min_atoms_ratio * child_mol.GetNumAtoms()}", flush=True) + return None + + return parent_match, child_match + + +def generate_highlighted_overlap_image(parent_smiles, child_smiles, show_atom_mapping = False, min_atoms_ratio = .5): + parent_mol = Chem.MolFromSmiles(parent_smiles) + child_mol = Chem.MolFromSmiles(child_smiles) + + if not show_atom_mapping: + for atom in parent_mol.GetAtoms(): + atom.SetAtomMapNum(0) + for atom in child_mol.GetAtoms(): + atom.SetAtomMapNum(0) + + result = get_mcs_match(parent_mol, child_mol, min_atoms_ratio=min_atoms_ratio) + if result is None: + print(f"[MCS TOO SMALL or NOT FOUND] Parent: {parent_smiles} | Child: {child_smiles}", flush=True) + return generate_image_from_smiles(parent_smiles) + + rdDepictor.Compute2DCoords(parent_mol) + parent_match, _ = result + highlight_atoms = list(parent_match) + highlight_color = (0.0, 1.0, 0.0) + highlight_dict = {idx: highlight_color for idx in highlight_atoms} + + drawer = rdMolDraw2D.MolDraw2DCairo(200, 200) + drawer.DrawMolecule(parent_mol, highlightAtoms=highlight_atoms, highlightAtomColors=highlight_dict) + drawer.FinishDrawing() + img_bytes = drawer.GetDrawingText() + img = Image.open(BytesIO(img_bytes)).convert("RGBA") + return img + + +def generate_highlighted_grid_image( + parent_smiles_list, + child_smiles_list, + parent_reaction_ids, + page, + images_per_row, + rows_per_grid, + image_size, + show_atom_mapping=False +): + if len(parent_smiles_list) != len(child_smiles_list): + raise ValueError("parent_smiles_list and child_smiles_list must have the same length.") + + images_per_grid = images_per_row * rows_per_grid + start = page * images_per_grid + end = start + images_per_grid + + parent_subset = parent_smiles_list[start:end] + child_subset = child_smiles_list[start:end] + reaction_subset = parent_reaction_ids[start:end] if parent_reaction_ids else [None] * len(parent_subset) + + images = [] + + for i, (parent_smiles, child_smiles) in enumerate(zip(parent_subset, child_subset)): + parent_mol = Chem.MolFromSmiles(parent_smiles) + child_mol = Chem.MolFromSmiles(child_smiles) + + if not show_atom_mapping: + for atom in parent_mol.GetAtoms(): + atom.SetAtomMapNum(0) + for atom in child_mol.GetAtoms(): + atom.SetAtomMapNum(0) + + result = get_mcs_match(parent_mol, child_mol, min_atoms_ratio=0.5) + if result is None: + print(f"[NO MATCH or MCS TOO SMALL] Parent: {parent_smiles}\nChild: {child_smiles}", flush=True) + images.append(generate_image_from_smiles(child_smiles)) + continue + + parent_match, _ = result + highlight_atoms = list(parent_match) + highlight_color = (0.0, 1.0, 0.0) + highlight_dict = {idx: highlight_color for idx in highlight_atoms} + + drawer = rdMolDraw2D.MolDraw2DCairo(image_size, image_size) + drawer.DrawMolecule(parent_mol, highlightAtoms=highlight_atoms, highlightAtomColors=highlight_dict) + drawer.FinishDrawing() + + img_bytes = drawer.GetDrawingText() + img = Image.open(BytesIO(img_bytes)).convert("RGBA") + if reaction_subset[i]: + img = add_label_inside_image(img, str(reaction_subset[i])) + + images.append(img) + + rows = (len(images) + images_per_row - 1) // images_per_row + grid_img = Image.new('RGBA', (images_per_row * image_size, rows * image_size), (255, 255, 255, 0)) + for i, img in enumerate(images): + if img: + x_pos = (i % images_per_row) * image_size + y_pos = (i // images_per_row) * image_size + grid_img.paste(img, (x_pos, y_pos)) + + return grid_img + + +def make_normal_grid_image( + smiles_list, + page, + images_per_row, + rows_per_grid, + image_size, +): + images_per_grid = images_per_row * rows_per_grid + start = page * images_per_grid + end = start + images_per_grid + subset = smiles_list[start:end] + + if not subset: + return TRANSPARENT_PNG_BASE64 + + images = [generate_image_from_smiles(smi) for smi in subset if smi] + rows = (len(images) + images_per_row - 1) // images_per_row + + grid_img = Image.new('RGBA', (images_per_row * image_size, rows * image_size), (255, 255, 255, 0)) + for i, img in enumerate(images): + if img: + x_pos = (i % images_per_row) * image_size + y_pos = (i // images_per_row) * image_size + grid_img.paste(img, (x_pos, y_pos)) + + return grid_img + + + +def generate_image_from_smiles(smiles, show_atom_mapping=False): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"Invalid SMILES for image generation: {smiles}") + return None + if not show_atom_mapping: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + rdDepictor.Compute2DCoords(mol) + drawer = rdMolDraw2D.MolDraw2DCairo(200, 200) + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + img_bytes = drawer.GetDrawingText() + img = Image.open(BytesIO(img_bytes)).convert("RGBA") + return img + +def image_to_base64(img): + buffer = BytesIO() + img.save(buffer, format='PNG') + return f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode()}" + +def find_nearest_bin(center_array, val): + idx = np.abs(center_array - val).argmin() + return idx + +def plot_interactive_hs( + embedded, target_ids, hc_orders, ref_ids, parent_target_ids, parent_ref_ids, closest_indices_all, parent_reaction_ids, target_inchikeys, + ref_inchikeys, special_ref_inchikeys=None, special_unique_ref_inchikeys=None, start_port=8050, max_mols_per_point=500, images_per_row = 5, rows_per_grid = 4, image_size = 200 ): + idx = 0 + data_by_hc_order = {} + + for i, coords in enumerate(embedded): + hc_order_val = hc_orders[i] + + rows = [] + for j in range(coords.shape[0]): + dim_1, dim_2 = map(float, coords[j]) + ref_idx = closest_indices_all[idx] + rows.append({ + "target_id": target_ids[idx], + "parent_target_id": parent_target_ids[idx], + "ref_id": ref_ids[ref_idx], + "parent_ref_id": parent_ref_ids[ref_idx], + "dim_1": dim_1, + "dim_2": dim_2, + "hc_order": hc_order_val, + "target_inchikey": target_inchikeys[idx], + "ref_inchikey": ref_inchikeys[ref_idx], + "parent_reaction_id": parent_reaction_ids[idx], + }) + idx += 1 + data_by_hc_order.setdefault(hc_order_val, []).extend(rows) + + for i, (hc_order_val, rows) in enumerate(data_by_hc_order.items()): + df = pd.DataFrame(rows) + + # Compute frequency as the number of targets per (dim_1, dim_2, ref_id) + freq = ( + df.groupby(["dim_1", "dim_2", "ref_id"]) + .size() + .reset_index(name="frequency") + ) + + agg_targets = ( + df.groupby(["dim_1", "dim_2", "ref_id"]) + .apply(lambda group: pd.DataFrame({ + "target_id_list": [list(group["target_id"])], + "parent_target_id_list": [list(group["parent_target_id"])], + "parent_reaction_id_list": [list(group["parent_reaction_id"])] + })) + .reset_index() + .drop(columns=["level_3"]) # adjust if needed + ) + + df_ref_inchikeys = df[['ref_id', 'ref_inchikey']].drop_duplicates() + + agg_ref = ( + df.groupby(["dim_1", "dim_2", "ref_id"]) + .agg({ + "parent_ref_id": "first" + }) + .reset_index() + ) + agg_ref = agg_ref.merge(df_ref_inchikeys, on='ref_id', how='left') + + agg_combined = agg_ref.merge(agg_targets, on=["dim_1", "dim_2", "ref_id"], how="outer") + + agg_combined = agg_combined.merge( + df[["dim_1", "dim_2", "ref_id", "parent_reaction_id"]], + on=["dim_1", "dim_2", "ref_id"], + how="left" + ) + + agg_combined = agg_combined.merge(freq, on=["dim_1", "dim_2", "ref_id"], how="left") + + target_inchikey_group = df.groupby(["dim_1", "dim_2"])["target_inchikey"].unique().reset_index() + target_inchikey_group.rename(columns={"target_inchikey": "target_inchikey_list"}, inplace=True) + agg_combined = agg_combined.merge(target_inchikey_group, on=["dim_1", "dim_2"], how="left") + + # Remove duplicate rows caused by merges + agg_combined = agg_combined.drop_duplicates(subset=["dim_1", "dim_2", "ref_id"]) + + # Save selected fields to CSV with hc_order in filename + save_fields = ["dim_1", "dim_2", "ref_id", "parent_ref_id", "ref_inchikey", "parent_reaction_id", "frequency"] + csv_filename = f"agg_combined_hc_order_{hc_order_val}.csv" + agg_combined[save_fields].to_csv(csv_filename, index=False) + print(f"Saved {csv_filename} with selected fields.") + + # Clip target lists to max_mols_per_point + agg_combined["target_id_list"] = agg_combined["target_id_list"].apply(lambda lst: lst[:max_mols_per_point] if isinstance(lst, list) else []) + agg_combined["parent_target_id_list"] = agg_combined["parent_target_id_list"].apply(lambda lst: lst[:max_mols_per_point] if isinstance(lst, list) else []) + agg_combined["parent_reaction_id_list"] = agg_combined["parent_reaction_id_list"].apply(lambda lst: lst[:max_mols_per_point] if isinstance(lst, list) else []) + + bin_size = 1 + max_dim = 2 ** hc_order_val + + x_min = 0 + x_max = min(agg_combined.dim_1.max(), max_dim) + y_min = 0 + y_max = min(agg_combined.dim_2.max(), max_dim) + + x_bins = np.arange(x_min, x_max + bin_size, bin_size) + y_bins = np.arange(y_min, y_max + bin_size, bin_size) + + agg_combined['x_bin'] = np.digitize(agg_combined.dim_1, x_bins) - 1 + agg_combined['x_bin'] = agg_combined['x_bin'].clip(0, len(x_bins) - 2) + agg_combined['y_bin'] = np.digitize(agg_combined.dim_2, y_bins) - 1 + agg_combined['y_bin'] = agg_combined['y_bin'].clip(0, len(y_bins) - 2) + + x_centers = x_bins[:-1] + bin_size / 2 + y_centers = y_bins[:-1] + bin_size / 2 + + canonical_smiles_to_coord = {} + + # Build bin_data_map with keys = (y_bin, x_bin, ref_id) + bin_data_map = {} + for _, row in agg_combined.iterrows(): + smiles = row.get('ref_id') + if pd.notnull(smiles): + canonical_smiles_to_coord[smiles] = (row['dim_1'], row['dim_2']) + key = (row['y_bin'], row['x_bin'], row['ref_id']) + if key not in bin_data_map: + bin_data_map[key] = { + 'ref_id': row['ref_id'], + 'target_id_list': row['target_id_list'], + 'parent_ref_id': row['parent_ref_id'], + 'parent_target_id_list': row['parent_target_id_list'], + 'parent_reaction_id_list': row['parent_reaction_id_list'], + 'parent_reaction_id': row['parent_reaction_id'] + } + + n_bins = 10 + # We'll need freq_matrix to create bins, so generate it here + freq_df = agg_combined.groupby(['y_bin', 'x_bin'])["frequency"].sum().unstack(fill_value=0) + freq_matrix = freq_df.values + bins = create_custom_linear_bins(freq_matrix, n_bins) + n_bins = len(bins) - 1 + colors = sns.color_palette("cubehelix", n_colors=n_bins - 1) + colors = colors[::-1] # Invert color scheme: light = low, dark = high + # First color is white for bin 0 (frequency == 0) + discrete_colorscale = [(0.0, "rgb(255,255,255)")] + discrete_colorscale += [ + ((i + 1) / (n_bins - 1), rgb_tuple_to_str(color)) + for i, color in enumerate(colors) + ] + # Only generate the frequency plot + fig, freq_bins, min_raw_freq, max_raw_freq = generate_heatmap_figure( + agg_combined, "frequency", None, + x_centers, y_centers, discrete_colorscale, n_bins, + special_ref_inchikeys=special_ref_inchikeys, + special_unique_ref_inchikeys=special_unique_ref_inchikeys + ) + + figs = {"frequency": fig} + legend_data = { + "frequency": { + "colors": ["rgb(255,255,255)"] + [rgb_tuple_to_str(c) for c in colors], + "bins": freq_bins + } + } + + app = Dash(__name__) + + def to_human_notation(val): + if val == 0: + return "0" + elif val < 1: + return f"{val:.0f}" + elif val < 1_000: + return f"{int(val)}" + elif val < 1_000_000: + return f"{int(val/1_000)}K" + elif val < 1_000_000_000: + return f"{int(val/1_000_000)}M" + else: + return f"{int(val/1_000_000_000)}B" + + # Show marks for all bins using just the bin edge values + slider_marks = {i: to_human_notation(legend_data['frequency']['bins'][i]) for i in range(n_bins)} + + # Reverse marks for the grid slider so 0 is at the top + def make_reversed_marks(max_page): + return {i: str(max_page - i) for i in range(max_page + 1)} + + app.layout = html.Div([ + dcc.Store(id='hover-coords', data={'x': 0, 'y': 0}), + dcc.Store(id='mouse-coords', data={'x': 0, 'y': 0}), + html.Img( + id='hover-img', + src=TRANSPARENT_PNG_BASE64, + style={ + 'position': 'fixed', + 'top': '0px', + 'left': '0px', + 'display': 'none', + 'zIndex': 1000, + 'border': '2px solid #333', + 'backgroundColor': 'white', + 'height': '120px' + } + ), + # dcc.Store removed, now using clickData + html.Div([ + html.H2( + f"HS Coordinates (hc_order = {hc_order_val})", + style={"textAlign": "center", "marginBottom": "10px"} + ), + + html.Div([ + dcc.RadioItems( + id='mode-toggle', + options=[ + {'label': 'Use Ref/Target Scaffolds', 'value': 'normal'}, + {'label': 'Use Parent Ref/Target Product SMILES', 'value': 'parent'} + ], + value='parent', + labelStyle={'display': 'inline-block', 'marginRight': '15px'}, + ) + ], style={"textAlign": "center", "marginBottom": "15px"}), + + # Plot type dropdown removed + ]), + + html.Div([ + html.Label("Search Reference (SMILES or InChIKey):", style={"fontWeight": "bold"}), + dcc.Input( + id='search-input', + type='text', + debounce=True, + placeholder="Enter SMILES or InChIKey", + style={'width': '400px'} + ), + html.Div(id='search-status', style={"marginTop": "5px", "color": "red"}) + ], style={"textAlign": "center", "marginBottom": "20px"}), + html.Div([ + html.Label("Search Target (SMILES or InChIKey):", style={"fontWeight": "bold"}), + dcc.Input( + id='target-search-input', + type='text', + debounce=True, + placeholder="Enter SMILES or InChIKey", + style={'width': '400px'} + ), + html.Div(id='target-search-status', style={"marginTop": "5px", "color": "red"}) + ], style={"textAlign": "center", "marginBottom": "20px"}), + + + + html.Div([ + html.Label("Frequency Range:", style={"marginBottom": "5px"}), + dcc.RangeSlider( + id='freq-range-slider', + min=0, + max=n_bins-1, + step=1, + value=[0, n_bins-1], + marks=slider_marks, + allowCross=False + ) + ], style={ + "width": "80%", + "margin": "0 auto 10px auto", + "padding": "0", + "backgroundColor": "transparent", + "border": "none" + }), + + + html.Div([ + dcc.Graph( + id='heatmap', + figure=figs['frequency'], + style={ + 'height': '700px', + 'width': '80%' + } + ), + html.Div(id='color-legend', style={ + 'marginLeft': '5px', + 'alignSelf': 'center', + }) + ], style={ + 'display': 'flex', + 'flexDirection': 'row', + 'justifyContent': 'flex-center', + 'alignItems': 'center', + }), + + html.Div([ + html.Div("Reference Space Structure", style={"fontWeight": "bold", "fontSize": 18, "marginBottom": "5px"}), + html.Img(id='ref-img', src="", style={ + 'height': '150px', + 'display': 'none', + 'border': '1px solid #ccc', + 'backgroundColor': 'white', + 'margin': '0 auto 10px auto', + 'maxWidth': '200px', + 'objectFit': 'contain' + }), + ], style={'textAlign': 'center'}), + + html.Div([ + html.Div("Corresponding Target Space Structures", style={ + "fontWeight": "bold", + "fontSize": 18, + "marginBottom": "5px", + "textAlign": "center" + }), + + html.Div([ + html.Img(id='grid-img', src="", style={ + 'height': '600px', + 'border': '1px solid #ccc', + 'margin': '10px', + 'maxWidth': '100%', + 'objectFit': 'contain' + }), + + html.Div([ + html.Label("Scroll Pages", style={"fontWeight": "bold", "marginBottom": "10px"}), + dcc.Slider( + id='grid-page-slider', + min=0, + max=0, # will be set dynamically + step=1, + value=0, + marks={0: "0"}, + vertical=True, + updatemode='drag', + tooltip={"always_visible": True, "placement": "right"}, + ) + ], style={ + 'width': '60px', + 'height': '600px', + 'marginLeft': '10px', + 'display': 'flex', + 'flexDirection': 'column', + 'justifyContent': 'center' + }) + ], style={ + 'display': 'flex', + 'flexDirection': 'row', + 'justifyContent': 'center', + 'alignItems': 'center', + 'marginBottom': '30px' + }) + ]) + + ]) + + # Clientside callback to update mouse-coords with actual mouse position over the graph + app.clientside_callback( + """ + function(n_intervals) { + // Only attach the event listener once + if (!window._dash_mousemove_attached) { + var graphDiv = document.querySelector('#heatmap'); + if (graphDiv) { + graphDiv.addEventListener('mousemove', function(e) { + var coords = {'x': e.clientX, 'y': e.clientY}; + window.dash_clientside.callback_context.setStoreData('mouse-coords', coords); + }); + window._dash_mousemove_attached = true; + } + } + return window._dash_last_mouse_coords || {'x': 0, 'y': 0}; + } + """, + Output('mouse-coords', 'data'), + Input('heatmap', 'n_clicks'), # dummy input to trigger on load + prevent_initial_call=True + ) + + @app.callback( + Output('hover-img', 'src'), + Output('hover-img', 'style'), + Input('heatmap', 'hoverData') + ) + def update_hover_img(hoverData): + if not hoverData or "points" not in hoverData: + return TRANSPARENT_PNG_BASE64, { + 'position': 'fixed', + 'top': '0px', + 'left': '0px', + 'display': 'none', + 'zIndex': 1000, + 'border': '2px solid #333', + 'backgroundColor': 'white', + 'height': '120px' + } + # Get data coordinates + point = hoverData["points"][0] + x = point["x"] + y = point["y"] + + # Find the nearest bin and ref_id + x_bin = find_nearest_bin(x_centers, x) + y_bin = find_nearest_bin(y_centers, y) + possible_keys = [key for key in bin_data_map.keys() if key[0] == y_bin and key[1] == x_bin] + ref_id = None + if possible_keys: + ref_id = bin_data_map[possible_keys[0]]['ref_id'] + + # Generate the image for the ref_id, or fallback to not rendering + if ref_id: + img = generate_image_from_smiles(ref_id) + img_src = image_to_base64(img) if img else TRANSPARENT_PNG_BASE64 + + # Assume graph is 80% of window width and 700px height, and centered horizontally + import dash + window_width = 1200 # fallback default + graph_width = int(window_width * 0.8) + graph_height = 700 + graph_left = int((window_width - graph_width) / 2) + graph_top = 300 # estimate vertical offset from top of page + + # Get axis ranges (assume linear, not log) + x_min = min(x_centers) + x_max = max(x_centers) + y_min = min(y_centers) + y_max = max(y_centers) + + # Map data coordinates to pixel coordinates + px = int((x - x_min) / (x_max - x_min) * graph_width) + graph_left + py = int((y - y_min) / (y_max - y_min) * graph_height) + graph_top + + style = { + 'position': 'fixed', + 'top': f"{py + 10}px", + 'left': f"{px + 10}px", + 'display': 'block', + 'zIndex': 1000, + 'border': '2px solid #333', + 'backgroundColor': 'white', + 'height': '120px' + } + return img_src, style + else: + # No ref_id, do not render image + return TRANSPARENT_PNG_BASE64, { + 'display': 'none' + } + + @app.callback( + Output('heatmap', 'figure'), + Output('search-status', 'children'), + Output('target-search-status', 'children'), + Input('search-input', 'value'), + Input('target-search-input', 'value'), + Input('freq-range-slider', 'value'), + ) + def update_figure_with_search(search_input, target_search_input, freq_range): + # freq_range is now [bin_start_idx, bin_end_idx] + bins = legend_data["frequency"]["bins"] + # Use all indices as valid + bin_start = min(range(n_bins), key=lambda x: abs(x - freq_range[0])) + bin_end = min(range(n_bins), key=lambda x: abs(x - freq_range[1])) + # Convert to frequency values + min_freq = bins[bin_start] + max_freq = bins[bin_end+1] if bin_end+1 < len(bins) else bins[-1] + freq_range_actual = [min_freq, max_freq] + + fig, freq_bins, min_freq, max_freq = generate_heatmap_figure( + agg_combined, "frequency", freq_range_actual, + x_centers, y_centers, discrete_colorscale, n_bins, + special_ref_inchikeys=special_ref_inchikeys, + special_unique_ref_inchikeys=special_unique_ref_inchikeys + ) + + status = "" + if search_input: + search_input = search_input.strip() + # Try SMILES first + mol = Chem.MolFromSmiles(search_input) + if mol is not None: + try: + inchikey = Chem.inchi.MolToInchiKey(mol) + match = agg_combined[agg_combined["ref_inchikey"] == inchikey] + if not match.empty: + x, y = match.iloc[0]['dim_1'], match.iloc[0]['dim_2'] + fig.add_trace(go.Scatter( + x=[x], y=[y], mode='markers+text', + marker=dict(size=15, color='red', symbol='x') + )) + status = f"Match found for SMILES (InChIKey: {inchikey})" + else: + status = f"No match found for SMILES (InChIKey: {inchikey})" + except Exception as e: + status = f"Error parsing SMILES: {e}" + else: + # Check if input looks like an InChIKey (27 chars, 2 hyphens, uppercase) + ik = search_input.strip().upper() + if len(ik) == 27 and ik.count('-') == 2 and all(c.isalnum() or c == '-' for c in ik): + match = agg_combined[agg_combined["ref_inchikey"] == ik] + if not match.empty: + x, y = match.iloc[0]['dim_1'], match.iloc[0]['dim_2'] + fig.add_trace(go.Scatter( + x=[x], y=[y], mode='markers+text', + marker=dict(size=15, color='blue', symbol='circle'), + text=["Match (InChIKey)"], textposition="top center" + )) + status = f"Match found for InChIKey: {ik}" + else: + status = f"No match found for InChIKey: {ik}" + else: + status = "Input is not a valid SMILES or InChIKey." + + target_status = "" + if target_search_input: + target_search_input = target_search_input.strip() + # Try SMILES first + mol = Chem.MolFromSmiles(target_search_input) + if mol is not None: + try: + inchikey = Chem.inchi.MolToInchiKey(mol) + match = agg_combined[agg_combined["target_inchikey_list"].apply(lambda lst: inchikey in lst if isinstance(lst, (list, np.ndarray)) else False)] + if not match.empty: + xs = match['dim_1'].tolist() + ys = match['dim_2'].tolist() + fig.add_trace(go.Scatter( + x=xs, y=ys, mode='markers+text', + marker=dict(size=15, color='green', symbol='diamond'), + text=["Target"]*len(xs), textposition="top center" + )) + target_status = f"Match found for Target SMILES (InChIKey: {inchikey})" + else: + target_status = f"No match found for Target SMILES (InChIKey: {inchikey})" + except Exception as e: + target_status = f"Error parsing SMILES: {e}" + else: + # Check if input looks like an InChIKey (27 chars, 2 hyphens, uppercase) + ik = target_search_input.strip().upper() + if len(ik) == 27 and ik.count('-') == 2 and all(c.isalnum() or c == '-' for c in ik): + match = agg_combined[agg_combined["target_inchikey_list"].apply(lambda lst: ik in lst if isinstance(lst, (list, np.ndarray)) else False)] + if not match.empty: + xs = match['dim_1'].tolist() + ys = match['dim_2'].tolist() + fig.add_trace(go.Scatter( + x=xs, y=ys, mode='markers+text', + marker=dict(size=15, color='orange', symbol='star'), + text=["Target (InChIKey)"]*len(xs), textposition="top center" + )) + target_status = f"Match found for Target InChIKey: {ik}" + else: + target_status = f"No match found for Target InChIKey: {ik}" + else: + target_status = "Input is not a valid SMILES or InChIKey." + + return fig, status, target_status + + + + + @app.callback( + Output('grid-img', 'src'), + Input('heatmap', 'clickData'), + Input('mode-toggle', 'value'), + Input('grid-page-slider', 'value') + ) + def update_grid_image(clickData, mode, page): + if not clickData or "points" not in clickData: + return TRANSPARENT_PNG_BASE64 + + point = clickData["points"][0] + x, y = point["x"], point["y"] + + x_bin = find_nearest_bin(x_centers, x) + y_bin = find_nearest_bin(y_centers, y) + + possible_keys = [key for key in bin_data_map.keys() if key[0] == y_bin and key[1] == x_bin] + if not possible_keys: + return TRANSPARENT_PNG_BASE64 + selected_key = possible_keys[0] + + # Use page directly (no inversion) + if mode == 'normal': + smiles_list = bin_data_map[selected_key]['target_id_list'] + grid_img = make_normal_grid_image( + smiles_list, + page=page, + images_per_row=images_per_row, + rows_per_grid=rows_per_grid, + image_size=image_size + ) + else: + grid_img = generate_highlighted_grid_image( + parent_smiles_list=bin_data_map[selected_key]['parent_target_id_list'], + child_smiles_list=bin_data_map[selected_key]['target_id_list'], + parent_reaction_ids=bin_data_map[selected_key]['parent_reaction_id_list'], + page=page, + images_per_row=images_per_row, + rows_per_grid=rows_per_grid, + image_size=image_size + ) + + return image_to_base64(grid_img) + + @app.callback( + [Output('ref-img', 'src'), Output('ref-img', 'style')], + [Input('heatmap', 'clickData'), + Input('mode-toggle', 'value')] + ) + def update_ref_image_and_position(clickData, mode): + if not clickData or "points" not in clickData: + return TRANSPARENT_PNG_BASE64, { + 'height': '150px', + 'border': '1px solid #ccc', + 'backgroundColor': 'white', + 'margin': '20px auto 10px auto', + 'display': 'block', # keep display block so space is reserved + 'opacity': '0' # optionally make it fully transparent but still occupies space + } + point = clickData["points"][0] + x, y = point["x"], point["y"] + x_bin = find_nearest_bin(x_centers, x) + y_bin = find_nearest_bin(y_centers, y) + + possible_keys = [key for key in bin_data_map.keys() if key[0] == y_bin and key[1] == x_bin] + if not possible_keys: + return TRANSPARENT_PNG_BASE64, { + 'height': '150px', + 'border': '1px solid #ccc', + 'backgroundColor': 'white', + 'margin': '20px auto 10px auto', + 'display': 'block', # keep display block so space is reserved + 'opacity': '0' # optionally make it fully transparent but still occupies space + } + + selected_key = possible_keys[0] + ref_id = bin_data_map[selected_key]['ref_id'] + if not ref_id: + return TRANSPARENT_PNG_BASE64, { + 'height': '150px', + 'border': '1px solid #ccc', + 'backgroundColor': 'white', + 'margin': '20px auto 10px auto', + 'display': 'block', # keep display block so space is reserved + 'opacity': '0' # optionally make it fully transparent but still occupies space + } + img = generate_image_from_smiles(ref_id) + + src = image_to_base64(img) if img else "" + style = { + 'height': '150px', + 'border': '1px solid #ccc', + 'backgroundColor': 'white', + 'margin': '20px auto 10px auto', + 'display': 'block', + } + return src, style + + + @app.callback( + Output('color-legend', 'children'), + Input('freq-range-slider', 'value'), + ) + def update_color_legend(freq_range): + colors = legend_data["frequency"]["colors"] + bins = legend_data["frequency"]["bins"] + + def bin_label(i): + if i == 0: + return "0" + elif i == len(colors) - 1: + return f"[{to_human_notation(bins[i])}, ∞)" + else: + return f"[{to_human_notation(bins[i])}, {to_human_notation(bins[i+1])})" + + return html.Div([ + html.H4("Color Legend"), + html.Ul([ + html.Li(style={'listStyle': 'none', 'marginBottom': '5px'}, children=[ + html.Span(style={ + 'display': 'inline-block', + 'width': '20px', + 'height': '20px', + 'backgroundColor': colors[i], + 'marginRight': '10px', + 'border': '1px solid #000' + }), + bin_label(i) + ]) for i in reversed(range(len(colors))) + ]) + ], style={'textAlign': 'left'}) + + @app.callback( + Output('grid-page-slider', 'max'), + Output('grid-page-slider', 'value'), + Output('grid-page-slider', 'marks'), + Input('heatmap', 'clickData') + ) + def update_slider_max(clickData): + if not clickData or "points" not in clickData: + return 0, 0, {0: "0"} + + point = clickData["points"][0] + x, y = point["x"], point["y"] + x_bin = find_nearest_bin(x_centers, x) + y_bin = find_nearest_bin(y_centers, y) + + possible_keys = [key for key in bin_data_map if key[0] == y_bin and key[1] == x_bin] + if not possible_keys: + return 0, 0, {0: "0"} + + selected_key = possible_keys[0] + total = len(bin_data_map[selected_key]['target_id_list']) + per_page = images_per_row * rows_per_grid + max_page = (total - 1) // per_page + + # Use increasing marks so 0 is at the top and matches the tooltip + marks = {i: str(i) for i in range(max_page + 1)} + + return max_page, 0, marks + + port = start_port + i * 10 + print(f"Dash app running on port {port} for hc_order={hc_order_val}") + app.run_server(port=port, host='0.0.0.0') diff --git a/hcase/substructures.py b/hcase/substructures.py index ec345e6..bfeb93c 100644 --- a/hcase/substructures.py +++ b/hcase/substructures.py @@ -15,7 +15,7 @@ def smiles2bmscaffold(smiles): try: mol = Chem.MolFromSmiles(smiles) bms = MurckoScaffold.GetScaffoldForMol(mol) - bms = Chem.MolToSmiles(bms) + bms = Chem.MolToSmiles(bms, canonical=True) except: bms = 'NA'