Merged
66 changes: 66 additions & 0 deletions .github/workflows/actions.yml
@@ -0,0 +1,66 @@
name: Installing and Testing
on: [push]

permissions:
contents: write
checks: write
pull-requests: write

env:
UV_SYSTEM_PYTHON: true

jobs:
install-and-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Install dependencies
run: |
pip install --upgrade pip
uv pip install -e '.[dev]'
# - name: Test with pytest
# run: |
# python -m pytest -sv tests
build-and-install:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Build the package
run: |
uv pip install build
python -m build --sdist
- name: Install the package
run: |
uv pip install dist/*.tar.gz
ruff-linting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3
with:
version: "0.15.11"
args: "check"
src: "."
ruff-formatting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3
with:
version: "0.15.11"
args: "format --check --verbose"
src: "."
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.11
hooks:
- id: ruff
args: [--fix]
files: ^((fiora|tests|resources)/.*\.py|scripts/.*)$
- id: ruff-format
files: ^((fiora|tests|resources)/.*\.py|scripts/.*)$
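Reviewer note: pre-commit treats `files` as a Python regular expression, so the filter above can be checked locally before relying on it. The sketch below is only an illustration; apart from `fiora/GNN/AtomFeatureEncoder.py`, which appears in this diff, the sample paths are hypothetical.

```python
import re

# Pattern copied from the `files` entries in .pre-commit-config.yaml above.
FILES_PATTERN = re.compile(r"^((fiora|tests|resources)/.*\.py|scripts/.*)$")


def is_linted(path: str) -> bool:
    """Return True if the hook's `files` filter would select this path."""
    return FILES_PATTERN.search(path) is not None


# Hypothetical sample paths, except the encoder module touched in this PR.
for candidate in [
    "fiora/GNN/AtomFeatureEncoder.py",
    "resources/data/msnlib/download_msnlib.py",
    "scripts/run.sh",
    "README.md",
    "setup.py",
]:
    print(f"{candidate}: {is_linted(candidate)}")
```

Note that everything under `scripts/` is selected regardless of extension, while files under `fiora`, `tests`, and `resources` are selected only when they end in `.py`.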
58 changes: 58 additions & 0 deletions README.md
@@ -80,6 +80,64 @@ Run the fiora-predict from within this directory

By default, an open-source model is selected automatically, and predictions typically complete within a few seconds. For faster performance, specify a GPU device using the `--dev` option (e.g., `--dev cuda:0`). The output file (e.g., examples/example_spec.mgf) can be compared with the [expected results](examples/expected_output.mgf) to verify model accuracy. This verification is automatically performed by running pytest (as described above).

### Models and Resources

Default models are packaged under `fiora/resources/models`. The CLI uses these automatically when `--model default` is selected.

Scripts for downloading and preprocessing MSnLib are provided in `resources/data/msnlib` (`download_msnlib.py` and `preprocess_msnlib.py`).

The downloader defaults to MSnLib v7 on Zenodo (`https://zenodo.org/records/16984129`) and accepts both direct file URLs and Zenodo record URLs. For Zenodo record URLs it downloads all files matching `*_ms2.mgf` by default.

```bash
python resources/data/msnlib/download_msnlib.py
python resources/data/msnlib/preprocess_msnlib.py
```

Use `--record-pattern` to select a different subset, e.g. `--record-pattern "*_pos_*.mgf"`.
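
To preview which files a given pattern would select, shell-style glob matching can be tried out in isolation. This is a minimal sketch that assumes the downloader applies glob semantics to file names; the actual matching logic in `download_msnlib.py` may differ, and the file names below are hypothetical.

```python
from fnmatch import fnmatchcase


def select_files(file_names, pattern="*_ms2.mgf"):
    """Return the file names that match a shell-style glob pattern."""
    return [name for name in file_names if fnmatchcase(name, pattern)]


# Hypothetical record contents, for illustration only.
record_files = [
    "libA_pos_ms2.mgf",
    "libA_neg_ms2.mgf",
    "libA_pos_msn.mgf",
    "README.txt",
]

print(select_files(record_files))                 # default pattern
print(select_files(record_files, "*_pos_*.mgf"))  # positive-mode subset
```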

### MSnLib (Re)training

To (re)train a new FIORA model, use `fiora-train`. For example, to train on MSnLib with the same parameters used for the v1.0 release:

```bash
fiora-train \
-i resources/data/msnlib/library.csv \
-o checkpoints/fiora.pt \
--device cuda:0 \
--instruments HCD \
--precursor-modes "[M+H]+,[M-H]-,[M]+,[M]-"
```

To persist per-epoch training history, add `--history-out` (supports `.json` or `.csv`):

```bash
fiora-train ... --history-out checkpoints/fiora_history.json
```

### Model Evaluation CLI

You can evaluate a trained model on validation/test splits with:

```bash
fiora-eval \
-i resources/data/msnlib/library.csv \
-m checkpoints/fiora.pt \
--device cuda:0 \
--output-dir checkpoints/eval
```

This prints split-level summary scores (default: `spectral_sqrt_cosine`) and, when available, also reports precursor-excluded metrics (`spectral_sqrt_cosine_wo_prec`, `spectral_sqrt_cosine_avg`). Per-split result files like `validation_eval.csv` and `test_eval.csv` are written when `--output-dir` is set.
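
For intuition about the default score, the sketch below computes a cosine similarity on square-root-transformed intensities. It assumes that `spectral_sqrt_cosine` follows this common definition and that both spectra are already aligned to the same fragment bins; the function name `sqrt_cosine` is hypothetical, and the implementation inside `fiora-eval` may differ.

```python
import math


def sqrt_cosine(predicted, reference):
    """Cosine similarity of square-root-transformed, pre-aligned intensities."""
    p = [math.sqrt(x) for x in predicted]
    r = [math.sqrt(x) for x in reference]
    dot = sum(a * b for a, b in zip(p, r))
    norm = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in r))
    # An empty or all-zero spectrum has no direction; score it as 0.
    return dot / norm if norm > 0 else 0.0


# Toy intensities; the last bin stands in for the precursor peak (assumption).
pred = [10.0, 40.0, 0.0, 50.0]
ref = [12.0, 35.0, 3.0, 50.0]

print(sqrt_cosine(pred, ref))            # all peaks
print(sqrt_cosine(pred[:-1], ref[:-1]))  # precursor bin dropped before scoring
```

The square-root transform dampens the influence of dominant peaks, which is one reason a precursor-excluded variant can be informative alongside the full score.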

### Model Info CLI

To inspect key parameters of a trained model checkpoint:

```bash
fiora-model-info -m checkpoints/fiora.pt
```

Use `--as-json` to print the full `model_params` dictionary.

## The Algorithm

FIORA has been developed as a computational tool to predict bond cleavages that occur in the MS/MS fragmentation process and estimate the probabilities of resulting fragment ions. To that end, FIORA utilizes graph neural networks to learn local molecular neighborhoods around bonds, combined with edge prediction to simulate bond dissociation. The prediction determines which fragment (left or right of the bond cleavage, with up to four possible hydrogen losses) retains the charge and which becomes the neutral loss. The figure below illustrates an example fragmentation prediction for a single bond.
2 changes: 2 additions & 0 deletions constraints.txt
@@ -121,6 +121,7 @@ pytz==2025.2
PyYAML==6.0.2
pyzmq==26.4.0
rdkit==2024.9.6
regex==2024.11.6
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
@@ -141,6 +142,7 @@ threadpoolctl==3.6.0
tinycss2==1.4.0
tomli==2.2.1
torch==2.6.0
lightning-fabric==2.6.1
torch-geometric==2.6.1
torchmetrics==1.8.1
tornado==6.4.2
121 changes: 77 additions & 44 deletions fiora/GNN/AtomFeatureEncoder.py
@@ -1,121 +1,154 @@
import torch
import numpy as np
from rdkit import Chem
from typing import Literal
from fiora.MOL.constants import ORDERED_ELEMENT_LIST

import torch
from rdkit import Chem

from fiora.MOL.constants import ORDERED_ELEMENT_LIST


class AtomFeatureEncoder:
def __init__(self, feature_list=["symbol", "num_hydrogen", "ring_type"]):
def __init__(self, feature_list=['symbol', 'num_hydrogen', 'ring_type']):
self.encoding_dim = 0
self.sets = {
"symbol": ORDERED_ELEMENT_LIST, #OTHERS: Au, Se, Si #standard list {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
"num_hydrogen": [0, 1, 2, 3], #OTHERS: 5, 6, 7, 8},
"ring_type": ["no-ring", "small-ring", "5-cycle", "6-cycle", "large-ring"],
"hybridization": ["SP", "SP2", "SP3", "SP3D2"],
"valence_electrons": [1,2,3,4,5,6,7,8],
"oxidation_number": [1,2,3,4,5,6,7,8,9],
'symbol': ORDERED_ELEMENT_LIST, # OTHERS: Au, Se, Si #standard list {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
'num_hydrogen': [0, 1, 2, 3], # OTHERS: 5, 6, 7, 8},
'ring_type': ['no-ring', 'small-ring', '5-cycle', '6-cycle', 'large-ring'],
'hybridization': ['SP', 'SP2', 'SP3', 'SP3D2'],
'valence_electrons': [1, 2, 3, 4, 5, 6, 7, 8],
'oxidation_number': [1, 2, 3, 4, 5, 6, 7, 8, 9],
}
self.feature_list = feature_list
self.reduced_features = ["symbol", "num_hydrogen", "hybridization"] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS)

self.reduced_features = [
'symbol',
'num_hydrogen',
'hybridization',
] # Reduced features may have additional values that will be encoded with another bit (representing OTHERS)

self.one_hot_mapper = {}
self.number_mapper = {}
self.feature_numbers = {}
for feature in self.feature_list:
variables = self.sets[feature]
num_variables = len(variables)
self.one_hot_mapper[feature] = dict(zip(variables, range(self.encoding_dim, num_variables + self.encoding_dim)))
self.one_hot_mapper[feature] = dict(
zip(
variables,
range(self.encoding_dim, num_variables + self.encoding_dim),
)
)
self.number_mapper[feature] = dict(zip(variables, range(0, num_variables)))
self.encoding_dim += num_variables
if feature in self.reduced_features:
self.encoding_dim += 1
num_variables += 1
self.feature_numbers[feature] = num_variables

def encode(self, G, encoder_type: Literal['one_hot', 'number']):

if encoder_type == 'one_hot':
feature_matrix = torch.zeros(G.number_of_nodes(), self.encoding_dim, dtype=torch.float32)
feature_matrix = torch.zeros(
G.number_of_nodes(), self.encoding_dim, dtype=torch.float32
)

for i in range(G.number_of_nodes()):
atom = G.nodes()[i]['atom']

if 'symbol' in self.feature_list:
if not atom.GetSymbol() in self.sets['symbol']:
feature_matrix[i][self.one_hot_mapper['symbol'][list(self.sets['symbol'])[-1]] + 1] = 1.0
if atom.GetSymbol() not in self.sets['symbol']:
feature_matrix[i][
self.one_hot_mapper['symbol'][list(self.sets['symbol'])[-1]]
+ 1
] = 1.0
else:
feature_matrix[i][self.one_hot_mapper['symbol'][atom.GetSymbol()]] = 1.0
feature_matrix[i][
self.one_hot_mapper['symbol'][atom.GetSymbol()]
] = 1.0

if 'num_hydrogen' in self.feature_list:
value = atom.GetTotalNumHs()
if value in self.sets["num_hydrogen"]:
feature_matrix[i][self.one_hot_mapper['num_hydrogen'][atom.GetTotalNumHs()]] = 1.0
if value in self.sets['num_hydrogen']:
feature_matrix[i][
self.one_hot_mapper['num_hydrogen'][atom.GetTotalNumHs()]
] = 1.0
else:
feature_matrix[i][self.one_hot_mapper['num_hydrogen'][list(self.sets['num_hydrogen'])[-1]] + 1] = 1.0
feature_matrix[i][
self.one_hot_mapper['num_hydrogen'][
list(self.sets['num_hydrogen'])[-1]
]
+ 1
] = 1.0
if 'ring_type' in self.feature_list:
if not atom.IsInRing():
ring_type = "no-ring"
ring_type = 'no-ring'
elif atom.IsInRingSize(7):
ring_type = "large-ring"
ring_type = 'large-ring'
elif atom.IsInRingSize(6):
ring_type = "6-cycle"
ring_type = '6-cycle'
elif atom.IsInRingSize(5):
ring_type = "5-cycle"
ring_type = '5-cycle'
else:
ring_type = "small-ring"
ring_type = 'small-ring'
feature_matrix[i][self.one_hot_mapper['ring_type'][ring_type]] = 1.0
if 'hybridization' in self.feature_list:
orbi = atom.GetHybridization().name
if orbi in self.sets['hybridization']:
feature_matrix[i][self.one_hot_mapper['hybridization'][orbi]] = 1.0
feature_matrix[i][
self.one_hot_mapper['hybridization'][orbi]
] = 1.0
else:
feature_matrix[i][self.one_hot_mapper['hybridization'][list(self.sets['hybridization'])[-1]] + 1] = 1.0

else: # Case: Number mapping
feature_matrix = torch.zeros(G.number_of_nodes(), len(self.feature_list), dtype=torch.int)
feature_matrix[i][
self.one_hot_mapper['hybridization'][
list(self.sets['hybridization'])[-1]
]
+ 1
] = 1.0

else: # Case: Number mapping
feature_matrix = torch.zeros(
G.number_of_nodes(), len(self.feature_list), dtype=torch.int
)
for i in range(G.number_of_nodes()):
atom = G.nodes()[i]['atom']

for j, feature in enumerate(self.feature_list):
if feature == "symbol":
if feature == 'symbol':
if atom.GetSymbol() in self.sets['symbol']:
feature_matrix[i][j] = self.number_mapper[feature][atom.GetSymbol()]
feature_matrix[i][j] = self.number_mapper[feature][
atom.GetSymbol()
]
else:
feature_matrix[i][j] = self.feature_numbers[feature] - 1
elif feature == 'num_hydrogen':
value = atom.GetTotalNumHs()
if value in self.sets["num_hydrogen"]:
if value in self.sets['num_hydrogen']:
feature_matrix[i][j] = self.number_mapper[feature][value]
else:
feature_matrix[i][j] = self.feature_numbers[feature] - 1
elif feature == 'valence_electrons':
value = atom.GetExplicitValence()
if value in self.sets["valence_electrons"]:
if value in self.sets['valence_electrons']:
feature_matrix[i][j] = self.number_mapper[feature][value]
else:
feature_matrix[i][j] = self.feature_numbers[feature] - 1
elif feature == 'oxidation_number':
raise NotImplementedError()
value = Chem.rdMolDescriptors.CalcOxidationNumbers(atom)
if value in self.sets["oxidation_number"]:
if value in self.sets['oxidation_number']:
feature_matrix[i][j] = self.number_mapper[feature][value]
else:
feature_matrix[i][j] = self.feature_numbers[feature] - 1

elif feature == 'ring_type':
if not atom.IsInRing():
ring_type = "no-ring"
ring_type = 'no-ring'
elif atom.IsInRingSize(7):
ring_type = "large-ring"
ring_type = 'large-ring'
elif atom.IsInRingSize(6):
ring_type = "6-cycle"
ring_type = '6-cycle'
elif atom.IsInRingSize(5):
ring_type = "5-cycle"
ring_type = '5-cycle'
else:
ring_type = "small-ring"
ring_type = 'small-ring'
feature_matrix[i][j] = self.number_mapper[feature][ring_type]
if feature == 'hybridization':
orbi = atom.GetHybridization().name
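
Reviewer note: the index layout that `__init__` builds for the one-hot encoding can be reproduced without RDKit or torch. The sketch below mirrors the offset arithmetic shown in the hunk above: each feature gets a contiguous block of columns, and features listed in `reduced_features` get one extra column for values outside the known set. The helper name `build_layout`, the `others` dictionary, and the shortened `symbol` list are hypothetical; the real encoder uses `ORDERED_ELEMENT_LIST`.

```python
def build_layout(feature_sets, feature_list, reduced_features):
    """Assign one-hot column indices per feature value, plus an OTHERS column."""
    offset = 0
    layout = {}  # feature -> {value: column index}
    others = {}  # feature -> column index of the extra OTHERS bit
    for feature in feature_list:
        values = feature_sets[feature]
        layout[feature] = {v: offset + k for k, v in enumerate(values)}
        offset += len(values)
        if feature in reduced_features:
            # Same column as one_hot_mapper[feature][last value] + 1 in the encoder.
            others[feature] = offset
            offset += 1
    return layout, others, offset


# Toy value sets; "symbol" is shortened for readability (assumption).
feature_sets = {
    "symbol": ["C", "N", "O"],
    "num_hydrogen": [0, 1, 2, 3],
    "ring_type": ["no-ring", "small-ring", "5-cycle", "6-cycle", "large-ring"],
}
layout, others, encoding_dim = build_layout(
    feature_sets,
    feature_list=["symbol", "num_hydrogen", "ring_type"],
    reduced_features=["symbol", "num_hydrogen", "hybridization"],
)

print(encoding_dim)
print(others)
print(layout["ring_type"])
```

Because `ring_type` is not a reduced feature, it has no OTHERS column, so every atom must fall into one of its five categories.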