diff --git a/download_from_zenodo.sh b/download_from_zenodo.sh index f6bad4b..96d1458 100755 --- a/download_from_zenodo.sh +++ b/download_from_zenodo.sh @@ -38,6 +38,16 @@ declare -A FILE_MAP=( ["SpecBridge_Spectraverse_candidates.pkl"]="data/SpecBridge_Spectraverse_candidates.pkl" ) +# Determine which Python to use +if command -v python3 >/dev/null 2>&1; then + PYTHON_CMD=python3 +elif command -v python >/dev/null 2>&1; then + PYTHON_CMD=python +else + echo "❌ Failed to find a Python interpreter (python3 or python)." + exit 1 +fi + # Get file list from Zenodo echo "📦 Fetching file list from Zenodo..." FILES_JSON=$(curl -s "${ZENODO_URL}") @@ -56,7 +66,7 @@ for zenodo_name in "${!FILE_MAP[@]}"; do local_path="${FILE_MAP[$zenodo_name]}" # Extract download URL for this file - download_url=$(echo "$FILES_JSON" | python3 -c " + download_url=$(echo "$FILES_JSON" | "$PYTHON_CMD" -c " import sys, json data = json.load(sys.stdin) for f in data.get('files', []): @@ -89,7 +99,7 @@ for f in data.get('files', []): echo " → ${local_path}" # Get file size for progress - file_size=$(echo "$FILES_JSON" | python3 -c " + file_size=$(echo "$FILES_JSON" | "$PYTHON_CMD" -c " import sys, json data = json.load(sys.stdin) for f in data.get('files', []): diff --git a/inference.ipynb b/inference.ipynb new file mode 100644 index 0000000..ebd408b --- /dev/null +++ b/inference.ipynb @@ -0,0 +1,620 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c3a6d7af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wrote 3 spectra to temp_data/output.mgf\n" + ] + } + ], + "source": [ + "from scripts.dataframe_to_mgf import convert_pd_to_mgf\n", + "import pandas as pd\n", + "import pickle\n", + "import json\n", + "\n", + "\n", + "temp_data_folder = \"temp_data\"\n", + "columns = [\"smiles\", \"formula\", \"precursor_mz\", \"adduct\", \"instrument_type\", \"collision_energy\", \"fold\", \"mzs\", \"intensities\"]\n", + "data = 
[[\"CCCCCC\", \"C6H14\", 100, \"[M+H]+\", \"Orbitrap\", 30.0, \"test\", '91.0542,125.0233,154.0499,155.0577,185.0961,200.107,229.0859,246.1125', '0.24524524524524524,1.0,0.08008008008008008,0.35535535535535534,0.34934934934934936,0.04504504504504504,0.14214214214214213,0.7347347347347347'],\n", + " [\"CCCC(CCC)CC\", \"C6H14\", 100, \"[M+H]+\", \"Orbitrap\", 30.0, \"test\", '70.0652,72.0807,80.0497,81.0572,84.0806,94.0653,95.0856,96.0808,105.0697,108.0808,110.0967,122.0965,124.052,126.1278,131.0492,136.052,137.0597,138.0677,139.0755,146.0728,150.0677,151.0754,153.1388,162.0676,163.0753,164.0831,165.0911,175.0756,176.0713,176.0838,177.0911,179.1072,186.0915,187.0994,189.0901,190.0869,191.1065,193.1227,197.0953,202.0863,203.0945,204.1025,212.12,214.1231,217.1101,218.1178,228.1152,229.1465,233.1288,233.1534,243.1379,244.132,245.1414,247.1443,248.1521,249.1586,260.1647,264.1955,291.2068', '0.15204982,0.02029124,0.013635669999999999,0.011343970000000002,0.01605363,0.03316496,0.06067071,0.12251221,0.03452229,0.01859035,0.0497156,0.24131891,0.04573695,0.02566249,0.01281502,0.12434985,0.09406283,0.13051761,0.10595964000000001,0.01560572,0.04302669,0.93310411,0.08373860000000001,0.01206865,0.035472250000000004,0.03266045,0.49504897999999997,0.01554959,0.021134919999999998,0.00877385,0.7773944199999999,0.04626964,0.02004546,0.04536204,0.010848070000000001,0.02189874,0.07254942,0.033431829999999996,0.01577041,0.05776354,0.051225849999999996,0.02355171,0.026938080000000003,0.0284604,0.14925443,0.31206497,0.03363947,0.024091550000000003,0.15443661,0.15695751,0.13653456,0.01398749,0.12058118000000001,0.05583403,0.68011155,0.01409389,1.0,0.08633982,0.97755591'],\n", + " [\"CCCCCC\", \"C6H14\", 100, \"[M+H]+\", \"Orbitrap\", 30.0, \"test\",'79.021126,206.104706,207.088135,246.098892,260.114868,261.110107,287.126129,315.046417,315.132233,315.218964', '0.00362,0.00409,0.00932,0.0032400000000000003,0.04627,0.04598,0.02078,0.008239999999999999,1.0,0.00536']]\n", + "df = 
pd.DataFrame(data, columns=columns)\n", + "convert_pd_to_mgf(df, f'{temp_data_folder}/output.mgf') \n", + "\n", + "candidates = {smiles: [df.smiles.tolist()] for smiles in df.smiles.tolist()}\n", + "with open(f'{temp_data_folder}/candidates.json', 'w') as f:\n", + " json.dump(candidates, f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea5f6953", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[config] device=cpu\n", + "[config] mol_space=chemberta\n", + "[config] use_mapped=False\n", + "[config] deterministic_map=True\n", + "[config] n_blocks=8\n", + "[candidates] Loading from temp_data/candidates.json\n", + "[candidates] Loaded 2 feature_id entries\n", + "[model] Building SpecBridge model...\n", + "Error using dreams encoder from checkpoint DreaMS/dreams/models/pretrained/ssl_model.ckpt: No module named 'pkg_resources'\n", + "Error using dreams encoder from checkpoint DreaMS/dreams/models/pretrained/ssl_model.ckpt using DreamsEncoder: cannot import name 'DreamsEncoder' from 'dreams' (G:\\Other computers\\My laptop\\Cambridge PhD Project\\Projects\\Nitrosation Project\\Code\\SpecAssign\\SpecBridge\\DreaMS\\dreams\\__init__.py)\n", + "Using dummy dreams encoder\n", + "Warning: Could not load DreaMS checkpoint DreaMS/dreams/models/pretrained/ssl_model.ckpt: cannot instantiate 'PosixPath' on your system\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of RobertaModel were not initialized from the model checkpoint at Derify/ChemBERTa_augmented_pubchem_13m and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[model] Loading adapter checkpoint: runs/msgym/SpecBridge_MSGYM_checkpoint.pt\n", + "[model] Missing keys: 4\n", + "[model] 
Unexpected keys: 65\n", + "[model] Model ready\n", + "[dataset] Loading MGF from temp_data/output.mgf\n", + "[dataset] Intensity normalization: enabled\n", + "[dataset] Loaded 3 spectra\n", + "[embedding] Collecting unique candidate SMILES...\n", + "[embedding] Found 2 unique candidate SMILES\n", + "[embedding] Computing candidate embeddings...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Embedding candidates: 100%|██████████| 1/1 [00:00<00:00, 29.06it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[embedding] Computed embeddings for 2 candidates\n", + "[prediction] Processing spectra...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Predicting: 0%| | 0/1 [00:00 \u001b[39m\u001b[32m17\u001b[39m results = \u001b[43mpredictor\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\utils\\_contextlib.py:115\u001b[39m, in \u001b[36mcontext_decorator..decorate_context\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 112\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 113\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdecorate_context\u001b[39m(*args, **kwargs):\n\u001b[32m 114\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[32m--> \u001b[39m\u001b[32m115\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mg:\\Other computers\\My laptop\\Cambridge PhD Project\\Projects\\Nitrosation Project\\Code\\SpecAssign\\SpecBridge\\specbridge\\eval\\predict_smiles.py:706\u001b[39m, 
in \u001b[36mSpecBridgePredictor.predict\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 702\u001b[39m meta_with_smiles[\u001b[33m\"\u001b[39m\u001b[33msmi_key\u001b[39m\u001b[33m\"\u001b[39m] = dummy_smiles\n\u001b[32m 704\u001b[39m \u001b[38;5;28mprint\u001b[39m(meta_with_smiles) \u001b[38;5;66;03m# TESTING, DELETE LATER\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m706\u001b[39m z_s, z_m, z_hat, mu, lv = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmeta_with_smiles\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minference\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 708\u001b[39m \u001b[38;5;66;03m# Build query embeddings\u001b[39;00m\n\u001b[32m 709\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.args.use_mapped:\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1509\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1510\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1511\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile 
\u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1515\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1516\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1517\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1518\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1519\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1520\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 1523\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mg:\\Other computers\\My laptop\\Cambridge PhD Project\\Projects\\Nitrosation Project\\Code\\SpecAssign\\SpecBridge\\specbridge\\models\\mapper.py:128\u001b[39m, in \u001b[36mDreamsToMolCondition.forward\u001b[39m\u001b[34m(self, spectra_binned, meta, mol_feats, inference)\u001b[39m\n\u001b[32m 123\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m 
\u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, spectra_binned, meta, mol_feats, inference: \u001b[38;5;28mbool\u001b[39m = \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[32m 124\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 125\u001b[39m \u001b[33;03m meta must include a list[str] SMILES under key 'smi_key' (or 'smiles').\u001b[39;00m\n\u001b[32m 126\u001b[39m \u001b[33;03m mol_feats is only used when mol_space == 'ecfp' (0/1 fingerprint tensor).\u001b[39;00m\n\u001b[32m 127\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m128\u001b[39m z_s = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mspec\u001b[49m\u001b[43m(\u001b[49m\u001b[43mspectra_binned\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmeta\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 130\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.mol_space == \u001b[33m\"\u001b[39m\u001b[33mchemberta\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 131\u001b[39m smiles = meta.get(\u001b[33m\"\u001b[39m\u001b[33msmi_key\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mor\u001b[39;00m meta.get(\u001b[33m\"\u001b[39m\u001b[33msmiles\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1509\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1510\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1511\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1515\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1516\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1517\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1518\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1519\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1520\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 1523\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mg:\\Other computers\\My laptop\\Cambridge PhD 
Project\\Projects\\Nitrosation Project\\Code\\SpecAssign\\SpecBridge\\specbridge\\adapters\\dreams_adapter.py:222\u001b[39m, in \u001b[36mDreamsAdapter.forward\u001b[39m\u001b[34m(self, spectra_binned, meta)\u001b[39m\n\u001b[32m 220\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m 221\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(meta, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[33m'\u001b[39m\u001b[33mpeaks\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m meta \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(meta[\u001b[33m'\u001b[39m\u001b[33mpeaks\u001b[39m\u001b[33m'\u001b[39m], torch.Tensor):\n\u001b[32m--> \u001b[39m\u001b[32m222\u001b[39m out = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdreams\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmeta\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mpeaks\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmeta\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# could be [B,N,D] or [B,D]\u001b[39;00m\n\u001b[32m 223\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m out.dim() == \u001b[32m3\u001b[39m:\n\u001b[32m 224\u001b[39m z0 = \u001b[38;5;28mself\u001b[39m._pool(out, meta[\u001b[33m'\u001b[39m\u001b[33mpeaks\u001b[39m\u001b[33m'\u001b[39m])\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1509\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1510\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1511\u001b[39m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1515\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1516\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1517\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1518\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1519\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1520\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 1523\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mg:\\Other computers\\My 
laptop\\Cambridge PhD Project\\Projects\\Nitrosation Project\\Code\\SpecAssign\\SpecBridge\\specbridge\\adapters\\dreams_adapter.py:19\u001b[39m, in \u001b[36mDummyDreams.forward\u001b[39m\u001b[34m(self, spectra_binned, meta)\u001b[39m\n\u001b[32m 18\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, spectra_binned: torch.Tensor, meta: \u001b[38;5;28mdict\u001b[39m):\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mspectra_binned\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1509\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1510\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1511\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1515\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1516\u001b[39m \u001b[38;5;66;03m# this function, 
and just call forward.\u001b[39;00m\n\u001b[32m 1517\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1518\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1519\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1520\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 1523\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\container.py:217\u001b[39m, in \u001b[36mSequential.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 215\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[32m 216\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m217\u001b[39m \u001b[38;5;28minput\u001b[39m = \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 218\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28minput\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1509\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1510\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1511\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1515\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1516\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1517\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1518\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1519\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks 
\u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1520\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1522\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 1523\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\pietr\\anaconda3\\envs\\specbridge_test\\Lib\\site-packages\\torch\\nn\\modules\\linear.py:116\u001b[39m, in \u001b[36mLinear.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m116\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[31mRuntimeError\u001b[39m: mat1 and mat2 shapes cannot be multiplied (177x2 and 2048x1024)" + ] + } + ], + "source": [ + "from specbridge.eval.predict_smiles import SpecBridgePredictor\n", + "\n", + "\n", + "predictor = SpecBridgePredictor(\n", + " mgf=\"temp_data/output.mgf\",\n", + " dreams_ckpt=\"DreaMS/dreams/models/pretrained/ssl_model.ckpt\",\n", + " adapter_ckpt=\"runs/msgym/SpecBridge_MSGYM_checkpoint.pt\",\n", + " candidates_json=\"temp_data/candidates.json\",\n", + " use_mapped=True,\n", + " deterministic_map=True,\n", + " no_gaussian=True,\n", + 
" batch_size=32,\n", + " top_k=5, # Get top 5 predictions\n", + " limit=100 # Optional: only process first 100 spectra for testing\n", + ")\n", + "\n", + "results = predictor.predict()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10fb1eba", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e26222ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 59, 2])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "t = {'formula': torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), 'adduct': torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), 'charge': torch.tensor([[1., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 0.],\n", + " [1., 0., 0., 0., 0., 0., 0., 0.]]), 'peaks': torch.tensor([[[9.1054e+01, 2.4525e-01],\n", + " [1.2502e+02, 1.0000e+00],\n", + " [1.5405e+02, 8.0080e-02],\n", + " [1.5506e+02, 3.5536e-01],\n", + " [1.8510e+02, 3.4935e-01],\n", + " [2.0011e+02, 4.5045e-02],\n", + " [2.2909e+02, 1.4214e-01],\n", + " [2.4611e+02, 7.3473e-01],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 
0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[7.0065e+01, 1.5205e-01],\n", + " [7.2081e+01, 2.0291e-02],\n", + " [8.0050e+01, 1.3636e-02],\n", + " [8.1057e+01, 1.1344e-02],\n", + " [8.4081e+01, 1.6054e-02],\n", + " [9.4065e+01, 3.3165e-02],\n", + " [9.5086e+01, 6.0671e-02],\n", + " [9.6081e+01, 1.2251e-01],\n", + " [1.0507e+02, 3.4522e-02],\n", + " [1.0808e+02, 1.8590e-02],\n", + " [1.1010e+02, 4.9716e-02],\n", + " [1.2210e+02, 2.4132e-01],\n", + " [1.2405e+02, 4.5737e-02],\n", + " [1.2613e+02, 2.5662e-02],\n", 
+ " [1.3105e+02, 1.2815e-02],\n", + " [1.3605e+02, 1.2435e-01],\n", + " [1.3706e+02, 9.4063e-02],\n", + " [1.3807e+02, 1.3052e-01],\n", + " [1.3908e+02, 1.0596e-01],\n", + " [1.4607e+02, 1.5606e-02],\n", + " [1.5007e+02, 4.3027e-02],\n", + " [1.5108e+02, 9.3310e-01],\n", + " [1.5314e+02, 8.3739e-02],\n", + " [1.6207e+02, 1.2069e-02],\n", + " [1.6308e+02, 3.5472e-02],\n", + " [1.6408e+02, 3.2660e-02],\n", + " [1.6509e+02, 4.9505e-01],\n", + " [1.7508e+02, 1.5550e-02],\n", + " [1.7607e+02, 2.1135e-02],\n", + " [1.7608e+02, 8.7739e-03],\n", + " [1.7709e+02, 7.7739e-01],\n", + " [1.7911e+02, 4.6270e-02],\n", + " [1.8609e+02, 2.0045e-02],\n", + " [1.8710e+02, 4.5362e-02],\n", + " [1.8909e+02, 1.0848e-02],\n", + " [1.9009e+02, 2.1899e-02],\n", + " [1.9111e+02, 7.2549e-02],\n", + " [1.9312e+02, 3.3432e-02],\n", + " [1.9710e+02, 1.5770e-02],\n", + " [2.0209e+02, 5.7764e-02],\n", + " [2.0309e+02, 5.1226e-02],\n", + " [2.0410e+02, 2.3552e-02],\n", + " [2.1212e+02, 2.6938e-02],\n", + " [2.1412e+02, 2.8460e-02],\n", + " [2.1711e+02, 1.4925e-01],\n", + " [2.1812e+02, 3.1206e-01],\n", + " [2.2812e+02, 3.3639e-02],\n", + " [2.2915e+02, 2.4092e-02],\n", + " [2.3313e+02, 1.5444e-01],\n", + " [2.3315e+02, 1.5696e-01],\n", + " [2.4314e+02, 1.3653e-01],\n", + " [2.4413e+02, 1.3987e-02],\n", + " [2.4514e+02, 1.2058e-01],\n", + " [2.4714e+02, 5.5834e-02],\n", + " [2.4815e+02, 6.8011e-01],\n", + " [2.4916e+02, 1.4094e-02],\n", + " [2.6016e+02, 1.0000e+00],\n", + " [2.6420e+02, 8.6340e-02],\n", + " [2.9121e+02, 9.7756e-01]],\n", + "\n", + " [[7.9021e+01, 3.6200e-03],\n", + " [2.0610e+02, 4.0900e-03],\n", + " [2.0709e+02, 9.3200e-03],\n", + " [2.4610e+02, 3.2400e-03],\n", + " [2.6011e+02, 4.6270e-02],\n", + " [2.6111e+02, 4.5980e-02],\n", + " [2.8713e+02, 2.0780e-02],\n", + " [3.1505e+02, 8.2400e-03],\n", + " [3.1513e+02, 1.0000e+00],\n", + " [3.1522e+02, 5.3600e-03],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 
0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00]]]), 'smi_key': ['C', 'C', 'C']}\n", + "\n", + "t[\"peaks\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "61131d31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "177" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "3*59" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "specbridge_test", + "language": "python", + "name": 
# (MGF key, DataFrame column) pairs; written in this order when the value is present.
_MGF_METADATA_FIELDS = (
    ("SMILES", "smiles"),
    ("FORMULA", "formula"),
    ("PRECURSOR_MZ", "precursor_mz"),
    ("ADDUCT", "adduct"),
    ("INSTRUMENT_TYPE", "instrument_type"),
    ("COLLISION_ENERGY", "collision_energy"),
    ("FOLD", "fold"),
)


def _parse_peak_values(raw):
    """Return one peak column of a row (m/z values or intensities) as a list.

    Accepts a comma-separated string, an existing sequence/array, or a
    missing value (None / NaN), which yields an empty list.
    """
    if isinstance(raw, str):
        # Skip empty tokens so "" and trailing commas never reach float().
        return [float(tok) for tok in (part.strip() for part in raw.split(',')) if tok]
    if pd.api.types.is_scalar(raw):
        # None / NaN mean "no peaks"; any other scalar is a single value.
        return [] if pd.isna(raw) else [float(raw)]
    # Already a list-like: keep the elements untouched.
    return list(raw)


def convert_pd_to_mgf(df, output_path):
    """
    Convert a pandas DataFrame to MGF format.

    One ``BEGIN IONS`` / ``END IONS`` block is written per row, titled
    ``spectrum_<row index>``, followed by a blank line.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame with columns: smiles, formula, precursor_mz, adduct,
        instrument_type, collision_energy, fold, mzs, intensities.
        Metadata columns that are absent or hold a missing value are skipped.
        ``mzs`` / ``intensities`` may be comma-separated strings or
        list-likes; a missing cell is written as a spectrum without peaks.
    output_path : str
        Path to save the MGF file

    Raises:
    -------
    ValueError
        If a row has a different number of m/z values and intensities, or a
        peak string contains a token that is not a number. Blocks for earlier
        rows have already been written to ``output_path`` at that point.

    Example:
    --------
    df = pd.read_csv('your_data.csv')
    convert_pd_to_mgf(df, 'output.mgf')
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        for idx, row in df.iterrows():
            # Parse and validate the peaks first so a bad row never leaves a
            # half-written BEGIN IONS block behind.
            masses = _parse_peak_values(row.get('mzs'))
            intensities = _parse_peak_values(row.get('intensities'))
            if len(masses) != len(intensities):
                raise ValueError(
                    f"Row {idx!r}: {len(masses)} m/z values but "
                    f"{len(intensities)} intensities"
                )

            # Start of spectrum block
            f.write("BEGIN IONS\n")
            f.write(f"TITLE=spectrum_{idx}\n")

            # Write metadata (only the fields that are actually present)
            for mgf_key, column in _MGF_METADATA_FIELDS:
                value = row.get(column)
                if pd.notna(value):
                    f.write(f"{mgf_key}={value}\n")

            # Write m/z intensity pairs
            for mz, intensity in zip(masses, intensities):
                f.write(f"{mz} {intensity}\n")

            # End of spectrum block
            f.write("END IONS\n")
            f.write("\n")  # Blank line between spectra (optional but common)

    print(f"Wrote {len(df)} spectra to {output_path}")
class SpecBridgePredictor:
    """
    Jupyter-friendly wrapper for SpecBridge SMILES prediction.

    The constructor only stores the configuration, picks the device and seeds
    the RNG; the candidates file and the model are loaded lazily on the first
    call to ``predict()`` and reused on later calls.

    Usage in Jupyter notebook:
    ```python
    from specbridge.eval.predict_smiles import SpecBridgePredictor

    predictor = SpecBridgePredictor(
        mgf="data/my_spectra.mgf",
        dreams_ckpt="DreaMS/dreams/models/pretrained/ssl_model.ckpt",
        adapter_ckpt="runs/msgym/SpecBridge_MSGYM_checkpoint.pt",
        candidates_json="data/candidates.json",
        use_mapped=True,
        deterministic_map=True,
        no_gaussian=True,
        batch_size=32,
        cond_dim=2048,
        mapper_hidden=2048,
        mol_space="chemberta",
        chemberta_model="Derify/ChemBERTa_augmented_pubchem_13m",
        top_k=5,
        seed=1234,
        cpu=False
    )

    # Run prediction
    results = predictor.predict()

    # Save results
    predictor.save_results(results, "predictions.json")
    ```
    """

    def __init__(
        self,
        mgf: str,
        dreams_ckpt: str,
        adapter_ckpt: str,
        candidates_json: str,
        spec_bins: int = 2048,
        cond_dim: int = 2048,
        mapper_hidden: int = 2048,
        n_blocks: int = 8,  # Number of blocks in the mapper
        no_gaussian: bool = True,
        use_mapped: bool = True,
        deterministic_map: bool = True,
        batch_size: int = 32,
        mol_space: str = "chemberta",
        chemberta_model: str = "Derify/ChemBERTa_augmented_pubchem_13m",
        limit: Optional[int] = None,
        seed: int = 1234,
        cpu: bool = False,
        handle_duplicates: str = "keep_first",
        top_k: int = 1,
        no_normalize_intensities: bool = False,
    ):
        """
        Initialize the SpecBridge predictor.

        Args:
            mgf: Path to input MGF file with mass spectra
            dreams_ckpt: Path to DreaMS backbone checkpoint (ssl_model.ckpt)
            adapter_ckpt: Path to trained adapter checkpoint (.pt file)
            candidates_json: Path to JSON file with {feature_id: [candidate_smiles]}
            spec_bins: Number of bins for spectrum binning (default: 2048)
            cond_dim: Conditioning dimension (default: 2048)
            mapper_hidden: Mapper hidden dimension (default: 2048)
            n_blocks: Number of blocks in the mapper (default: 8)
            no_gaussian: Disable Gaussian uncertainty (default: True)
            use_mapped: Use mapped embedding for query (default: True)
            deterministic_map: Use deterministic mapping/mean (default: True)
            batch_size: Batch size for processing (default: 32)
            mol_space: Molecule embedding space - "chemberta" or "ecfp" (default: "chemberta")
            chemberta_model: HuggingFace ChemBERTa model name (default: "Derify/ChemBERTa_augmented_pubchem_13m")
            limit: Limit number of spectra to process (default: None = all; 0 is also treated as "no limit")
            seed: Random seed (default: 1234)
            cpu: Force CPU usage even if GPU available (default: False)
            handle_duplicates: How to handle duplicate feature_ids - "keep_first" or "combine" (default: "keep_first")
            top_k: Number of top predictions to return per spectrum (default: 1)
            no_normalize_intensities: Disable intensity normalization (default: False)
        """
        # Create args object matching main() expectations
        self.args = argparse.Namespace(
            mgf=mgf,
            dreams_ckpt=dreams_ckpt,
            adapter_ckpt=adapter_ckpt,
            candidates_json=candidates_json,
            spec_bins=spec_bins,
            cond_dim=cond_dim,
            mapper_hidden=mapper_hidden,
            n_blocks=n_blocks,
            no_gaussian=no_gaussian,
            use_mapped=use_mapped,
            deterministic_map=deterministic_map,
            batch_size=batch_size,
            mol_space=mol_space,
            chemberta_model=chemberta_model,
            limit=limit,
            seed=seed,
            cpu=cpu,
            handle_duplicates=handle_duplicates,
            top_k=top_k,
            no_normalize_intensities=no_normalize_intensities,
        )

        # Setup device and seed
        self.device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
        set_seed(seed)

        print(f"[config] device={self.device}")
        print(f"[config] mol_space={mol_space}")
        print(f"[config] use_mapped={use_mapped}")
        print(f"[config] deterministic_map={deterministic_map}")
        print(f"[config] n_blocks={n_blocks}")

        # Will be initialized in predict()
        self.model = None
        self.embed_fn = None
        self.candidates_dict = None

    @staticmethod
    def _flatten_candidates(value) -> List[str]:
        """Normalize one candidates-JSON value to a flat list of SMILES strings.

        A list may contain strings and/or (one level of) nested lists of
        strings. Anything that is not a list yields no candidates, and
        non-string items are dropped, so the embedding step and the
        per-spectrum lookup always see exactly the same SMILES.
        """
        if not isinstance(value, list):
            return []
        flat: List[str] = []
        for item in value:
            if isinstance(item, list):
                flat.extend(s for s in item if isinstance(s, str))
            elif isinstance(item, str):
                flat.append(item)
        return flat

    @staticmethod
    def _empty_result(feature_id, title, status: str, num_candidates: int = 0) -> Dict:
        """Build the result record used whenever a spectrum cannot be scored."""
        return {
            "feature_id": feature_id,
            "title": title,
            "predicted_smiles": None,
            "top_k_smiles": [],
            "top_k_scores": [],
            "all_candidates": [],
            "all_scores": [],
            "num_candidates": num_candidates,
            "status": status,
        }

    def _load_candidates(self):
        """Load candidates JSON file."""
        print(f"[candidates] Loading from {self.args.candidates_json}")
        with open(self.args.candidates_json, 'r', encoding='utf-8') as f:
            self.candidates_dict = json.load(f)
        print(f"[candidates] Loaded {len(self.candidates_dict)} feature_id entries")

    def _build_model(self):
        """Build and load the SpecBridge model."""
        print("[model] Building SpecBridge model...")
        self.model = build_model(self.args, self.device)
        self.embed_fn = build_mol_embed_fn(self.args, self.model, self.device)
        print("[model] Model ready")

    @torch.no_grad()
    def predict(self) -> List[Dict]:
        """
        Run prediction on all spectra in the MGF file.

        Returns:
            List of prediction results, where each result is a dict with:
            - feature_id: Feature ID from MGF
            - title: Spectrum title
            - predicted_smiles: Top predicted SMILES
            - top_k_smiles: List of top-K predicted SMILES
            - top_k_scores: List of top-K similarity scores
            - all_candidates: All candidates sorted by score (descending)
            - all_scores: All scores sorted (descending)
            - num_candidates: Number of candidates available
            - status: "success", "no_feature_id", "no_candidates" or "no_embeddings"
        """
        # Load candidates if not already loaded
        if self.candidates_dict is None:
            self._load_candidates()

        # Build model if not already built
        if self.model is None:
            self._build_model()

        # Create dataset
        print(f"[dataset] Loading MGF from {self.args.mgf}")
        normalize_intensities = not self.args.no_normalize_intensities
        print(f"[dataset] Intensity normalization: {'enabled' if normalize_intensities else 'disabled'}")
        dataset = FeatureIDDataset(self.args.mgf, normalize_intensities=normalize_intensities)
        print(f"[dataset] Loaded {len(dataset)} spectra")

        # Handle duplicates if needed
        # NOTE(review): "combine" currently only REPORTS how many feature_ids
        # occur more than once; every spectrum is still scored on its own.
        if self.args.handle_duplicates == "combine":
            occurrences = {}
            for rec in dataset._records:
                fid = rec["feature_id"]
                if fid:
                    occurrences[fid] = occurrences.get(fid, 0) + 1
            n_duplicated = sum(1 for count in occurrences.values() if count > 1)
            print(f"[dataset] Found {n_duplicated} duplicate feature_ids")

        # Create dataloader
        # The positional constants mirror how specbridge/notebook.py calls the
        # same collate function: formula vocab, adduct vocab, charge vocab, fp bits.
        loader = DataLoader(
            dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            collate_fn=lambda b: collate_with_feature_id(
                b, self.args.spec_bins, 32, 16, 8, 2048, seed=self.args.seed
            ),
            num_workers=0,
            pin_memory=False,
        )

        # Precompute all candidate embeddings
        print("[embedding] Collecting unique candidate SMILES...")
        unique_smiles = set()
        for cand_value in self.candidates_dict.values():
            unique_smiles.update(self._flatten_candidates(cand_value))
        all_candidate_smiles = sorted(unique_smiles)
        print(f"[embedding] Found {len(all_candidate_smiles)} unique candidate SMILES")

        print("[embedding] Computing candidate embeddings...")
        cand_embeddings = {}
        batch_size_emb = 512
        for i in tqdm(range(0, len(all_candidate_smiles), batch_size_emb), desc="Embedding candidates"):
            chunk = all_candidate_smiles[i:i + batch_size_emb]
            Z = self.embed_fn(chunk, self.device, bs=batch_size_emb)
            for smi, z in zip(chunk, Z):
                # Key by canonical SMILES (raw string when canonicalization
                # returns a falsy value) so lookups below use the same key.
                canon_smi = _canon_smi(smi) or smi
                cand_embeddings[canon_smi] = z
        print(f"[embedding] Computed embeddings for {len(cand_embeddings)} candidates")

        # Process spectra and predict
        print("[prediction] Processing spectra...")
        results = []
        total_processed = 0
        missing_candidates = 0
        missing_embeddings = 0

        for batch in tqdm(loader, desc="Predicting"):
            if self.args.limit and total_processed >= self.args.limit:
                break

            s = batch["spectra"].to(self.device)
            B = s.size(0)
            meta = {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in batch["meta"].items()}
            feature_ids = batch["feature_ids"]
            titles = batch["titles"]

            # Get query embeddings; the model is given a placeholder SMILES
            # per spectrum because no ground-truth molecule exists at inference.
            meta_with_smiles = meta.copy()
            meta_with_smiles["smi_key"] = ["C"] * B

            _z_s, _z_m, _z_hat, mu, lv = self.model(s, meta_with_smiles, None, inference=True)

            # Build query embeddings
            # NOTE(review): the original had separate `use_mapped` /
            # not-`use_mapped` branches with identical bodies, so the flag has
            # no effect here, and sampling is requested with deterministic=True
            # even when deterministic_map is False. Behaviour is kept as-is;
            # confirm the intended semantics against main().
            if self.args.deterministic_map or (lv is None):
                z_query = mu
            else:
                z_query = self.model.mapB.sample(mu, lv, deterministic=True)

            for i in range(B):
                if self.args.limit and total_processed >= self.args.limit:
                    break

                feature_id = feature_ids[i]
                title = titles[i]

                # Handle None feature_id
                if feature_id is None:
                    results.append(self._empty_result(None, title, "no_feature_id"))
                    total_processed += 1
                    continue

                zq = z_query[i]

                # Get candidates for this feature_id
                cand_list = self._flatten_candidates(self.candidates_dict.get(feature_id))
                if not cand_list:
                    missing_candidates += 1
                    results.append(self._empty_result(feature_id, title, "no_candidates"))
                    total_processed += 1
                    continue

                # Get embeddings for candidates
                Z_candidates = []
                candidate_smiles = []
                for smi in cand_list:
                    canon_smi = _canon_smi(smi) or smi
                    z_cand = cand_embeddings.get(canon_smi)
                    if z_cand is not None:
                        Z_candidates.append(z_cand)
                        candidate_smiles.append(canon_smi)
                    else:
                        missing_embeddings += 1

                if not Z_candidates:
                    results.append(
                        self._empty_result(feature_id, title, "no_embeddings", num_candidates=len(cand_list))
                    )
                    total_processed += 1
                    continue

                # Compute cosine similarities between the query and each candidate
                Z_stack = torch.stack(Z_candidates, dim=0).to(self.device)
                zq_norm = F.normalize(zq.unsqueeze(0), dim=-1)
                Z_norm = F.normalize(Z_stack, dim=-1)
                similarities = (zq_norm @ Z_norm.T).squeeze(0)

                # Sort all candidates by similarity (descending)
                order = torch.argsort(similarities, descending=True).tolist()
                all_smiles_sorted = [candidate_smiles[j] for j in order]
                all_scores_sorted = [float(similarities[j]) for j in order]

                # Get top-k (clamped so a negative top_k cannot slice from the end)
                top_k = max(0, min(self.args.top_k, len(order)))
                top_smiles = all_smiles_sorted[:top_k]
                top_scores = all_scores_sorted[:top_k]

                results.append({
                    "feature_id": feature_id,
                    "title": title,
                    "predicted_smiles": all_smiles_sorted[0] if all_smiles_sorted else None,
                    "top_k_smiles": top_smiles,
                    "top_k_scores": top_scores,
                    "all_candidates": all_smiles_sorted,
                    "all_scores": all_scores_sorted,
                    "num_candidates": len(cand_list),
                    "status": "success"
                })
                total_processed += 1

        print(f"\n[summary] Processed {total_processed} spectra")
        print(f"[summary] Missing candidates: {missing_candidates}")
        print(f"[summary] Missing embeddings: {missing_embeddings}")

        return results

    def save_results(self, results: List[Dict], output_path: str):
        """
        Save prediction results to file.

        A path ending in ``.tsv`` gets one summary row per spectrum
        (feature_id, title, predicted_smiles, best score, num_candidates,
        status); any other path, including ``.json``, gets the full result
        list as indented JSON.

        Args:
            results: List of prediction results from predict()
            output_path: Output file path (.json or .tsv)
        """
        print(f"[output] Writing results to {output_path}")
        if output_path.endswith('.tsv'):
            import csv
            with open(output_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f, delimiter='\t')
                writer.writerow(['feature_id', 'title', 'predicted_smiles', 'score', 'num_candidates', 'status'])
                for r in results:
                    writer.writerow([
                        r['feature_id'],
                        r['title'],
                        r['predicted_smiles'] or '',
                        r['top_k_scores'][0] if r['top_k_scores'] else '',
                        r['num_candidates'],
                        r['status']
                    ])
        else:
            # .json and every unrecognised extension default to JSON
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)

        print(f"[done] Results written to {output_path}")
class SpecBridgeNotebookPredictor:
    """
    Lightweight wrapper for notebook inference.

    The model and the molecule embedding function are built once in the
    constructor; candidate embeddings are cached by ``embed_candidates`` and
    reused by every ``predict`` call until a different candidate list is given.

    Usage:
        predictor = SpecBridgeNotebookPredictor(
            dreams_ckpt="runs/msgym/ssl_model.ckpt",
            adapter_ckpt="runs/msgym/checkpoint.ckpt",
            device="cuda"  # or "cpu"
        )
        predictor.embed_candidates(["CCO", "CCN"])
        # Provide your own spectrum arrays (1D m/z and intensity values)
        result = predictor.predict(your_mz_values, your_intensity_values, top_k=3)

    Args:
        dreams_ckpt: Path to the DreaMS checkpoint (ssl_model.ckpt).
        adapter_ckpt: Path to the SpecBridge adapter checkpoint.
        device: Target device; defaults to CUDA if available.
        spec_bins: Number of spectral bins (matches training setup).
        cond_dim: SpecBridge conditioning dimension.
        mapper_hidden: Hidden size for the mapper network.
        n_blocks: Residual mapper blocks (matches training default of 8).
        mol_space: Molecule embedding space (ChemBERTa supported).
        chemberta_model: Hugging Face ChemBERTa model name.
        no_gaussian: Disable Gaussian sampling in the mapper.
        formula_vocab: One-hot vocab size for formulas; keep defaults for pretrained checkpoints.
        adduct_vocab: One-hot vocab size for adducts; keep defaults for pretrained checkpoints.
        charge_vocab: One-hot vocab size for charges; keep defaults for pretrained checkpoints.
        fp_bits: Fingerprint length used for placeholder molecule features.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        dreams_ckpt: str,
        adapter_ckpt: str,
        *,
        device: str | None = None,
        spec_bins: int = 2048,
        cond_dim: int = 2048,
        mapper_hidden: int = 2048,
        n_blocks: int = 8,
        mol_space: str = "chemberta",
        chemberta_model: str = "Derify/ChemBERTa_augmented_pubchem_13m",
        no_gaussian: bool = False,
        seed: int = 1234,
        formula_vocab: int = 32,
        adduct_vocab: int = 16,
        charge_vocab: int = 8,
        fp_bits: int = 2048,
    ):
        set_seed(seed)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        # Kept on the instance because _make_batch forwards them to the collate function.
        self.formula_vocab = formula_vocab
        self.adduct_vocab = adduct_vocab
        self.charge_vocab = charge_vocab
        self.fp_bits = fp_bits
        self.seed = seed
        self.spec_bins = spec_bins

        # build_model / build_mol_embed_fn read their settings from an
        # args-style namespace, so mimic one here.
        args = SimpleNamespace(
            dreams_ckpt=dreams_ckpt,
            adapter_ckpt=adapter_ckpt,
            spec_bins=spec_bins,
            cond_dim=cond_dim,
            mapper_hidden=mapper_hidden,
            mol_space=mol_space,
            chemberta_model=chemberta_model,
            no_gaussian=no_gaussian,
            n_blocks=n_blocks,
        )
        self.model = build_model(args, self.device)
        self.embed_fn = build_mol_embed_fn(args, self.model, self.device)

        # Candidate cache: canonical SMILES and their L2-normalized embeddings.
        self._cand_smiles: List[str] | None = None
        self._cand_embeddings: torch.Tensor | None = None

    def embed_candidates(self, candidates: Sequence[str], *, batch_size: int = 512) -> List[str]:
        """Pre-compute and cache candidate embeddings (canonicalized).

        Candidates are canonicalized, de-duplicated (first occurrence wins,
        order preserved) and embedded; the embeddings are stored
        L2-normalized. A single SMILES passed as a bare string is treated as
        a one-element list.

        Returns:
            The de-duplicated canonical SMILES, in the order used for scoring.
        """
        if isinstance(candidates, str):
            # Without this a bare "CCO" would be iterated character by character.
            candidates = [candidates]

        uniq: List[str] = []
        seen: Set[str] = set()
        for smi in candidates:
            canon = _canon_smi(smi) or smi  # fall back to the raw string
            if canon not in seen:
                uniq.append(canon)
                seen.add(canon)

        # Reuse cached embeddings if candidate list is unchanged.
        if self._cand_smiles == uniq:
            return uniq
        if not uniq:
            self._cand_smiles = []
            self._cand_embeddings = None
            return []

        Z = F.normalize(self.embed_fn(uniq, self.device, bs=batch_size), dim=-1)
        self._cand_smiles = uniq
        self._cand_embeddings = Z
        return uniq

    def _compute_mapped(self, mu: torch.Tensor, lv: torch.Tensor | None, deterministic_map: bool) -> torch.Tensor:
        """Return the mapped query embedding: ``mu`` itself, or a sample from the mapper.

        ``mu`` is returned unchanged when no ``lv`` is available or when
        ``deterministic_map`` is set; otherwise ``model.mapB.sample`` is asked
        for a non-deterministic sample.
        """
        if lv is None or deterministic_map:
            return mu
        return self.model.mapB.sample(mu, lv, deterministic=False)

    def _make_batch(
        self,
        mz: Sequence[float] | torch.Tensor,
        intensity: Sequence[float] | torch.Tensor,
        *,
        normalize_intensities: bool,
        title: str,
    ) -> Dict[str, object]:
        """Wrap one spectrum in a single-record batch via ``collate_with_feature_id``.

        Raises:
            ValueError: if ``mz`` and ``intensity`` do not have the same shape.
        """
        mz_t = torch.as_tensor(mz, dtype=torch.float32)
        inten_t = torch.as_tensor(intensity, dtype=torch.float32)
        if mz_t.shape != inten_t.shape:
            raise ValueError(
                f"mz and intensity must have the same shape, got "
                f"{tuple(mz_t.shape)} and {tuple(inten_t.shape)}"
            )
        if normalize_intensities and inten_t.numel() > 0:
            max_int = inten_t.max()
            if max_int > 0:  # avoid dividing an all-zero spectrum by zero
                inten_t = inten_t / max_int

        # No metadata is known for an ad-hoc spectrum, so every meta field is
        # None and a placeholder SMILES stands in for the unknown molecule.
        rec = {
            "mz": mz_t,
            "intensity": inten_t,
            "title": title,
            "feature_id": title,
            "meta": {
                "formula_idx": None,
                "adduct_idx": None,
                "charge_idx": None,
                "nce": None,
                "instrument": None,
                "smi_key": DUMMY_SMILES,
            },
            "smiles": DUMMY_SMILES,
        }
        return collate_with_feature_id(
            [rec],
            self.spec_bins,
            self.formula_vocab,
            self.adduct_vocab,
            self.charge_vocab,
            self.fp_bits,
            seed=self.seed,
        )

    @torch.no_grad()
    def predict(
        self,
        mz: Sequence[float] | torch.Tensor,
        intensity: Sequence[float] | torch.Tensor,
        *,
        candidates: Sequence[str] | None = None,
        top_k: int = 5,
        normalize_intensities: bool = True,
        title: str = "query_0",
        use_mapped: bool = True,
        deterministic_map: bool = True,
    ) -> Dict[str, object]:
        """
        Score candidate SMILES for a single spectrum.

        Args:
            mz: iterable of m/z values.
            intensity: iterable of intensities matching ``mz``.
            candidates: list of candidate SMILES (optional if already cached via ``embed_candidates``).
            top_k: number of top predictions to return; values below 0 are treated as 0.
            normalize_intensities: divide intensities by max value before binning.
            title: identifier for the spectrum (used in outputs only).
            use_mapped: when False, try spectrum-space embedding if its dimension matches
                candidate embeddings; otherwise fall back to the mapped embedding.
            deterministic_map: controls sampling in the mapper; also affects the fallback
                path when ``use_mapped`` is False but a mapped embedding is required.

        Returns:
            Dict with ``title``, ``predicted_smiles`` (best-ranked candidate,
            independent of ``top_k``), ``top_k_smiles``, ``top_k_scores``,
            ``all_candidates`` / ``all_scores`` (sorted by descending cosine
            similarity) and ``status`` ("success" or "no_candidates").

        Raises:
            ValueError: if no candidates were ever provided, or if ``mz`` and
                ``intensity`` differ in shape.
        """
        if candidates is not None:
            self.embed_candidates(candidates)
        if self._cand_smiles is None:
            raise ValueError("No candidates provided. Call embed_candidates() or pass candidates to predict().")
        if len(self._cand_smiles) == 0:
            return {
                "title": title,
                "predicted_smiles": None,
                "top_k_smiles": [],
                "top_k_scores": [],
                "all_candidates": [],
                "all_scores": [],
                "status": "no_candidates",
            }
        if self._cand_embeddings is None:
            raise ValueError("Candidate embeddings are missing. Call embed_candidates() first.")

        batch = self._make_batch(mz, intensity, normalize_intensities=normalize_intensities, title=title)
        spectra = batch["spectra"].to(self.device)
        meta = {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in batch["meta"].items()}
        if "smi_key" not in meta:
            meta["smi_key"] = [DUMMY_SMILES]

        cand_emb = self._cand_embeddings.to(self.device)
        z_s, _z_m, _z_hat, mu, lv = self.model(spectra, meta, None, inference=True)
        # Only use z_s (spectrum space) if it was requested AND it already
        # matches the candidate embedding dimensionality; otherwise use the
        # mapped embedding.
        if not use_mapped and z_s.shape[-1] == cand_emb.shape[-1]:
            z_query = z_s
        else:
            z_query = self._compute_mapped(mu, lv, deterministic_map)

        # Candidate embeddings are stored normalized, so this is cosine similarity.
        sims = (F.normalize(z_query, dim=-1) @ cand_emb.T).squeeze(0)
        order = torch.argsort(sims, descending=True).tolist()
        ranked_smiles = [self._cand_smiles[i] for i in order]
        ranked_scores = [float(sims[i]) for i in order]

        # Clamp so a negative top_k cannot slice from the end of the ranking.
        k = max(0, min(top_k, len(order)))

        return {
            "title": title,
            "predicted_smiles": ranked_smiles[0] if ranked_smiles else None,
            "top_k_smiles": ranked_smiles[:k],
            "top_k_scores": ranked_scores[:k],
            "all_candidates": ranked_smiles,
            "all_scores": ranked_scores,
            "status": "success",
        }